From d1d1e7ef20c604fe3ec1a9c7bfd37db8054f8008 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 21 Apr 2025 14:54:23 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- import_openie.py | 25 ++---- info_extraction.py | 4 +- raw_data_preprocessor.py | 33 ++++---- .../tool_can_use/lpmm_get_knowledge.py | 1 + src/plugins/knowledge/knowledge_lib.py | 5 +- src/plugins/knowledge/src/embedding_store.py | 28 ++----- src/plugins/knowledge/src/global_logger.py | 4 +- src/plugins/knowledge/src/ie_process.py | 18 +---- src/plugins/knowledge/src/kg_manager.py | 80 +++++-------------- src/plugins/knowledge/src/llm_client.py | 14 +--- src/plugins/knowledge/src/lpmmconfig.py | 5 +- .../knowledge/src/mem_active_manager.py | 8 +- src/plugins/knowledge/src/open_ie.py | 23 ++---- src/plugins/knowledge/src/prompt_template.py | 16 +--- src/plugins/knowledge/src/qa_manager.py | 32 ++++---- src/plugins/knowledge/src/raw_processing.py | 4 +- .../knowledge/src/utils/data_loader.py | 6 +- src/plugins/knowledge/src/utils/dyn_topk.py | 8 +- src/plugins/knowledge/src/utils/json_fix.py | 5 +- .../knowledge/src/utils/visualize_graph.py | 2 +- 20 files changed, 97 insertions(+), 224 deletions(-) diff --git a/import_openie.py b/import_openie.py index 537187db..5e347ef5 100644 --- a/import_openie.py +++ b/import_openie.py @@ -20,6 +20,7 @@ import sys logger = get_module_logger("LPMM知识库-OpenIE导入") + def hash_deduplicate( raw_paragraphs: Dict[str, str], triple_list_data: Dict[str, List[List[str]]], @@ -43,14 +44,10 @@ def hash_deduplicate( # 保存去重后的三元组 new_triple_list_data = dict() - for _, (raw_paragraph, triple_list) in enumerate( - zip(raw_paragraphs.values(), triple_list_data.values()) - ): + for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())): # 段落hash paragraph_hash = get_sha256(raw_paragraph) - if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and ( - paragraph_hash in stored_paragraph_hashes - ): + if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (paragraph_hash in stored_paragraph_hashes): continue new_raw_paragraphs[paragraph_hash] = raw_paragraph new_triple_list_data[paragraph_hash] = triple_list @@ -58,9 +55,7 @@ def hash_deduplicate( return new_raw_paragraphs, new_triple_list_data -def handle_import_openie( - openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager -) -> bool: +def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool: # 从OpenIE数据中提取段落原文与三元组列表 # 索引的段落原文 raw_paragraphs = openie_data.extract_raw_paragraph_dict() @@ -68,9 +63,7 @@ def handle_import_openie( entity_list_data = openie_data.extract_entity_dict() # 索引的三元组列表 triple_list_data = openie_data.extract_triple_dict() - if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len( - triple_list_data - ): + if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): logger.error("OpenIE数据存在异常") return False # 将索引换为对应段落的hash值 @@ -112,11 +105,11 @@ def main(): print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行") print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G") confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != 'y': + if confirm != "y": logger.info("用户取消操作") print("操作已取消") sys.exit(1) - print("\n" + "="*40 + "\n") + print("\n" + "=" * 40 + "\n") logger.info("----开始导入openie数据----\n") @@ -129,9 +122,7 @@ def main(): ) # 初始化Embedding库 - embed_manager = embed_manager = EmbeddingManager( - llm_client_list[global_config["embedding"]["provider"]] - ) + embed_manager = embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) logger.info("正在从文件加载Embedding库") try: embed_manager.load_from_file() diff --git a/info_extraction.py b/info_extraction.py index 04f2dd4f..b6ad8a9c 100644 --- a/info_extraction.py +++ b/info_extraction.py @@ -91,11 +91,11 @@ def main(): print("或者使用可以用赠金抵扣的Pro模型") print("请确保账户余额充足,并且在执行前确认无误。") confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != 'y': + if confirm != "y": logger.info("用户取消操作") print("操作已取消") sys.exit(1) - print("\n" + "="*40 + "\n") + print("\n" + "=" * 40 + "\n") logger.info("--------进行信息提取--------\n") diff --git a/raw_data_preprocessor.py b/raw_data_preprocessor.py index abb1ad64..645a8506 100644 --- a/raw_data_preprocessor.py +++ b/raw_data_preprocessor.py @@ -3,18 +3,17 @@ import os from pathlib import Path import sys # 新增系统模块导入 + def check_and_create_dirs(): """检查并创建必要的目录""" - required_dirs = [ - "data/lpmm_raw_data", - "data/imported_lpmm_data" - ] - + required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"] + for dir_path in required_dirs: if not os.path.exists(dir_path): os.makedirs(dir_path) print(f"已创建目录: {dir_path}") + def process_text_file(file_path): """处理单个文本文件,返回段落列表""" with open(file_path, "r", encoding="utf-8") as f: @@ -29,12 +28,13 @@ def process_text_file(file_path): paragraph = "" else: paragraph += line + "\n" - + if paragraph != "": paragraphs.append(paragraph.strip()) - + return paragraphs + def main(): # 新增用户确认提示 print("=== 重要操作确认 ===") @@ -43,42 +43,43 @@ def main(): print("在进行知识库导入之前") print("请修改config/lpmm_config.toml中的配置项") confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != 'y': + if confirm != "y": print("操作已取消") sys.exit(1) - print("\n" + "="*40 + "\n") + print("\n" + "=" * 40 + "\n") # 检查并创建必要的目录 check_and_create_dirs() - + # 检查输出文件是否存在 if os.path.exists("data/import.json"): print("错误: data/import.json 已存在,请先处理或删除该文件") sys.exit(1) - + if os.path.exists("data/openie.json"): print("错误: data/openie.json 已存在,请先处理或删除该文件") sys.exit(1) - + # 获取所有原始文本文件 raw_files = list(Path("data/lpmm_raw_data").glob("*.txt")) if not raw_files: print("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") sys.exit(1) - + # 处理所有文件 all_paragraphs = [] for file in raw_files: print(f"正在处理文件: {file.name}") paragraphs = process_text_file(file) all_paragraphs.extend(paragraphs) - + # 保存合并后的结果 output_path = "data/import.json" with open(output_path, "w", encoding="utf-8") as f: json.dump(all_paragraphs, f, ensure_ascii=False, indent=4) - + print(f"处理完成,结果已保存到: {output_path}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/do_tool/tool_can_use/lpmm_get_knowledge.py b/src/do_tool/tool_can_use/lpmm_get_knowledge.py index 5c70adc1..601d6083 100644 --- a/src/do_tool/tool_can_use/lpmm_get_knowledge.py +++ b/src/do_tool/tool_can_use/lpmm_get_knowledge.py @@ -1,5 +1,6 @@ from src.do_tool.tool_can_use.base_tool import BaseTool from src.plugins.chat.utils import get_embedding + # from src.common.database import db from src.common.logger import get_module_logger from typing import Dict, Any diff --git a/src/plugins/knowledge/knowledge_lib.py b/src/plugins/knowledge/knowledge_lib.py index 31167391..c0d2fe61 100644 --- a/src/plugins/knowledge/knowledge_lib.py +++ b/src/plugins/knowledge/knowledge_lib.py @@ -20,9 +20,7 @@ for key in global_config["llm_providers"]: ) # 初始化Embedding库 -embed_manager = EmbeddingManager( - llm_client_list[global_config["embedding"]["provider"]] -) +embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) logger.info("正在从文件加载Embedding库") try: embed_manager.load_from_file() @@ -62,4 +60,3 @@ inspire_manager = MemoryActiveManager( embed_manager, llm_client_list[global_config["embedding"]["provider"]], ) - diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index 59c804fa..e972db57 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -47,9 +47,7 @@ class EmbeddingStore: self.idx2hash = None def _get_embedding(self, s: str) -> List[float]: - return self.llm_client.send_embedding_request( - global_config["embedding"]["model"], s - ) + return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s) def batch_insert_strs(self, strs: List[str]) -> None: """向库中存入字符串""" @@ -83,14 +81,10 @@ class EmbeddingStore: logger.info(f"{self.namespace}嵌入库保存成功") if self.faiss_index is not None and self.idx2hash is not None: - logger.info( - f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}" - ) + logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}") faiss.write_index(self.faiss_index, self.index_file_path) logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功") - logger.info( - f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}" - ) + logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}") with open(self.idx2hash_file_path, "w", encoding="utf-8") as f: f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4)) logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功") @@ -103,24 +97,18 @@ class EmbeddingStore: logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库") data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow") for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)): - self.store[row["hash"]] = EmbeddingStoreItem( - row["hash"], row["embedding"], row["str"] - ) + self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) logger.info(f"{self.namespace}嵌入库加载成功") try: if os.path.exists(self.index_file_path): - logger.info( - f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex" - ) + logger.info(f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex") self.faiss_index = faiss.read_index(self.index_file_path) logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功") else: raise Exception(f"文件{self.index_file_path}不存在") if os.path.exists(self.idx2hash_file_path): - logger.info( - f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射" - ) + logger.info(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射") with open(self.idx2hash_file_path, "r") as f: self.idx2hash = json.load(f) logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功") @@ -215,9 +203,7 @@ class EmbeddingManager: for triples in triple_list_data.values(): graph_triples.extend([tuple(t) for t in triples]) graph_triples = list(set(graph_triples)) - self.relation_embedding_store.batch_insert_strs( - [str(triple) for triple in graph_triples] - ) + self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples]) def load_from_file(self): """从文件加载""" diff --git a/src/plugins/knowledge/src/global_logger.py b/src/plugins/knowledge/src/global_logger.py index 0311db5f..f7d8297e 100644 --- a/src/plugins/knowledge/src/global_logger.py +++ b/src/plugins/knowledge/src/global_logger.py @@ -7,8 +7,6 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) console_logging_handler = logging.StreamHandler() -console_logging_handler.setFormatter( - logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") -) +console_logging_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) console_logging_handler.setLevel(logging.DEBUG) logger.addHandler(console_logging_handler) diff --git a/src/plugins/knowledge/src/ie_process.py b/src/plugins/knowledge/src/ie_process.py index 5da9ad9e..3e53e4b2 100644 --- a/src/plugins/knowledge/src/ie_process.py +++ b/src/plugins/knowledge/src/ie_process.py @@ -38,16 +38,12 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]: return entity_extract_result -def _rdf_triple_extract( - llm_client: LLMClient, paragraph: str, entities: list -) -> List[List[str]]: +def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" entity_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=json.dumps(entities, ensure_ascii=False) ) - _, request_result = llm_client.send_chat_request( - global_config["rdf_build"]["llm"]["model"], entity_extract_context - ) + _, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context) # 去除‘{’前的内容(结果中可能有多个‘{’) if "[" in request_result: @@ -60,11 +56,7 @@ def _rdf_triple_extract( entity_extract_result = json.loads(fix_broken_generated_json(request_result)) for triple in entity_extract_result: - if ( - len(triple) != 3 - or (triple[0] is None or triple[1] is None or triple[2] is None) - or "" in triple - ): + if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: raise Exception("RDF提取结果格式错误") return entity_extract_result @@ -91,9 +83,7 @@ def info_extract_from_str( try_count = 0 while True: try: - rdf_triple_extract_result = _rdf_triple_extract( - llm_client_for_rdf, paragraph, entity_extract_result - ) + rdf_triple_extract_result = _rdf_triple_extract(llm_client_for_rdf, paragraph, entity_extract_result) break except Exception as e: logger.warning(f"实体提取失败,错误信息:{e}") diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py index 4fcdcf80..71ce65ef 100644 --- a/src/plugins/knowledge/src/kg_manager.py +++ b/src/plugins/knowledge/src/kg_manager.py @@ -22,6 +22,7 @@ from .lpmmconfig import ( from .global_logger import logger + class KGManager: def __init__(self): # 会被保存的字段 @@ -35,9 +36,7 @@ class KGManager: # 持久化相关 self.dir_path = global_config["persistence"]["rag_data_dir"] self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml" - self.ent_cnt_data_path = ( - self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" - ) + self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json" def save_to_file(self): @@ -50,9 +49,7 @@ class KGManager: di_graph.save_to_file(self.graph, self.graph_data_path) # 保存实体计数到文件 - ent_cnt_df = pd.DataFrame( - [{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()] - ) + ent_cnt_df = pd.DataFrame([{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()]) ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False) # 保存段落hash到文件 @@ -77,9 +74,7 @@ class KGManager: # 加载实体计数 ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow") - self.ent_appear_cnt = dict( - {row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()} - ) + self.ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()}) # 加载KG self.graph = di_graph.load_from_file(self.graph_data_path) @@ -99,20 +94,14 @@ class KGManager: # 一个triple就是一条边(同时构建双向联系) hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0]) hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2]) - node_to_node[(hash_key1, hash_key2)] = ( - node_to_node.get((hash_key1, hash_key2), 0) + 1.0 - ) - node_to_node[(hash_key2, hash_key1)] = ( - node_to_node.get((hash_key2, hash_key1), 0) + 1.0 - ) + node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 + node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) entity_set.add(hash_key2) # 实体出现次数统计 for hash_key in entity_set: - self.ent_appear_cnt[hash_key] = ( - self.ent_appear_cnt.get(hash_key, 0) + 1.0 - ) + self.ent_appear_cnt[hash_key] = self.ent_appear_cnt.get(hash_key, 0) + 1.0 @staticmethod def _build_edges_between_ent_pg( @@ -124,9 +113,7 @@ class KGManager: for triple in triple_list_data[idx]: ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0]) pg_hash_key = PG_NAMESPACE + "-" + str(idx) - node_to_node[(ent_hash_key, pg_hash_key)] = ( - node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 - ) + node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod def _synonym_connect( @@ -175,9 +162,7 @@ class KGManager: new_edge_cnt += 1 res_ent.append( ( - embedding_manager.entities_embedding_store.store[ - res_ent_hash - ].str, + embedding_manager.entities_embedding_store.store[res_ent_hash].str, similarity, ) ) # Debug @@ -235,9 +220,7 @@ class KGManager: if node_hash not in existed_nodes: if node_hash.startswith(ENT_NAMESPACE): # 新增实体节点 - node = embedding_manager.entities_embedding_store.store[ - node_hash - ] + node = embedding_manager.entities_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) node_item = self.graph[node_hash] node_item["content"] = node.str @@ -246,15 +229,11 @@ class KGManager: self.graph.update_node(node_item) elif node_hash.startswith(PG_NAMESPACE): # 新增文段节点 - node = embedding_manager.paragraphs_embedding_store.store[ - node_hash - ] + node = embedding_manager.paragraphs_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) content = node.str.replace("\n", " ") node_item = self.graph[node_hash] - node_item["content"] = ( - content if len(content) < 8 else content[:8] + "..." - ) + node_item["content"] = content if len(content) < 8 else content[:8] + "..." node_item["type"] = "pg" node_item["create_time"] = now_time self.graph.update_node(node_item) @@ -324,9 +303,7 @@ class KGManager: ent_sim_scores = {} for relation_hash, similarity, _ in relation_search_result: # 提取主宾短语 - relation = embed_manager.relation_embedding_store.store.get( - relation_hash - ).str + relation = embed_manager.relation_embedding_store.store.get(relation_hash).str assert relation is not None # 断言:relation不为空 # 关系三元组 triple = relation[2:-2].split("', '") @@ -340,9 +317,7 @@ class KGManager: ent_mean_scores = {} # 记录实体的平均相似度 for ent_hash, scores in ent_sim_scores.items(): # 先对相似度进行累加,然后与实体计数相除获取最终权重 - ent_weights[ent_hash] = ( - float(np.sum(scores)) / self.ent_appear_cnt[ent_hash] - ) + ent_weights[ent_hash] = float(np.sum(scores)) / self.ent_appear_cnt[ent_hash] # 记录实体的平均相似度,用于后续的top_k筛选 ent_mean_scores[ent_hash] = float(np.mean(scores)) del ent_sim_scores @@ -359,21 +334,14 @@ class KGManager: for ent_hash, score in ent_weights.items(): # 缩放相似度 ent_weights[ent_hash] = ( - (score - ent_weights_min) - * (1 - down_edge) - / (ent_weights_max - ent_weights_min) + (score - ent_weights_min) * (1 - down_edge) / (ent_weights_max - ent_weights_min) ) + down_edge # 取平均相似度的top_k实体 top_k = global_config["qa"]["params"]["ent_filter_top_k"] if len(ent_mean_scores) > top_k: # 从大到小排序,取后len - k个 - ent_mean_scores = { - k: v - for k, v in sorted( - ent_mean_scores.items(), key=lambda item: item[1], reverse=True - ) - } + ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)} for ent_hash, _ in ent_mean_scores.items(): # 删除被淘汰的实体节点权重设置 del ent_weights[ent_hash] @@ -394,9 +362,7 @@ class KGManager: # 归一化 for pg_hash, similarity in pg_sim_scores.items(): # 归一化相似度 - pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / ( - pg_sim_score_max - pg_sim_score_min - ) + pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (pg_sim_score_max - pg_sim_score_min) del pg_sim_score_max, pg_sim_score_min for pg_hash, score in pg_sim_scores.items(): @@ -406,9 +372,7 @@ class KGManager: del pg_sim_scores # 最终权重数据 = 实体权重 + 文段权重 - ppr_node_weights = { - k: v for d in [ent_weights, pg_weights] for k, v in d.items() - } + ppr_node_weights = {k: v for d in [ent_weights, pg_weights] for k, v in d.items()} del ent_weights, pg_weights # PersonalizedPageRank @@ -422,15 +386,11 @@ class KGManager: # 获取最终结果 # 从搜索结果中提取文段节点的结果 passage_node_res = [ - (node_key, score) - for node_key, score in ppr_res.items() - if node_key.startswith(PG_NAMESPACE) + (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE) ] del ppr_res # 排序:按照分数从大到小 - passage_node_res = sorted( - passage_node_res, key=lambda item: item[1], reverse=True - ) + passage_node_res = sorted(passage_node_res, key=lambda item: item[1], reverse=True) return passage_node_res, ppr_node_weights diff --git a/src/plugins/knowledge/src/llm_client.py b/src/plugins/knowledge/src/llm_client.py index 3036662c..52d0dca0 100644 --- a/src/plugins/knowledge/src/llm_client.py +++ b/src/plugins/knowledge/src/llm_client.py @@ -21,20 +21,14 @@ class LLMClient: def send_chat_request(self, model, messages): """发送对话请求,等待返回结果""" - response = self.client.chat.completions.create( - model=model, messages=messages, stream=False - ) + response = self.client.chat.completions.create(model=model, messages=messages, stream=False) if hasattr(response.choices[0].message, "reasoning_content"): # 有单独的推理内容块 reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content else: # 无单独的推理内容块 - response = ( - response.choices[0] - .message.content.split("")[-1] - .split("") - ) + response = response.choices[0].message.content.split("")[-1].split("") # 如果有推理内容,则分割推理内容和内容 if len(response) == 2: reasoning_content = response[0] @@ -48,6 +42,4 @@ class LLMClient: def send_embedding_request(self, model, text): """发送嵌入请求,等待返回结果""" text = text.replace("\n", " ") - return ( - self.client.embeddings.create(input=[text], model=model).data[0].embedding - ) + return self.client.embeddings.create(input=[text], model=model).data[0].embedding diff --git a/src/plugins/knowledge/src/lpmmconfig.py b/src/plugins/knowledge/src/lpmmconfig.py index ff1ac8fa..7f59bc89 100644 --- a/src/plugins/knowledge/src/lpmmconfig.py +++ b/src/plugins/knowledge/src/lpmmconfig.py @@ -65,7 +65,6 @@ def _load_config(config, config_file_path): config["persistence"] = file_config["persistence"] print(config) print("Configurations loaded from file: ", config_file_path) - parser = argparse.ArgumentParser(description="Configurations for the pipeline") @@ -129,9 +128,9 @@ global_config = dict( "embedding_data_dir": "data/embedding", "rag_data_dir": "data/rag", }, - "info_extraction":{ + "info_extraction": { "workers": 10, - } + }, } ) diff --git a/src/plugins/knowledge/src/mem_active_manager.py b/src/plugins/knowledge/src/mem_active_manager.py index 073e3bda..3998c066 100644 --- a/src/plugins/knowledge/src/mem_active_manager.py +++ b/src/plugins/knowledge/src/mem_active_manager.py @@ -16,13 +16,9 @@ class MemoryActiveManager: def get_activation(self, question: str) -> float: """获取记忆激活度""" # 生成问题的Embedding - question_embedding = self.embedding_client.send_embedding_request( - "text-embedding", question - ) + question_embedding = self.embedding_client.send_embedding_request("text-embedding", question) # 查询关系库中的相似度 - rel_search_res = self.embed_manager.relation_embedding_store.search_top_k( - question_embedding, 10 - ) + rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10) # 动态过滤阈值 rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0) diff --git a/src/plugins/knowledge/src/open_ie.py b/src/plugins/knowledge/src/open_ie.py index 58259ef8..5fe163bb 100644 --- a/src/plugins/knowledge/src/open_ie.py +++ b/src/plugins/knowledge/src/open_ie.py @@ -9,12 +9,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]: """过滤无效的实体""" valid_entities = set() for entity in entities: - if ( - not isinstance(entity, str) - or entity.strip() == "" - or entity in INVALID_ENTITY - or entity in valid_entities - ): + if not isinstance(entity, str) or entity.strip() == "" or entity in INVALID_ENTITY or entity in valid_entities: # 非字符串/空字符串/在无效实体列表中/重复 continue valid_entities.add(entity) @@ -74,9 +69,7 @@ class OpenIE: for doc in self.docs: # 过滤实体列表 - doc["extracted_entities"] = _filter_invalid_entities( - doc["extracted_entities"] - ) + doc["extracted_entities"] = _filter_invalid_entities(doc["extracted_entities"]) # 过滤无效的三元组 doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"]) @@ -100,9 +93,7 @@ class OpenIE: @staticmethod def load() -> "OpenIE": """从文件中加载OpenIE数据""" - with open( - global_config["persistence"]["openie_data_path"], "r", encoding="utf-8" - ) as f: + with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f: data = json.loads(f.read()) openie_data = OpenIE._from_dict(data) @@ -112,9 +103,7 @@ class OpenIE: @staticmethod def save(openie_data: "OpenIE"): """保存OpenIE数据到文件""" - with open( - global_config["persistence"]["openie_data_path"], "w", encoding="utf-8" - ) as f: + with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f: f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4)) def extract_entity_dict(self): @@ -141,7 +130,5 @@ class OpenIE: def extract_raw_paragraph_dict(self): """提取原始段落""" - raw_paragraph_dict = dict( - {doc_item["idx"]: doc_item["passage"] for doc_item in self.docs} - ) + raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}) return raw_paragraph_dict diff --git a/src/plugins/knowledge/src/prompt_template.py b/src/plugins/knowledge/src/prompt_template.py index ab6ac08c..18a5002e 100644 --- a/src/plugins/knowledge/src/prompt_template.py +++ b/src/plugins/knowledge/src/prompt_template.py @@ -41,9 +41,7 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描 def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]: messages = [ LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(), - LLMMessage( - "user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""" - ).to_dict(), + LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(), ] return messages @@ -58,16 +56,10 @@ qa_system_prompt = """ """ -def build_qa_context( - question: str, knowledge: list[(str, str, str)] -) -> List[LLMMessage]: - knowledge = "\n".join( - [f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)] - ) +def build_qa_context(question: str, knowledge: list[(str, str, str)]) -> List[LLMMessage]: + knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]) messages = [ LLMMessage("system", qa_system_prompt).to_dict(), - LLMMessage( - "user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}" - ).to_dict(), + LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(), ] return messages diff --git a/src/plugins/knowledge/src/qa_manager.py b/src/plugins/knowledge/src/qa_manager.py index 986c8419..894c0124 100644 --- a/src/plugins/knowledge/src/qa_manager.py +++ b/src/plugins/knowledge/src/qa_manager.py @@ -2,6 +2,7 @@ import time from typing import Tuple, List, Dict from .global_logger import logger + # from . import prompt_template from .embedding_store import EmbeddingManager from .llm_client import LLMClient @@ -31,7 +32,7 @@ class QAManager: """处理查询""" # 生成问题的Embedding - part_start_time =time.perf_counter() + part_start_time = time.perf_counter() question_embedding = self.llm_client_list["embedding"].send_embedding_request( global_config["embedding"]["model"], question ) @@ -39,7 +40,7 @@ class QAManager: logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s") # 根据问题Embedding查询Relation Embedding库 - part_start_time =time.perf_counter() + part_start_time = time.perf_counter() relation_search_res = self.embed_manager.relation_embedding_store.search_top_k( question_embedding, global_config["qa"]["params"]["relation_search_top_k"], @@ -47,10 +48,7 @@ class QAManager: # 过滤阈值 # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) - if ( - relation_search_res[0][1] - < global_config["qa"]["params"]["relation_threshold"] - ): + if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]: # 未找到相关关系 relation_search_res = [] @@ -66,12 +64,10 @@ class QAManager: # part_start_time = time.time() # 根据问题Embedding查询Paragraph Embedding库 - part_start_time =time.perf_counter() - paragraph_search_res = ( - self.embed_manager.paragraphs_embedding_store.search_top_k( - question_embedding, - global_config["qa"]["params"]["paragraph_search_top_k"], - ) + part_start_time = time.perf_counter() + paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( + question_embedding, + global_config["qa"]["params"]["paragraph_search_top_k"], ) part_end_time = time.perf_counter() logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") @@ -79,7 +75,7 @@ class QAManager: if len(relation_search_res) != 0: logger.info("找到相关关系,将使用RAG进行检索") # 使用KG检索 - part_start_time =time.perf_counter() + part_start_time = time.perf_counter() result, ppr_node_weights = self.kg_manager.kg_search( relation_search_res, paragraph_search_res, self.embed_manager ) @@ -94,9 +90,7 @@ class QAManager: result = dyn_select_top_k(result, 0.5, 1.0) for res in result: - raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[ - res[0] - ].str + raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") return result, ppr_node_weights @@ -113,5 +107,7 @@ class QAManager: ) for res in query_res ] - found_knowledge = "\n".join([f"第{i + 1}条知识:{k[1]}\n 该条知识对于问题的相关性:{k[0]}" for i, k in enumerate(knowledge)]) - return found_knowledge \ No newline at end of file + found_knowledge = "\n".join( + [f"第{i + 1}条知识:{k[1]}\n 该条知识对于问题的相关性:{k[0]}" for i, k in enumerate(knowledge)] + ) + return found_knowledge diff --git a/src/plugins/knowledge/src/raw_processing.py b/src/plugins/knowledge/src/raw_processing.py index 75e4b863..91e681c7 100644 --- a/src/plugins/knowledge/src/raw_processing.py +++ b/src/plugins/knowledge/src/raw_processing.py @@ -17,9 +17,7 @@ def load_raw_data() -> tuple[list[str], list[str]]: """ # 读取import.json文件 if os.path.exists(global_config["persistence"]["raw_data_path"]) is True: - with open( - global_config["persistence"]["raw_data_path"], "r", encoding="utf-8" - ) as f: + with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f: import_json = json.loads(f.read()) else: raise Exception("原始数据文件读取失败") diff --git a/src/plugins/knowledge/src/utils/data_loader.py b/src/plugins/knowledge/src/utils/data_loader.py index 08bdf414..3b5e8d2e 100644 --- a/src/plugins/knowledge/src/utils/data_loader.py +++ b/src/plugins/knowledge/src/utils/data_loader.py @@ -14,11 +14,7 @@ class DataLoader: Args: custom_data_dir: 可选的自定义数据目录路径,如果不提供则使用配置文件中的默认路径 """ - self.data_dir = ( - Path(custom_data_dir) - if custom_data_dir - else Path(config["persistence"]["data_root_path"]) - ) + self.data_dir = Path(custom_data_dir) if custom_data_dir else Path(config["persistence"]["data_root_path"]) if not self.data_dir.exists(): raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在") diff --git a/src/plugins/knowledge/src/utils/dyn_topk.py b/src/plugins/knowledge/src/utils/dyn_topk.py index 02bc5e3e..eb40ef3a 100644 --- a/src/plugins/knowledge/src/utils/dyn_topk.py +++ b/src/plugins/knowledge/src/utils/dyn_topk.py @@ -36,14 +36,10 @@ def dyn_select_top_k( # 计算均值 mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score) # 计算方差 - var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len( - normalized_score - ) + var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(normalized_score) # 动态阈值 - threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * ( - mean_score + var_factor * var_score - ) + threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (mean_score + var_factor * var_score) # 重新过滤 res = [s for s in normalized_score if s[2] > threshold] diff --git a/src/plugins/knowledge/src/utils/json_fix.py b/src/plugins/knowledge/src/utils/json_fix.py index 672fb1f8..a83eb491 100644 --- a/src/plugins/knowledge/src/utils/json_fix.py +++ b/src/plugins/knowledge/src/utils/json_fix.py @@ -29,10 +29,7 @@ def _find_unclosed(json_str): elif char in "{[": unclosed.append(char) elif char in "}]": - if unclosed and ( - (char == "}" and unclosed[-1] == "{") - or (char == "]" and unclosed[-1] == "[") - ): + if unclosed and ((char == "}" and unclosed[-1] == "{") or (char == "]" and unclosed[-1] == "[")): unclosed.pop() return unclosed diff --git a/src/plugins/knowledge/src/utils/visualize_graph.py b/src/plugins/knowledge/src/utils/visualize_graph.py index 845e18a2..7ca9b7e6 100644 --- a/src/plugins/knowledge/src/utils/visualize_graph.py +++ b/src/plugins/knowledge/src/utils/visualize_graph.py @@ -14,4 +14,4 @@ def draw_graph_and_show(graph): font_family="Sarasa Mono SC", font_size=8, ) - fig.show() \ No newline at end of file + fig.show()