
TopKernel/ChatLM-mini-Chinese

pre_train.py 4.70 KB
Charent committed on 2024-02-05 16:02 · fixed disk use
# coding=utf-8
import time
import os
import pandas as pd
from dataclasses import dataclass
import torch
from typing import Dict
from tqdm import tqdm
import numpy as np
from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from transformers.generation.configuration_utils import GenerationConfig
from datasets import Dataset, load_dataset
from model.chat_model import TextToTextModel
from model.dataset import MyDataset
from config import TrainConfig, T5ModelConfig
from utils.functions import json_to_dataclass, get_T5_config, MyTrainerCallback
tqdm.pandas()
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
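# (Setting TF_ENABLE_ONEDNN_OPTS=0 disables TensorFlow's oneDNN custom ops in case
# TF gets imported by a dependency; the training itself runs on PyTorch.)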
def get_dataset(file: str, split: str, tokenizer: PreTrainedTokenizerFast, cache_dir: str='.cache') -> Dataset:
    """
    Load the dataset.
    """
    dataset = load_dataset('parquet', data_files=file, split=split, cache_dir=cache_dir)

    def tokens_to_ids(samples: dict) -> Dict[str, str]:
        eos_token_id = tokenizer.eos_token_id

        batch_prompt = samples['prompt']
        batch_response = samples['response']

        encoded_prompt = tokenizer(batch_prompt, truncation=False, padding=False, return_attention_mask=False)
        encoded_response = tokenizer(batch_response, truncation=False, padding=False, return_attention_mask=False)

        # A vocab size below 65535 fits in uint16; eos_token_id must be appended to every sample
        input_ids = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_prompt["input_ids"]]
        labels = [np.array(item + [eos_token_id], dtype=np.uint16) for item in encoded_response["input_ids"]]

        return {
            'input_ids': input_ids,
            'labels': labels,
        }

    dataset = dataset.map(tokens_to_ids, batched=True, batch_size=8192, remove_columns=dataset.column_names)
    return dataset
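# Note: get_dataset() expects the parquet file to provide two string columns,
# 'prompt' and 'response'; each sample gets eos_token_id appended and is stored
# as uint16 to keep the cached dataset small.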
def pre_train(config: TrainConfig) -> None:

    # step 1. Load the tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. Build the model configuration
    t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)

    # step 3. Initialize the model
    model = TextToTextModel(t5_config)

    # Step 4: Load my dataset
    dataset = get_dataset(file=config.train_file, split='train', tokenizer=tokenizer)
    # Step 5: Define the training arguments
    # T5 is a sequence-to-sequence model, so use Seq2SeqTrainingArguments, DataCollatorForSeq2Seq and Seq2SeqTrainer;
    # the SFT tooling on the Hugging Face site targets (causal) language models.
    generation_config = GenerationConfig()
    generation_config.remove_invalid_values = True
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    generation_config.decoder_start_token_id = tokenizer.pad_token_id
    generation_config.max_new_tokens = 320
    generation_config.num_beams = 1      # greedy search
    generation_config.do_sample = False  # greedy search
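    # The generation_config above is attached to the training arguments below; it only
    # takes effect when the trainer actually generates (e.g. evaluation with
    # predict_with_generate), not during the forward/backward passes of training.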
    training_args = Seq2SeqTrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size_per_gpu,
        auto_find_batch_size=True,  # guard against OOM
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learn_rate,
        logging_steps=config.logging_steps,
        num_train_epochs=config.epochs,
        optim="adafactor",
        report_to='tensorboard',
        log_level='info',
        save_steps=config.save_steps,
        save_total_limit=3,
        fp16=True if config.mixed_precision == 'fp16' else False,
        bf16=True if config.mixed_precision == 'bf16' else False,
        logging_first_step=True,
        warmup_steps=config.warmup_steps,
        seed=config.seed,
        generation_config=generation_config,
    )
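    # 'adafactor' is the memory-efficient optimizer used in the original T5 recipe;
    # fp16/bf16 mixed precision follows config.mixed_precision.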
    # step 6: init my collator
    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)
    empty_cuda_cache = MyTrainerCallback()
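    # DataCollatorForSeq2Seq pads input_ids and labels to the longest sequence in each
    # batch (labels are padded with -100, so padding is ignored by the loss);
    # MyTrainerCallback presumably frees the CUDA cache during training, as the
    # variable name suggests (see utils.functions).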
    # Step 7: Define the Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        callbacks=[empty_cuda_cache],
    )

    # step 8: train
    trainer.train(
        # resume_from_checkpoint=True
    )
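    # Passing resume_from_checkpoint=True to trainer.train() would resume from the
    # most recent checkpoint under output_dir instead of starting from scratch.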
    # step 9: save the training log
    loss_log = pd.DataFrame(trainer.state.log_history)
    log_dir = './logs'
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    loss_log.to_csv(f"{log_dir}/pre_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")

    # Step 10: Save the model
    trainer.save_model(config.output_dir)
if __name__ == '__main__':
    config = TrainConfig()
    pre_train(config)
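For reference, pre_train.py only reads the TrainConfig fields listed below (TrainConfig itself lives in config.py, which is not part of this page). A minimal sketch of the shape the script assumes; the field names come from the code above, while the default values are purely illustrative placeholders:

from dataclasses import dataclass

@dataclass
class TrainConfig:  # illustrative sketch only, not the real config.py
    tokenizer_dir: str = './model_save/tokenizer'   # placeholder path
    train_file: str = './data/pre_train.parquet'    # parquet with 'prompt'/'response' columns
    output_dir: str = './model_save/pre_train'      # placeholder path
    max_seq_len: int = 320
    batch_size_per_gpu: int = 16
    gradient_accumulation_steps: int = 8
    learn_rate: float = 1e-4
    epochs: int = 1
    logging_steps: int = 50
    save_steps: int = 5000
    warmup_steps: int = 1000
    mixed_precision: str = 'bf16'   # 'fp16', 'bf16', or anything else for full precision
    seed: int = 42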