代码拉取完成,页面将自动刷新
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author :hhx
@Date :2022/5/22 13:45
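@Description :AEE test — run the trained AEE_Convd.Q_net over the 'test' split, plot selected feature vectors read from model.x3, and cluster all of them with KMeans (k=2).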
"""
import numpy as np
import os
from utils import *
import torch
from torch import nn, optim
from torch.utils import data
from models import AE, AE_withLinear, AEE_Convd
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from sklearn.cluster import KMeans
# Build models and default-constructed float tensors in float64.
# torch.set_default_dtype is the non-deprecated replacement for
# torch.set_default_tensor_type(torch.DoubleTensor) (deprecated since PyTorch 2.1).
torch.set_default_dtype(torch.float64)
# Presumably works around Intel OpenMP aborting when a second copy of its runtime
# is loaded (e.g. MKL + PyTorch on Windows) — TODO confirm this is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Device used for the model and its inputs; swap in the line below to use CUDA.
device = 'cpu'
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Sample indices (0-based positions in the unshuffled test loader) that are
# highlighted in the feature plot. Presumably hand-picked scenes — TODO confirm
# what each group represents.
CLOUD_INDICES = frozenset({0, 1, 2, 3, 4, 5, 9, 17, 24, 37, 39, 40})  # drawn red
GREEN_INDICES = frozenset({6, 7, 8})  # drawn blue
GREEN_FROM_INDEX = 66  # every index >= this is drawn green


def _plot_color(index, cloud_indices=CLOUD_INDICES, green_indices=GREEN_INDICES,
                green_from=GREEN_FROM_INDEX):
    """Return the matplotlib colour for sample ``index``, or None to skip it.

    Groups are checked in this order: ``index in cloud_indices`` -> 'red',
    ``index >= green_from`` -> 'green', ``index in green_indices`` -> 'blue'.
    """
    if index in cloud_indices:
        return 'red'
    if index >= green_from:
        return 'green'
    if index in green_indices:
        return 'blue'
    return None


def _extract_features(model, batches, device='cpu'):
    """Run ``model`` over ``batches`` and stack the activations read from ``model.x3``.

    Args:
        model: torch module whose forward pass leaves a tensor on ``model.x3``
            (assumed from how this script reads it — TODO confirm in
            AEE_Convd.Q_net), with the batch on dim 0.
        batches: iterable of input tensors, batch on dim 0.
        device: device the inputs are moved to before the forward pass.

    Returns:
        np.ndarray with one row per sample; each sample's activation is flattened.

    Raises:
        ValueError: if ``batches`` yields nothing.
    """
    rows = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for images in batches:
            model(images.to(device))
            # .cpu() so .numpy() also works when device is CUDA.
            fea = model.x3.detach().cpu().numpy()
            # Keep every sample of the batch (not just fea[0]) so batch > 1 works.
            rows.append(fea.reshape(fea.shape[0], -1))
    if not rows:
        raise ValueError("no batches to extract features from")
    return np.concatenate(rows, axis=0)


def _plot_features(features):
    """Draw each highlighted feature row as a line, coloured by ``_plot_color``."""
    for index, row in enumerate(features):
        color = _plot_color(index)
        if color is not None:
            plt.plot(row, label=index, c=color)
    # One legend and one window for all curves.
    plt.legend()
    plt.show()


if __name__ == '__main__':
    batch = 1
    # Raw string: '\哨' is not a valid escape sequence (SyntaxWarning on
    # Python 3.12+); the resulting path is unchanged.
    datasetpath = r'G:\哨兵2号数据'
    test_set = CarTiffDateSet(datasetpath, type='test')
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
                                              batch_size=batch,
                                              shuffle=False)
    model = AEE_Convd.Q_net().to(device)
    # map_location lets a checkpoint saved from a GPU load on this device.
    model.load_state_dict(torch.load("SavedModels/QNet.pkl", map_location=device))
    model.eval()

    # Collect features for the whole test set. The old fixed np.empty([72, 64])
    # buffer overflowed for > 72 samples and left uninitialised rows
    # (then clustered) for < 72.
    features = _extract_features(model, tqdm(test_loader), device)
    _plot_features(features)

    # algorithm="auto" was deprecated in scikit-learn 1.1 and rejected from 1.3.
    # Omitting it keeps the library default ('auto' before 1.1, 'lloyd' after).
    kmeans = KMeans(n_clusters=2, random_state=0)
    kmeans.fit(features)
    print(kmeans.labels_)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。