代码拉取完成,页面将自动刷新
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author :hhx
@Date :2022/5/22 14:55
@Description :UNet训练
"""
import numpy as np
import os
from utils import *
import torch
from torch import nn, optim
from torch.utils import data
from models import AE, AE_withLinear, UNet
from tqdm import tqdm
from PIL import Image
# import os
# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
device = 'cpu'
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_tensor_type(torch.DoubleTensor)
if __name__ == '__main__':
batch = 8
datasetpath = 'G:\哨兵2号数据'
trainSet = CarTiffDateSet(datasetpath)
train_loader = torch.utils.data.DataLoader(dataset=trainSet,
batch_size=batch,
shuffle=True)
# model = AE_withLinear.Autoencoder().to(device)
model = UNet.UNet().to(device)
cost2 = nn.MSELoss().to(device)
optimizer2 = optim.Adam(model.parameters(), weight_decay=1e-6)
for epoch in range(1, 8):
model.train()
index = 0
loss = 0
for images, labels in tqdm(train_loader):
images = images.to(device)
labels = labels.to(device)
outputs2 = model(images)
loss2 = cost2(outputs2, labels)
optimizer2.zero_grad()
loss2.backward()
loss += loss2
optimizer2.step()
index += 1
print(loss)
torch.save(model.state_dict(), 'SavedModels/UNet.pkl')
# torch.save(model.state_dict(), 'SavedModels/AE_withLinear.pkl')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。