代码拉取完成,页面将自动刷新
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author :hhx
@Date :2022/5/21 21:39
@Description :AEE 训练
"""
import numpy as np
import os
from utils import *
import torch
from torch import nn, optim
from torch.utils import data
from models import AE, AE_withLinear, AEE_Convd
from tqdm import tqdm
import torch.nn.functional as F
from PIL import Image
# import os
# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
device = 'cpu'
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_tensor_type(torch.DoubleTensor)
EPS = 1e-15
if __name__ == '__main__':
batch = 8
datasetpath = 'G:\哨兵2号数据'
trainSet = CarTiffDateSet(datasetpath)
train_loader = torch.utils.data.DataLoader(dataset=trainSet,
batch_size=batch,
shuffle=True)
# 加载模型
Q = AEE_Convd.Q_net().to(device)
P = AEE_Convd.P_net().to(device)
D_gauss = AEE_Convd.D_net_gauss().to(device)
# 优化器
# encode/decode optimizers
optim_P = torch.optim.Adam(P.parameters())
optim_Q_enc = torch.optim.Adam(Q.parameters())
# regularizing optimizers
optim_Q_gen = torch.optim.Adam(Q.parameters())
optim_D = torch.optim.Adam(D_gauss.parameters())
cost = nn.MSELoss().to(device)
# 开始训练
for epoch in range(20):
# 启用编码器、解码器、判别器中的随机失活(与批标准化)
Q.train()
P.train()
D_gauss.train()
D_loss_sum = torch.zeros(1).to(device) # 累计损失
G_loss_sum = torch.zeros(1).to(device) # 累计损失
recon_loss_sum = torch.zeros(1).to(device) # 累计损失
for step, data in enumerate(train_loader):
images, labels = data
# 自编码
z_sample = Q(images) # encode to z
X_sample = P(z_sample) # decode to X reconstruction
# 二分类交叉熵,计算损失
# print(images.shape,X_sample.shape)
recon_loss = cost(X_sample + EPS, images + EPS)
# recon_loss = F.binary_cross_entropy(X_sample + EPS, images + EPS)
# 重构损失反向传播
recon_loss.backward()
# 优化更新编码器与解码器
optim_P.step()
optim_Q_enc.step()
# 梯度清零,不希望之前的梯度影响之后的梯度
P.zero_grad()
Q.zero_grad()
D_gauss.zero_grad()
#######################
# Regularization phase
#######################
# Discriminator
# 判别器
Q.eval()
# 从一个高斯随机分布中采样隐变量
z_real_gauss = torch.randn(images.size()[0], 64) * 5.
# print(z_real_gauss.shape)
z_real_gauss = z_real_gauss.to(device)
# 由输入数据到编码器中产生的隐变量,目前不一定是高斯分布
z_fake_gauss = Q(images)
# 将两种隐变量分别输入到判别器中,每张图片的每种隐变量对应一个数
D_fake_gauss = D_gauss(z_fake_gauss)
D_real_gauss = D_gauss(z_real_gauss)
# 判别器的损失函数
# 判别器认为真实高斯隐变量为1,从训练数据得到的隐变量为0
D_loss = -torch.mean(torch.log(D_real_gauss + EPS) + torch.log(1 - D_fake_gauss + EPS))
# 判别器损失函数反向传播
D_loss.backward()
# 优化更新判别器
optim_D.step()
# 梯度再次清零
P.zero_grad()
Q.zero_grad()
D_gauss.zero_grad()
# Generator
# 对于生成器/解码器
# 启用编码器中的随机失活
Q.train()
# 由输入数据到编码器中产生的隐变量,目前不一定是高斯分布
z_fake_gauss = Q(images)
# 将隐变量输入到判别器中,每张图片的结果对应一个数
D_fake_gauss = D_gauss(z_fake_gauss)
# print(D_fake_gauss)
# 生成器/解码器的损失函数
# 生成器/解码器努力使由输入数据得到的隐变量标签为1,与判别器对抗训练,使得隐变量分布接近于高斯分布
# 生成器/解码器的目标在于生成以假乱真的数据,与GAN不同在于以假乱真是指是否是高斯分布,而非是否是真实数据
G_loss = -torch.mean(torch.log(D_fake_gauss + EPS))
# 生成器/解码器损失函数反向传播
G_loss.backward()
# 优化更新生成器/解码器
optim_Q_gen.step()
# 清除梯度
P.zero_grad()
Q.zero_grad()
D_gauss.zero_grad()
D_loss_sum += D_loss.detach()
G_loss_sum += G_loss.detach()
recon_loss_sum += recon_loss.detach()
print(f"epoch:{epoch};"
f" D_loss_gauss:{D_loss_sum.item() / (step + 1):.3f};"
f" G_loss:{G_loss_sum.item() / (step + 1):.3f};"
f" recon_loss: {recon_loss_sum.item() / (step + 1):.3f} ")
torch.save(Q.state_dict(), 'SavedModels/QNet.pkl')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。