代码拉取完成,页面将自动刷新
import argparse
import math
import time
import dill as pickle
# import pickle
# import pickle5 as pickle
from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.optim as optim
# from torchtext.data import Field, Dataset, BucketIterator
# from torchtext.datasets import TranslationDataset
import random
import numpy as np
# Seed Python's, NumPy's and PyTorch's global RNGs with one shared value and
# pin cuDNN to deterministic, non-autotuned kernels so runs are repeatable.
seed = 1
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn  # keep the loop variable out of the module namespace
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
import time
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import importlib
class Log:
    """Per-run experiment logger that writes into a timestamped directory.

    Each instance creates ``<root_path>/<method_name>/<YYYY-MM-DD>/<HH-MM-SS>``.
    The timestamp format contains a ``/``, so date and time become two
    directory levels. Options, metric history, checkpoints and text reports
    are written there. The timestamp has one-second resolution, so two
    loggers for the same method created within the same second share a
    directory.
    """

    def __init__(self, root_path, method_name):
        """Create (if needed) the log directory for this run and print it.

        Args:
            root_path: base directory under which all runs are stored.
            method_name: sub-directory name grouping runs of one method.
        """
        self.root_path = root_path
        self.method_name = method_name
        # Per-instance metric history. As a class attribute it was shared by
        # every Log instance, so separate loggers mixed their metrics.
        self.state_dict = {}
        self._time = time.strftime("%Y-%m-%d/%H-%M-%S", time.localtime())
        self._log_path = os.path.join(self.root_path, self.method_name, self._time)
        # exist_ok replaces the racy os.path.exists() check before makedirs().
        os.makedirs(self._log_path, exist_ok=True)
        print('log_path:', self._log_path)

    def record_opt(self, opt):
        """Write ``str(opt)`` to ``opt.txt``, replacing any previous content.

        Args:
            opt: any object; only its ``str()`` form is stored.
        """
        self._opt_path = os.path.join(self._log_path, 'opt.txt')
        with open(self._opt_path, 'w', encoding='utf-8') as f:
            f.write(str(opt))

    def state_dict_update(self, key_value_list):
        """Append values to per-key history lists and persist the full history.

        Args:
            key_value_list: iterable of ``(key, value)`` pairs. Each value is
                appended to the list stored under its key (created on first use).

        The whole dict is rewritten to ``state_dict.npy`` after every call.
        ``np.save`` stores a dict as a pickled 0-d object array, so read it back
        with ``np.load(path, allow_pickle=True).item()``.
        """
        for key, value in key_value_list:
            self.state_dict.setdefault(key, []).append(value)
        np.save(os.path.join(self._log_path, 'state_dict.npy'), self.state_dict)

    def save_model(self, model_name, checkpoint):
        """Serialize ``checkpoint`` with ``torch.save`` to ``<log dir>/<model_name>``.

        Args:
            model_name: file name (relative to the log directory).
            checkpoint: any object ``torch.save`` accepts.
        """
        self._model_path = os.path.join(self._log_path, model_name)
        torch.save(checkpoint, self._model_path)

    def record_report(self, report_str):
        """Append ``report_str`` followed by a newline to ``report.txt``.

        Args:
            report_str: one line of report text (without trailing newline).
        """
        self._report_path = os.path.join(self._log_path, 'report.txt')
        with open(self._report_path, 'a', encoding='utf-8') as f:
            # write(), not writelines(): writelines() on a str iterates it
            # character by character.
            f.write(report_str + '\n')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。