import os

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image


def run_discriminator_one_batch(d_net: nn.Module,
                                g_net: nn.Module,
                                batch_size: int,
                                latent_size: int,
                                images: torch.Tensor,
                                criterion: nn.Module,
                                optimizer: optim.Optimizer,
                                device: str):
    # Labels for real and fake samples
    real_labels = torch.ones(batch_size, 1).to(device)
    fake_labels = torch.zeros(batch_size, 1).to(device)
    # Train the discriminator on real samples
    outputs = d_net(images)
    d_loss_real = criterion(outputs, real_labels)
    real_score = outputs
    # Train the discriminator on generated samples; detach() keeps the
    # generator out of this backward pass
    z = torch.randn(batch_size, latent_size).to(device)
    fake_images = g_net(z)
    outputs = d_net(fake_images.detach())
    d_loss_fake = criterion(outputs, fake_labels)
    fake_score = outputs
    d_loss = d_loss_real + d_loss_fake  # total discriminator loss
    # Clear gradients before backward: the generator step also leaves
    # gradients on d_net's parameters, which must not leak into this update
    optimizer.zero_grad()
    d_loss.backward()   # backpropagate
    optimizer.step()    # update discriminator parameters
    return d_loss, real_score, fake_score


def run_generator_one_batch(d_net: nn.Module,
                            g_net: nn.Module,
                            batch_size: int,
                            latent_size: int,
                            criterion: nn.Module,
                            optimizer: optim.Optimizer,
                            device: str):
    # Labels and noise for the generated samples: the generator is trained
    # to make the discriminator label its samples as real (1)
    real_labels = torch.ones(batch_size, 1).to(device)
    z = torch.randn(batch_size, latent_size).to(device)
    # Train the generator
    fake_images = g_net(z)
    outputs = d_net(fake_images)
    g_loss = criterion(outputs, real_labels)  # loss between discriminator output and the "real" labels
    optimizer.zero_grad()  # clear stale gradients
    g_loss.backward()      # backpropagate
    optimizer.step()       # update generator parameters
    return g_loss, fake_images


def generate_and_save_images(g_net: nn.Module,
                             batch_size: int,
                             latent_size: int,
                             device: str,
                             image_prefix: str,
                             index: int) -> bool:
    def dnorm(x: torch.Tensor):
        # Map generator outputs (assumed Tanh range [-1, 1]) back to [0, 1]
        min_value = -1
        max_value = 1
        out = (x - min_value) / (max_value - min_value)
        return out.clamp(0, 1)  # save_image expects values in [0, 1]

    with torch.no_grad():  # sampling only, no gradients needed
        sample_vectors = torch.randn(batch_size, latent_size).to(device)
        fake_images = g_net(sample_vectors)
    fake_images = fake_images.view(batch_size, 1, 28, 28)  # reshape to MNIST-sized image tensors
    if not os.path.exists(image_prefix):
        os.makedirs(image_prefix)
    save_image(dnorm(fake_images), os.path.join(image_prefix, f'fake_images-{index:03d}.png'), nrow=10)
    return True


def run_epoch(d_net: nn.Module,
              g_net: nn.Module,
              train_loader: DataLoader,
              criterion: nn.Module,
              d_optim: optim.Optimizer,
              g_optim: optim.Optimizer,
              batch_size: int,
              latent_size: int,
              device: str,
              d_loss_list: list,
              g_loss_list: list,
              real_score_list: list,
              fake_score_list: list,
              epoch: int, num_epochs: int):
    d_net.train()
    g_net.train()
    for idx, (images, _) in enumerate(train_loader):
        # Flatten the images; assumes the DataLoader uses drop_last=True so
        # every batch contains exactly batch_size samples
        images = images.view(batch_size, -1).to(device)
        # Train the discriminator
        d_loss, real_score, fake_score = run_discriminator_one_batch(d_net, g_net, batch_size, latent_size, images,
                                                                     criterion, d_optim, device)
        # Train the generator
        g_loss, _ = run_generator_one_batch(d_net, g_net, batch_size, latent_size, criterion, g_optim, device)
        if (idx + 1) % 300 == 0:
            num = f"Epoch: [{epoch + 1}/{num_epochs}], Batch: [{idx + 1}/{len(train_loader)}], "
            loss_info = f"Discriminator Loss: {d_loss.item():.4f}, Generator Loss: {g_loss.item():.4f}"
            real_sample_score = f"Real sample score for Discriminator D(x): {real_score.mean().item():.4f}"
            fake_sample_score = f"Fake sample score for Discriminator D(G(z)): {fake_score.mean().item():.4f}"
            print(num + loss_info)
            print(num + real_sample_score)
            print(num + fake_sample_score)
    # Record the last batch's statistics once per epoch for later plotting
    d_loss_list.append(d_loss.item())
    g_loss_list.append(g_loss.item())
    real_score_list.append(real_score.mean().item())
    fake_score_list.append(fake_score.mean().item())
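

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal example of how these helpers could be wired together, assuming
# simple MLP generator/discriminator networks, the MNIST dataset, BCELoss, and
# Adam optimizers. The model definitions, hyperparameters, and paths below are
# hypothetical choices for illustration only.
if __name__ == "__main__":
    from torchvision import datasets, transforms

    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size, latent_size, image_size = 100, 64, 28 * 28
    num_epochs, image_prefix = 5, "./samples"

    # Generator ends with Tanh so outputs lie in [-1, 1], matching dnorm()
    g_net = nn.Sequential(nn.Linear(latent_size, 256), nn.ReLU(),
                          nn.Linear(256, image_size), nn.Tanh()).to(device)
    # Discriminator ends with Sigmoid so BCELoss receives probabilities
    d_net = nn.Sequential(nn.Linear(image_size, 256), nn.LeakyReLU(0.2),
                          nn.Linear(256, 1), nn.Sigmoid()).to(device)

    # Normalize MNIST into [-1, 1] to match the generator's Tanh range;
    # drop_last=True keeps every batch at exactly batch_size
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    train_loader = DataLoader(datasets.MNIST("./data", train=True, download=True, transform=transform),
                              batch_size=batch_size, shuffle=True, drop_last=True)

    criterion = nn.BCELoss()
    d_optim = optim.Adam(d_net.parameters(), lr=2e-4)
    g_optim = optim.Adam(g_net.parameters(), lr=2e-4)

    d_losses, g_losses, real_scores, fake_scores = [], [], [], []
    for epoch in range(num_epochs):
        run_epoch(d_net, g_net, train_loader, criterion, d_optim, g_optim,
                  batch_size, latent_size, device,
                  d_losses, g_losses, real_scores, fake_scores,
                  epoch, num_epochs)
        generate_and_save_images(g_net, batch_size, latent_size, device, image_prefix, epoch + 1)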