6 Star 1 Fork 0

PolarDB/AIRobot-ECNU

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
test_zzzw.py 6.61 KB
一键复制 编辑 原始数据 按行查看 历史
Ashura5 提交于 2023-12-12 11:54 . 初始化代码
import requests
import json
import sys
from pathlib import Path
import time
sys.path.append("/home/lixiang/Langchain-Chatchat")
from server.utils import api_address
from pprint import pprint
import jsonlines
api_base_url = api_address()
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from nltk.translate.bleu_score import sentence_bleu
import torch
import jieba
import csv
# Command-line arguments (sys.argv[0] is the script itself), e.g.
#   python test_zzzw.py ChatGLM2 ChatGLM2.json test_data2.csv
# argv[1]: model name, used to label answers and name the output JSON files
# argv[2]: JSON Lines file that is read back for scoring
# argv[-1]: output CSV path
if len(sys.argv) > 3:
    model_name = sys.argv[1]
    file_list = sys.argv[2]
    output_file = sys.argv[-1]
else:
    # Without all three arguments the script would fail later with a
    # NameError/IndexError, or (with exactly two) argv[-1] would equal
    # argv[2] and the CSV would overwrite the input JSON file -- stop here.
    print("未传入参数")
    print(f"usage: python {sys.argv[0]} MODEL_NAME ANSWERS.json OUTPUT.csv", file=sys.stderr)
    sys.exit(1)
def dump_input(d, title):
    """Print a blank spacer, a banner naming *title*, then pretty-print *d*.

    Used to echo each request payload before it is sent.
    """
    bar = "=" * 30
    print("\n")
    print("".join([bar, title, " input ", bar]))
    pprint(d)
def dump_output(r, title):
    """Print a banner naming *title*, then echo the streamed body of *r*.

    Each chunk from ``r.iter_content`` is printed without a newline and
    flushed immediately, so streamed text appears as it arrives.
    """
    bar = "=" * 30
    print("\n")
    print("".join([bar, title, " output", bar]))
    chunks = r.iter_content(None, decode_unicode=True)
    for piece in chunks:
        print(piece, end="", flush=True)
# HTTP headers sent with every API request below: JSON in, JSON accepted.
headers = {"accept": "application/json", "Content-Type": "application/json"}
def test_chat_chat_case(query, model_name, api="/chat/fastchat"):
    """Send *query* to the plain (no knowledge base) chat endpoint and return the reply.

    A canned greeting exchange in which the assistant introduces itself as
    *model_name* is placed before the query. The reply is requested as a
    stream and all chunks are concatenated.

    Args:
        query: user question sent as the last message.
        model_name: name inserted into the canned assistant greeting.
        api: endpoint path appended to the module-level ``api_base_url``.

    Returns:
        The full reply text (also printed).

    Raises:
        RuntimeError: if the server does not answer with HTTP 200.
    """
    url = f"{api_base_url}{api}"
    payload = {
        "messages": [
            {"role": "user", "content": "你好"},
            {"role": "assistant", "content": "你好,我是 " + model_name},
            {"role": "user", "content": query},
        ],
        "stream": True,
    }
    dump_input(payload, api)
    # `with` closes the streamed connection even if reading fails.
    with requests.post(url, headers=headers, json=payload, stream=True) as response:
        # An explicit check instead of `assert`, which is stripped under -O.
        if response.status_code != 200:
            raise RuntimeError(f"POST {url} failed with HTTP {response.status_code}")
        # iter_content(decode_unicode=True) yields raw bytes when the response
        # declares no encoding; force UTF-8 so chunks are always str.
        if response.encoding is None:
            response.encoding = "utf-8"
        pieces = list(response.iter_content(None, decode_unicode=True))
    result = "".join(pieces)
    print(result)
    return result
