# [Gitee web-page text captured together with this file — not part of the program]
# 代码拉取完成，页面将自动刷新
# 同步操作将从 Charent/ChatLM-mini-Chinese 强制同步，此操作会覆盖自 Fork 仓库以来所做的任何修改，且无法恢复！！！
# 确定后同步将在后台操作，完成时将刷新页面，请耐心等待。
# coding=utf-8
import time
import os
import pandas as pd
from dataclasses import dataclass
import torch
from typing import Dict
from tqdm import tqdm
import numpy as np
from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from transformers.generation.configuration_utils import GenerationConfig
from datasets import Dataset, load_dataset
from model.chat_model import TextToTextModel
from model.dataset import MyDataset
from config import TrainConfig, T5ModelConfig
from utils.functions import json_to_dataclass, get_T5_config, MyTrainerCallback
# Module-level setup.
# Turn off TensorFlow's oneDNN custom ops through the environment.
# NOTE(review): this only affects a TensorFlow import that happens after this line;
# transformers is already imported above — confirm TensorFlow is loaded lazily.
os.environ.update(TF_ENABLE_ONEDNN_OPTS='0')
# Enable tqdm progress bars on pandas objects (progress_apply & co.).
# Nothing in this visible file calls them.
tqdm.pandas()
def get_dataset(file: str, split: str, tokenizer: PreTrainedTokenizerFast, cache_dir: str='.cache') -> Dataset:
    """
    Load a parquet prompt/response dataset and tokenize it for seq2seq training.

    Each row's ``prompt`` is encoded into ``input_ids`` and its ``response`` into
    ``labels``. The tokenizer's EOS id is appended to both. No truncation or
    padding is applied here; padding is left to the caller's data collator.

    Args:
        file: path of the parquet data, passed to ``load_dataset`` as ``data_files``.
        split: dataset split to load, e.g. ``'train'``.
        tokenizer: tokenizer used to encode the ``prompt`` and ``response`` columns.
        cache_dir: cache directory handed to ``load_dataset``.

    Returns:
        A ``Dataset`` that contains only the ``input_ids`` and ``labels`` columns.

    Raises:
        ValueError: if the tokenizer has no ``eos_token_id``.
    """
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        # Without this check the failure surfaces later as an opaque TypeError
        # inside dataset.map when building the uint arrays.
        raise ValueError('tokenizer has no eos_token_id; it is required to terminate input_ids and labels')

    # Store ids compactly. uint16 holds ids 0..65535, i.e. a vocabulary of up to
    # 65536 entries. Larger vocabularies would overflow uint16 (silently wrapping
    # or raising, depending on the NumPy version), so fall back to uint32.
    id_dtype = np.uint16 if len(tokenizer) <= np.iinfo(np.uint16).max + 1 else np.uint32

    dataset = load_dataset('parquet', data_files=file, split=split, cache_dir=cache_dir)

    def tokens_to_ids(samples: dict) -> Dict[str, list]:
        """Batched map function: encode prompts/responses and append EOS to every sample."""
        encoded_prompt = tokenizer(samples['prompt'], truncation=False, padding=False, return_attention_mask=False,)
        encoded_response = tokenizer(samples['response'], truncation=False, padding=False, return_attention_mask=False,)
        input_ids = [np.array(ids + [eos_token_id], dtype=id_dtype) for ids in encoded_prompt["input_ids"]]
        labels = [np.array(ids + [eos_token_id], dtype=id_dtype) for ids in encoded_response["input_ids"]]
        return {
            'input_ids': input_ids,
            'labels': labels,
        }

    # Drop the original text columns so only the token-id columns remain.
    return dataset.map(tokens_to_ids, batched=True, batch_size=8192, remove_columns=dataset.column_names)
