代码拉取完成,页面将自动刷新
import json
import os
import re
from chat import get_mongo_query_from_ai, get_answer_from_ai, check_consistency, \
get_mysql_query_from_ai,vehicle_and_energy_types # 假设这是你的 AI 查询生成模块
from db_connection import get_db_connection, get_mysql_connection # 假设这是你的数据库连接模块
def parse_json_string(json_str: str):
"""将 MongoDB 查询语句中的单引号替换为双引号,并解析 JSON 字符串,同时将 Python 的 True/False 转换为 MongoDB 的 true/false"""
json_str = json_str.replace("'", '"')
# 替换 Python 的 True 和 False 为 MongoDB 的 true 和 false
json_str = json_str.replace('False', 'false').replace('True', 'true')
return json.loads(json_str)
# 从大模型中提取 vehicle_type 和 energy_type 的函数
def extract_vehicle_and_energy_types_from_model(user_input_text):
# 构建中文提示模板
prompt = (
f"根据以下输入:'{user_input_text}',"
"提取车辆类型(车辆类型只能是‘货车’、‘挂车’、‘网约车’、‘私家车’)和能源类型(能源类型只能是‘新能源’或‘燃油’),别的车辆类型和能源类型不要提取"
"请返回一个包含‘vehicle_type’和‘energy_type’两个键的JSON对象。如果用户输入没有车辆类型信息或者能源类型信息,可以只返回只含有一个健的JSON对象"
)
# 使用大模型提取信息
extraction_result = vehicle_and_energy_types(prompt)
# 初始化结果字典
result = {}
# 先尝试使用 json.loads 来解析
try:
result = json.loads(extraction_result)
except json.JSONDecodeError:
# 如果 json.loads 失败,则使用正则表达式提取
pattern = r'\"([a-zA-Z_]+)\":\s*\"([^\"]+)\"'
matches = re.findall(pattern, extraction_result)
result = {key: value for key, value in matches}
#print(result)
if "vehicle_type" in result and result["vehicle_type"] and "energy_type" in result and result["energy_type"]:
vehicle_type = result["vehicle_type"]
energy_type = result["energy_type"]
return vehicle_type, energy_type
elif "vehicle_type" in result and result["vehicle_type"]:
vehicle_type = result["vehicle_type"]
energy_type = None
return vehicle_type, energy_type
elif "energy_type" in result and result["energy_type"]:
vehicle_type = None
energy_type = result["energy_type"]
return vehicle_type, energy_type
else :
vehicle_type = None
energy_type = None
return vehicle_type, energy_type
# 用户输入处理逻辑
def handle_user_input(user_input_text):
# 从大模型提取 vehicle_type 和 energy_type
vehicle_type, energy_type = extract_vehicle_and_energy_types_from_model(user_input_text)
# 预定义合法选项
valid_vehicle_types = ['货车', '挂车', '私家车', '网约车']
valid_energy_types = ['新能源', '燃油']
# 检查并纠正提取结果
if vehicle_type not in valid_vehicle_types:
print(f"提取的 vehicle_type '{vehicle_type}' 不在合法选项中,可能出现错误。")
vehicle_type = None
if energy_type not in valid_energy_types:
print(f"提取的 energy_type '{energy_type}' 不在合法选项中,可能出现错误。")
energy_type = None
if vehicle_type and energy_type:
return None
elif vehicle_type:
return f"无法从输入中检测到汽车是燃油车还是新能源。请提供更多详细信息。能源类型包括:燃油、新能源。"
elif energy_type:
return f"无法从输入中检测到汽车使用类型。请提供更多详细信息。车型包括:挂车、货车、网约车、私家车;"
else:
return f"无法从输入中检测到汽车类型和能源类型。请提供更多详细信息。车型包括:挂车、货车、网约车、私家车;能源类型包括:燃油、新能源。"
def query_documents_from_ai(question: str) -> str:
#预先过滤
check_question = handle_user_input(question)
if check_question is not None:
return check_question
# 获取数据库连接和集合
client, collection = get_db_connection()
# 调用大模型获取查询语句,传递 first_record 作为文档结构
filter_str = get_mongo_query_from_ai(question)
# 打印原始查询语句以便调试
print(f"生成的 filter_str: {filter_str}")
# 将单引号替换为双引号以符合 JSON 标准
filter_dict = parse_json_string(filter_str)
# 执行 find 查询
query_results = collection.find(filter_dict)
# 将查询结果存入 query_results 列表
results = list(query_results)
# 如果查询结果不为空,则将结果和问题传递给 get_answer_from_ai 方法
if results:
answer = get_answer_from_ai(question, results)
# 打印 AI 返回的答案
print(f"AI 返回的答案: {answer}")
return answer
else:
print("未找到符合条件的文档")
# 关闭连接
client.close()
def query_mysql_from_ai(question: str) -> str:
"""
从 AI 生成 MySQL 查询语句,连接 MySQL 数据库,执行查询并返回 AI 的答案。
:param question: 用户输入的问题
:return: AI 返回的答案
"""
# 获取 MySQL 数据库连接和 cursor
connection, cursor = get_mysql_connection()
try:
# 调用大模型生成 MySQL 查询语句
query_str = get_mysql_query_from_ai(question)
# 打印生成的查询语句以便调试
print(f"生成的 MySQL 查询语句: {query_str}")
# 执行查询
cursor.execute(query_str)
# 获取查询结果
results = cursor.fetchall()
# 如果查询结果不为空,则将结果和问题传递给 get_answer_from_ai 方法
if results:
# 转换为适合 AI 的结构
answer = get_answer_from_ai(question, results)
# 打印 AI 返回的答案
print(f"AI 返回的答案: {answer}")
return answer
else:
print("未找到符合条件的记录")
except Exception as e:
print(f"执行 MySQL 查询时出错: {e}")
finally:
# 关闭游标和连接
cursor.close()
connection.close()
if __name__ == "__main__":
# 定义 JSON 文件的路径
json_file_path = os.path.join(os.getcwd(), 'insurance_qa_data.json')
result_json_file_path = os.path.join(os.getcwd(), 'ai_vs_expected_results.json')
# 读取 JSON 文件并加载内容
with open(json_file_path, 'r', encoding='utf-8') as f:
json_data = json.load(f)
total_questions = len(json_data)
correct_answers = 0
# 创建用于存储结果的列表
result_data = []
for entry in json_data:
question = entry['Q']
expected_answer = entry['A']
# 调用 query_mysql_from_ai 获取 AI 返回的答案
ai_answer = query_documents_from_ai(question)
print(ai_answer)
print(expected_answer)
# 检查 AI 返回的答案与期望的答案是否一致
c = check_consistency(expected_answer, ai_answer)
print(c)
if c == "是":
correct_answers += 1
# 将问题、AI答案、期望答案存储为键值对
result_data.append({
"question": question,
"ai_answer": ai_answer,
"expected_answer": expected_answer,
"consistency_check": c
})
# 计算并打印正确率
accuracy = (correct_answers / total_questions) * 100
print(f"正确率: {accuracy:.2f}%")
# 将结果写入新的 JSON 文件
with open(result_json_file_path, 'w', encoding='utf-8') as f:
json.dump(result_data, f, ensure_ascii=False, indent=4)
print(f"结果已写入 {result_json_file_path}")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。