1 Star 1 Fork 0

Haixu He/自编码器提取特征

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
UNet_train.py 1.58 KB
一键复制 编辑 原始数据 按行查看 历史
Haixu He 提交于 2022-05-22 15:28 . add UNet
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author :hhx
@Date :2022/5/22 14:55
@Description :UNet训练
"""
import numpy as np
import os
from utils import *
import torch
from torch import nn, optim
from torch.utils import data
from models import AE, AE_withLinear, UNet
from tqdm import tqdm
from PIL import Image
# import os
# os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
device = 'cpu'
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_tensor_type(torch.DoubleTensor)
if __name__ == '__main__':
batch = 8
datasetpath = 'G:\哨兵2号数据'
trainSet = CarTiffDateSet(datasetpath)
train_loader = torch.utils.data.DataLoader(dataset=trainSet,
batch_size=batch,
shuffle=True)
# model = AE_withLinear.Autoencoder().to(device)
model = UNet.UNet().to(device)
cost2 = nn.MSELoss().to(device)
optimizer2 = optim.Adam(model.parameters(), weight_decay=1e-6)
for epoch in range(1, 8):
model.train()
index = 0
loss = 0
for images, labels in tqdm(train_loader):
images = images.to(device)
labels = labels.to(device)
outputs2 = model(images)
loss2 = cost2(outputs2, labels)
optimizer2.zero_grad()
loss2.backward()
loss += loss2
optimizer2.step()
index += 1
print(loss)
torch.save(model.state_dict(), 'SavedModels/UNet.pkl')
# torch.save(model.state_dict(), 'SavedModels/AE_withLinear.pkl')
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/HaixuHe/Feature-from-encoder.git
git@gitee.com:HaixuHe/Feature-from-encoder.git
HaixuHe
Feature-from-encoder
自编码器提取特征
master

搜索帮助