1 Star 0 Fork 0

BROZHANG/Bearing_Fault_Diagnosis

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Restruct_sourceDataset.py 7.29 KB
一键复制 编辑 原始数据 按行查看 历史
BROZHANG 提交于 2024-07-14 11:35 . 240714 整理
# -*- encoding: utf-8 -*-
'''
@File : Restruct_sourceDataset.py
@Time : 2024/03/13 12:45:46
@Author : zhangqidong
@email : 835439833@qq.com
@info : 重构源域数据集;原则1:选点要靠近类中心;原则2:新类中心要尽可能分散
'''
import os
import sys
import torch
from torch import cuda, nn, no_grad
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import itertools
import torch.nn.functional as F
from sklearn.manifold import TSNE
sys.path.append(os.getcwd())
from config.config import Config
from utils.utils import weightInit, calMetrics, mmd, seedTorch, plotTSNE, createDataset, splitDataset, createProcedureCase, tripletLoss
from net.VGG7 import VGG7
import numpy as np
from collections import Counter
import random
config = Config()
def euclidean_distance(x1, x2):
    """Euclidean (L2) distance between two tensors, moving x2 onto x1's device first."""
    diff = x1 - x2.to(x1.device)
    return diff.pow(2).sum().sqrt()
class KMeans:
    """Minimal k-means clustering over a batch of (possibly multi-dim) feature tensors.

    Fixes over the naive implementation:
      * initial centroids drawn WITHOUT replacement (``randperm``), so two initial
        centroids can never be the same point;
      * an empty cluster keeps its previous centroid instead of producing a NaN
        centroid via ``mean`` of an empty tensor;
      * assignment uses one vectorized ``torch.cdist`` call instead of a Python
        loop over every point (O(n*k) tensor constructions per iteration);
      * ``self.centroids`` stays a single stacked tensor throughout (the original
        silently turned it into a Python list after the first iteration).
    """

    def __init__(self, k=2, max_iters=100, device='cpu'):
        """
        Args:
            k (int): number of clusters.
            max_iters (int): maximum number of Lloyd iterations.
            device (str): device on which clustering is performed.
        """
        self.k = k
        self.max_iters = max_iters
        self.device = device
        # set by fit(); shape (k, *feature_shape)
        self.centroids = None

    def fit(self, X):
        """Run Lloyd's algorithm on X (shape (n, *feature_shape)) until
        convergence or ``max_iters``; stores centroids on ``self.device``."""
        X = X.detach().to(self.device)
        # sample k DISTINCT points as the initial centroids
        init_idx = torch.randperm(len(X), device=X.device)[:self.k]
        self.centroids = X[init_idx].clone()
        for it in range(self.max_iters):
            print(f"epoch={it}")
            assignments = self._assign(X)
            new_centroids = self.centroids.clone()
            for j in range(self.k):
                members = X[assignments == j]
                # empty cluster: keep the old centroid rather than emitting NaN
                if len(members) > 0:
                    new_centroids[j] = members.mean(dim=0)
            # converged: no centroid moved this iteration
            if torch.equal(new_centroids, self.centroids):
                break
            self.centroids = new_centroids

    def predict(self, X):
        """Group the points of X by nearest centroid.

        Returns:
            list[Tensor]: k tensors; the j-th holds the points assigned to
            centroid j (same return layout as the original implementation).
        """
        X = X.detach().to(self.device)
        assignments = self._assign(X)
        return [X[assignments == j] for j in range(self.k)]

    def _assign(self, X):
        """Nearest-centroid index for each point, via one pairwise-distance call."""
        flat_x = X.reshape(len(X), -1)
        flat_c = self.centroids.reshape(self.k, -1)
        # (n, k) Euclidean distance matrix computed in native code
        return torch.cdist(flat_x, flat_c).argmin(dim=1)
