1 Star 2 Fork 0

zh-jp/pytorch-gan

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
training.py 4.66 KB
一键复制 编辑 原始数据 按行查看 历史
zh-jp 提交于 2024-03-04 14:55 . init
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
def run_discriminator_one_batch(d_net: nn.Module,
                                g_net: nn.Module,
                                batch_size: int,
                                latent_size: int,
                                images: torch.Tensor,
                                criterion: nn.Module,
                                optimizer: optim.Optimizer,
                                device: str):
    """Run one optimisation step of the discriminator on a single batch.

    The discriminator is trained to output label ``1`` for the real
    ``images`` and label ``0`` for samples produced by ``g_net`` from fresh
    Gaussian noise.

    Args:
        d_net: Discriminator; called on ``images`` and on generator output.
        g_net: Generator; only used to produce fake samples. Its forward pass
            runs without autograd, so no gradient reaches the generator.
        batch_size: Nominal batch size. Kept for interface compatibility; the
            actual number of samples is taken from ``images.size(0)`` so that
            a short final batch does not cause a label/output shape mismatch.
        latent_size: Dimension of the noise vectors fed to ``g_net``.
        images: Batch of real samples, on ``device`` and already shaped the
            way ``d_net`` expects.
        criterion: Loss comparing discriminator outputs with 0/1 labels of
            shape ``(n, 1)``.
        optimizer: Optimizer stepped after the backward pass (presumably over
            the discriminator's parameters).
        device: Device on which labels and noise are created.

    Returns:
        ``(d_loss, real_score, fake_score)``: the summed real + fake loss and
        the discriminator outputs on the real and fake samples, all detached
        from the autograd graph.
    """
    n = images.size(0)
    real_labels = torch.ones(n, 1, device=device)
    fake_labels = torch.zeros(n, 1, device=device)

    # Clear gradients *before* backward: the generator step backpropagates
    # through d_net and leaves gradients on its parameters. Those would
    # otherwise be accumulated into this discriminator update.
    optimizer.zero_grad()

    # Real samples should be classified as real.
    real_score = d_net(images)
    d_loss_real = criterion(real_score, real_labels)

    # Generated samples should be classified as fake. The generator output
    # is not differentiated here, so skip building its autograd graph.
    z = torch.randn(n, latent_size, device=device)
    with torch.no_grad():
        fake_images = g_net(z)
    fake_score = d_net(fake_images)
    d_loss_fake = criterion(fake_score, fake_labels)

    d_loss = d_loss_real + d_loss_fake  # total discriminator loss
    d_loss.backward()
    optimizer.step()
    return d_loss.detach(), real_score.detach(), fake_score.detach()
def run_generator_one_batch(d_net: nn.Module,
                            g_net: nn.Module,
                            batch_size: int,
                            latent_size: int,
                            criterion: nn.Module,
                            optimizer: optim.Optimizer,
                            device: str):
    """Run one optimisation step of the generator.

    Samples ``batch_size`` noise vectors, generates fakes with ``g_net`` and
    trains the generator so that ``d_net`` labels those fakes as real
    (target ``1``).

    Args:
        d_net: Discriminator used to score the generated samples.
        g_net: Generator being trained.
        batch_size: Number of samples to generate.
        latent_size: Dimension of the noise vectors fed to ``g_net``.
        criterion: Loss comparing discriminator outputs with labels of shape
            ``(batch_size, 1)``.
        optimizer: Optimizer stepped after the backward pass (presumably over
            the generator's parameters).
        device: Device on which labels and noise are created.

    Returns:
        ``(g_loss, fake_images)``, both detached from the autograd graph.

    Side effects:
        Clears ``d_net``'s parameter gradients after the step. ``backward()``
        also populates them, and they would otherwise leak into the next
        discriminator update.
    """
    # The generator wants its samples classified as real.
    real_labels = torch.ones(batch_size, 1, device=device)
    z = torch.randn(batch_size, latent_size, device=device)

    optimizer.zero_grad()
    fake_images = g_net(z)
    outputs = d_net(fake_images)
    g_loss = criterion(outputs, real_labels)
    g_loss.backward()
    optimizer.step()

    # backward() also wrote gradients onto d_net's parameters, which this
    # optimizer does not clear; drop them so they cannot contaminate the
    # discriminator's next step.
    d_net.zero_grad()
    return g_loss.detach(), fake_images.detach()
