1 Star 0 Fork 2

GITHUBear/ai-workshop-2024

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
query_for_graph_rag.py 4.00 KB
一键复制 编辑 原始数据 按行查看 历史
shanhaikang.shk 提交于 2024-11-28 17:13 . ob graphrag demo finish
import os
import dspy
import requests
from typing import List
from llm_extractor import ExtractKeywords
import dspy
from neo4j import GraphDatabase
from pyobvector import ObVecClient, VECTOR
from sqlalchemy import func
# LLM setup: Qwen (qwen-plus) served through DashScope's OpenAI-compatible
# endpoint; the API key is read from the environment.
tongyi_lm = dspy.LM(
    model="openai/qwen-plus",
    api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
    api_key=os.environ.get("DASHSCOPE_API_KEY")
)
dspy.settings.configure(lm=tongyi_lm)

# Neo4j connection for the knowledge graph (Chunk and Entity nodes).
neo_uri = "neo4j://localhost:7687"
user = "neo4j"
password = os.environ.get("NEO4J_PASSWORD", "")
graph_db = GraphDatabase.driver(uri=neo_uri, auth=(user, password))

# OceanBase vector client holding chunk-content embeddings.
ob = ObVecClient()
CONTENT_EMBED_TABLE = "content_embed_table"
VECTOR_RECALL_TOPK = 3
def extract_keywords(query: str):
    """Extract keywords from *query* via the DSPy ``ExtractKeywords`` predictor.

    Args:
        query: the user question to mine for keywords.

    Returns:
        The ``keywords`` field of the prediction (presumably a list of
        strings; depends on the ``ExtractKeywords`` signature — confirm
        against ``llm_extractor``).

    Raises:
        Exception: re-raised from the underlying LLM call after logging.
    """
    try:
        keywords_extractor = dspy.Predict(ExtractKeywords)
        pred = keywords_extractor(text=query)
        return pred.keywords
    except Exception as e:
        # Include the query and the error so failures are diagnosable,
        # instead of an opaque placeholder message.
        print(f"failed to extract keywords for query {query!r}: {e}")
        raise
def query_graphdb(query: str):
    """Return graph Chunk records linked to entities named by *query*'s keywords.

    Args:
        query: the user question; keywords are extracted from it first.

    Returns:
        list: neo4j records, each binding a ``Chunk`` node to ``c``.
    """
    keywords = extract_keywords(query)
    print(keywords)
    with graph_db.session() as session:
        # Pass keywords as a query parameter rather than f-string
        # interpolation: avoids Cypher injection and malformed list
        # literals, and lets the server cache the query plan.
        res = session.run(
            "MATCH (c: Chunk)-[r]->(e: Entity) WHERE e.name IN $keywords RETURN c",
            keywords=keywords,
        )
        leaf_chunks = [r for r in res]
    return leaf_chunks
def query_graphdb_entities_and_rels_with_chunk_ids(
    chunk_ids: List[str]
):
    """Collect relation triples between entities that co-occur in the given chunks.

    Args:
        chunk_ids: ids of ``Chunk`` nodes whose entity relations are wanted.

    Returns:
        List[str]: strings of the form ``'start#description#end'``, one per
        relationship on each matched path (duplicates are possible; callers
        dedupe as needed).
    """
    with graph_db.session() as session:
        # Parameterized Cypher instead of f-string interpolation:
        # avoids injection and malformed list literals.
        chunks_res = session.run(
            "MATCH (c: Chunk)-[]->(e1:Entity) WHERE c.id IN $chunk_ids "
            "MATCH (c: Chunk)-[]->(e2:Entity) WHERE c.id IN $chunk_ids "
            "MATCH p=(e1)-[r]->(e2) RETURN p",
            chunk_ids=chunk_ids,
        )
        relations = []
        for r in chunks_res:
            start_ent = r['p'].start_node['name']
            end_ent = r['p'].end_node['name']
            for rel in r['p'].relationships:
                relations.append(start_ent + "#" + rel['description'] + "#" + end_ent)
    return relations
