1 Star 0 Fork 0

yyiOe/autoencoder

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 1.17 KB
一键复制 编辑 原始数据 按行查看 历史
ps 提交于 2021-07-06 11:33 . update
import argparse
from torch.nn.modules.activation import ReLU
from models.unetcae import UnetCAE
import os
import warnings
from config import TrainingConfig
from utils.logger import get_logger
from datasets import MvTec
from trainer import Trainer
from models import ResnetCAE
if __name__ == '__main__':
warnings.filterwarnings('ignore')
parser = argparse.ArgumentParser()
parser.add_argument('-nc')
parser.add_argument('-device', default='0')
cfg = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = cfg.device
train_cfg = TrainingConfig()
train_cfg.ckpt_path = os.path.join(train_cfg.ckpt_path, cfg.nc)
if not os.path.exists(train_cfg.ckpt_path):
os.makedirs(train_cfg.ckpt_path)
logger = get_logger(train_cfg.ckpt_path, specified_file='train.txt', console_open=False)
train_ds = MvTec('train', cfg.nc, zoom_size=train_cfg.img_size)
val_ds = MvTec('test', cfg.nc, zoom_size=train_cfg.img_size, test_mode='broken')
model = ResnetCAE().cuda()
resume = None
# resume='checkpoints/resnet18/mvtec/tile/epoch_200.pth'
trainer = Trainer(model, logger, train_ds, val_ds, train_cfg)
trainer.training_process()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yyioe/autoencoder.git
git@gitee.com:yyioe/autoencoder.git
yyioe
autoencoder
autoencoder
master

搜索帮助