1 Star 0 Fork 0

pchen.Lyndon/deepke

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
main.py 4.99 KB
一键复制 编辑 原始数据 按行查看 历史
海钓的猫 提交于 2019-12-06 10:44 . pretty logger
import os
import hydra
import torch
import logging
import torch.nn as nn
from torch import optim
from hydra import utils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# self
import models
from preprocess import preprocess
from dataset import CustomDataset, collate_fn
from trainer import train, validate
from utils import manual_seed, load_pkl
# Module-level logger, named after this module; used by main() for all log output.
logger = logging.getLogger(__name__)
@hydra.main(config_path='conf/config.yaml')
def main(cfg):
    """Train the model selected by ``cfg.model_name``, then evaluate it.

    Steps, in order: optionally preprocess the data, load the pickled
    train/valid/test splits, build the model, train for ``cfg.epoch`` epochs
    (saving a checkpoint after every epoch), optionally plot the loss curves,
    log the best epochs, and finally run validation on the test split.

    Args:
        cfg: Hydra configuration object. Fields read here: ``pos_limit``,
            ``use_gpu``, ``gpu_id``, ``preprocess``, ``out_path``,
            ``model_name``, ``batch_size``, ``learning_rate``,
            ``weight_decay``, ``lr_factor``, ``lr_patience``, ``show_plot``,
            ``plot_utils``, ``epoch``, ``seed`` and
            ``early_stopping_patience``. Fields written here: ``cwd``,
            ``pos_size`` and ``vocab_size``.

    Raises:
        KeyError: if ``cfg.model_name`` is not one of the registered models.
    """
    # Data paths below are resolved against the directory the program was
    # launched from, so keep it on the config.
    cwd = utils.get_original_cwd()
    cfg.cwd = cwd
    cfg.pos_size = 2 * cfg.pos_limit + 2
    logger.info(f'\n{cfg.pretty()}')

    # Registry mapping the configured model name to its implementation.
    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # device
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    # Unless the preprocessing itself was changed, it is best to turn this
    # step off so the data is not preprocessed again on every run.
    if cfg.preprocess:
        preprocess(cfg)

    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')

    # The 'lm' model is given no vocabulary size; every other model takes it
    # from the pickled vocabulary.
    if cfg.model_name == 'lm':
        cfg.vocab_size = None
    else:
        cfg.vocab_size = load_pkl(vocab_path).count

    train_dataset = CustomDataset(train_data_path)
    valid_dataset = CustomDataset(valid_data_path)
    test_dataset = CustomDataset(test_data_path)

    # Only the training split is shuffled. Evaluation splits are read in a
    # fixed order so the batches behind each epoch's validation loss (which
    # drives the LR scheduler and early stopping) are the same every epoch.
    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn(cfg))
    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn(cfg))

    model = __Model__[cfg.model_name](cfg)
    model.to(device)
    logger.info(f'\n {model}')

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
    criterion = nn.CrossEntropyLoss()

    # Best epoch by validation macro F1.
    best_f1, best_epoch = -1, 0
    # Running best by validation loss (es_*) and the snapshot of it taken
    # whenever the patience threshold is reached (best_es_*).
    es_loss, es_f1, es_epoch, es_patience, es_path = 1e8, -1, 0, 0, ''
    best_es_epoch, best_es_f1, best_es_path = 0, -1, ''
    es_triggered = False
    train_losses, valid_losses = [], []

    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
        writer = SummaryWriter('tensorboard')
    else:
        writer = None

    logger.info('=' * 10 + ' Start training ' + '=' * 10)

    for epoch in range(1, cfg.epoch + 1):
        # Reseed at the start of every epoch with an epoch-dependent seed.
        manual_seed(cfg.seed + epoch)
        train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
        valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
        scheduler.step(valid_loss)
        model_path = model.save(epoch, cfg)
        # logger.info(model_path)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        if best_f1 < valid_f1:
            best_f1 = valid_f1
            best_epoch = epoch

        # Validation loss is the early-stopping criterion.
        if es_loss > valid_loss:
            es_loss = valid_loss
            es_f1 = valid_f1
            es_epoch = epoch
            es_patience = 0
            es_path = model_path
        else:
            es_patience += 1
            # Training is not interrupted when patience runs out; the best
            # checkpoint so far is only recorded for the final report.
            if es_patience >= cfg.early_stopping_patience:
                es_triggered = True
                best_es_epoch = es_epoch
                best_es_f1 = es_f1
                best_es_path = es_path

    # If patience never ran out, the snapshot above was never taken and the
    # report would show the initial sentinels (epoch 0, f1 -1, empty path).
    # Fall back to the epoch with the lowest validation loss.
    if not es_triggered:
        best_es_epoch, best_es_f1, best_es_path = es_epoch, es_f1, es_path

    if cfg.show_plot:
        if cfg.plot_utils == 'matplot':
            plt.plot(train_losses, 'x-')
            plt.plot(valid_losses, '+-')
            plt.legend(['train', 'valid'])
            plt.title('train/valid comparison loss')
            plt.show()

        if cfg.plot_utils == 'tensorboard':
            # writer is a SummaryWriter here: it is created above under the
            # same show_plot/tensorboard condition.
            for step, (epoch_train_loss, epoch_valid_loss) in enumerate(zip(train_losses, valid_losses)):
                writer.add_scalars('train/valid_comparison_loss', {
                    'train': epoch_train_loss,
                    'valid': epoch_valid_loss
                }, step)
            writer.close()

    logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
                f'this epoch macro f1: {best_es_f1:0.4f}')
    logger.info(f'this model save path: {best_es_path}')
    logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
                f'this epoch macro f1: {best_f1:.4f}')

    # The test split is evaluated with the model as it stands after the last
    # training epoch; no saved checkpoint is reloaded here.
    validate(-1, model, test_dataloader, criterion, device, cfg)
# Script entry point: Hydra's decorator supplies the cfg argument to main().
if __name__ == '__main__':
    main()

# Example invocations.
# NOTE(review): these name predict.py although this file is main.py -- confirm
# which script they are meant for.
# python predict.py --help  # show help for the parameters
# python predict.py -c
# python predict.py chinese_split=0,1 replace_entity_with_type=0,1 -m
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/pchen_Lyndon/deepke.git
git@gitee.com:pchen_Lyndon/deepke.git
pchen_Lyndon
deepke
deepke
master

搜索帮助