1 Star 0 Fork 71

TopKernel/ChatLM-mini-Chinese

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
dpo_train.py 6.69 KB
一键复制 编辑 原始数据 按行查看 历史
Charent 提交于 2024-01-30 22:17 . bug fixed and update readme
# coding=utf-8
from typing import Dict, Optional
import time
import os
import pandas as pd
import torch
from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizerFast, TrainingArguments
from trl import DPOTrainer
from tokenizers import Tokenizer
from peft import LoraConfig, TaskType, PeftModel
from config import DpoConfig, T5ModelConfig
from model.chat_model import TextToTextModel
from utils.functions import get_T5_config
# Disable TensorFlow's oneDNN custom ops (presumably to avoid its oneDNN notice /
# numeric-difference warning).
# NOTE(review): TensorFlow reads this variable when it is first imported; since it is
# set after the imports above, it only takes effect if TF gets imported later — confirm.
os.environ.update(TF_ENABLE_ONEDNN_OPTS='0')
def get_dataset(
    split: str,
    file: str,
    cache_dir: str = '.cache',
    seed: int = 2333,
    eos_token: str = '[EOS]',
) -> Dataset:
    """Load a local DPO preference dataset from a JSON file and prepare it for training.

    Each record in ``file`` must provide ``prompt``, ``chosen`` and ``rejected``
    fields. An end-of-sentence marker is appended to all three so that the model
    learns where a generation should stop.

    Args:
        split: Split to load. With a single ``data_files`` path the JSON loader
            exposes the file as the ``"train"`` split.
        file: Path to the JSON / JSON-lines data file.
        cache_dir: Cache directory used by ``datasets``.
        seed: Seed for the final shuffle (default keeps the original 2333).
        eos_token: Marker appended to every field; presumably must match the
            tokenizer's EOS token string — confirm against the tokenizer.

    Returns:
        A shuffled ``Dataset`` whose ``prompt``, ``chosen`` and ``rejected``
        columns end with ``eos_token``; any other columns in the file are kept.
    """
    dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)

    def split_prompt_and_responses(sample: dict) -> Dict[str, str]:
        # Append the EOS marker to each text so generation learns to stop there.
        return {
            key: f"{sample[key]}{eos_token}"
            for key in ("prompt", "chosen", "rejected")
        }

    return dataset.map(split_prompt_and_responses).shuffle(seed=seed)
def train_dpo(config: DpoConfig, peft_config: Optional[LoraConfig] = None) -> None:
    """Run DPO fine-tuning on top of an SFT checkpoint and save the result.

    Args:
        config: Hyper-parameters and paths (tokenizer dir, SFT model path,
            DPO training file, output/log dirs, batch sizes, ...).
        peft_config: Optional LoRA configuration forwarded to ``DPOTrainer``.
            When given, the result is saved under ``<sft dir>/lora``; otherwise
            the fully fine-tuned model is saved under ``<sft dir>/dpo``.

    Side effects:
        Writes the trainer's log history as a CSV file under ``./logs`` and saves
        the trained model / adapter in a directory next to
        ``config.sft_model_file``.
    """
    # step 1. load the tokenizer
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. load the SFT model twice: the policy being trained and the
    # reference model used by the DPO loss
    if os.path.isdir(config.sft_model_file):
        # a directory is a `save_pretrained` checkpoint
        model_train = TextToTextModel.from_pretrained(config.sft_model_file)
        model_ref = TextToTextModel.from_pretrained(config.sft_model_file)
    else:
        # a file is a raw state_dict: read it from disk once, copy it into both models
        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # map to CPU so loading works regardless of the device the weights were saved on
        state_dict = torch.load(config.sft_model_file, map_location='cpu')
        model_train = TextToTextModel(t5_config)
        model_train.load_state_dict(state_dict)
        model_ref = TextToTextModel(t5_config)
        model_ref.load_state_dict(state_dict)
        # load_state_dict copies the tensors into each model; drop the shared source
        del state_dict

    # step 3. training data
    train_dataset = get_dataset("train", file=config.dpo_train_file)

    # step 4. evaluation data (currently disabled)
    # eval_dataset = get_dataset("train", file=config.dpo_eval_file)
    eval_dataset = None

    # step 5. training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=config.per_device_train_batch_size,
        num_train_epochs=config.num_train_epochs,
        auto_find_batch_size=True,
        # DPOTrainer needs the raw prompt/chosen/rejected columns
        remove_unused_columns=False,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        logging_first_step=True,
        logging_steps=config.logging_steps,
        save_steps=config.save_steps,
        output_dir=config.output_dir,
        optim="adafactor",
        report_to="tensorboard",
        log_level='info',
        warmup_steps=config.warmup_steps,
        bf16=False,
        fp16=config.fp16,
        seed=config.seed,
        logging_dir=config.log_dir,
    )

    # step 6. DPO trainer
    dpo_trainer = DPOTrainer(
        model_train,
        model_ref,
        peft_config=peft_config,
        args=training_args,
        beta=config.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=config.max_seq_len,
        max_target_length=config.max_seq_len,
        max_prompt_length=config.max_seq_len,
        # Sample generation only happens during evaluation; without an eval set the
        # flag is useless, and trl refuses it when wandb is not installed.
        generate_during_eval=eval_dataset is not None,
        is_encoder_decoder=True,
    )

    # step 7. train
    dpo_trainer.train(
        # resume_from_checkpoint=True
    )

    # step 8. save the log history
    loss_log = pd.DataFrame(dpo_trainer.state.log_history)
    log_dir = './logs'
    os.makedirs(log_dir, exist_ok=True)
    loss_log.to_csv(f"{log_dir}/dpo_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")

    # step 9. save the model / LoRA adapter next to the SFT checkpoint.
    # merge_lora_weight_into_model expects the adapter in the same "<sft dir>/lora".
    suffix = 'lora' if peft_config is not None else 'dpo'
    model_save_dir = os.path.join(os.path.dirname(config.sft_model_file), suffix)
    dpo_trainer.save_model(model_save_dir)
    print('save model or lora adapter to: {}'.format(model_save_dir))