def test_knowledge_chat_case(query, model_name, api="/chat/knowledge_base_chat"):
    """Ask *query* against the ``samples`` knowledge base and return the answer.

    A canned greeting exchange in which the assistant introduces itself as
    *model_name* is sent as history. The reply is streamed as JSON chunks;
    the ``"answer"`` field of each chunk is concatenated, and chunks without
    that field are ignored.

    Args:
        query: user question.
        model_name: name inserted into the canned assistant greeting.
        api: endpoint path appended to the module-level ``api_base_url``.

    Returns:
        The full answer text (also printed).

    Raises:
        RuntimeError: if the server does not answer with HTTP 200.
    """
    url = f"{api_base_url}{api}"
    payload = {
        "query": query,
        "knowledge_base_name": "samples",
        "history": [
            {"role": "user", "content": "你好"},
            {"role": "assistant", "content": "你好,我是 " + model_name},
        ],
        "stream": True,
    }
    dump_input(payload, api)
    pieces = []
    # `with` closes the streamed connection even if parsing fails.
    with requests.post(url, headers=headers, json=payload, stream=True) as response:
        # An explicit check instead of `assert`, which is stripped under -O.
        if response.status_code != 200:
            raise RuntimeError(f"POST {url} failed with HTTP {response.status_code}")
        # Without a declared encoding iter_content would yield bytes.
        if response.encoding is None:
            response.encoding = "utf-8"
        for chunk in response.iter_content(None, decode_unicode=True):
            if not chunk.strip():
                continue  # json.loads rejects blank/keep-alive chunks
            # NOTE(review): assumes each streamed chunk is exactly one JSON
            # object; a chunk that splits or merges objects would make
            # json.loads fail -- confirm against the server's stream format.
            message = json.loads(chunk)
            if "answer" in message:
                pieces.append(message["answer"])
    result = "".join(pieces)
    print(result)
    return result
