diff --git a/llm_extractor.py b/llm_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2be9b727f7c647b5ef34ea9223fdf8c3b5fc95cf
--- /dev/null
+++ b/llm_extractor.py
@@ -0,0 +1,73 @@
+import dspy
+from typing import List
+from pydantic import BaseModel, Field
+
+class Entity(BaseModel):
+    name: str = Field(
+        description="Name of an entity extracted from the text, used to build the knowledge graph"
+    )
+
+class Relationship(BaseModel):
+    source_entity: str = Field(
+        description="Name of the source entity of the relationship; it must be one of the names in the entity list"
+    )
+    target_entity: str = Field(
+        description="Name of the target entity of the relationship; it must be one of the names in the entity list"
+    )
+    relation_name: str = Field(
+        description="Name of the relationship, usually a predicate"
+    )
+
+class KnowledgeGraph(BaseModel):
+    entities: List[Entity] = Field(
+        description="The entities of the knowledge graph; entity names must not repeat"
+    )
+    relationships: List[Relationship] = Field(
+        description="The relationships of the knowledge graph; relationship names must not repeat"
+    )
+
+class ExtractKG(dspy.Signature):
+    text: str = dspy.InputField(
+        desc="Extract entities and relationships from this text to form a knowledge graph"
+    )
+    entities: List[str] = dspy.OutputField(
+        desc="The entities of the knowledge graph, each formatted as '<entity name>'; entity names must not repeat and there must be at most 10 entities"
+    )
+    relationships: List[str] = dspy.OutputField(
+        desc="The relationships of the knowledge graph, each formatted as '<source entity>#<relation name>#<target entity>'; replace any '#' inside the source entity, relation name or target entity with '_'; the source and target entities must appear in the entity list"
+    )
+    # knowledge_graph: KnowledgeGraph = dspy.OutputField(
+    #     desc="The knowledge graph extracted from the text"
+    # )
+
+class ExtractKeywords(dspy.Signature):
+    text: str = dspy.InputField(
+        desc="Extract a list of keywords from this text for searching the knowledge graph"
+    )
+    keywords: List[str] = dspy.OutputField(
+        desc="A list of keywords; each keyword should be clear and specific, and should include as many synonyms as possible"
+    )
+
+def parse_extract_output_to_kg(pred) -> KnowledgeGraph:
+    entities = []
+    for ent in pred.entities:
+        entities.append(Entity(
+            name=ent
+        ))
+
+    rels = []
+    for rel in pred.relationships:
+        rel_eles = rel.split('#')
+        if len(rel_eles) != 3:
+            continue
+        rels.append(Relationship(
+            source_entity=rel_eles[0],
+            target_entity=rel_eles[2],
+            relation_name=rel_eles[1],
+        ))
+
+    return KnowledgeGraph(
+        entities=entities,
+        relationships=rels
+    )
diff --git a/load_docs_for_graph_rag.py b/load_docs_for_graph_rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..83b7249dfab34b20bf7af4db8e0f27de52d67c80
--- /dev/null
+++ b/load_docs_for_graph_rag.py
@@ -0,0 +1,393 @@
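+"""Offline indexing pipeline for the GraphRAG demo.
+
+Splits markdown files into header-based chunks, links the chunks into a
+document tree, asks the LLM to extract entities and relationships for each
+chunk, writes chunk embeddings to OceanBase and the chunk/entity graph to
+Neo4j.
+"""
+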
+import os
+import requests
+import uuid
+from typing import List, Optional
+
+from pydantic import BaseModel
+from pyobvector import ObVecClient, VECTOR
+from sqlalchemy import Column, String
+from sqlalchemy.dialects.mysql import TEXT
+from langchain.text_splitter import MarkdownHeaderTextSplitter
+from neo4j import GraphDatabase
+
+from concurrent.futures import ThreadPoolExecutor
+import asyncio
+
+from llm_extractor import ExtractKG, KnowledgeGraph, Entity, Relationship, parse_extract_output_to_kg
+import dspy
+
+headers_to_split_on = [
+    ("#", "Header1"),
+    ("##", "Header2"),
+    ("###", "Header3"),
+    ("####", "Header4"),
+    ("#####", "Header5"),
+    ("######", "Header6"),
+]
+
+splitter = MarkdownHeaderTextSplitter(
+    headers_to_split_on=headers_to_split_on,
+)
+
+neo_uri = "neo4j://localhost:7687"
+user = "neo4j"
+password = os.environ.get("NEO4J_PASSWORD", "")
+graph_db = GraphDatabase.driver(uri=neo_uri, auth=(user, password))
+
+ob = ObVecClient()
+cols = [
+    Column("chunk_id", String(128), primary_key=True, autoincrement=False),
+    Column("content_embedding", VECTOR(1024)),
+    Column("content", TEXT),
+]
+CONTENT_EMBED_TABLE = "content_embed_table"
+if not ob.check_table_exists(CONTENT_EMBED_TABLE):
+    print(f"################### create table {CONTENT_EMBED_TABLE}")
+    ob.create_table(
+        CONTENT_EMBED_TABLE, columns=cols
+    )
+    # create vector index
+    ob.create_index(
+        CONTENT_EMBED_TABLE,
+        is_vec_index=True,
+        index_name="vidx",
+        column_names=["content_embedding"],
+        vidx_params="distance=l2, type=hnsw, lib=vsag",
+    )
+OB_DEFAULT_BATCH_SIZE = 10
+
+INCLUDE = "include"
+CHUNK_INCLUDE_CHUNK = "chunk_include_chunk"
+CHUNK_NEXT_CHUNK = "chunk_next_chunk"
+DOC_INCLUDE_CHUNK = "doc_include_chunk"
+DOC_INCLUDE_DOC = "doc_include_doc"
+RELATIONSHIP = "relationship"
+CHUNK_INCLUDE_ENTITY = "chunk_include_entity"
+
+tongyi_lm = dspy.LM(
+    model="openai/qwen-plus",
+    api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
+    api_key=os.environ.get("DASHSCOPE_API_KEY")
+)
+dspy.settings.configure(lm=tongyi_lm)
+
+extract_executor = ThreadPoolExecutor(max_workers=16)
+
+def reset_ob():
+    ob.perform_raw_text_sql(f"DROP TABLE {CONTENT_EMBED_TABLE}")
+
+def reset_graphdb():
+    graph_db.execute_query("MATCH (n) DETACH DELETE n")
+
+
+class ChunkWithRelation(BaseModel):
+    chunk_id: str
+    content: str
+    chunk_name: str
+    lv: int
+    parent_chunk: Optional["ChunkWithRelation"]
+    next_chunk: Optional["ChunkWithRelation"]
+
+class Doc(BaseModel):
+    doc_id: str
+    doc_name: str
+    keywords: List[str]
+
+
+def parse_headern_to_lv(headern: str):
+    return int(headern[len("Header"):])
+
+def embedding(queries: List[str]):
+    res = requests.post(
+        os.environ.get("REMOTE_BGE_URL", ""),
+        json={"model": "bge-m3", "input": queries},
+        headers={
+            "X-Token": os.environ.get("REMOTE_BGE_TOKEN", "")
+        },
+    )
+    try:
+        data = res.json()
+    except Exception as e:
+        print(f"XXXXXXXXXXXXXXXXXXXXXX {res.text} XXXXXXXXXXXXXXXXXXXX")
+        raise e
+    return data["embeddings"]
+
+def embed_chunks(
+    chunks: List[ChunkWithRelation],
+):
+    chunk_contents = [chunk.content for chunk in chunks]
+    vecs = embedding(chunk_contents)
+    datas = [
+        {
+            "chunk_id": chunk.chunk_id,
+            "content_embedding": vec,
+            "content": chunk.content,
+        }
+        for (chunk, vec) in zip(chunks, vecs)
+    ]
+    ob.insert(CONTENT_EMBED_TABLE, datas)
+
+
+async def embed_chunks_with_batch_size(
+    executor: ThreadPoolExecutor,
+    chunks: List[ChunkWithRelation],
+    batch_size: int = OB_DEFAULT_BATCH_SIZE,
+):
+    loop = asyncio.get_running_loop()
+    tasks = []
+    for start_idx in range(0, len(chunks), batch_size):
+        tasks.append(
+            loop.run_in_executor(executor, embed_chunks, chunks[start_idx : start_idx + batch_size])
+        )
+    await asyncio.gather(*tasks)
+
+def extract_chunk_kg(
+    chunk: ChunkWithRelation
+):
+    try:
+        kg_extractor = dspy.Predict(ExtractKG)
+        pred = kg_extractor(text=chunk.chunk_name + ": " + chunk.content)
+    except Exception as e:
+        print(f"XXXXXXXXXX failed to extract kg for chunk {chunk.chunk_name}: {e}")
+        return
+    kg: KnowledgeGraph = parse_extract_output_to_kg(pred)
+    print("======================= extract_chunk_kg: ", kg.model_dump())
+    graphdb_upsert_graph(kg, chunk)
+
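+# aget_doc_tree builds a chunk tree from the markdown headers: a stack of the
+# currently open header levels is kept, so for headers "# A", "## B", "### C",
+# "## D" the parent of D is A (B and C are popped because D's level is not
+# deeper than theirs), and every chunk's next_chunk points to the chunk that
+# follows it in the file.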
+async def aget_doc_tree(
+    file_path: str,
+) -> List[ChunkWithRelation]:
+    chunks_with_rel: List[ChunkWithRelation] = []
+    level_stk: List[ChunkWithRelation] = []
+    with open(file_path, "r", encoding="utf-8") as f:
+        file_name = os.path.basename(file_path)
+        file_content = f.read()
+    chunks = splitter.split_text(file_content)
+    for chunk in chunks:
+        metadata_keys = list(chunk.metadata.keys())
+        chunk_name = '-'.join(list(chunk.metadata.values()))
+        chunk_name = '-'.join([file_name, chunk_name])
+        # TODO: handle the case where chunk_name does not exist
+        chunk_lv = (
+            0 if len(metadata_keys) == 0
+            else parse_headern_to_lv(metadata_keys[-1])
+        )
+
+        # pop the stack until its top is a shallower header than this chunk
+        while len(level_stk) > 0 and chunk_lv <= level_stk[-1].lv:
+            level_stk.pop()
+
+        chunk_with_rel = ChunkWithRelation(
+            chunk_id=str(uuid.uuid4()),
+            content=chunk.page_content,
+            chunk_name=chunk_name,
+            lv=chunk_lv,
+            parent_chunk=(
+                None if len(level_stk) == 0
+                else level_stk[-1]
+            ),
+            next_chunk=None,
+        )
+        if len(chunks_with_rel) > 0:
+            chunks_with_rel[-1].next_chunk = chunk_with_rel
+        chunks_with_rel.append(chunk_with_rel)
+        level_stk.append(chunk_with_rel)
+    return chunks_with_rel
+
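+# Graph schema written by the helpers below:
+#   (Doc)-[:include {description: 'doc_include_doc'}]->(Doc)
+#   (Doc)-[:include {description: 'doc_include_chunk'}]->(Chunk)
+#   (Chunk)-[:include {description: 'chunk_include_chunk'}]->(Chunk)
+#   (Chunk)-[:include {description: 'chunk_next_chunk'}]->(Chunk)
+#   (Chunk)-[:include {description: 'chunk_include_entity'}]->(Entity)
+#   (Entity)-[:relationship {description: <relation name>}]->(Entity)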
+def graphdb_upsert_chunks(
+    chunks: List[ChunkWithRelation],
+):
+    def _create_chunk(tx, chunk_batch):
+        for chunk in chunk_batch:
+            query = (
+                "CREATE (chunk: Chunk {id: $id, name: $name, content: $content})"
+            )
+            tx.run(query, id=chunk.chunk_id, name=chunk.chunk_name, content=chunk.content)
+
+    with graph_db.session() as session:
+        session.execute_write(_create_chunk, chunks)
+
+def graphdb_upsert_chunk_rels(
+    chunks: List[ChunkWithRelation],
+):
+    def _create_chunk_rels(tx, chunk_batch):
+        for chunk in chunk_batch:
+            parent_chunk = chunk.parent_chunk
+            next_chunk = chunk.next_chunk
+            if parent_chunk:
+                query = (
+                    f"MATCH (s: Chunk {{id: $parent_id}}), (t: Chunk {{id: $id}}) "
+                    f"CREATE (s)-[r: {INCLUDE} {{description: '{CHUNK_INCLUDE_CHUNK}'}}]->(t)"
+                )
+                tx.run(query, parent_id=parent_chunk.chunk_id, id=chunk.chunk_id)
+            if next_chunk:
+                query = (
+                    f"MATCH (s: Chunk {{id: $id}}), (t: Chunk {{id: $next_id}}) "
+                    f"CREATE (s)-[r: {INCLUDE} {{description: '{CHUNK_NEXT_CHUNK}'}}]->(t)"
+                )
+                tx.run(query, id=chunk.chunk_id, next_id=next_chunk.chunk_id)
+    with graph_db.session() as session:
+        session.execute_write(_create_chunk_rels, chunks)
+
+def graphdb_upsert_doc(
+    doc: Doc
+):
+    def _create_doc(tx, doc):
+        query = (
+            "CREATE (d: Doc {id: $id, name: $name, keywords: $keywords})"
+        )
+        tx.run(query, id=doc.doc_id, name=doc.doc_name, keywords=doc.keywords)
+
+    with graph_db.session() as session:
+        session.execute_write(_create_doc, doc)
+
+def graphdb_upsert_doc_include_chunk(
+    doc: Doc,
+    hd_chunks: List[ChunkWithRelation]
+):
+    def _create_doc_include_chunks(tx, doc, chunks):
+        for chunk in chunks:
+            query = (
+                f"MATCH (s: Doc {{id: $doc_id}}), (t: Chunk {{id: $chunk_id}}) "
+                f"CREATE (s)-[r: {INCLUDE} {{description: '{DOC_INCLUDE_CHUNK}'}}]->(t)"
+            )
+            tx.run(query, doc_id=doc.doc_id, chunk_id=chunk.chunk_id)
+    with graph_db.session() as session:
+        session.execute_write(_create_doc_include_chunks, doc, hd_chunks)
+
+def graphdb_upsert_doc_include_doc(
+    parent_doc: Doc,
+    doc: Doc
+):
+    def _create_doc_include_doc(tx, par_doc, doc):
+        query = (
+            f"MATCH (s: Doc {{id: $par_id}}), (t: Doc {{id: $doc_id}}) "
+            f"CREATE (s)-[r: {INCLUDE} {{description: '{DOC_INCLUDE_DOC}'}}]->(t)"
+        )
+        tx.run(query, par_id=par_doc.doc_id, doc_id=doc.doc_id)
+    with graph_db.session() as session:
+        session.execute_write(_create_doc_include_doc, parent_doc, doc)
+
+def graphdb_upsert_entities(
+    entities: List[Entity]
+):
+    def _create_entities(tx, ents: List[Entity]):
+        for ent in ents:
+            query = (
+                "MERGE (e: Entity {name: $name})"
+            )
+            tx.run(query, name=ent.name)
+    with graph_db.session() as session:
+        session.execute_write(_create_entities, entities)
+
+def graphdb_upsert_relations(
+    relations: List[Relationship]
+):
+    def _create_rels(tx, rels: List[Relationship]):
+        for rel in rels:
+            query = (
+                f"MATCH (s: Entity {{name: $sname}}), (t: Entity {{name: $tname}}) "
+                f"MERGE (s)-[r: {RELATIONSHIP} {{description: $rdesc}}]->(t)"
+            )
+            tx.run(query, sname=rel.source_entity, tname=rel.target_entity, rdesc=rel.relation_name)
+    with graph_db.session() as session:
+        session.execute_write(_create_rels, relations)
+
+def graphdb_upsert_chunk_include_entities(
+    chunk_id: str,
+    entities: List[Entity]
+):
+    def _create_chunk_include_entities(tx, cid, ents):
+        for ent in ents:
+            query = (
+                f"MATCH (c: Chunk {{id: $id}}), (e: Entity {{name: $name}}) "
+                f"CREATE (c)-[r: {INCLUDE} {{description: '{CHUNK_INCLUDE_ENTITY}'}}]->(e)"
+            )
+            tx.run(query, id=cid, name=ent.name)
+    with graph_db.session() as session:
+        session.execute_write(_create_chunk_include_entities, chunk_id, entities)
+
+def graphdb_upsert_graph(
+    graph: KnowledgeGraph,
+    chunk: ChunkWithRelation,
+):
+    graphdb_upsert_entities(graph.entities)
+    graphdb_upsert_relations(graph.relationships)
+    graphdb_upsert_chunk_include_entities(chunk.chunk_id, graph.entities)
+
+
+async def aload_doc_graph(
+    doc_root_path: str,
+    relative_path: str,
+):
+    print("=========== start_load_doc_graph: ", doc_root_path, relative_path)
+    file_path = os.path.join(doc_root_path, relative_path)
+
+    # Get Document tree
+    chunks_with_rel = await aget_doc_tree(file_path=file_path)
+
+    graphdb_upsert_chunks(chunks_with_rel)
+    graphdb_upsert_chunk_rels(chunks_with_rel)
+
+    loop = asyncio.get_running_loop()
+    extract_tasks = []
+    for chunk in chunks_with_rel:
+        extract_tasks.append(
+            loop.run_in_executor(extract_executor, extract_chunk_kg, chunk)
+        )
+    await asyncio.gather(*extract_tasks)
+
+    await embed_chunks_with_batch_size(extract_executor, chunks_with_rel)
+
+    return chunks_with_rel[0] if len(chunks_with_rel) > 0 else None
+
+async def aload_doc(
+    doc_root_path: str,
+    doc_repo_name: str,
+    parent_doc: Optional[Doc] = None,
+):
+    files = []
+    dirs = []
+    with os.scandir(doc_root_path) as entries:
+        for entry in entries:
+            if entry.is_file() and entry.name.endswith(".md"):
+                files.append(entry.name)
+            elif entry.is_dir():
+                dirs.append(entry.name)
+
+    load_doc_graph_tasks = []
+    for file in files:
+        load_doc_graph_tasks.append(
+            aload_doc_graph(doc_root_path, file)
+        )
+    hd_chunks = []
+    if len(load_doc_graph_tasks) > 0:
+        hd_chunks = await asyncio.gather(*load_doc_graph_tasks)
+    # aload_doc_graph returns None for files that yielded no chunks
+    hd_chunks = [chunk for chunk in hd_chunks if chunk is not None]
+
+    doc = Doc(
+        doc_id=str(uuid.uuid4()),
+        doc_name=doc_repo_name[doc_repo_name.find('.') + 1:],
+        keywords=[],
+    )
+    graphdb_upsert_doc(doc)
+    if len(hd_chunks) > 0:
+        graphdb_upsert_doc_include_chunk(doc, hd_chunks)
+    if parent_doc:
+        graphdb_upsert_doc_include_doc(parent_doc, doc)
+
+    print(f"Doc: {{doc_name: {doc.doc_name}, doc_path: {doc_root_path}}}")
+
+    load_doc_tasks = []
+    for dir_name in dirs:
+        new_doc_root_path = os.path.join(doc_root_path, dir_name)
+        load_doc_tasks.append(
+            aload_doc(new_doc_root_path, dir_name, doc)
+        )
+    if len(load_doc_tasks) > 0:
+        await asyncio.gather(*load_doc_tasks)
+
+
+reset_graphdb()
+asyncio.run(
+    aload_doc(doc_root_path="./doc_test", doc_repo_name="OceanBase")
+)
diff --git a/query_for_graph_rag.py b/query_for_graph_rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc4a03028cd5c84603e825b2f2254873a044f2fe
--- /dev/null
+++ b/query_for_graph_rag.py
@@ -0,0 +1,139 @@
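+"""Online query side of the GraphRAG demo.
+
+Recalls the top-k most similar chunks from OceanBase, looks up the entity
+relationships attached to those chunks in Neo4j, and feeds both as context
+to the LLM to answer the user's question.
+"""
+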
+import os
+import dspy
+import requests
+from typing import List
+
+from llm_extractor import ExtractKeywords
+from neo4j import GraphDatabase
+from pyobvector import ObVecClient
+from sqlalchemy import func
+
+tongyi_lm = dspy.LM(
+    model="openai/qwen-plus",
+    api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
+    api_key=os.environ.get("DASHSCOPE_API_KEY")
+)
+dspy.settings.configure(lm=tongyi_lm)
+
+neo_uri = "neo4j://localhost:7687"
+user = "neo4j"
+password = os.environ.get("NEO4J_PASSWORD", "")
+graph_db = GraphDatabase.driver(uri=neo_uri, auth=(user, password))
+
+ob = ObVecClient()
+CONTENT_EMBED_TABLE = "content_embed_table"
+
+VECTOR_RECALL_TOPK = 3
+
+def extract_keywords(query: str):
+    try:
+        keywords_extractor = dspy.Predict(ExtractKeywords)
+        pred = keywords_extractor(text=query)
+        return pred.keywords
+    except Exception as e:
+        print("XXXXXXXXXX failed to extract keywords")
+        raise e
+
+# print(extract_keywords("What is OceanBase?"))
+
+def query_graphdb(query: str):
+    keywords = extract_keywords(query)
+    print(keywords)
+    with graph_db.session() as session:
+        # res = session.run(
+        #     "MATCH (e: Entity) WHERE e.name IN $keywords RETURN e",
+        #     keywords=keywords,
+        # )
+        # ents = [r for r in res]
+        res = session.run(
+            "MATCH (c: Chunk)-[r]->(e: Entity) WHERE e.name IN $keywords RETURN c",
+            keywords=keywords,
+        )
+        leaf_chunks = [r for r in res]
+        return leaf_chunks
+
+def query_graphdb_entities_and_rels_with_chunk_ids(
+    chunk_ids: List[str]
+):
+    with graph_db.session() as session:
+        chunks_res = session.run(
+            "MATCH (c: Chunk)-[]->(e1:Entity) WHERE c.id IN $chunk_ids "
+            "MATCH (c: Chunk)-[]->(e2:Entity) WHERE c.id IN $chunk_ids "
+            "MATCH p=(e1)-[r]->(e2) RETURN p",
+            chunk_ids=chunk_ids,
+        )
+        relations = []
+        for r in chunks_res:
+            start_ent = r['p'].start_node['name']
+            end_ent = r['p'].end_node['name']
+            for rel in r['p'].relationships:
+                relations.append(start_ent + "#" + rel['description'] + "#" + end_ent)
+        return relations
+
+def embedding(queries: List[str]):
+    res = requests.post(
+        os.environ.get("REMOTE_BGE_URL", ""),
+        json={"model": "bge-m3", "input": queries},
+        headers={
+            "X-Token": os.environ.get("REMOTE_BGE_TOKEN", "")
+        },
+    )
+    try:
+        data = res.json()
+    except Exception as e:
+        print(f"XXXXXXXXXXXXXXXXXXXXXX {res.text} XXXXXXXXXXXXXXXXXXXX")
+        raise e
+    return data["embeddings"]
+
+def query_vecdb(query: str):
+    vec = embedding([query])[0]
+    res = ob.ann_search(
+        table_name=CONTENT_EMBED_TABLE,
+        vec_data=vec,
+        vec_column_name="content_embedding",
+        distance_func=func.l2_distance,
+        with_dist=False,
+        topk=VECTOR_RECALL_TOPK,
+        output_column_names=["chunk_id", "content"],
+    )
+    return [
+        {
+            "chunk_id": r[0],
+            "content": r[1]
+        }
+        for r in res
+    ]
+
+PROMPT = """
+    You are a knowledge-base Q&A assistant. You are very good at using document
+    context, together with the relationships between the entities in the documents,
+    to give users detailed and correct answers.
+
+    Here is the relevant document context:
+    {context}
+
+    Here are the entity relationships found in the document context (one per
+    line, formatted as 'source entity#relation#target entity'):
+    {relations}
+
+    Here is the user's question:
+    {query}
+
+    Now answer the user's question:
+"""
+
+def response_query(
+    query: str,
+):
+    vres = query_vecdb(query)
+    chunk_ids = [v["chunk_id"] for v in vres]
+    rels = query_graphdb_entities_and_rels_with_chunk_ids(chunk_ids)
+    # print("\n".join(list(set(rels))))
+    prompt = PROMPT.format(
+        context="\n".join([v["content"] for v in vres]),
+        relations="\n".join(list(set(rels))),
+        query=query,
+    )
+    print(prompt)
+    return tongyi_lm(prompt)
+
+while True:
+    query = input("> ")
+    res = response_query(query)
+    print(f"====================================\n{res}\n====================================")