善若水/Oh-My-Bloom
train.py (5.70 KB): Lightning + DeepSpeed training entry point for ChatBloom.
善若水 committed on 2023-04-28 16:26: add wandb logger
import lightning as L
from typing import Optional
import torch
from pathlib import Path
import typer
from lightning.pytorch.loggers import CSVLogger, WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from lightning.pytorch.utilities import rank_zero_only
import time
from oh_my_bloom.model import DeepSpeedChatBloom
from oh_my_bloom.tokenizer import get_chatbloom_tokeizer
from oh_my_bloom.datamodule import ChatDataModule
import os

# Silence the fork-related warnings from HuggingFace tokenizers in DataLoader workers.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


def run(plm_path: str,
        output_dir: str,
        ckpt_convert_dir: str,
        cache_dir: str,
        resume_ckpt_path: Optional[str] = None,
        version: str = 'v1',
        model_max_length: int = 512,
        lr: float = 1e-5,
        num_warmup_steps: int = 100,
        weight_decay: float = 0.0,
        seed: int = 42,
        num_workers: int = 10,
        batch_size: int = 10,
        max_epochs: int = 3,
        limit_train_batches: Optional[float] = None,
        limit_val_batches: Optional[float] = None,
        devices: int = 6,
        accelerator: str = 'cuda',
        strategy: str = 'deepspeed_stage_2',
        precision: str = '16-mixed',
        fast_dev_run: bool = False):
"""
plm_path (str): 预训练模型路径.
output_dir (str): 日志以及模型参数保存目录.
ckpt_convert_dir (str): deepspeed类型参数转换后的保存目录.
version (str, optional): 版本. Defaults to 'v1'.
model_max_length (int, optional): 模型最大输入长度. Defaults to 512.
lr (float, optional): 学习率. Defaults to 2e-5.
num_warmup_steps (int, optional): 预热的步数. Defaults to 100.
weight_decay (float, optional): 权重衰减. Defaults to 0.0.
seed (int, optional): 随机数种子. Defaults to 42.
num_workers (int, optional): dataloader中的进程数. Defaults to 10.
batch_size (int, optional): 批次大小. Defaults to 10.
max_epochs (int, optional): 最大训练迭代. Defaults to 3.
limit_train_batches (Optional[float], optional): 限制训练批次. Defaults to None.
limit_val_batches (Optional[float], optional): 限制验证批次. Defaults to None.
devices (int, optional): 设备. Defaults to 6.
accelerator (str, optional): 加速器. Defaults to 'cuda'.
strategy (Optional[str], optional): 加速策略. Defaults to 'deepspeed_stage_2'.
precision (str, optional):训练精度. Defaults to '16-mixed'.
fast_dev_run (bool, optional): 快速开发模式. Defaults to False.
"""
    L.seed_everything(seed=seed, workers=True)
    torch.set_float32_matmul_precision('high')
    tokenizer = get_chatbloom_tokeizer(plm_path=plm_path, model_max_length=model_max_length)
    # NOTE: the instruction dataset directory is hard-coded here.
    dm = ChatDataModule(tokenizer=tokenizer,
                        data_dir='/root/autodl-tmp/OMInstructions',
                        cache_dir=cache_dir,
                        batch_size=batch_size,
                        num_workers=num_workers)
    if strategy.startswith('deepspeed'):
        offload = strategy.endswith('offload')
        model = DeepSpeedChatBloom(plm_path=plm_path,
                                   chat_tokenizer=tokenizer,
                                   lr=lr,
                                   weight_decay=weight_decay,
                                   num_warmup_steps=num_warmup_steps,
                                   offload=offload)
    else:
        # FSDP training is not implemented yet; fail fast instead of leaving `model` undefined.
        raise NotImplementedError(f'unsupported strategy: {strategy}')
    # Save logs and checkpoints under a date-stamped directory.
    output_timed_dir = Path(output_dir, time.strftime('%Y-%m-%d'))
    output_timed_dir.mkdir(parents=True, exist_ok=True)
    csv_logger = CSVLogger(output_timed_dir, name='logs', version=version)
    wandb_dir = Path(output_timed_dir, 'wandb_logs')
    wandb_dir.mkdir(exist_ok=True)
    wandb_logger = WandbLogger(project='oh-my-bloom', save_dir=wandb_dir, version=version)
    ckpt_dir = Path(output_timed_dir, 'checkpoints')
    model_checkpoint = ModelCheckpoint(dirpath=ckpt_dir,
                                       filename='epoch={epoch}-step={step}-loss={train/loss:.2f}',
                                       auto_insert_metric_name=False,
                                       save_last=True)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    trainer = L.Trainer(accelerator=accelerator,
                        devices=devices,
                        strategy=strategy,
                        max_epochs=max_epochs,
                        enable_checkpointing=True,
                        precision=precision,
                        fast_dev_run=fast_dev_run,
                        limit_train_batches=limit_train_batches,
                        limit_val_batches=limit_val_batches,
                        callbacks=[model_checkpoint, lr_monitor],
                        logger=[wandb_logger, csv_logger],
                        profiler='simple')
    trainer.fit(model=model, datamodule=dm, ckpt_path=resume_ckpt_path)
    # Convert the DeepSpeed checkpoint into a format Lightning can load directly.
    if strategy.startswith('deepspeed') and not fast_dev_run:
        best_ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
        convert_save_path = Path(ckpt_convert_dir, version, best_ckpt_path.name)
        convert_save_path.parent.mkdir(parents=True, exist_ok=True)
        rank_zero_only(convert_zero_checkpoint_to_fp32_state_dict)(checkpoint_dir=best_ckpt_path,
                                                                   output_file=convert_save_path)


if __name__ == "__main__":
    typer.run(run)
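
Because run is handed to typer.run, the four parameters without defaults (plm_path, output_dir, ckpt_convert_dir, cache_dir) become required CLI arguments and every other parameter becomes an option. For a quick wiring check the function can also be called directly. A minimal sketch, assuming train.py is importable; all paths are placeholders, and bigscience/bloom-560m merely stands in for whichever BLOOM checkpoint the project actually targets:

from train import run

# All paths and the model id below are placeholders, not values from the repo.
run(plm_path='bigscience/bloom-560m',
    output_dir='outputs',
    ckpt_convert_dir='converted_ckpts',
    cache_dir='cache',
    devices=1,
    fast_dev_run=True)  # single-batch dry run to validate data and model wiring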
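
After conversion, the fp32 file should load like an ordinary Lightning checkpoint. A sketch under the assumption that DeepSpeedChatBloom is a LightningModule that saved its hyperparameters (otherwise the constructor arguments must be re-passed to load_from_checkpoint); the path is a placeholder following the ckpt_convert_dir/version layout above:

import torch
from oh_my_bloom.model import DeepSpeedChatBloom

# Placeholder path; use the file written by the conversion step.
model = DeepSpeedChatBloom.load_from_checkpoint('converted_ckpts/v1/last.ckpt')
model.eval()
with torch.no_grad():
    ...  # run generation here with the chat tokenizer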