代码拉取完成,页面将自动刷新
同步操作将从 oceanbase-devhub/ai-workshop-2024 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import os
import dspy
import requests
from typing import List
from llm_extractor import ExtractKeywords
import dspy
from neo4j import GraphDatabase
from pyobvector import ObVecClient, VECTOR
from sqlalchemy import func
# Shared LLM client: Qwen-Plus served through DashScope's OpenAI-compatible
# endpoint, authenticated with the DASHSCOPE_API_KEY environment variable.
tongyi_lm = dspy.LM(
    model="openai/qwen-plus",
    api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
    api_key=os.environ.get("DASHSCOPE_API_KEY"),
)
# Register it as dspy's default LM so dspy modules created without an explicit
# lm (e.g. dspy.Predict in extract_keywords) use it.
dspy.settings.configure(lm=tongyi_lm)

# Neo4j graph store. Like the password, the URI and user name can now be set
# through the environment; the defaults are the previously hard-coded values.
neo_uri = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
user = os.environ.get("NEO4J_USERNAME", "neo4j")
password = os.environ.get("NEO4J_PASSWORD", "")
graph_db = GraphDatabase.driver(uri=neo_uri, auth=(user, password))

# OceanBase vector client, constructed with no arguments (library defaults).
ob = ObVecClient()
# Table searched by query_vecdb (reads chunk_id, content, content_embedding).
CONTENT_EMBED_TABLE = "content_embed_table"
# Number of nearest chunks returned by the vector search.
VECTOR_RECALL_TOPK = 3
def extract_keywords(query: str):
    """Extract keywords from *query* with the ExtractKeywords dspy signature.

    A ``dspy.Predict`` is built over the signature and called with
    ``text=query``; the prediction's ``keywords`` field is returned.

    Any exception raised while predicting is re-raised unchanged, after a
    marker line is printed to stdout.
    """
    try:
        predictor = dspy.Predict(ExtractKeywords)
        return predictor(text=query).keywords
    except Exception:
        print("XXXXXXXXXX failed to extract keywords")
        raise
# print(extract_keywords("OceanBase是什么"))
def query_graphdb(query: str):
    """Return the Chunk nodes linked to entities named by the query's keywords.

    Keywords are extracted from *query* with ``extract_keywords`` and printed.
    The graph is then searched for ``(Chunk)-[]->(Entity)`` edges whose entity
    ``name`` is one of those keywords.

    Returns:
        A list of Neo4j records, each carrying the matched chunk under key
        ``"c"``. A chunk linked to several matching entities appears once per
        matching entity.
    """
    keywords = extract_keywords(query)
    print(keywords)
    with graph_db.session() as session:
        # The keywords come from an LLM, so pass them as a query parameter
        # rather than formatting their Python repr into the Cypher text. Quotes
        # in a keyword can then neither break nor inject into the query.
        res = session.run(
            "MATCH (c: Chunk)-[r]->(e: Entity) WHERE e.name IN $keywords RETURN c",
            {"keywords": keywords},
        )
        # Consume the records while the session is still open.
        return list(res)
def query_graphdb_entities_and_rels_with_chunk_ids(
    chunk_ids: List[str]
):
    """Collect the entity-to-entity relations found within the given chunks.

    For each Chunk whose ``id`` is in *chunk_ids*, every pair of entities
    ``e1``, ``e2`` that the same chunk points to is considered. Any direct
    relationship ``e1 -> e2`` in the graph is then returned.

    Args:
        chunk_ids: ``id`` values of the Chunk nodes to inspect.

    Returns:
        A list of strings ``"<start name>#<description>#<end name>"``, built from
        the start entity's ``name``, the relationship's ``description``, and the
        end entity's ``name``. The same relation can be matched through several
        chunks, so the list may contain duplicates.
    """
    cypher = (
        "MATCH (c: Chunk)-[]->(e1:Entity) WHERE c.id IN $chunk_ids "
        "MATCH (c: Chunk)-[]->(e2:Entity) WHERE c.id IN $chunk_ids "
        "MATCH p=(e1)-[r]->(e2) RETURN p"
    )
    relations = []
    with graph_db.session() as session:
        # Bind the ids as a parameter instead of splicing their Python repr
        # into the query text. The repr broke on ids containing quotes and
        # allowed Cypher injection.
        for record in session.run(cypher, {"chunk_ids": list(chunk_ids)}):
            path = record["p"]
            start_ent = path.start_node["name"]
            end_ent = path.end_node["name"]
            for rel in path.relationships:
                relations.append(start_ent + "#" + rel["description"] + "#" + end_ent)
    return relations