def embedding(queries: List[str]):
    """Embed *queries* with the remote bge-m3 embedding service.

    Args:
        queries: texts to embed.

    Returns:
        The ``embeddings`` field of the JSON response — one vector per input.

    Raises:
        requests.RequestException: on transport failures, timeouts, or
            non-2xx HTTP status.
        Exception: re-raised when the response body is not valid JSON.
    """
    res = requests.post(
        os.environ.get("REMOTE_BGE_URL", ""),
        json={"model": "bge-m3", "input": queries},
        headers={
            "X-Token": os.environ.get("REMOTE_BGE_TOKEN", "")
        },
        timeout=30,  # without a timeout the call can hang forever
    )
    # Fail fast on HTTP errors instead of trying to JSON-parse an error page.
    res.raise_for_status()
    try:
        data = res.json()
    except Exception as e:
        # Dump the raw body so malformed responses can be diagnosed.
        print(f"XXXXXXXXXXXXXXXXXXXXXX {res.text} XXXXXXXXXXXXXXXXXXXX")
        raise e
    return data["embeddings"]
def query_vecdb(query: str):
    """Vector-recall the top-K chunks most similar to *query*.

    Args:
        query: the user question.

    Returns:
        list[dict]: one ``{"chunk_id": ..., "content": ...}`` dict per
        recalled row, at most ``VECTOR_RECALL_TOPK`` entries.
    """
    query_vec = embedding([query])[0]
    rows = ob.ann_search(
        table_name=CONTENT_EMBED_TABLE,
        vec_data=query_vec,
        vec_column_name="content_embedding",
        distance_func=func.l2_distance,
        with_dist=False,
        topk=VECTOR_RECALL_TOPK,
        output_column_names=["chunk_id", "content"],
    )
    results = []
    for row in rows:
        # Column order matches output_column_names above.
        results.append({"chunk_id": row[0], "content": row[1]})
    return results
# Prompt template for the final answer-generation call (Chinese; sent to
# the LLM verbatim). Placeholders: {context} = recalled chunk texts,
# {relations} = one 'start#relation#end' triple per line, {query} = the
# user question.
PROMPT = """
你是一个知识库问答助手,非常擅长利用文档上下文以及文档中实体的关系为用户提供详实、正确的问答服务
以下是相关的文档上下文:
{context}
以下是文档上下文中的实体关系(每一行表示一组实体关系,格式为'起始实体#关系#目标实体'):
{relations}
以下是用户的问题:
{query}
现在请回答用户的问题:
"""
def response_query(
    query: str,
):
    """Answer *query* with GraphRAG: vector recall, relation lookup, LLM call.

    Args:
        query: the user question.

    Returns:
        The LLM completion for the assembled prompt.
    """
    recalled = query_vecdb(query)
    ids = [item["chunk_id"] for item in recalled]
    triples = query_graphdb_entities_and_rels_with_chunk_ids(ids)
    # Deduplicate relation triples; chunk contents are joined in recall order.
    prompt = PROMPT.format(
        context="\n".join(item["content"] for item in recalled),
        relations="\n".join(set(triples)),
        query=query,
    )
    print(prompt)
    return tongyi_lm(prompt)
# Interactive REPL: read a question, answer it, repeat.
# Handle EOF (Ctrl-D) and Ctrl-C so the script exits cleanly instead of
# dumping a traceback.
while True:
    try:
        query = input("> ")
    except (EOFError, KeyboardInterrupt):
        break
    res = response_query(query)
    print(f"====================================\n{res}\n==============================")
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/GITHUBear/ai-workshop-2024.git
git@gitee.com:GITHUBear/ai-workshop-2024.git
GITHUBear
ai-workshop-2024
ai-workshop-2024
tongyi

搜索帮助