"""Experiment-running framework."""
import argparse
import importlib
import numpy as np
import torch
import pytorch_lightning as pl
import lit_models
from lit_models import TransformerLitModel
from transformers import AutoConfig
import os
import json
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def _import_class(module_and_class_name: str) -> type:
"""Import class from a module, e.g. 'text_recognizer.models.MLP'"""
module_name, class_name = module_and_class_name.rsplit(".", 1)
module = importlib.import_module(module_name)
class_ = getattr(module, class_name)
return class_
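
# For example (illustrative, matching this repo's defaults):
#   cls = _import_class("models.bert.BertForSequenceClassification")
# imports the `models.bert` module and returns its `BertForSequenceClassification` class.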
def _setup_parser():
"""Set up Python's ArgumentParser with data, model, trainer, and other arguments."""
parser = argparse.ArgumentParser(add_help=False)
# Add Trainer specific arguments, such as --max_epochs, --gpus, --precision
trainer_parser = pl.Trainer.add_argparse_args(parser)
trainer_parser._action_groups[1].title = "Trainer Args" # pylint: disable=protected-access
parser = argparse.ArgumentParser(add_help=False, parents=[trainer_parser])
# Basic arguments
parser.add_argument("--wandb", action="store_true", default=False)
parser.add_argument("--seed", type=int, default=666)
parser.add_argument("--data_class", type=str, default="TEXTCLS")
parser.add_argument("--model_class", type=str, default="bert.BertForSequenceClassification")
parser.add_argument("--load_checkpoint", type=str, default=None)
# Get the data and model classes, so that we can add their specific arguments
temp_args, _ = parser.parse_known_args()
data_class = _import_class(f"data.{temp_args.data_class}")
model_class = _import_class(f"models.{temp_args.model_class}")
# Get data, model, and LitModel specific arguments
data_group = parser.add_argument_group("Data Args")
data_class.add_to_argparse(data_group)
model_group = parser.add_argument_group("Model Args")
model_class.add_to_argparse(model_group)
lit_model_group = parser.add_argument_group("LitModel Args")
lit_models.BaseLitModel.add_to_argparse(lit_model_group)
parser.add_argument("--help", "-h", action="help")
return parser
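
# A sketch of the resulting CLI (with the default data/model classes): Trainer flags,
# Data Args, Model Args, and LitModel Args all land in one flat namespace, e.g.
#   python training/run_experiment.py --max_epochs=3 --data_class=TEXTCLS \
#       --model_class=bert.BertForSequenceClassification --lr=3e-5
# Here --lr is hypothetical; the real argument names come from each class's add_to_argparse.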
def main():
"""
Run an experiment.
Sample command:
```
    python training/run_experiment.py --max_epochs=3 --gpus='0,' --num_workers=20 --data_class=TEXTCLS --model_class=bert.BertForSequenceClassification
```
"""
parser = _setup_parser()
args = parser.parse_args()
with open("config.json", "w") as file:
t = vars(args).copy()
t.pop("tpu_cores")
json.dump(t, file)
    # Set random seeds to make experiments reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
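    # Note: pl.seed_everything(args.seed) would additionally seed Python's
    # `random` module (and, with workers=True, dataloader workers); the two
    # calls above match the original behavior and leave those unseeded.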
data_class = _import_class(f"data.{args.data_class}")
model_class = _import_class(f"models.{args.model_class}")
data = data_class(args)
data_config = data.get_data_config()
    # Load the pretrained config and set the label count so the classification
    # head matches this dataset.
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    config.num_labels = data_config["num_labels"]
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
lit_model = TransformerLitModel(args=args, model=model, data_config=data_config)
logger = pl.loggers.TensorBoardLogger("training/logs")
    # Log to Weights & Biases instead of TensorBoard when --wandb is passed.
    if args.wandb:
        logger = pl.loggers.WandbLogger(project="text_cls")
        logger.log_hyperparams(vars(args))
    early_callback = pl.callbacks.EarlyStopping(monitor="Eval/acc", mode="max", patience=5)
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        monitor="Eval/acc",
        mode="max",
        filename="{epoch}-{Eval/acc:.2f}",
        dirpath="output",
    )
callbacks = [early_callback, model_checkpoint]
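    # Both callbacks watch "Eval/acc", which assumes the LitModel logs a metric
    # with exactly that name during validation, e.g. self.log("Eval/acc", acc)
    # in TransformerLitModel's validation hooks.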
# args.weights_summary = "full" # Print full summary of the model
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, logger=logger, default_root_dir="training/logs")
trainer.tune(lit_model, datamodule=data) # If passing --auto_lr_find, this will set learning rate
trainer.fit(lit_model, datamodule=data)
trainer.test(lit_model, datamodule=data)
    # best_model_path ends with the checkpoint extension (".ckpt"); strip it and
    # save the underlying Hugging Face model (config + weights) to that directory.
    path, ext = os.path.splitext(model_checkpoint.best_model_path)
    trainer.model.save_pretrained(path)
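    # The directory written by save_pretrained can later be reloaded without the
    # Lightning checkpoint, e.g. model_class.from_pretrained(path); this sketch
    # assumes `path` still exists at load time.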
if __name__ == "__main__":
main()