def embedding(queries: List[str], timeout: float = 60.0):
    """Embed *queries* with the remote bge-m3 embedding service.

    POSTs ``{"model": "bge-m3", "input": queries}`` to the URL in the
    ``REMOTE_BGE_URL`` environment variable, sending ``REMOTE_BGE_TOKEN`` in
    the ``X-Token`` header.

    Args:
        queries: Texts to embed.
        timeout: Seconds to wait for the connection and for each read from the
            service. Without a timeout, ``requests`` would wait forever on an
            unresponsive server.

    Returns:
        The ``"embeddings"`` field of the JSON response. This is presumably one
        vector per query in input order (TODO confirm against the service).

    Raises:
        requests.Timeout: if the service does not respond within *timeout*.
        Any JSON decoding error, re-raised after the raw body is printed.
    """
    res = requests.post(
        os.environ.get("REMOTE_BGE_URL", ""),
        json={"model": "bge-m3", "input": queries},
        headers={"X-Token": os.environ.get("REMOTE_BGE_TOKEN", "")},
        timeout=timeout,
    )
    try:
        data = res.json()
    except Exception:
        # Print the raw body to help diagnose a non-JSON reply.
        print(f"XXXXXXXXXXXXXXXXXXXXXX {res.text} XXXXXXXXXXXXXXXXXXXX")
        raise
    return data["embeddings"]
def query_vecdb(query: str):
    """Return the stored chunks closest to *query* in embedding space.

    The query is embedded with ``embedding``. It is then searched against the
    ``content_embedding`` column of CONTENT_EMBED_TABLE using L2 distance,
    asking for VECTOR_RECALL_TOPK rows and no distance column.

    Returns:
        A list of ``{"chunk_id": ..., "content": ...}`` dicts, one per row, in
        the order the search yields them.
    """
    query_vec = embedding([query])[0]
    hits = ob.ann_search(
        table_name=CONTENT_EMBED_TABLE,
        vec_data=query_vec,
        vec_column_name="content_embedding",
        distance_func=func.l2_distance,
        with_dist=False,
        topk=VECTOR_RECALL_TOPK,
        output_column_names=["chunk_id", "content"],
    )
    # Column positions follow output_column_names: 0 -> chunk_id, 1 -> content.
    return [dict(chunk_id=hit[0], content=hit[1]) for hit in hits]
# Answer-generation prompt template, filled with str.format by response_query.
# Placeholders:
#   {context}   - retrieved chunk texts, one per line
#   {relations} - entity relations, one per line, each "start#relation#target"
#   {query}     - the user's question
# The Chinese instructions cast the model as a knowledge-base QA assistant and
# ask it to answer using the document context and the entity relations.
PROMPT = """
你是一个知识库问答助手,非常擅长利用文档上下文以及文档中实体的关系为用户提供详实、正确的问答服务
以下是相关的文档上下文:
{context}
以下是文档上下文中的实体关系(每一行表示一组实体关系,格式为'起始实体#关系#目标实体'):
{relations}
以下是用户的问题:
{query}
现在请回答用户的问题:
"""
def response_query(
    query: str,
):
    """Answer *query* using retrieval-augmented generation.

    Steps:
      1. Fetch the nearest chunks from the vector store.
      2. Look up the entity relations inside those chunks in the graph store.
      3. Fill PROMPT with the chunk texts, the deduplicated relations and the
         query.
      4. Print the prompt and send it to ``tongyi_lm``.

    Returns:
        Whatever ``tongyi_lm(prompt)`` returns. This is presumably a list of
        completions for dspy.LM; verify against the installed dspy version.
    """
    vres = query_vecdb(query)
    chunk_ids = [v["chunk_id"] for v in vres]
    rels = query_graphdb_entities_and_rels_with_chunk_ids(chunk_ids)
    # Deduplicate while keeping first-seen order. Iteration order of a set of
    # str changes between interpreter runs (hash randomization), which made
    # the prompt non-reproducible.
    unique_rels = list(dict.fromkeys(rels))
    prompt = PROMPT.format(
        context="\n".join(v["content"] for v in vres),
        relations="\n".join(unique_rels),
        query=query,
    )
    print(prompt)
    return tongyi_lm(prompt)
if __name__ == "__main__":
    # Interactive loop: read a question, answer it, print the answer framed by
    # separator lines. The guard keeps the loop from running (and blocking)
    # when this module is imported.
    while True:
        try:
            query = input("> ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt ends the session without a traceback.
            print()
            break
        if not query.strip():
            # Skip blank lines rather than running retrieval on an empty query.
            continue
        res = response_query(query)
        print(f"====================================\n{res}\n==============================")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。