1 Star 0 Fork 0

Mortal Fu、/minimind

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
eval_ceval.py 6.87 KB
一键复制 编辑 原始数据 按行查看 历史
gongjingyao 提交于 2024-08-28 16:41 . MiniMind first open source
import random
import time
import os
import pandas as pd
import torch
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import Transformer
from model.LMConfig import LMConfig
import torch.nn.functional as F
warnings.filterwarnings('ignore')
def init_model(lm_config):
tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer',
trust_remote_code=True, use_fast=False)
model_from = 1 # 1从权重,2用transformers
if model_from == 1:
moe_path = '_moe' if lm_config.use_moe else ''
ckp = f'./out/single_chat/full_sft_{lm_config.dim}{moe_path}.pth'
model = Transformer(lm_config)
state_dict = torch.load(ckp, map_location=device)
# 处理不需要的前缀
unwanted_prefix = '_orig_mod.'
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
# 加载到模型中
model.load_state_dict(state_dict, strict=False)
else:
model = AutoModelForCausalLM.from_pretrained('minimind', trust_remote_code=True)
model = model.to(device)
return model, tokenizer
if __name__ == "__main__":
# -----------------------------------------------------------------------------
seed = random.randint(1, 2000)
# device = 'cuda:0'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16'
lm_config = LMConfig()
# -----------------------------------------------------------------------------
model, tokenizer = init_model(lm_config)
model = model.eval()
# 消息模板,具体实现根据你的tokenizer进行调整
messages_origin = [{"role": "system", "content": "开始回答问题"}]
# 定义文件目录
File_Dir = "ceval/ceval-exam/val"
results_dir = "ceval/ceval_result"
# 确保结果目录存在
if not os.path.exists(results_dir):
os.makedirs(results_dir)
# 用于记录所有文件的总正确数和总题数
total_correct = 0
total_questions = 0
# 遍历目录下的所有CSV文件
for filename in os.listdir(File_Dir):
if filename.endswith('.csv'):
file_path = os.path.join(File_Dir, filename)
test_df = pd.read_csv(file_path)
# 存储结果的DataFrame
results_df = pd.DataFrame(columns=['question', 'A', 'B', 'C', 'D', 'answer', 'llm_answer', 'is_right'])
total_correct_in_file = 0 # 用于记录当前文件的正确数
for row in test_df.itertuples(index=True, name='Pandas'):
id = getattr(row, 'id')
question = getattr(row, 'question')
A = getattr(row, 'A')
B = getattr(row, 'B')
C = getattr(row, 'C')
D = getattr(row, 'D')
right_answer = getattr(row, 'answer')
prompt = f'{question}。选择 A: {A}, B: {B}, C: {C}, D: {D}'
messages = messages_origin.copy()
messages.append({"role": "user", "content": prompt})
# print(messages)
new_prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
x = tokenizer(new_prompt).data['input_ids']
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
res_ids = model.eval_answer(x)
# 假设 res_ids 是模型的 logits 输出,我们使用 softmax 转换为概率分布
probabilities = F.softmax(res_ids, dim=-1)
# 定义每个选项的 token id
A_id = tokenizer('A').data['input_ids']
B_id = tokenizer('B').data['input_ids']
C_id = tokenizer('C').data['input_ids']
D_id = tokenizer('D').data['input_ids']
# 获取每个选项的概率
A_prob = probabilities[0, A_id].item()
B_prob = probabilities[0, B_id].item()
C_prob = probabilities[0, C_id].item()
D_prob = probabilities[0, D_id].item()
# 将每个选项的概率放入字典中便于处理
options_prob = {
'A': A_prob,
'B': B_prob,
'C': C_prob,
'D': D_prob
}
# 找到具有最大概率的选项
max_option_answer = max(options_prob, key=options_prob.get)
# 比较答案并记录
is_right = 1 if max_option_answer == right_answer else 0
results_df = results_df.append({
'question': question,
'A': A,
'B': B,
'C': C,
'D': D,
'answer': right_answer,
'llm_answer': max_option_answer,
'is_right': is_right
}, ignore_index=True)
# print(f'id: {id} 问题: {question[:10]}... 是否正确: {is_right}')
if is_right:
total_correct_in_file += 1
total_correct += total_correct_in_file
total_questions += len(test_df)
# 计算当前文件的正确率并添加到结果DataFrame的最后一行
accuracy = total_correct_in_file / len(test_df)
results_df = results_df.append({
'question': '-',
'A': '-',
'B': '-',
'C': '-',
'D': '-',
'answer': f'文件 {filename} 的正确率: {accuracy:.2%}',
'llm_answer': '-',
'is_right': '-'
}, ignore_index=True)
print(f'{filename.split(".")[0]}{total_correct_in_file}/{len(test_df)},正确率: {accuracy:.2%}')
# 保存结果到CSV
results_path = os.path.join(results_dir, f"{filename.split('.')[0]}_result.csv")
results_df.to_csv(results_path, index=False)
# 计算总正确率
total_accuracy = total_correct / total_questions if total_questions > 0 else 0
# 将各个文件的正确率以及总正确率写入到 "ceval/ceval_result/test.log"
log_path = os.path.join(results_dir, "test.log")
with open(log_path, 'w') as log_file:
result = f"总题数: {total_questions}\n总正确数: {total_correct}\n总正确率: {total_accuracy:.2%}"
log_file.write(result)
print(result)
for filename in os.listdir(File_Dir):
if filename.endswith('.csv'):
accuracy_file = pd.read_csv(os.path.join(results_dir, f"{filename.split('.')[0]}_result.csv"))
last_row = accuracy_file.iloc[-1]['answer']
log_file.write(f"{filename}: {last_row}\n")
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/lhwz666/minimind.git
git@gitee.com:lhwz666/minimind.git
lhwz666
minimind
minimind
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385