def embedding(sentences, tokenizer, model, device="cuda", max_length=512, batch_size=None):
    """Encode *sentences* into unit-length sentence embeddings.

    The hidden state of the first token (presumably [CLS] for BERT-style
    tokenizers such as bge) from the model's first output is taken as the
    sentence vector. It is then L2-normalised, so a dot product between two
    rows equals their cosine similarity.

    Args:
        sentences: sequence of strings to encode.
        tokenizer: callable tokenizer, invoked with padding, truncation and
            ``return_tensors='pt'``; its result must support ``.to(device)``.
        model: encoder called with the tokenised inputs; its first output is
            indexed ``[:, 0]`` (presumably last_hidden_state).
        device: device the tokenised inputs are moved to; must match the
            model's device. Defaults to ``"cuda"`` as before.
        max_length: truncation length in tokens.
        batch_size: if given, encode at most this many sentences per forward
            pass to bound peak memory; ``None`` encodes everything in a
            single batch (the original behaviour).

    Returns:
        Tensor with one unit-norm row per sentence.

    Raises:
        ValueError: if *batch_size* is given but smaller than 1.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if batch_size is None or len(sentences) <= batch_size:
        batches = [sentences]
    else:
        batches = [sentences[i:i + batch_size] for i in range(0, len(sentences), batch_size)]

    first_token_states = []
    with torch.no_grad():
        for batch in batches:
            encoded_input = tokenizer(batch, padding=True, truncation=True,
                                      return_tensors='pt', max_length=max_length)
            encoded_input = encoded_input.to(device)
            model_output = model(**encoded_input)
            first_token_states.append(model_output[0][:, 0])

    if len(first_token_states) == 1:
        sentence_embeddings = first_token_states[0]
    else:
        sentence_embeddings = torch.cat(first_token_states, dim=0)
    return torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
def cos_sim(a, b):
    """Cosine similarity of *a* and *b* computed along dim 1.

    For 2-D inputs this gives one similarity per row pair.
    """
    functional = torch.nn.functional
    similarity = functional.cosine_similarity(a, b, dim=1)
    return similarity
# Build the evaluation set from the PolarDB knowledge spreadsheet: one
# {"query": ..., "answer": ...} record per row, taken from the "question"
# and "answer" columns.
# Alternatives used earlier (not active):
#   - read yuliao.txt as alternating query/answer lines, keeping the text
#     after ": " on each line;
#   - also split each row's "similarQuestion" cell on "#;#" and add every
#     paraphrase as an extra query that shares the row's answer.
df = pd.read_excel('/home/lixiang/Langchain-Chatchat/PolarDB_knowledge.xlsx')
similar_questions = df['similarQuestion']
questions = df['question']
answers = df['answer']
data = [{"query": q, "answer": a} for q, a in zip(questions, answers)]
# Pass 1: answer every query through the knowledge-base endpoint and save the
# records (with the answer stored under model_name) to "<model_name>.json".
# perf_counter is monotonic, so the elapsed time cannot be skewed by system
# clock adjustments the way time.time() can.
start_knowledge_chat = time.perf_counter()
for item in data:
    item[model_name] = test_knowledge_chat_case(item["query"], model_name)
end_knowledge_chat = time.perf_counter()
print(end_knowledge_chat - start_knowledge_chat)
with jsonlines.open(f'{model_name}.json', mode='w') as writer:
    for item in data:
        writer.write(item)

# Pass 2: answer the same queries with plain chat (no knowledge base). The
# model_name field is overwritten in place; pass 1 results were already saved.
start_chat = time.perf_counter()
for item in data:
    item[model_name] = test_chat_chat_case(item["query"], model_name)
end_chat = time.perf_counter()
print(end_chat - start_chat)
with jsonlines.open(f'{model_name}_no.json', mode='w') as writer:
    for item in data:
        writer.write(item)
# Scoring stage: load the JSON Lines file given as argv[2]. Each record holds
# "query", the reference "answer" and model-answer field(s); every other field
# is collected into a column named after the file itself.
data = {
    "query": [],
    "answer": []
}
# `with` guarantees the file is closed; utf-8 because the records contain
# Chinese text.
with open(file_list, 'r', encoding='utf-8') as fin:
    for line in fin:
        if not line.strip():
            continue  # json.loads would reject a blank line
        dic = json.loads(line)
        data["query"].append(dic["query"])
        data["answer"].append(dic["answer"])
        for key, value in dic.items():
            if key not in ("query", "answer"):
                data.setdefault(file_list, []).append(value)
# Embed every answer column (reference and model) with bge-large-zh and
# report the mean cosine similarity of each column against the reference
# "answer" column.
bge_path = '/home/lixiang/LLM/bge-large-zh-v1.5'
tokenizer = AutoTokenizer.from_pretrained(bge_path, model_max_length=512)
model = AutoModel.from_pretrained(bge_path, device_map="cuda")
embeddings = {
    key: embedding(values, tokenizer, model)
    for key, values in data.items()
    if key != "query"
}
all_cos = {key: cos_sim(emb, embeddings["answer"]) for key, emb in embeddings.items()}
# Convert to plain floats so the report shows numbers, not tensor reprs
# such as "tensor(0.85, device='cuda:0')".
mean_cos = {key: torch.mean(sims).item() for key, sims in all_cos.items()}
for key, value in mean_cos.items():
    print(model_name, ":", key, ":", value)
# Explicit encoding: model/file names may be non-ASCII.
with open('cos_output.txt', 'w', encoding='utf-8') as fout:
    for key, value in mean_cos.items():
        fout.write(f"{model_name} : {key} : {value}\n")
# Sentence-level BLEU of each answer column against the reference answers.
# Texts are segmented with jieba; sentence_bleu expects a list of reference
# token lists per candidate, hence the extra nesting in ref['reference'].
ref = {}
for key in data:
    if key != "query":
        ref[key] = [list(jieba.cut(x)) for x in data[key]]
        if key == 'answer':
            ref['reference'] = [[tokens] for tokens in ref[key]]
for key, candidates in ref.items():
    if key == "reference":
        continue
    scores = [
        sentence_bleu(reference, candidate)
        for reference, candidate in zip(ref["reference"], candidates)
    ]
    # Average over the number of pairs actually scored, not a fixed
    # dataset size; guard against an empty file.
    print(key, ":", sum(scores) / len(scores) if scores else 0.0)
# Write one CSV row per question: question, reference answer and the model
# answer column(s), in data's key order. The header goes through the csv
# writer too, so a model name containing commas/quotes is quoted correctly
# and every line uses the same terminator. utf-8-sig prepends a BOM
# (presumably so spreadsheet tools such as Excel detect UTF-8).
with open(output_file, "w", newline='', encoding='utf-8-sig') as f:
    writer = csv.writer(f)
    writer.writerow(["问题", "标准答案", model_name])
    for i in range(len(data["query"])):
        writer.writerow([column[i] for column in data.values()])
# Final summary: elapsed time of the knowledge-base and plain chat passes,
# then the mean cosine similarity of each answer column.
print(f"使用知识库的执行时间:{end_knowledge_chat - start_knowledge_chat}")
print(f"不使用知识库的执行时间:{end_chat - start_chat}")
for key, sims in all_cos.items():
    print(key, ":", torch.mean(sims))
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Java
1
https://gitee.com/polardb/AIRobot-ECNU.git
git@gitee.com:polardb/AIRobot-ECNU.git
polardb
AIRobot-ECNU
AIRobot-ECNU
master

搜索帮助