diff --git a/import_openie.py b/import_openie.py index 62150259..f958e34e 100644 --- a/import_openie.py +++ b/import_openie.py @@ -20,6 +20,7 @@ import sys logger = get_module_logger("LPMM知识库-OpenIE导入") + def hash_deduplicate( raw_paragraphs: Dict[str, str], triple_list_data: Dict[str, List[List[str]]], @@ -43,14 +44,10 @@ def hash_deduplicate( # 保存去重后的三元组 new_triple_list_data = dict() - for _, (raw_paragraph, triple_list) in enumerate( - zip(raw_paragraphs.values(), triple_list_data.values()) - ): + for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())): # 段落hash paragraph_hash = get_sha256(raw_paragraph) - if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and ( - paragraph_hash in stored_paragraph_hashes - ): + if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (paragraph_hash in stored_paragraph_hashes): continue new_raw_paragraphs[paragraph_hash] = raw_paragraph new_triple_list_data[paragraph_hash] = triple_list @@ -58,9 +55,7 @@ def hash_deduplicate( return new_raw_paragraphs, new_triple_list_data -def handle_import_openie( - openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager -) -> bool: +def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool: # 从OpenIE数据中提取段落原文与三元组列表 # 索引的段落原文 raw_paragraphs = openie_data.extract_raw_paragraph_dict() @@ -68,9 +63,7 @@ def handle_import_openie( entity_list_data = openie_data.extract_entity_dict() # 索引的三元组列表 triple_list_data = openie_data.extract_triple_dict() - if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len( - triple_list_data - ): + if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): logger.error("OpenIE数据存在异常") return False # 将索引换为对应段落的hash值 @@ -112,11 +105,11 @@ def main(): print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行") print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G") confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != 'y': + if confirm != "y": logger.info("用户取消操作") print("操作已取消") sys.exit(1) - print("\n" + "="*40 + "\n") + print("\n" + "=" * 40 + "\n") logger.info("----开始导入openie数据----\n") @@ -129,9 +122,7 @@ def main(): ) # 初始化Embedding库 - embed_manager = embed_manager = EmbeddingManager( - llm_client_list[global_config["embedding"]["provider"]] - ) + embed_manager = embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) logger.info("正在从文件加载Embedding库") try: embed_manager.load_from_file() diff --git a/info_extraction.py b/info_extraction.py index 04f2dd4f..b6ad8a9c 100644 --- a/info_extraction.py +++ b/info_extraction.py @@ -91,11 +91,11 @@ def main(): print("或者使用可以用赠金抵扣的Pro模型") print("请确保账户余额充足,并且在执行前确认无误。") confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != 'y': + if confirm != "y": logger.info("用户取消操作") print("操作已取消") sys.exit(1) - print("\n" + "="*40 + "\n") + print("\n" + "=" * 40 + "\n") logger.info("--------进行信息提取--------\n") diff --git a/raw_data_preprocessor.py b/raw_data_preprocessor.py index abb1ad64..645a8506 100644 --- a/raw_data_preprocessor.py +++ b/raw_data_preprocessor.py @@ -3,18 +3,17 @@ import os from pathlib import Path import sys # 新增系统模块导入 + def check_and_create_dirs(): """检查并创建必要的目录""" - required_dirs = [ - "data/lpmm_raw_data", - "data/imported_lpmm_data" - ] - + required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"] + for dir_path in required_dirs: if not os.path.exists(dir_path): os.makedirs(dir_path) print(f"已创建目录: {dir_path}") + def process_text_file(file_path): """处理单个文本文件,返回段落列表""" with open(file_path, "r", encoding="utf-8") as f: @@ -29,12 +28,13 @@ def process_text_file(file_path): paragraph = "" else: paragraph += line + "\n" - + if paragraph != "": paragraphs.append(paragraph.strip()) - + return paragraphs + def main(): # 新增用户确认提示 print("=== 重要操作确认 ===") @@ -43,42 +43,43 @@ def main(): print("在进行知识库导入之前") print("请修改config/lpmm_config.toml中的配置项") confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != 'y': + if confirm != "y": print("操作已取消") sys.exit(1) - print("\n" + "="*40 + "\n") + print("\n" + "=" * 40 + "\n") # 检查并创建必要的目录 check_and_create_dirs() - + # 检查输出文件是否存在 if os.path.exists("data/import.json"): print("错误: data/import.json 已存在,请先处理或删除该文件") sys.exit(1) - + if os.path.exists("data/openie.json"): print("错误: data/openie.json 已存在,请先处理或删除该文件") sys.exit(1) - + # 获取所有原始文本文件 raw_files = list(Path("data/lpmm_raw_data").glob("*.txt")) if not raw_files: print("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") sys.exit(1) - + # 处理所有文件 all_paragraphs = [] for file in raw_files: print(f"正在处理文件: {file.name}") paragraphs = process_text_file(file) all_paragraphs.extend(paragraphs) - + # 保存合并后的结果 output_path = "data/import.json" with open(output_path, "w", encoding="utf-8") as f: json.dump(all_paragraphs, f, ensure_ascii=False, indent=4) - + print(f"处理完成,结果已保存到: {output_path}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/do_tool/tool_can_use/lpmm_get_knowledge.py b/src/do_tool/tool_can_use/lpmm_get_knowledge.py index d91d697c..a7b9b1b9 100644 --- a/src/do_tool/tool_can_use/lpmm_get_knowledge.py +++ b/src/do_tool/tool_can_use/lpmm_get_knowledge.py @@ -1,5 +1,6 @@ from src.do_tool.tool_can_use.base_tool import BaseTool from src.plugins.chat.utils import get_embedding + # from src.common.database import db from src.common.logger import get_module_logger from typing import Dict, Any diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py index 63f4be7f..e94931fb 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py @@ -364,7 +364,7 @@ class PromptBuilder: # grouped_results[topic] = [] # grouped_results[topic].append(result) - # 按主题组织输出 + # 按主题组织输出 # for topic, results in grouped_results.items(): # related_info += f"【主题: {topic}】\n" # for _i, result in enumerate(results, 1): diff --git a/src/plugins/knowledge/knowledge_lib.py b/src/plugins/knowledge/knowledge_lib.py index 24d06968..4edd519a 100644 --- a/src/plugins/knowledge/knowledge_lib.py +++ b/src/plugins/knowledge/knowledge_lib.py @@ -22,9 +22,7 @@ for key in global_config["llm_providers"]: ) # 初始化Embedding库 -embed_manager = EmbeddingManager( - llm_client_list[global_config["embedding"]["provider"]] -) +embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) logger.info("正在从文件加载Embedding库") try: embed_manager.load_from_file() @@ -63,4 +61,3 @@ inspire_manager = MemoryActiveManager( embed_manager, llm_client_list[global_config["embedding"]["provider"]], ) - diff --git a/src/plugins/knowledge/lib/quick_algo/__init__.py b/src/plugins/knowledge/lib/quick_algo/__init__.py index 612d1970..64a498de 100644 --- a/src/plugins/knowledge/lib/quick_algo/__init__.py +++ b/src/plugins/knowledge/lib/quick_algo/__init__.py @@ -23,9 +23,7 @@ def _nx_graph_to_lists( A tuple containing the list of edges and the list of nodes. """ nodes = [node for node in graph.nodes()] - edges = [ - (u, v, graph.get_edge_data(u, v).get("weight", 0.0)) for u, v in graph.edges() - ] + edges = [(u, v, graph.get_edge_data(u, v).get("weight", 0.0)) for u, v in graph.edges()] return edges, nodes diff --git a/src/plugins/knowledge/lib/quick_algo/pagerank_py.py b/src/plugins/knowledge/lib/quick_algo/pagerank_py.py index ec51fa1d..dd11cbd6 100644 --- a/src/plugins/knowledge/lib/quick_algo/pagerank_py.py +++ b/src/plugins/knowledge/lib/quick_algo/pagerank_py.py @@ -9,7 +9,7 @@ def pagerank_py( personalization: Optional[Dict[str, float]] = None, alpha: float = 0.85, max_iter: int = 100, - tol: float = 1e-6 + tol: float = 1e-6, ) -> Dict[str, float]: """使用 Python、NumPy 和 SciPy 计算个性化 PageRank。 @@ -44,14 +44,14 @@ def pagerank_py( raw_values = np.maximum(raw_values, 0) norm_sum = np.sum(raw_values) - if norm_sum > 1e-9: # 避免除以零 + if norm_sum > 1e-9: # 避免除以零 personalization_vec = raw_values / norm_sum else: # 如果所有提供的个性化值都为零或负数,则回退到均匀分布 print("警告:个性化值总和为零或所有值均为非正数。回退到均匀个性化设置。") personalization_vec.fill(1.0 / num_nodes) - # --- 构建稀疏邻接矩阵 --- + # --- 构建稀疏邻接矩阵 --- # 标准 PageRank 需要基于出度的归一化 row_ind = [] col_ind = [] @@ -66,9 +66,9 @@ def pagerank_py( col_ind.append(src_idx) # 暂存原始权重,如果需要加权 PageRank,可以在此使用 w # 对于标准 PageRank,我们只需要知道连接存在 - data.append(1.0) # 初始数据设为 1,之后归一化 + data.append(1.0) # 初始数据设为 1,之后归一化 # 标准 PageRank 的出度是边的数量,加权 PageRank 可以用 w - out_degree[src_idx] += 1 + out_degree[src_idx] += 1 # 归一化权重(构建转移矩阵 M 的转置 M.T) # M[j, i] 是从 i 到 j 的概率 @@ -82,17 +82,16 @@ def pagerank_py( if out_degree[c] > 0: # 标准 PageRank: 1.0 / out_degree[c] # 如果要用原始权重 w 作为转移概率(需确保它们已归一化),则用 w / sum(w for edges from c) - normalized_data.append(d / out_degree[c]) - new_row_ind.append(c) # M.T 的行索引是 src_idx - new_col_ind.append(r) # M.T 的列索引是 dst_idx + normalized_data.append(d / out_degree[c]) + new_row_ind.append(c) # M.T 的行索引是 src_idx + new_col_ind.append(r) # M.T 的列索引是 dst_idx # 创建稀疏矩阵 (M.T) # 注意:scipy.sparse 期望 (data, (row_ind, col_ind)) 格式 # 这里构建的是 M 的转置,方便后续计算 scores = alpha * M.T @ scores + ... if len(normalized_data) > 0: # 使用 csc_matrix 以便高效地进行列操作(矩阵向量乘法) - M_T = sp.csc_matrix((normalized_data, (new_row_ind, new_col_ind)), - shape=(num_nodes, num_nodes)) + M_T = sp.csc_matrix((normalized_data, (new_row_ind, new_col_ind)), shape=(num_nodes, num_nodes)) else: M_T = sp.csc_matrix((num_nodes, num_nodes)) @@ -109,44 +108,45 @@ def pagerank_py( # 还有一种做法是仅分配给个性化向量中非零的节点 # --- PageRank 迭代 --- - scores = personalization_vec.copy() # 从个性化向量开始 + scores = personalization_vec.copy() # 从个性化向量开始 for iteration in range(max_iter): prev_scores = scores.copy() - + # 计算来自链接的贡献 linked_scores = M_T @ scores - + # 计算来自悬挂节点的贡献 # 悬挂节点的总分数 * 悬挂权重向量 dangling_sum = np.sum(scores[is_dangling]) dangling_contribution = dangling_sum * dangling_weights - + # 结合瞬移、链接贡献和悬挂节点贡献 scores = alpha * (linked_scores + dangling_contribution) + (1 - alpha) * personalization_vec - + # 检查收敛性 (L1 范数) diff = np.sum(np.abs(scores - prev_scores)) if diff < tol: print(f"在 {iteration + 1} 次迭代后收敛。") break - else: # 循环完成但未中断 + else: # 循环完成但未中断 print(f"达到最大迭代次数 ({max_iter}) 但未收敛。") # --- 格式化输出 --- result_dict = {index_to_node[i]: scores[i] for i in range(num_nodes)} return result_dict + # --- 示例用法(类似于 pr.c 中的 main)--- if __name__ == "__main__": nodes_test = ["0", "1", "2", "3", "4"] edges_test = [ - ("0", "1", 0.5), # 权重在此实现中仅用于确定出度 + ("0", "1", 0.5), # 权重在此实现中仅用于确定出度 ("1", "2", 0.3), ("2", "0", 0.2), ("1", "3", 0.4), ("3", "4", 0.6), - ("4", "1", 0.7) + ("4", "1", 0.7), ] # 添加一个悬挂节点示例 nodes_test.append("5") @@ -161,37 +161,33 @@ if __name__ == "__main__": print("运行优化的 Python PageRank 实现...") result = pagerank_py( - nodes_test, - edges_test, - personalization_test, - alpha=alpha_test, - max_iter=max_iter_test, - tol=tol_test + nodes_test, edges_test, personalization_test, alpha=alpha_test, max_iter=max_iter_test, tol=tol_test ) print("\nPageRank 分数:") # 按节点索引排序以获得一致的输出 sorted_nodes = sorted(result.keys(), key=lambda x: int(x)) for node_id in sorted_nodes: - print(f"节点 {node_id}: {result[node_id]:.6f}") + print(f"节点 {node_id}: {result[node_id]:.6f}") print("\n使用默认个性化设置运行...") result_default_pers = pagerank_py( nodes_test, edges_test, - personalization=None, # 使用默认的统一性化设置 + personalization=None, # 使用默认的统一性化设置 alpha=alpha_test, max_iter=max_iter_test, - tol=tol_test + tol=tol_test, ) print("\nPageRank 分数(默认个性化):") sorted_nodes_default = sorted(result_default_pers.keys(), key=lambda x: int(x)) for node_id in sorted_nodes_default: - print(f"节点 {node_id}: {result_default_pers[node_id]:.6f}") + print(f"节点 {node_id}: {result_default_pers[node_id]:.6f}") # 与 NetworkX 对比 (如果安装了) try: import networkx as nx + print("\n与 NetworkX PageRank 对比 (个性化)...") G = nx.DiGraph() G.add_nodes_from(nodes_test) @@ -200,25 +196,29 @@ if __name__ == "__main__": # 为了更接近我们的实现,我们不传递权重给 add_edges_from edges_for_nx = [(u, v) for u, v, w in edges_test] G.add_edges_from(edges_for_nx) - + # 归一化 NetworkX 的个性化向量 nx_pers = {node: personalization_test.get(node, 0.0) for node in nodes_test} pers_sum = sum(nx_pers.values()) if pers_sum > 0: nx_pers = {k: v / pers_sum for k, v in nx_pers.items()} - else: # 如果全为0,NetworkX 会报错或行为未定义,我们设为 None + else: # 如果全为0,NetworkX 会报错或行为未定义,我们设为 None nx_pers = None - nx_result = nx.pagerank(G, alpha=alpha_test, personalization=nx_pers, max_iter=max_iter_test, tol=tol_test, weight=None) # weight=None 强制标准 PageRank + nx_result = nx.pagerank( + G, alpha=alpha_test, personalization=nx_pers, max_iter=max_iter_test, tol=tol_test, weight=None + ) # weight=None 强制标准 PageRank for node_id in sorted_nodes: print(f"节点 {node_id}: {nx_result.get(node_id, 0.0):.6f}") print("\n与 NetworkX PageRank 对比 (默认)...") - nx_result_default = nx.pagerank(G, alpha=alpha_test, personalization=None, max_iter=max_iter_test, tol=tol_test, weight=None) + nx_result_default = nx.pagerank( + G, alpha=alpha_test, personalization=None, max_iter=max_iter_test, tol=tol_test, weight=None + ) for node_id in sorted_nodes_default: print(f"节点 {node_id}: {nx_result_default.get(node_id, 0.0):.6f}") except ImportError: print("\n未安装 NetworkX,跳过对比。") except Exception as e: - print(f"\n运行 NetworkX PageRank 时出错: {e}") \ No newline at end of file + print(f"\n运行 NetworkX PageRank 时出错: {e}") diff --git a/src/plugins/knowledge/src/config.py b/src/plugins/knowledge/src/config.py index fa42a481..8771949e 100644 --- a/src/plugins/knowledge/src/config.py +++ b/src/plugins/knowledge/src/config.py @@ -1,6 +1,7 @@ import os import toml from .global_logger import logger + PG_NAMESPACE = "paragraph" ENT_NAMESPACE = "entity" REL_NAMESPACE = "relation" @@ -58,6 +59,7 @@ def _load_config(config, config_file_path): logger.info(f"Configurations loaded from file: {config_file_path}") + global_config = dict( { "llm_providers": { @@ -119,4 +121,4 @@ file_path = os.path.abspath(__file__) dir_path = os.path.dirname(file_path) root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir) config_path = os.path.join(root_path, "config", "lpmm_config.toml") -_load_config(global_config, config_path) \ No newline at end of file +_load_config(global_config, config_path) diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index 59c804fa..e972db57 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -47,9 +47,7 @@ class EmbeddingStore: self.idx2hash = None def _get_embedding(self, s: str) -> List[float]: - return self.llm_client.send_embedding_request( - global_config["embedding"]["model"], s - ) + return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s) def batch_insert_strs(self, strs: List[str]) -> None: """向库中存入字符串""" @@ -83,14 +81,10 @@ class EmbeddingStore: logger.info(f"{self.namespace}嵌入库保存成功") if self.faiss_index is not None and self.idx2hash is not None: - logger.info( - f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}" - ) + logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}") faiss.write_index(self.faiss_index, self.index_file_path) logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功") - logger.info( - f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}" - ) + logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}") with open(self.idx2hash_file_path, "w", encoding="utf-8") as f: f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4)) logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功") @@ -103,24 +97,18 @@ class EmbeddingStore: logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库") data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow") for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)): - self.store[row["hash"]] = EmbeddingStoreItem( - row["hash"], row["embedding"], row["str"] - ) + self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"]) logger.info(f"{self.namespace}嵌入库加载成功") try: if os.path.exists(self.index_file_path): - logger.info( - f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex" - ) + logger.info(f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex") self.faiss_index = faiss.read_index(self.index_file_path) logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功") else: raise Exception(f"文件{self.index_file_path}不存在") if os.path.exists(self.idx2hash_file_path): - logger.info( - f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射" - ) + logger.info(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射") with open(self.idx2hash_file_path, "r") as f: self.idx2hash = json.load(f) logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功") @@ -215,9 +203,7 @@ class EmbeddingManager: for triples in triple_list_data.values(): graph_triples.extend([tuple(t) for t in triples]) graph_triples = list(set(graph_triples)) - self.relation_embedding_store.batch_insert_strs( - [str(triple) for triple in graph_triples] - ) + self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples]) def load_from_file(self): """从文件加载""" diff --git a/src/plugins/knowledge/src/global_logger.py b/src/plugins/knowledge/src/global_logger.py index 0311db5f..f7d8297e 100644 --- a/src/plugins/knowledge/src/global_logger.py +++ b/src/plugins/knowledge/src/global_logger.py @@ -7,8 +7,6 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) console_logging_handler = logging.StreamHandler() -console_logging_handler.setFormatter( - logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") -) +console_logging_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) console_logging_handler.setLevel(logging.DEBUG) logger.addHandler(console_logging_handler) diff --git a/src/plugins/knowledge/src/ie_process.py b/src/plugins/knowledge/src/ie_process.py index 72435dc1..a959568e 100644 --- a/src/plugins/knowledge/src/ie_process.py +++ b/src/plugins/knowledge/src/ie_process.py @@ -38,16 +38,12 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]: return entity_extract_result -def _rdf_triple_extract( - llm_client: LLMClient, paragraph: str, entities: list -) -> List[List[str]]: +def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" entity_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=json.dumps(entities, ensure_ascii=False) ) - _, request_result = llm_client.send_chat_request( - global_config["rdf_build"]["llm"]["model"], entity_extract_context - ) + _, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context) # 去除‘{’前的内容(结果中可能有多个‘{’) if "[" in request_result: @@ -60,11 +56,7 @@ def _rdf_triple_extract( entity_extract_result = json.loads(fix_broken_generated_json(request_result)) for triple in entity_extract_result: - if ( - len(triple) != 3 - or (triple[0] is None or triple[1] is None or triple[2] is None) - or "" in triple - ): + if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: raise Exception("RDF提取结果格式错误") return entity_extract_result @@ -91,9 +83,7 @@ def info_extract_from_str( try_count = 0 while True: try: - rdf_triple_extract_result = _rdf_triple_extract( - llm_client_for_rdf, paragraph, entity_extract_result - ) + rdf_triple_extract_result = _rdf_triple_extract(llm_client_for_rdf, paragraph, entity_extract_result) break except Exception as e: logger.warning(f"实体提取失败,错误信息:{e}") diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py index f0e6090d..848dd9b7 100644 --- a/src/plugins/knowledge/src/kg_manager.py +++ b/src/plugins/knowledge/src/kg_manager.py @@ -22,6 +22,7 @@ from .config import ( from .global_logger import logger + class KGManager: def __init__(self): # 会被保存的字段 @@ -35,9 +36,7 @@ class KGManager: # 持久化相关 self.dir_path = global_config["persistence"]["rag_data_dir"] self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphmlz" - self.ent_cnt_data_path = ( - self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" - ) + self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json" def save_to_file(self): @@ -50,9 +49,7 @@ class KGManager: nx.write_graphml(self.graph, path=self.graph_data_path, encoding="utf-8") # 保存实体计数到文件 - ent_cnt_df = pd.DataFrame( - [{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()] - ) + ent_cnt_df = pd.DataFrame([{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()]) ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False) # 保存段落hash到文件 @@ -77,9 +74,7 @@ class KGManager: # 加载实体计数 ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow") - self.ent_appear_cnt = dict( - {row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()} - ) + self.ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()}) # 加载KG self.graph = nx.read_graphml(self.graph_data_path) @@ -101,20 +96,14 @@ class KGManager: # 一个triple就是一条边(同时构建双向联系) hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0]) hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2]) - node_to_node[(hash_key1, hash_key2)] = ( - node_to_node.get((hash_key1, hash_key2), 0) + 1.0 - ) - node_to_node[(hash_key2, hash_key1)] = ( - node_to_node.get((hash_key2, hash_key1), 0) + 1.0 - ) + node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 + node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) entity_set.add(hash_key2) # 实体出现次数统计 for hash_key in entity_set: - self.ent_appear_cnt[hash_key] = ( - self.ent_appear_cnt.get(hash_key, 0) + 1.0 - ) + self.ent_appear_cnt[hash_key] = self.ent_appear_cnt.get(hash_key, 0) + 1.0 @staticmethod def _build_edges_between_ent_pg( @@ -126,9 +115,7 @@ class KGManager: for triple in triple_list_data[idx]: ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0]) pg_hash_key = PG_NAMESPACE + "-" + str(idx) - node_to_node[(ent_hash_key, pg_hash_key)] = ( - node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 - ) + node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod def _synonym_connect( @@ -177,9 +164,7 @@ class KGManager: new_edge_cnt += 1 res_ent.append( ( - embedding_manager.entities_embedding_store.store[ - res_ent_hash - ].str, + embedding_manager.entities_embedding_store.store[res_ent_hash].str, similarity, ) ) # Debug @@ -236,22 +221,16 @@ class KGManager: if node_hash not in existed_nodes: if node_hash.startswith(ENT_NAMESPACE): # 新增实体节点 - node = embedding_manager.entities_embedding_store.store[ - node_hash - ] + node = embedding_manager.entities_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) self.graph.nodes[node_hash]["content"] = node.str self.graph.nodes[node_hash]["type"] = "ent" elif node_hash.startswith(PG_NAMESPACE): # 新增文段节点 - node = embedding_manager.paragraphs_embedding_store.store[ - node_hash - ] + node = embedding_manager.paragraphs_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) content = node.str.replace("\n", " ") - self.graph.nodes[node_hash]["content"] = ( - content if len(content) < 8 else content[:8] + "..." - ) + self.graph.nodes[node_hash]["content"] = content if len(content) < 8 else content[:8] + "..." self.graph.nodes[node_hash]["type"] = "pg" def build_kg( @@ -319,9 +298,7 @@ class KGManager: ent_sim_scores = {} for relation_hash, similarity, _ in relation_search_result: # 提取主宾短语 - relation = embed_manager.relation_embedding_store.store.get( - relation_hash - ).str + relation = embed_manager.relation_embedding_store.store.get(relation_hash).str assert relation is not None # 断言:relation不为空 # 关系三元组 triple = relation[2:-2].split("', '") @@ -335,9 +312,7 @@ class KGManager: ent_mean_scores = {} # 记录实体的平均相似度 for ent_hash, scores in ent_sim_scores.items(): # 先对相似度进行累加,然后与实体计数相除获取最终权重 - ent_weights[ent_hash] = ( - float(np.sum(scores)) / self.ent_appear_cnt[ent_hash] - ) + ent_weights[ent_hash] = float(np.sum(scores)) / self.ent_appear_cnt[ent_hash] # 记录实体的平均相似度,用于后续的top_k筛选 ent_mean_scores[ent_hash] = float(np.mean(scores)) del ent_sim_scores @@ -354,21 +329,14 @@ class KGManager: for ent_hash, score in ent_weights.items(): # 缩放相似度 ent_weights[ent_hash] = ( - (score - ent_weights_min) - * (1 - down_edge) - / (ent_weights_max - ent_weights_min) + (score - ent_weights_min) * (1 - down_edge) / (ent_weights_max - ent_weights_min) ) + down_edge # 取平均相似度的top_k实体 top_k = global_config["qa"]["params"]["ent_filter_top_k"] if len(ent_mean_scores) > top_k: # 从大到小排序,取后len - k个 - ent_mean_scores = { - k: v - for k, v in sorted( - ent_mean_scores.items(), key=lambda item: item[1], reverse=True - ) - } + ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)} for ent_hash, _ in ent_mean_scores.items(): # 删除被淘汰的实体节点权重设置 del ent_weights[ent_hash] @@ -389,9 +357,7 @@ class KGManager: # 归一化 for pg_hash, similarity in pg_sim_scores.items(): # 归一化相似度 - pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / ( - pg_sim_score_max - pg_sim_score_min - ) + pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (pg_sim_score_max - pg_sim_score_min) del pg_sim_score_max, pg_sim_score_min for pg_hash, score in pg_sim_scores.items(): @@ -401,9 +367,7 @@ class KGManager: del pg_sim_scores # 最终权重数据 = 实体权重 + 文段权重 - ppr_node_weights = { - k: v for d in [ent_weights, pg_weights] for k, v in d.items() - } + ppr_node_weights = {k: v for d in [ent_weights, pg_weights] for k, v in d.items()} del ent_weights, pg_weights # PersonalizedPageRank @@ -418,14 +382,12 @@ class KGManager: # 从搜索结果中提取文段节点的结果 passage_node_res = [ (node_key, score) - for node_key, score in ppr_res.items() # Iterate over dictionary items + for node_key, score in ppr_res.items() # Iterate over dictionary items if node_key.startswith(PG_NAMESPACE) ] del ppr_res # 排序:按照分数从大到小 - passage_node_res = sorted( - passage_node_res, key=lambda item: item[1], reverse=True - ) + passage_node_res = sorted(passage_node_res, key=lambda item: item[1], reverse=True) return passage_node_res, ppr_node_weights diff --git a/src/plugins/knowledge/src/llm_client.py b/src/plugins/knowledge/src/llm_client.py index 3036662c..52d0dca0 100644 --- a/src/plugins/knowledge/src/llm_client.py +++ b/src/plugins/knowledge/src/llm_client.py @@ -21,20 +21,14 @@ class LLMClient: def send_chat_request(self, model, messages): """发送对话请求,等待返回结果""" - response = self.client.chat.completions.create( - model=model, messages=messages, stream=False - ) + response = self.client.chat.completions.create(model=model, messages=messages, stream=False) if hasattr(response.choices[0].message, "reasoning_content"): # 有单独的推理内容块 reasoning_content = response.choices[0].message.reasoning_content content = response.choices[0].message.content else: # 无单独的推理内容块 - response = ( - response.choices[0] - .message.content.split("")[-1] - .split("") - ) + response = response.choices[0].message.content.split("")[-1].split("") # 如果有推理内容,则分割推理内容和内容 if len(response) == 2: reasoning_content = response[0] @@ -48,6 +42,4 @@ class LLMClient: def send_embedding_request(self, model, text): """发送嵌入请求,等待返回结果""" text = text.replace("\n", " ") - return ( - self.client.embeddings.create(input=[text], model=model).data[0].embedding - ) + return self.client.embeddings.create(input=[text], model=model).data[0].embedding diff --git a/src/plugins/knowledge/src/lpmmconfig.py b/src/plugins/knowledge/src/lpmmconfig.py index 152f895e..751177ed 100644 --- a/src/plugins/knowledge/src/lpmmconfig.py +++ b/src/plugins/knowledge/src/lpmmconfig.py @@ -58,7 +58,6 @@ def _load_config(config, config_file_path): config["persistence"] = file_config["persistence"] print(config) print("Configurations loaded from file: ", config_file_path) - parser = argparse.ArgumentParser(description="Configurations for the pipeline") @@ -122,9 +121,9 @@ global_config = dict( "embedding_data_dir": "data/embedding", "rag_data_dir": "data/rag", }, - "info_extraction":{ + "info_extraction": { "workers": 10, - } + }, } ) diff --git a/src/plugins/knowledge/src/mem_active_manager.py b/src/plugins/knowledge/src/mem_active_manager.py index 073e3bda..3998c066 100644 --- a/src/plugins/knowledge/src/mem_active_manager.py +++ b/src/plugins/knowledge/src/mem_active_manager.py @@ -16,13 +16,9 @@ class MemoryActiveManager: def get_activation(self, question: str) -> float: """获取记忆激活度""" # 生成问题的Embedding - question_embedding = self.embedding_client.send_embedding_request( - "text-embedding", question - ) + question_embedding = self.embedding_client.send_embedding_request("text-embedding", question) # 查询关系库中的相似度 - rel_search_res = self.embed_manager.relation_embedding_store.search_top_k( - question_embedding, 10 - ) + rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10) # 动态过滤阈值 rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0) diff --git a/src/plugins/knowledge/src/open_ie.py b/src/plugins/knowledge/src/open_ie.py index 3b5d704d..592c66d8 100644 --- a/src/plugins/knowledge/src/open_ie.py +++ b/src/plugins/knowledge/src/open_ie.py @@ -9,12 +9,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]: """过滤无效的实体""" valid_entities = set() for entity in entities: - if ( - not isinstance(entity, str) - or entity.strip() == "" - or entity in INVALID_ENTITY - or entity in valid_entities - ): + if not isinstance(entity, str) or entity.strip() == "" or entity in INVALID_ENTITY or entity in valid_entities: # 非字符串/空字符串/在无效实体列表中/重复 continue valid_entities.add(entity) @@ -74,9 +69,7 @@ class OpenIE: for doc in self.docs: # 过滤实体列表 - doc["extracted_entities"] = _filter_invalid_entities( - doc["extracted_entities"] - ) + doc["extracted_entities"] = _filter_invalid_entities(doc["extracted_entities"]) # 过滤无效的三元组 doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"]) @@ -100,9 +93,7 @@ class OpenIE: @staticmethod def load() -> "OpenIE": """从文件中加载OpenIE数据""" - with open( - global_config["persistence"]["openie_data_path"], "r", encoding="utf-8" - ) as f: + with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f: data = json.loads(f.read()) openie_data = OpenIE._from_dict(data) @@ -112,9 +103,7 @@ class OpenIE: @staticmethod def save(openie_data: "OpenIE"): """保存OpenIE数据到文件""" - with open( - global_config["persistence"]["openie_data_path"], "w", encoding="utf-8" - ) as f: + with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f: f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4)) def extract_entity_dict(self): @@ -141,7 +130,5 @@ class OpenIE: def extract_raw_paragraph_dict(self): """提取原始段落""" - raw_paragraph_dict = dict( - {doc_item["idx"]: doc_item["passage"] for doc_item in self.docs} - ) + raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}) return raw_paragraph_dict diff --git a/src/plugins/knowledge/src/prompt_template.py b/src/plugins/knowledge/src/prompt_template.py index ab6ac08c..18a5002e 100644 --- a/src/plugins/knowledge/src/prompt_template.py +++ b/src/plugins/knowledge/src/prompt_template.py @@ -41,9 +41,7 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描 def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]: messages = [ LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(), - LLMMessage( - "user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""" - ).to_dict(), + LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(), ] return messages @@ -58,16 +56,10 @@ qa_system_prompt = """ """ -def build_qa_context( - question: str, knowledge: list[(str, str, str)] -) -> List[LLMMessage]: - knowledge = "\n".join( - [f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)] - ) +def build_qa_context(question: str, knowledge: list[(str, str, str)]) -> List[LLMMessage]: + knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]) messages = [ LLMMessage("system", qa_system_prompt).to_dict(), - LLMMessage( - "user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}" - ).to_dict(), + LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(), ] return messages diff --git a/src/plugins/knowledge/src/qa_manager.py b/src/plugins/knowledge/src/qa_manager.py index e65c9b8f..8eaaf7b6 100644 --- a/src/plugins/knowledge/src/qa_manager.py +++ b/src/plugins/knowledge/src/qa_manager.py @@ -2,6 +2,7 @@ import time from typing import Tuple, List, Dict from .global_logger import logger + # from . import prompt_template from .embedding_store import EmbeddingManager from .llm_client import LLMClient @@ -31,7 +32,7 @@ class QAManager: """处理查询""" # 生成问题的Embedding - part_start_time =time.perf_counter() + part_start_time = time.perf_counter() question_embedding = self.llm_client_list["embedding"].send_embedding_request( global_config["embedding"]["model"], question ) @@ -39,7 +40,7 @@ class QAManager: logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s") # 根据问题Embedding查询Relation Embedding库 - part_start_time =time.perf_counter() + part_start_time = time.perf_counter() relation_search_res = self.embed_manager.relation_embedding_store.search_top_k( question_embedding, global_config["qa"]["params"]["relation_search_top_k"], @@ -47,10 +48,7 @@ class QAManager: # 过滤阈值 # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) - if ( - relation_search_res[0][1] - < global_config["qa"]["params"]["relation_threshold"] - ): + if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]: # 未找到相关关系 relation_search_res = [] @@ -66,12 +64,10 @@ class QAManager: # part_start_time = time.time() # 根据问题Embedding查询Paragraph Embedding库 - part_start_time =time.perf_counter() - paragraph_search_res = ( - self.embed_manager.paragraphs_embedding_store.search_top_k( - question_embedding, - global_config["qa"]["params"]["paragraph_search_top_k"], - ) + part_start_time = time.perf_counter() + paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( + question_embedding, + global_config["qa"]["params"]["paragraph_search_top_k"], ) part_end_time = time.perf_counter() logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") @@ -79,7 +75,7 @@ class QAManager: if len(relation_search_res) != 0: logger.info("找到相关关系,将使用RAG进行检索") # 使用KG检索 - part_start_time =time.perf_counter() + part_start_time = time.perf_counter() result, ppr_node_weights = self.kg_manager.kg_search( relation_search_res, paragraph_search_res, self.embed_manager ) @@ -94,9 +90,7 @@ class QAManager: result = dyn_select_top_k(result, 0.5, 1.0) for res in result: - raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[ - res[0] - ].str + raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") return result, ppr_node_weights @@ -114,4 +108,4 @@ class QAManager: for res in query_res ] found_knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]) - return found_knowledge \ No newline at end of file + return found_knowledge diff --git a/src/plugins/knowledge/src/raw_processing.py b/src/plugins/knowledge/src/raw_processing.py index 3670a705..66fe5303 100644 --- a/src/plugins/knowledge/src/raw_processing.py +++ b/src/plugins/knowledge/src/raw_processing.py @@ -17,9 +17,7 @@ def load_raw_data() -> tuple[list[str], list[str]]: """ # 读取import.json文件 if os.path.exists(global_config["persistence"]["raw_data_path"]) is True: - with open( - global_config["persistence"]["raw_data_path"], "r", encoding="utf-8" - ) as f: + with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f: import_json = json.loads(f.read()) else: raise Exception("原始数据文件读取失败") diff --git a/src/plugins/knowledge/src/utils/data_loader.py b/src/plugins/knowledge/src/utils/data_loader.py index 08bdf414..3b5e8d2e 100644 --- a/src/plugins/knowledge/src/utils/data_loader.py +++ b/src/plugins/knowledge/src/utils/data_loader.py @@ -14,11 +14,7 @@ class DataLoader: Args: custom_data_dir: 可选的自定义数据目录路径,如果不提供则使用配置文件中的默认路径 """ - self.data_dir = ( - Path(custom_data_dir) - if custom_data_dir - else Path(config["persistence"]["data_root_path"]) - ) + self.data_dir = Path(custom_data_dir) if custom_data_dir else Path(config["persistence"]["data_root_path"]) if not self.data_dir.exists(): raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在") diff --git a/src/plugins/knowledge/src/utils/dyn_topk.py b/src/plugins/knowledge/src/utils/dyn_topk.py index 02bc5e3e..eb40ef3a 100644 --- a/src/plugins/knowledge/src/utils/dyn_topk.py +++ b/src/plugins/knowledge/src/utils/dyn_topk.py @@ -36,14 +36,10 @@ def dyn_select_top_k( # 计算均值 mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score) # 计算方差 - var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len( - normalized_score - ) + var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(normalized_score) # 动态阈值 - threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * ( - mean_score + var_factor * var_score - ) + threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (mean_score + var_factor * var_score) # 重新过滤 res = [s for s in normalized_score if s[2] > threshold] diff --git a/src/plugins/knowledge/src/utils/json_fix.py b/src/plugins/knowledge/src/utils/json_fix.py index 672fb1f8..a83eb491 100644 --- a/src/plugins/knowledge/src/utils/json_fix.py +++ b/src/plugins/knowledge/src/utils/json_fix.py @@ -29,10 +29,7 @@ def _find_unclosed(json_str): elif char in "{[": unclosed.append(char) elif char in "}]": - if unclosed and ( - (char == "}" and unclosed[-1] == "{") - or (char == "]" and unclosed[-1] == "[") - ): + if unclosed and ((char == "}" and unclosed[-1] == "{") or (char == "]" and unclosed[-1] == "[")): unclosed.pop() return unclosed diff --git a/src/plugins/knowledge/src/utils/visualize_graph.py b/src/plugins/knowledge/src/utils/visualize_graph.py index 845e18a2..7ca9b7e6 100644 --- a/src/plugins/knowledge/src/utils/visualize_graph.py +++ b/src/plugins/knowledge/src/utils/visualize_graph.py @@ -14,4 +14,4 @@ def draw_graph_and_show(graph): font_family="Sarasa Mono SC", font_size=8, ) - fig.show() \ No newline at end of file + fig.show()