1 Star 1 Fork 1

cheasim/标题段落分类

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
main.py 4.19 KB
一键复制 编辑 原始数据 按行查看 历史
cheasim 提交于 2021-04-26 17:21 . modify
"""Experiment-running framework."""
import argparse
import importlib
import numpy as np
import torch
import pytorch_lightning as pl
import lit_models
from lit_models import TransformerLitModel
from transformers import AutoConfig
import os
import json
# Set at import time so the variable is in the environment for everything that runs afterwards.
# NOTE(review): presumably this silences the Hugging Face tokenizers fork/parallelism warning
# when DataLoader worker processes are used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# In order to ensure reproducible experiments, we must set random seeds.
# (The seeding itself happens inside main(), using the --seed argument.)
def _import_class(module_and_class_name: str) -> type:
    """Import and return a class given its dotted path.

    Args:
        module_and_class_name: Fully qualified name such as
            ``"text_recognizer.models.MLP"``. Everything before the last dot
            is imported as a module; the part after it is looked up as an
            attribute of that module.

    Returns:
        The attribute named by the last path component (expected to be a class).

    Raises:
        ValueError: If the name contains no dot (there is no module part), or
            if the module part is empty.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_name, separator, class_name = module_and_class_name.rpartition(".")
    if not separator:
        # Without this check the failure is an opaque tuple-unpacking error
        # that does not mention the offending argument.
        raise ValueError(
            f"Expected a dotted path of the form 'module.ClassName', got {module_and_class_name!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
def _setup_parser():
    """Build the command-line parser covering trainer, data, model and LitModel options.

    The parser is assembled in two passes: the basic flags are registered and
    parsed once with ``parse_known_args`` to learn which data and model classes
    were requested, then those classes contribute their own arguments.
    ``--help`` is registered last, after every other argument has been added.

    Returns:
        The fully populated ``argparse.ArgumentParser``.
    """
    # Let Lightning register its Trainer flags (--max_epochs, --gpus, --precision, ...).
    bare_parser = argparse.ArgumentParser(add_help=False)
    trainer_parser = pl.Trainer.add_argparse_args(bare_parser)
    # Retitle the parser's second action group so the Trainer flags get their own heading.
    trainer_parser._action_groups[1].title = "Trainer Args"  # pylint: disable=protected-access
    parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])

    # Flags owned by this script.
    parser.add_argument("--wandb", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=666)
    parser.add_argument("--data_class", type=str, default="TEXTCLS")
    parser.add_argument("--model_class", type=str, default="bert.BertForSequenceClassification")
    parser.add_argument("--load_checkpoint", type=str, default=None)

    # First pass: only the class names are needed here; unrecognised flags are ignored.
    known_args, _unknown = parser.parse_known_args()
    data_class = _import_class(f"data.{known_args.data_class}")
    model_class = _import_class(f"models.{known_args.model_class}")

    # Each contributor fills its own titled argument group, in this order.
    contributors = (
        ("Data Args", data_class),
        ("Model Args", model_class),
        ("LitModel Args", lit_models.BaseLitModel),
    )
    for group_title, contributor in contributors:
        contributor.add_to_argparse(parser.add_argument_group(group_title))

    parser.add_argument("--help", "-h", action="help")
    return parser
def main():
    """
    Run an experiment: parse arguments, train, test and export the model.

    In order, this function:
      1. parses the command line and writes the arguments to ``config.json``;
      2. seeds the random number generators from ``--seed``;
      3. builds the data module, the pretrained model and the Lightning wrapper;
      4. trains with early stopping and checkpointing on ``Eval/acc``;
      5. runs the test loop and calls ``save_pretrained`` on ``trainer.model``.

    Sample command:
    ```
    python main.py --max_epochs=3 --gpus='0,' --data_class=TEXTCLS --model_class=bert.BertForSequenceClassification
    ```
    """
    parser = _setup_parser()
    args = parser.parse_args()
    # NOTE(review): --load_checkpoint is parsed but never read in this function.

    # Snapshot the run configuration. "tpu_cores" is left out of the dump
    # (presumably because its value is not always JSON-serializable — confirm);
    # the default keeps this from raising KeyError when the key is absent.
    run_config = vars(args).copy()
    run_config.pop("tpu_cores", None)
    with open("config.json", "w") as file:
        json.dump(run_config, file)

    # Seeds Python's `random`, NumPy and torch in one call, so no generator is left unseeded.
    pl.seed_everything(args.seed)

    data_class = _import_class(f"data.{args.data_class}")
    model_class = _import_class(f"models.{args.model_class}")
    data = data_class(args)
    data_config = data.get_data_config()

    # The classification head size comes from the dataset, not from the pretrained config.
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    config.num_labels = data_config["num_labels"]
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    lit_model = TransformerLitModel(args=args, model=model, data_config=data_config)

    # TensorBoard is the default logger; --wandb swaps in Weights & Biases and logs the arguments.
    logger = pl.loggers.TensorBoardLogger("training/logs")
    if args.wandb:
        logger = pl.loggers.WandbLogger(project="text_cls")
        logger.log_hyperparams(vars(args))

    # Both callbacks track the same metric: stop after 5 checks without
    # improvement and keep the checkpoint with the highest "Eval/acc".
    early_callback = pl.callbacks.EarlyStopping(monitor="Eval/acc", mode="max", patience=5)
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        monitor="Eval/acc",
        mode="max",
        filename='{epoch}-{Eval/acc:.2f}',
        dirpath="output",
    )
    callbacks = [early_callback, model_checkpoint]

    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
    trainer.tune(lit_model, datamodule=data)  # If passing --auto_lr_find, this will set learning rate
    trainer.fit(lit_model, datamodule=data)
    trainer.test(lit_model, datamodule=data)

    # Export next to the best checkpoint, using its path minus the file extension.
    # NOTE(review): no checkpoint is reloaded here, so this appears to save the
    # weights currently in memory rather than those of the best checkpoint — confirm.
    export_dir, _extension = os.path.splitext(model_checkpoint.best_model_path)
    trainer.model.save_pretrained(export_dir)
# Start an experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/cheasim/title-context-classification.git
git@gitee.com:cheasim/title-context-classification.git
cheasim
title-context-classification
标题段落分类
master

搜索帮助