def _build_generation_config(tokenizer: PreTrainedTokenizerFast) -> GenerationConfig:
    """Return the greedy-decoding GenerationConfig attached to the Seq2SeqTrainingArguments."""
    generation_config = GenerationConfig()
    generation_config.remove_invalid_values = True
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    # Keep in sync with the model config built in pre_train: decoding starts from the pad token.
    generation_config.decoder_start_token_id = tokenizer.pad_token_id
    generation_config.max_new_tokens = 320
    generation_config.num_beams = 1  # greedy search
    generation_config.do_sample = False  # greedy search
    return generation_config


def _save_log_history(log_history: list, log_dir: str = './logs') -> None:
    """Write a trainer's log history to a timestamped CSV file under ``log_dir``."""
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates missing parents.
    os.makedirs(log_dir, exist_ok=True)
    loss_log = pd.DataFrame(log_history)
    loss_log.to_csv(f"{log_dir}/pre_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")


def pre_train(config: TrainConfig) -> None:
    """
    Pre-train a T5-style TextToTextModel with the Hugging Face Seq2SeqTrainer.

    Pipeline:
        1. Load the tokenizer.
        2. Build the model config and model; no pretrained weights are loaded.
        3. Tokenize the training parquet.
        4. Configure and run the trainer.
        5. Write the trainer log history to ``./logs``.
        6. Save the final model to ``config.output_dir``.

    Args:
        config: training configuration (paths, batch size, learning rate,
            mixed precision mode, checkpointing, seed, ...).
    """
    # step 1. load the tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. build the model config; the decoder starts generation from the pad token
    t5_config = get_T5_config(
        T5ModelConfig(),
        vocab_size=len(tokenizer),
        decoder_start_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # step 3. initialise the model from the config (no from_pretrained here)
    model = TextToTextModel(t5_config)

    # step 4. load and tokenize the training data
    dataset = get_dataset(file=config.train_file, split='train', tokenizer=tokenizer)

    # step 5. training arguments
    # T5 is a sequence-to-sequence model, so Seq2SeqTrainingArguments,
    # DataCollatorForSeq2Seq and Seq2SeqTrainer are used.
    # Hugging Face's SFT tooling targets decoder-only language models instead.
    training_args = Seq2SeqTrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size_per_gpu,
        auto_find_batch_size=True,  # let the trainer shrink the batch size to avoid OOM
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learn_rate,
        logging_steps=config.logging_steps,
        num_train_epochs=config.epochs,
        optim="adafactor",
        report_to='tensorboard',
        log_level='info',
        save_steps=config.save_steps,
        save_total_limit=3,
        fp16=config.mixed_precision == 'fp16',
        bf16=config.mixed_precision == 'bf16',
        logging_first_step=True,
        warmup_steps=config.warmup_steps,
        seed=config.seed,
        generation_config=_build_generation_config(tokenizer),
    )

    # step 6. batch collator for seq2seq inputs/labels
    # NOTE(review): with the collator's default padding=True ("longest"), max_length is
    # not used for truncation, and get_dataset tokenizes with truncation=False. Samples
    # longer than config.max_seq_len therefore reach the model unchanged — confirm intent.
    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)
    # Project callback from utils.functions; the original variable name suggests it
    # empties the CUDA cache — TODO confirm.
    empty_cuda_cache_callback = MyTrainerCallback()

    # step 7. trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        callbacks=[empty_cuda_cache_callback],
    )

    # step 8. train
    # Pass resume_from_checkpoint=True to continue from a checkpoint in output_dir.
    trainer.train()

    # step 9. dump trainer.state.log_history to a timestamped CSV
    _save_log_history(trainer.state.log_history, log_dir='./logs')

    # step 10. save the final model
    trainer.save_model(config.output_dir)
if __name__ == '__main__':
    # Script entry point: pre-train using the default TrainConfig.
    train_config = TrainConfig()
    pre_train(config=train_config)
# [Gitee web-page text captured together with this file — not part of the program]
# 此处可能存在不合适展示的内容，页面不予展示。您可通过相关编辑功能自查并修改。
# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容，可点击提交进行申诉，我们将尽快为您处理。