def merge_lora_weight_into_model(config: DpoConfig, peft_config: LoraConfig) -> None:
    """Merge the LoRA adapter saved by ``train_dpo`` into the SFT model and save it.

    The adapter is read from ``<sft dir>/lora``, which is where ``train_dpo`` writes
    it when a LoRA config is used. The merged model is written with
    ``save_pretrained`` to ``config.sft_model_file + '.dpo_lora_merged'``.

    Args:
        config: Paths to the tokenizer and the SFT model (directory or state_dict file).
        peft_config: LoRA configuration passed to ``PeftModel.from_pretrained``.
    """
    # step 1. load the tokenizer (needed to size the model when loading a state_dict)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # step 2. load the SFT base model
    if os.path.isdir(config.sft_model_file):
        # a directory is a `save_pretrained` checkpoint
        sft_model = TextToTextModel.from_pretrained(config.sft_model_file)
    else:
        # a file is a raw state_dict
        t5_config = get_T5_config(
            T5ModelConfig(),
            vocab_size=len(tokenizer),
            decoder_start_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        sft_model = TextToTextModel(t5_config)
        # map to CPU so loading works regardless of the device the weights were saved on
        sft_model.load_state_dict(torch.load(config.sft_model_file, map_location='cpu'))

    # step 3. wrap the base model with the adapter. This path must match the one
    # train_dpo saves the LoRA adapter to.
    adapter_save_dir = os.path.join(os.path.dirname(config.sft_model_file), 'lora')
    print('load adapter from dir: {}'.format(adapter_save_dir))
    # from_pretrained already loads the adapter weights into the 'adapter' slot,
    # so no separate load_adapter call is needed.
    peft_model = PeftModel.from_pretrained(
        model=sft_model,
        model_id=adapter_save_dir,
        config=peft_config,
        adapter_name='adapter',
    )

    # step 4. fold the LoRA weights into the base weights and drop the adapter layers
    merged_model = peft_model.merge_and_unload()

    # step 5. save the merged model
    save_merge_file = config.sft_model_file + '.dpo_lora_merged'
    merged_model.save_pretrained(save_merge_file)
    print('save merge model file to: {}'.format(save_merge_file))
if __name__ == "__main__":
    # Switch between full-model DPO (False) and LoRA DPO followed by merging the
    # adapter back into the SFT model (True).
    use_lora = False

    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,  # text-to-text LoRA model
        inference_mode=False,
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="all",
    )

    dpo_config = DpoConfig()

    # 1. train (the LoRA config is only used when use_lora is True)
    train_dpo(dpo_config, peft_config=peft_config if use_lora else None)

    # 2. merge the LoRA adapter into the model; only meaningful after LoRA training
    if use_lora:
        merge_lora_weight_into_model(dpo_config, peft_config)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/topkernel/ChatLM-mini-Chinese.git
git@gitee.com:topkernel/ChatLM-mini-Chinese.git
topkernel
ChatLM-mini-Chinese
ChatLM-mini-Chinese
main

搜索帮助

0d507c66 1850385 C8b1a773 1850385