class TL_Trainer:
    """
    Trainer object.

    Builds source/target-domain datasets and dataloaders for transfer-learning
    bearing-fault diagnosis, then offers feature clustering (``Restruct``) and
    t-SNE visualization (``Visualization``) on a pretrained student network.
    """
    def __init__(self, seed: int, basemodel: str = None):
        """
        Initialize the trainer.
        Args:
            seed (int): random seed (used here to tag output image filenames)
            basemodel (str, optional): save path of the base model
        """
        # timestamped case folder where this run's outputs are written
        self.timeroot = createProcedureCase()
        # record the random seed
        self.seed = seed
        # current epoch index (unused until a training loop sets it)
        self.epoch = None
        # global-MMD vs local-MMD mode, taken from the config
        self.mode = config.udamode
        # training device
        self.device = "cuda" if cuda.is_available() else 'cpu'
        print(f"training by {self.device} !")
        # move the network onto the device
        self.student_model = VGG7().to(self.device)
        # initialize the network weights from the configured checkpoint
        weightInit(self.student_model, weight=config.student_model)
        # fault-data root directories
        source_root = config.domain0_root
        target_root = config.domain3_root
        # build the source- and target-domain datasets (domain label 0 vs 1)
        dataset_s = createDataset(source_root, domain=0)
        dataset_t = createDataset(target_root, domain=1)
        # presumably keeps the 10% part of a 1:9 split of the combined
        # domains — TODO confirm splitDataset(ratio) semantics
        self.dataset_sum, _ = splitDataset(dataset_s + dataset_t, 0.1)
        # batch size
        self.batch_size = 512
        # cross-entropy loss
        self.CE_loss = nn.CrossEntropyLoss()
        self.MSE_loss = nn.MSELoss()
        # split source and target datasets 8:2 into train/test
        trainset_s, testset_s = splitDataset(dataset_s, 0.8)
        trainset_t, testset_t = splitDataset(dataset_t, 0.8)
        trainset_sum, testset_sum = splitDataset(self.dataset_sum, 0.8)
        # wrap everything in dataloaders
        self.traindataloader_s = DataLoader(trainset_s, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.testdataloader_s = DataLoader(testset_s, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.traindataloader_t = DataLoader(trainset_t, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.testdataloader_t = DataLoader(testset_t, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.trainloader_sum = DataLoader(trainset_sum, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.testloader_sum = DataLoader(testset_sum, batch_size=self.batch_size, shuffle=True, drop_last=True)
    def Restruct(self):
        """Cluster the student model's features over the whole combined dataset
        with k-means (k=10) and save class/domain t-SNE plots."""
        # put the model in evaluation mode (inference only; the original
        # comment said "training mode", which contradicts the eval() call)
        self.student_model.eval()
        # NOTE(review): this tqdm-wrapped loader is created but never iterated
        # in this method — looks like leftover scaffolding; confirm before removing
        dataloader = tqdm(self.trainloader_sum, ncols=100, desc="test")
        # stack the entire combined dataset onto the device;
        # per-sample tuple layout is (x, fake_y, real_y, domain)
        all_X = torch.tensor(np.array([t[0] for t in self.dataset_sum])).to(self.device)
        all_fake_y = torch.tensor(np.array([t[1] for t in self.dataset_sum])).to(self.device)
        all_real_y = torch.tensor(np.array([t[2] for t in self.dataset_sum])).to(self.device)
        all_domain = torch.tensor(np.array([t[3] for t in self.dataset_sum])).to(self.device)
        # forward pass: class logits and hidden-layer features for every sample
        logits_s_stu, feature_s_stu = self.student_model.forward(all_X)  # student inference on the combined set
        kmeans = KMeans(k=10)
        kmeans.fit(feature_s_stu)
        clusters = kmeans.predict(feature_s_stu)
        # t-SNE plots colored by true class and by domain label
        fig1 = plotTSNE(feature_s_stu, all_real_y, title="class")
        fig1.savefig(os.path.join(self.timeroot, 'png', f"class-{self.seed}.png"))
        fig2 = plotTSNE(feature_s_stu, all_domain, title="domain")
        fig2.savefig(os.path.join(self.timeroot, 'png', f"domain-{self.seed}.png"))
        print(clusters)
    def Visualization(self) -> None:
        """
        Visualize the student model's features batch by batch.

        Runs inference over the combined train loader and saves a class-colored
        and a domain-colored t-SNE figure. Note: the filenames do not include
        the batch index, so each batch OVERWRITES the previous figures — only
        the last batch's plots survive. Returns None (the original ``-> dict``
        annotation was incorrect).
        """
        # put the model in evaluation mode (inference only)
        self.student_model.eval()
        dataloader = tqdm(self.trainloader_sum, ncols=100, desc="test")
        for x, fake_y, real_y, domain in dataloader:
            # move the batch onto the device
            x = x.to(self.device)
            fake_y = fake_y.to(self.device)
            real_y = real_y.to(self.device)
            domain = domain.to(self.device)
            # forward pass: class logits and hidden-layer features
            logits_s_stu, feature_s_stu = self.student_model.forward(x)  # student inference
            # t-SNE of the current batch, colored by true class and by domain
            fig1 = plotTSNE(feature_s_stu, real_y, title="class")
            fig1.savefig(os.path.join(self.timeroot, 'png', f"class-{self.seed}.png"))
            fig2 = plotTSNE(feature_s_stu, domain, title="domain")
            fig2.savefig(os.path.join(self.timeroot, 'png', f"domain-{self.seed}.png"))
if __name__ == '__main__':
    # Entry point: build the trainer with a fixed seed and render t-SNE plots.
    trainer = TL_Trainer(seed=100)
    trainer.Visualization()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/brozhang666/bearing-fault-diagnosis.git
git@gitee.com:brozhang666/bearing-fault-diagnosis.git
brozhang666
bearing-fault-diagnosis
Bearing_Fault_Diagnosis
main

搜索帮助