def generate_and_save_images(g_net: nn.Module,
                             batch_size: int,
                             latent_size: int,
                             device: str,
                             image_prefix: str,
                             index: int,
                             image_shape: tuple = (1, 28, 28),
                             nrow: int = 10) -> bool:
    """Sample images from ``g_net`` and save them as a single grid PNG.

    The file is written to ``<image_prefix>/fake_images-<index:03d>.png``;
    the directory is created if it does not exist.

    Args:
        g_net: Generator mapping ``(batch_size, latent_size)`` noise to
            outputs that can be viewed as ``(batch_size, *image_shape)``.
        batch_size: Number of images to sample.
        latent_size: Dimension of the noise vectors.
        device: Device on which the noise is created.
        image_prefix: Output directory.
        index: Number embedded in the file name (zero-padded to 3 digits).
        image_shape: Per-image ``(channels, height, width)``; defaults to the
            previously hard-coded ``(1, 28, 28)``.
        nrow: Number of images per row in the saved grid.

    Returns:
        Always ``True``.
    """
    def dnorm(x: torch.Tensor) -> torch.Tensor:
        # Linearly map [-1, 1] to [0, 1] and clamp; save_image expects [0, 1].
        # Assumes g_net outputs values in [-1, 1] (e.g. a Tanh head) - confirm.
        min_value, max_value = -1, 1
        return ((x - min_value) / (max_value - min_value)).clamp(0, 1)

    # Pure inference: no autograd graph is needed for sampling.
    # NOTE(review): g_net's train/eval mode is left as the caller set it.
    with torch.no_grad():
        sample_vectors = torch.randn(batch_size, latent_size, device=device)
        fake_images = g_net(sample_vectors).view(batch_size, *image_shape)

    os.makedirs(image_prefix, exist_ok=True)
    save_image(dnorm(fake_images),
               os.path.join(image_prefix, f'fake_images-{index:03d}.png'),
               nrow=nrow)
    return True
def run_epoch(d_net: nn.Module,
              g_net: nn.Module,
              train_loader: DataLoader,
              criterion: nn.Module,
              d_optim: optim.Optimizer,
              g_optim: optim.Optimizer,
              batch_size: int,
              latent_size: int,
              device: str,
              d_loss_list: list,
              g_loss_list: list,
              real_score_list: list,
              fake_score_list: list,
              epoch: int, num_epochs: int,
              log_interval: int = 300):
    """Train the discriminator and the generator for one pass over the data.

    For every batch, one discriminator step is run, then one generator step.
    Per-batch statistics are appended to the caller-supplied lists.

    Args:
        d_net, g_net: Discriminator and generator; both are put in train mode.
        train_loader: Yields ``(images, labels)`` pairs; labels are ignored.
            Each image batch is flattened to ``(n, -1)`` before use.
        criterion: Loss shared by both training steps.
        d_optim, g_optim: Optimizers for the discriminator and the generator.
        batch_size: Nominal batch size. Kept for interface compatibility; each
            step uses the actual size of the current batch, so a short final
            batch is handled.
        latent_size: Dimension of the generator's noise input.
        device: Device the images are moved to.
        d_loss_list, g_loss_list: Receive the per-batch losses as floats.
        real_score_list, fake_score_list: Receive the mean discriminator output
            on real and on generated samples for each batch.
        epoch: Zero-based epoch index (printed as ``epoch + 1``).
        num_epochs: Total number of epochs, used only for logging.
        log_interval: Print progress every ``log_interval`` batches; values
            <= 0 disable printing.
    """
    d_net.train()
    g_net.train()
    num_batches = len(train_loader)
    for idx, (images, _) in enumerate(train_loader):
        # Use the real batch size: the final batch can be smaller than
        # `batch_size` when the loader does not drop incomplete batches.
        current_batch = images.size(0)
        images = images.reshape(current_batch, -1).to(device)

        # Train the discriminator, then the generator.
        d_loss, real_score, fake_score = run_discriminator_one_batch(
            d_net, g_net, current_batch, latent_size, images,
            criterion, d_optim, device)
        g_loss, _ = run_generator_one_batch(
            d_net, g_net, current_batch, latent_size, criterion, g_optim, device)

        d_loss_value = d_loss.item()
        g_loss_value = g_loss.item()
        real_mean = real_score.mean().item()
        fake_mean = fake_score.mean().item()

        if log_interval > 0 and (idx + 1) % log_interval == 0:
            prefix = f"Epoch: [{epoch + 1}/{num_epochs}], Batch: [{idx + 1}/{num_batches}], "
            print(prefix + f"Discriminator Loss: {d_loss_value:.4f}, Generator Loss: {g_loss_value:.4f}")
            print(prefix + f"Real sample score for Discriminator D(x): {real_mean:.4f}")
            print(prefix + f"Fake sample score for Discriminator D(G(z)): {fake_mean:.4f}")

        d_loss_list.append(d_loss_value)
        g_loss_list.append(g_loss_value)
        real_score_list.append(real_mean)
        fake_score_list.append(fake_mean)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/zh-jp/pytorch-gan.git
git@gitee.com:zh-jp/pytorch-gan.git
zh-jp
pytorch-gan
pytorch-gan
master

搜索帮助