1 Star 3 Fork 1

Charent/pytorch-IE

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
main.py 2.33 KB
一键复制 编辑 原始数据 按行查看 历史
Charent 提交于 2023-10-01 16:24 . bug fixed
import torch
import numpy as np
from config import Config
import fire
# Seed every random number generator used here so runs are reproducible.
seed = 233
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn  # do not leave the loop variable behind in the module namespace
# from model.p_so_model import Trainer, load_model_and_test
# from model.s_model import Trainer, load_model_and_test
class PytorchIE:
    """Command-line front end for training and testing the IE models.

    The ``__main__`` guard of this file passes an instance to ``fire.Fire``,
    so each public method is a sub-command. Model modules are imported inside
    each method, so only the module of the requested model is loaded.
    """

    def __init__(self, config: "Config | None" = None):
        """Store the run configuration and choose the compute device.

        Args:
            config: run configuration. When omitted, a default ``Config()``
                is created. When CUDA is available, its
                ``cuda_device_number`` attribute selects the GPU.
        """
        super().__init__()
        # Fall back to the default configuration.
        if config is None:
            config = Config()
        self.config = config
        # Use the CPU unless CUDA is available.
        self.device = torch.device("cpu")
        if torch.cuda.is_available():
            self.device = torch.device("cuda:{}".format(config.cuda_device_number))
            # Enable cuDNN benchmark mode (algorithm auto-tuning) for training.
            torch.backends.cudnn.benchmark = True
        print('device: {}'.format(self.device))

    # ----------------------------- helpers -----------------------------

    def _run_train(self, trainer_cls) -> None:
        """Instantiate ``trainer_cls`` and train with this run's config and device."""
        trainer = trainer_cls()
        trainer.train(self.config, self.device)

    def _run_test(self, load_model_and_test) -> None:
        """Call ``load_model_and_test(config, device)`` with cuDNN benchmark off."""
        # Benchmark mode is turned off before evaluation
        # (presumably for run-to-run determinism).
        torch.backends.cudnn.benchmark = False
        load_model_and_test(self.config, self.device)

    # ============================ SP_O_MODEL ===========================

    def train_sp_o(self):
        """Train the sp_o model (``model.sp_o_model``)."""
        from model.sp_o_model import Trainer
        print('train sp_o model')
        self._run_train(Trainer)

    def test_sp_o(self):
        """Load and test the sp_o model (``model.sp_o_model``)."""
        from model.sp_o_model import load_model_and_test
        print('test sp_o model')
        self._run_test(load_model_and_test)

    def train_sp_o_2023(self):
        """Train the 2023 sp_o model (``model.sp_o_model_2023``)."""
        from model.sp_o_model_2023 import Trainer
        print('train sp_o model_2023')
        self._run_train(Trainer)

    def test_sp_o_2023(self):
        """Load and test the 2023 sp_o model (``model.sp_o_model_2023``)."""
        from model.sp_o_model_2023 import load_model_and_test
        print('test sp_o model_2023')
        self._run_test(load_model_and_test)

    # ============================ P_SO_MODEL ===========================

    def train_p_so(self):
        """Train the p_so model (``model.p_so_model``)."""
        from model.p_so_model import Trainer
        print('train p_so model')
        self._run_train(Trainer)

    def test_p_so(self):
        """Load and test the p_so model (``model.p_so_model``)."""
        from model.p_so_model import load_model_and_test
        print('test p_so model')
        self._run_test(load_model_and_test)
if __name__ == "__main__":
    # Make float32 the default floating-point dtype for new tensors.
    # torch.set_default_tensor_type() is deprecated since PyTorch 2.1;
    # torch.set_default_dtype() is the documented replacement.
    torch.set_default_dtype(torch.float32)
    # Parse the command line and run the requested PytorchIE method.
    fire.Fire(component=PytorchIE())
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/charent/pytorch-IE.git
git@gitee.com:charent/pytorch-IE.git
charent
pytorch-IE
pytorch-IE
main

搜索帮助

23e8dbc6 1850385 7e0993f3 1850385