代码拉取完成,页面将自动刷新
import argparse
from torch.nn.modules.activation import ReLU
from models.unetcae import UnetCAE
import os
import warnings
from config import TrainingConfig
from utils.logger import get_logger
from datasets import MvTec
from trainer import Trainer
from models import ResnetCAE
if __name__ == '__main__':
warnings.filterwarnings('ignore')
parser = argparse.ArgumentParser()
parser.add_argument('-nc')
parser.add_argument('-device', default='0')
cfg = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = cfg.device
train_cfg = TrainingConfig()
train_cfg.ckpt_path = os.path.join(train_cfg.ckpt_path, cfg.nc)
if not os.path.exists(train_cfg.ckpt_path):
os.makedirs(train_cfg.ckpt_path)
logger = get_logger(train_cfg.ckpt_path, specified_file='train.txt', console_open=False)
train_ds = MvTec('train', cfg.nc, zoom_size=train_cfg.img_size)
val_ds = MvTec('test', cfg.nc, zoom_size=train_cfg.img_size, test_mode='broken')
model = ResnetCAE().cuda()
resume = None
# resume='checkpoints/resnet18/mvtec/tile/epoch_200.pth'
trainer = Trainer(model, logger, train_ds, val_ds, train_cfg)
trainer.training_process()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。