🤖 自动格式化代码 [skip ci]

pull/789/head
github-actions[bot] 2025-04-18 04:34:37 +00:00
parent bf7827a571
commit 3c95813c90
18 changed files with 119 additions and 236 deletions

View File

@ -20,6 +20,7 @@ import sys
logger = get_module_logger("LPMM知识库-OpenIE导入") logger = get_module_logger("LPMM知识库-OpenIE导入")
def hash_deduplicate( def hash_deduplicate(
raw_paragraphs: Dict[str, str], raw_paragraphs: Dict[str, str],
triple_list_data: Dict[str, List[List[str]]], triple_list_data: Dict[str, List[List[str]]],
@ -43,14 +44,10 @@ def hash_deduplicate(
# 保存去重后的三元组 # 保存去重后的三元组
new_triple_list_data = dict() new_triple_list_data = dict()
for _, (raw_paragraph, triple_list) in enumerate( for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())):
zip(raw_paragraphs.values(), triple_list_data.values())
):
# 段落hash # 段落hash
paragraph_hash = get_sha256(raw_paragraph) paragraph_hash = get_sha256(raw_paragraph)
if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and ( if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (paragraph_hash in stored_paragraph_hashes):
paragraph_hash in stored_paragraph_hashes
):
continue continue
new_raw_paragraphs[paragraph_hash] = raw_paragraph new_raw_paragraphs[paragraph_hash] = raw_paragraph
new_triple_list_data[paragraph_hash] = triple_list new_triple_list_data[paragraph_hash] = triple_list
@ -58,9 +55,7 @@ def hash_deduplicate(
return new_raw_paragraphs, new_triple_list_data return new_raw_paragraphs, new_triple_list_data
def handle_import_openie( def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager
) -> bool:
# 从OpenIE数据中提取段落原文与三元组列表 # 从OpenIE数据中提取段落原文与三元组列表
# 索引的段落原文 # 索引的段落原文
raw_paragraphs = openie_data.extract_raw_paragraph_dict() raw_paragraphs = openie_data.extract_raw_paragraph_dict()
@ -68,9 +63,7 @@ def handle_import_openie(
entity_list_data = openie_data.extract_entity_dict() entity_list_data = openie_data.extract_entity_dict()
# 索引的三元组列表 # 索引的三元组列表
triple_list_data = openie_data.extract_triple_dict() triple_list_data = openie_data.extract_triple_dict()
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len( if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
triple_list_data
):
logger.error("OpenIE数据存在异常") logger.error("OpenIE数据存在异常")
return False return False
# 将索引换为对应段落的hash值 # 将索引换为对应段落的hash值
@ -112,7 +105,7 @@ def main():
print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行") print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
print("同上样例导入时10700K几乎跑满14900HX占用80%峰值内存占用约3G") print("同上样例导入时10700K几乎跑满14900HX占用80%峰值内存占用约3G")
confirm = input("确认继续执行?(y/n): ").strip().lower() confirm = input("确认继续执行?(y/n): ").strip().lower()
if confirm != 'y': if confirm != "y":
logger.info("用户取消操作") logger.info("用户取消操作")
print("操作已取消") print("操作已取消")
sys.exit(1) sys.exit(1)
@ -129,9 +122,7 @@ def main():
) )
# 初始化Embedding库 # 初始化Embedding库
embed_manager = embed_manager = EmbeddingManager( embed_manager = embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
llm_client_list[global_config["embedding"]["provider"]]
)
logger.info("正在从文件加载Embedding库") logger.info("正在从文件加载Embedding库")
try: try:
embed_manager.load_from_file() embed_manager.load_from_file()

View File

@ -91,7 +91,7 @@ def main():
print("或者使用可以用赠金抵扣的Pro模型") print("或者使用可以用赠金抵扣的Pro模型")
print("请确保账户余额充足,并且在执行前确认无误。") print("请确保账户余额充足,并且在执行前确认无误。")
confirm = input("确认继续执行?(y/n): ").strip().lower() confirm = input("确认继续执行?(y/n): ").strip().lower()
if confirm != 'y': if confirm != "y":
logger.info("用户取消操作") logger.info("用户取消操作")
print("操作已取消") print("操作已取消")
sys.exit(1) sys.exit(1)

View File

@ -3,18 +3,17 @@ import os
from pathlib import Path from pathlib import Path
import sys # 新增系统模块导入 import sys # 新增系统模块导入
def check_and_create_dirs():
    """Create the required data directories if they do not already exist.

    The directories are resolved relative to the current working directory.
    A message is printed for every directory that was actually created;
    directories that are already present are left untouched and not reported.
    """
    required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"]

    for dir_path in required_dirs:
        try:
            # EAFP: attempt the creation directly so there is no window between
            # an existence check and the mkdir call.
            os.makedirs(dir_path)
        except FileExistsError:
            # Already present: nothing to create and nothing to report.
            continue
        print(f"已创建目录: {dir_path}")
def process_text_file(file_path): def process_text_file(file_path):
"""处理单个文本文件,返回段落列表""" """处理单个文本文件,返回段落列表"""
with open(file_path, "r", encoding="utf-8") as f: with open(file_path, "r", encoding="utf-8") as f:
@ -35,6 +34,7 @@ def process_text_file(file_path):
return paragraphs return paragraphs
def main(): def main():
# 新增用户确认提示 # 新增用户确认提示
print("=== 重要操作确认 ===") print("=== 重要操作确认 ===")
@ -43,7 +43,7 @@ def main():
print("在进行知识库导入之前") print("在进行知识库导入之前")
print("请修改config/lpmm_config.toml中的配置项") print("请修改config/lpmm_config.toml中的配置项")
confirm = input("确认继续执行?(y/n): ").strip().lower() confirm = input("确认继续执行?(y/n): ").strip().lower()
if confirm != 'y': if confirm != "y":
print("操作已取消") print("操作已取消")
sys.exit(1) sys.exit(1)
print("\n" + "=" * 40 + "\n") print("\n" + "=" * 40 + "\n")
@ -80,5 +80,6 @@ def main():
print(f"处理完成,结果已保存到: {output_path}") print(f"处理完成,结果已保存到: {output_path}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,5 +1,6 @@
from src.do_tool.tool_can_use.base_tool import BaseTool from src.do_tool.tool_can_use.base_tool import BaseTool
from src.plugins.chat.utils import get_embedding from src.plugins.chat.utils import get_embedding
# from src.common.database import db # from src.common.database import db
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from typing import Dict, Any from typing import Dict, Any

View File

@ -22,9 +22,7 @@ for key in global_config["llm_providers"]:
) )
# 初始化Embedding库 # 初始化Embedding库
embed_manager = EmbeddingManager( embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
llm_client_list[global_config["embedding"]["provider"]]
)
logger.info("正在从文件加载Embedding库") logger.info("正在从文件加载Embedding库")
try: try:
embed_manager.load_from_file() embed_manager.load_from_file()
@ -63,4 +61,3 @@ inspire_manager = MemoryActiveManager(
embed_manager, embed_manager,
llm_client_list[global_config["embedding"]["provider"]], llm_client_list[global_config["embedding"]["provider"]],
) )

View File

@ -23,9 +23,7 @@ def _nx_graph_to_lists(
A tuple containing the list of edges and the list of nodes. A tuple containing the list of edges and the list of nodes.
""" """
nodes = [node for node in graph.nodes()] nodes = [node for node in graph.nodes()]
edges = [ edges = [(u, v, graph.get_edge_data(u, v).get("weight", 0.0)) for u, v in graph.edges()]
(u, v, graph.get_edge_data(u, v).get("weight", 0.0)) for u, v in graph.edges()
]
return edges, nodes return edges, nodes

View File

@ -9,7 +9,7 @@ def pagerank_py(
personalization: Optional[Dict[str, float]] = None, personalization: Optional[Dict[str, float]] = None,
alpha: float = 0.85, alpha: float = 0.85,
max_iter: int = 100, max_iter: int = 100,
tol: float = 1e-6 tol: float = 1e-6,
) -> Dict[str, float]: ) -> Dict[str, float]:
"""使用 Python、NumPy 和 SciPy 计算个性化 PageRank。 """使用 Python、NumPy 和 SciPy 计算个性化 PageRank。
@ -91,8 +91,7 @@ def pagerank_py(
# 这里构建的是 M 的转置,方便后续计算 scores = alpha * M.T @ scores + ... # 这里构建的是 M 的转置,方便后续计算 scores = alpha * M.T @ scores + ...
if len(normalized_data) > 0: if len(normalized_data) > 0:
# 使用 csc_matrix 以便高效地进行列操作(矩阵向量乘法) # 使用 csc_matrix 以便高效地进行列操作(矩阵向量乘法)
M_T = sp.csc_matrix((normalized_data, (new_row_ind, new_col_ind)), M_T = sp.csc_matrix((normalized_data, (new_row_ind, new_col_ind)), shape=(num_nodes, num_nodes))
shape=(num_nodes, num_nodes))
else: else:
M_T = sp.csc_matrix((num_nodes, num_nodes)) M_T = sp.csc_matrix((num_nodes, num_nodes))
@ -137,6 +136,7 @@ def pagerank_py(
result_dict = {index_to_node[i]: scores[i] for i in range(num_nodes)} result_dict = {index_to_node[i]: scores[i] for i in range(num_nodes)}
return result_dict return result_dict
# --- 示例用法(类似于 pr.c 中的 main--- # --- 示例用法(类似于 pr.c 中的 main---
if __name__ == "__main__": if __name__ == "__main__":
nodes_test = ["0", "1", "2", "3", "4"] nodes_test = ["0", "1", "2", "3", "4"]
@ -146,7 +146,7 @@ if __name__ == "__main__":
("2", "0", 0.2), ("2", "0", 0.2),
("1", "3", 0.4), ("1", "3", 0.4),
("3", "4", 0.6), ("3", "4", 0.6),
("4", "1", 0.7) ("4", "1", 0.7),
] ]
# 添加一个悬挂节点示例 # 添加一个悬挂节点示例
nodes_test.append("5") nodes_test.append("5")
@ -161,12 +161,7 @@ if __name__ == "__main__":
print("运行优化的 Python PageRank 实现...") print("运行优化的 Python PageRank 实现...")
result = pagerank_py( result = pagerank_py(
nodes_test, nodes_test, edges_test, personalization_test, alpha=alpha_test, max_iter=max_iter_test, tol=tol_test
edges_test,
personalization_test,
alpha=alpha_test,
max_iter=max_iter_test,
tol=tol_test
) )
print("\nPageRank 分数:") print("\nPageRank 分数:")
@ -182,7 +177,7 @@ if __name__ == "__main__":
personalization=None, # 使用默认的统一性化设置 personalization=None, # 使用默认的统一性化设置
alpha=alpha_test, alpha=alpha_test,
max_iter=max_iter_test, max_iter=max_iter_test,
tol=tol_test tol=tol_test,
) )
print("\nPageRank 分数(默认个性化):") print("\nPageRank 分数(默认个性化):")
sorted_nodes_default = sorted(result_default_pers.keys(), key=lambda x: int(x)) sorted_nodes_default = sorted(result_default_pers.keys(), key=lambda x: int(x))
@ -192,6 +187,7 @@ if __name__ == "__main__":
# 与 NetworkX 对比 (如果安装了) # 与 NetworkX 对比 (如果安装了)
try: try:
import networkx as nx import networkx as nx
print("\n与 NetworkX PageRank 对比 (个性化)...") print("\n与 NetworkX PageRank 对比 (个性化)...")
G = nx.DiGraph() G = nx.DiGraph()
G.add_nodes_from(nodes_test) G.add_nodes_from(nodes_test)
@ -209,12 +205,16 @@ if __name__ == "__main__":
else: # 如果全为0NetworkX 会报错或行为未定义,我们设为 None else: # 如果全为0NetworkX 会报错或行为未定义,我们设为 None
nx_pers = None nx_pers = None
nx_result = nx.pagerank(G, alpha=alpha_test, personalization=nx_pers, max_iter=max_iter_test, tol=tol_test, weight=None) # weight=None 强制标准 PageRank nx_result = nx.pagerank(
G, alpha=alpha_test, personalization=nx_pers, max_iter=max_iter_test, tol=tol_test, weight=None
) # weight=None 强制标准 PageRank
for node_id in sorted_nodes: for node_id in sorted_nodes:
print(f"节点 {node_id}: {nx_result.get(node_id, 0.0):.6f}") print(f"节点 {node_id}: {nx_result.get(node_id, 0.0):.6f}")
print("\n与 NetworkX PageRank 对比 (默认)...") print("\n与 NetworkX PageRank 对比 (默认)...")
nx_result_default = nx.pagerank(G, alpha=alpha_test, personalization=None, max_iter=max_iter_test, tol=tol_test, weight=None) nx_result_default = nx.pagerank(
G, alpha=alpha_test, personalization=None, max_iter=max_iter_test, tol=tol_test, weight=None
)
for node_id in sorted_nodes_default: for node_id in sorted_nodes_default:
print(f"节点 {node_id}: {nx_result_default.get(node_id, 0.0):.6f}") print(f"节点 {node_id}: {nx_result_default.get(node_id, 0.0):.6f}")

View File

@ -47,9 +47,7 @@ class EmbeddingStore:
self.idx2hash = None self.idx2hash = None
def _get_embedding(self, s: str) -> List[float]:
    """Return the embedding of ``s`` obtained from this store's LLM client.

    The model name is read from ``global_config["embedding"]["model"]``.
    """
    embedding_model = global_config["embedding"]["model"]
    embedding = self.llm_client.send_embedding_request(embedding_model, s)
    return embedding
def batch_insert_strs(self, strs: List[str]) -> None: def batch_insert_strs(self, strs: List[str]) -> None:
"""向库中存入字符串""" """向库中存入字符串"""
@ -83,14 +81,10 @@ class EmbeddingStore:
logger.info(f"{self.namespace}嵌入库保存成功") logger.info(f"{self.namespace}嵌入库保存成功")
if self.faiss_index is not None and self.idx2hash is not None: if self.faiss_index is not None and self.idx2hash is not None:
logger.info( logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}")
f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}"
)
faiss.write_index(self.faiss_index, self.index_file_path) faiss.write_index(self.faiss_index, self.index_file_path)
logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功") logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
logger.info( logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}")
f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}"
)
with open(self.idx2hash_file_path, "w", encoding="utf-8") as f: with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4)) f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4))
logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功") logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
@ -103,24 +97,18 @@ class EmbeddingStore:
logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库") logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow") data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)): for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
self.store[row["hash"]] = EmbeddingStoreItem( self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
row["hash"], row["embedding"], row["str"]
)
logger.info(f"{self.namespace}嵌入库加载成功") logger.info(f"{self.namespace}嵌入库加载成功")
try: try:
if os.path.exists(self.index_file_path): if os.path.exists(self.index_file_path):
logger.info( logger.info(f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex")
f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex"
)
self.faiss_index = faiss.read_index(self.index_file_path) self.faiss_index = faiss.read_index(self.index_file_path)
logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功") logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功")
else: else:
raise Exception(f"文件{self.index_file_path}不存在") raise Exception(f"文件{self.index_file_path}不存在")
if os.path.exists(self.idx2hash_file_path): if os.path.exists(self.idx2hash_file_path):
logger.info( logger.info(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射"
)
with open(self.idx2hash_file_path, "r") as f: with open(self.idx2hash_file_path, "r") as f:
self.idx2hash = json.load(f) self.idx2hash = json.load(f)
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功") logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
@ -215,9 +203,7 @@ class EmbeddingManager:
for triples in triple_list_data.values(): for triples in triple_list_data.values():
graph_triples.extend([tuple(t) for t in triples]) graph_triples.extend([tuple(t) for t in triples])
graph_triples = list(set(graph_triples)) graph_triples = list(set(graph_triples))
self.relation_embedding_store.batch_insert_strs( self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
[str(triple) for triple in graph_triples]
)
def load_from_file(self): def load_from_file(self):
"""从文件加载""" """从文件加载"""

View File

@ -7,8 +7,6 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
console_logging_handler = logging.StreamHandler() console_logging_handler = logging.StreamHandler()
console_logging_handler.setFormatter( console_logging_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
)
console_logging_handler.setLevel(logging.DEBUG) console_logging_handler.setLevel(logging.DEBUG)
logger.addHandler(console_logging_handler) logger.addHandler(console_logging_handler)

View File

@ -38,16 +38,12 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
return entity_extract_result return entity_extract_result
def _rdf_triple_extract( def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
llm_client: LLMClient, paragraph: str, entities: list
) -> List[List[str]]:
"""对段落进行实体提取返回提取出的实体列表JSON格式""" """对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_rdf_triple_extract_context( entity_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False) paragraph, entities=json.dumps(entities, ensure_ascii=False)
) )
_, request_result = llm_client.send_chat_request( _, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
global_config["rdf_build"]["llm"]["model"], entity_extract_context
)
# 去除‘{’前的内容(结果中可能有多个‘{ # 去除‘{’前的内容(结果中可能有多个‘{
if "[" in request_result: if "[" in request_result:
@ -60,11 +56,7 @@ def _rdf_triple_extract(
entity_extract_result = json.loads(fix_broken_generated_json(request_result)) entity_extract_result = json.loads(fix_broken_generated_json(request_result))
for triple in entity_extract_result: for triple in entity_extract_result:
if ( if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
len(triple) != 3
or (triple[0] is None or triple[1] is None or triple[2] is None)
or "" in triple
):
raise Exception("RDF提取结果格式错误") raise Exception("RDF提取结果格式错误")
return entity_extract_result return entity_extract_result
@ -91,9 +83,7 @@ def info_extract_from_str(
try_count = 0 try_count = 0
while True: while True:
try: try:
rdf_triple_extract_result = _rdf_triple_extract( rdf_triple_extract_result = _rdf_triple_extract(llm_client_for_rdf, paragraph, entity_extract_result)
llm_client_for_rdf, paragraph, entity_extract_result
)
break break
except Exception as e: except Exception as e:
logger.warning(f"实体提取失败,错误信息:{e}") logger.warning(f"实体提取失败,错误信息:{e}")

View File

@ -22,6 +22,7 @@ from .lpmmconfig import (
from .global_logger import logger from .global_logger import logger
class KGManager: class KGManager:
def __init__(self): def __init__(self):
# 会被保存的字段 # 会被保存的字段
@ -35,9 +36,7 @@ class KGManager:
# 持久化相关 # 持久化相关
self.dir_path = global_config["persistence"]["rag_data_dir"] self.dir_path = global_config["persistence"]["rag_data_dir"]
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphmlz" self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphmlz"
self.ent_cnt_data_path = ( self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
)
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json" self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
def save_to_file(self): def save_to_file(self):
@ -50,9 +49,7 @@ class KGManager:
nx.write_graphml(self.graph, path=self.graph_data_path, encoding="utf-8") nx.write_graphml(self.graph, path=self.graph_data_path, encoding="utf-8")
# 保存实体计数到文件 # 保存实体计数到文件
ent_cnt_df = pd.DataFrame( ent_cnt_df = pd.DataFrame([{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()])
[{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()]
)
ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False) ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False)
# 保存段落hash到文件 # 保存段落hash到文件
@ -77,9 +74,7 @@ class KGManager:
# 加载实体计数 # 加载实体计数
ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow") ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow")
self.ent_appear_cnt = dict( self.ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()})
{row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()}
)
# 加载KG # 加载KG
self.graph = nx.read_graphml(self.graph_data_path) self.graph = nx.read_graphml(self.graph_data_path)
@ -101,20 +96,14 @@ class KGManager:
# 一个triple就是一条边同时构建双向联系 # 一个triple就是一条边同时构建双向联系
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0]) hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2]) hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = ( node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
)
node_to_node[(hash_key2, hash_key1)] = (
node_to_node.get((hash_key2, hash_key1), 0) + 1.0
)
entity_set.add(hash_key1) entity_set.add(hash_key1)
entity_set.add(hash_key2) entity_set.add(hash_key2)
# 实体出现次数统计 # 实体出现次数统计
for hash_key in entity_set: for hash_key in entity_set:
self.ent_appear_cnt[hash_key] = ( self.ent_appear_cnt[hash_key] = self.ent_appear_cnt.get(hash_key, 0) + 1.0
self.ent_appear_cnt.get(hash_key, 0) + 1.0
)
@staticmethod @staticmethod
def _build_edges_between_ent_pg( def _build_edges_between_ent_pg(
@ -126,9 +115,7 @@ class KGManager:
for triple in triple_list_data[idx]: for triple in triple_list_data[idx]:
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0]) ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
pg_hash_key = PG_NAMESPACE + "-" + str(idx) pg_hash_key = PG_NAMESPACE + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = ( node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
)
@staticmethod @staticmethod
def _synonym_connect( def _synonym_connect(
@ -177,9 +164,7 @@ class KGManager:
new_edge_cnt += 1 new_edge_cnt += 1
res_ent.append( res_ent.append(
( (
embedding_manager.entities_embedding_store.store[ embedding_manager.entities_embedding_store.store[res_ent_hash].str,
res_ent_hash
].str,
similarity, similarity,
) )
) # Debug ) # Debug
@ -236,22 +221,16 @@ class KGManager:
if node_hash not in existed_nodes: if node_hash not in existed_nodes:
if node_hash.startswith(ENT_NAMESPACE): if node_hash.startswith(ENT_NAMESPACE):
# 新增实体节点 # 新增实体节点
node = embedding_manager.entities_embedding_store.store[ node = embedding_manager.entities_embedding_store.store[node_hash]
node_hash
]
assert isinstance(node, EmbeddingStoreItem) assert isinstance(node, EmbeddingStoreItem)
self.graph.nodes[node_hash]["content"] = node.str self.graph.nodes[node_hash]["content"] = node.str
self.graph.nodes[node_hash]["type"] = "ent" self.graph.nodes[node_hash]["type"] = "ent"
elif node_hash.startswith(PG_NAMESPACE): elif node_hash.startswith(PG_NAMESPACE):
# 新增文段节点 # 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store[ node = embedding_manager.paragraphs_embedding_store.store[node_hash]
node_hash
]
assert isinstance(node, EmbeddingStoreItem) assert isinstance(node, EmbeddingStoreItem)
content = node.str.replace("\n", " ") content = node.str.replace("\n", " ")
self.graph.nodes[node_hash]["content"] = ( self.graph.nodes[node_hash]["content"] = content if len(content) < 8 else content[:8] + "..."
content if len(content) < 8 else content[:8] + "..."
)
self.graph.nodes[node_hash]["type"] = "pg" self.graph.nodes[node_hash]["type"] = "pg"
def build_kg( def build_kg(
@ -319,9 +298,7 @@ class KGManager:
ent_sim_scores = {} ent_sim_scores = {}
for relation_hash, similarity, _ in relation_search_result: for relation_hash, similarity, _ in relation_search_result:
# 提取主宾短语 # 提取主宾短语
relation = embed_manager.relation_embedding_store.store.get( relation = embed_manager.relation_embedding_store.store.get(relation_hash).str
relation_hash
).str
assert relation is not None # 断言relation不为空 assert relation is not None # 断言relation不为空
# 关系三元组 # 关系三元组
triple = relation[2:-2].split("', '") triple = relation[2:-2].split("', '")
@ -335,9 +312,7 @@ class KGManager:
ent_mean_scores = {} # 记录实体的平均相似度 ent_mean_scores = {} # 记录实体的平均相似度
for ent_hash, scores in ent_sim_scores.items(): for ent_hash, scores in ent_sim_scores.items():
# 先对相似度进行累加,然后与实体计数相除获取最终权重 # 先对相似度进行累加,然后与实体计数相除获取最终权重
ent_weights[ent_hash] = ( ent_weights[ent_hash] = float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
)
# 记录实体的平均相似度用于后续的top_k筛选 # 记录实体的平均相似度用于后续的top_k筛选
ent_mean_scores[ent_hash] = float(np.mean(scores)) ent_mean_scores[ent_hash] = float(np.mean(scores))
del ent_sim_scores del ent_sim_scores
@ -354,21 +329,14 @@ class KGManager:
for ent_hash, score in ent_weights.items(): for ent_hash, score in ent_weights.items():
# 缩放相似度 # 缩放相似度
ent_weights[ent_hash] = ( ent_weights[ent_hash] = (
(score - ent_weights_min) (score - ent_weights_min) * (1 - down_edge) / (ent_weights_max - ent_weights_min)
* (1 - down_edge)
/ (ent_weights_max - ent_weights_min)
) + down_edge ) + down_edge
# 取平均相似度的top_k实体 # 取平均相似度的top_k实体
top_k = global_config["qa"]["params"]["ent_filter_top_k"] top_k = global_config["qa"]["params"]["ent_filter_top_k"]
if len(ent_mean_scores) > top_k: if len(ent_mean_scores) > top_k:
# 从大到小排序取后len - k个 # 从大到小排序取后len - k个
ent_mean_scores = { ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
k: v
for k, v in sorted(
ent_mean_scores.items(), key=lambda item: item[1], reverse=True
)
}
for ent_hash, _ in ent_mean_scores.items(): for ent_hash, _ in ent_mean_scores.items():
# 删除被淘汰的实体节点权重设置 # 删除被淘汰的实体节点权重设置
del ent_weights[ent_hash] del ent_weights[ent_hash]
@ -389,9 +357,7 @@ class KGManager:
# 归一化 # 归一化
for pg_hash, similarity in pg_sim_scores.items(): for pg_hash, similarity in pg_sim_scores.items():
# 归一化相似度 # 归一化相似度
pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / ( pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (pg_sim_score_max - pg_sim_score_min)
pg_sim_score_max - pg_sim_score_min
)
del pg_sim_score_max, pg_sim_score_min del pg_sim_score_max, pg_sim_score_min
for pg_hash, score in pg_sim_scores.items(): for pg_hash, score in pg_sim_scores.items():
@ -401,9 +367,7 @@ class KGManager:
del pg_sim_scores del pg_sim_scores
# 最终权重数据 = 实体权重 + 文段权重 # 最终权重数据 = 实体权重 + 文段权重
ppr_node_weights = { ppr_node_weights = {k: v for d in [ent_weights, pg_weights] for k, v in d.items()}
k: v for d in [ent_weights, pg_weights] for k, v in d.items()
}
del ent_weights, pg_weights del ent_weights, pg_weights
# PersonalizedPageRank # PersonalizedPageRank
@ -424,8 +388,6 @@ class KGManager:
del ppr_res del ppr_res
# 排序:按照分数从大到小 # 排序:按照分数从大到小
passage_node_res = sorted( passage_node_res = sorted(passage_node_res, key=lambda item: item[1], reverse=True)
passage_node_res, key=lambda item: item[1], reverse=True
)
return passage_node_res, ppr_node_weights return passage_node_res, ppr_node_weights

View File

@ -21,20 +21,14 @@ class LLMClient:
def send_chat_request(self, model, messages): def send_chat_request(self, model, messages):
"""发送对话请求,等待返回结果""" """发送对话请求,等待返回结果"""
response = self.client.chat.completions.create( response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
model=model, messages=messages, stream=False
)
if hasattr(response.choices[0].message, "reasoning_content"): if hasattr(response.choices[0].message, "reasoning_content"):
# 有单独的推理内容块 # 有单独的推理内容块
reasoning_content = response.choices[0].message.reasoning_content reasoning_content = response.choices[0].message.reasoning_content
content = response.choices[0].message.content content = response.choices[0].message.content
else: else:
# 无单独的推理内容块 # 无单独的推理内容块
response = ( response = response.choices[0].message.content.split("<think>")[-1].split("</think>")
response.choices[0]
.message.content.split("<think>")[-1]
.split("</think>")
)
# 如果有推理内容,则分割推理内容和内容 # 如果有推理内容,则分割推理内容和内容
if len(response) == 2: if len(response) == 2:
reasoning_content = response[0] reasoning_content = response[0]
@ -48,6 +42,4 @@ class LLMClient:
def send_embedding_request(self, model, text):
    """Send an embedding request and return the first result's vector.

    Newlines in ``text`` are replaced with spaces before the request is sent.
    """
    single_line_text = text.replace("\n", " ")
    response = self.client.embeddings.create(input=[single_line_text], model=model)
    first_item = response.data[0]
    return first_item.embedding

View File

@ -16,13 +16,9 @@ class MemoryActiveManager:
def get_activation(self, question: str) -> float: def get_activation(self, question: str) -> float:
"""获取记忆激活度""" """获取记忆激活度"""
# 生成问题的Embedding # 生成问题的Embedding
question_embedding = self.embedding_client.send_embedding_request( question_embedding = self.embedding_client.send_embedding_request("text-embedding", question)
"text-embedding", question
)
# 查询关系库中的相似度 # 查询关系库中的相似度
rel_search_res = self.embed_manager.relation_embedding_store.search_top_k( rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10)
question_embedding, 10
)
# 动态过滤阈值 # 动态过滤阈值
rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0) rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)

View File

@ -9,12 +9,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]:
"""过滤无效的实体""" """过滤无效的实体"""
valid_entities = set() valid_entities = set()
for entity in entities: for entity in entities:
if ( if not isinstance(entity, str) or entity.strip() == "" or entity in INVALID_ENTITY or entity in valid_entities:
not isinstance(entity, str)
or entity.strip() == ""
or entity in INVALID_ENTITY
or entity in valid_entities
):
# 非字符串/空字符串/在无效实体列表中/重复 # 非字符串/空字符串/在无效实体列表中/重复
continue continue
valid_entities.add(entity) valid_entities.add(entity)
@ -74,9 +69,7 @@ class OpenIE:
for doc in self.docs: for doc in self.docs:
# 过滤实体列表 # 过滤实体列表
doc["extracted_entities"] = _filter_invalid_entities( doc["extracted_entities"] = _filter_invalid_entities(doc["extracted_entities"])
doc["extracted_entities"]
)
# 过滤无效的三元组 # 过滤无效的三元组
doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"]) doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
@ -100,9 +93,7 @@ class OpenIE:
@staticmethod @staticmethod
def load() -> "OpenIE": def load() -> "OpenIE":
"""从文件中加载OpenIE数据""" """从文件中加载OpenIE数据"""
with open( with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
global_config["persistence"]["openie_data_path"], "r", encoding="utf-8"
) as f:
data = json.loads(f.read()) data = json.loads(f.read())
openie_data = OpenIE._from_dict(data) openie_data = OpenIE._from_dict(data)
@ -112,9 +103,7 @@ class OpenIE:
@staticmethod
def save(openie_data: "OpenIE"):
    """Write ``openie_data`` to the configured OpenIE data file as JSON.

    The target path is read from
    ``global_config["persistence"]["openie_data_path"]``; the file is written
    as UTF-8 with non-ASCII characters kept as-is and a 4-space indent.
    """
    target_path = global_config["persistence"]["openie_data_path"]
    with open(target_path, "w", encoding="utf-8") as out_file:
        serialized = json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4)
        out_file.write(serialized)
def extract_entity_dict(self): def extract_entity_dict(self):
@ -141,7 +130,5 @@ class OpenIE:
def extract_raw_paragraph_dict(self):
    """Return a dict mapping each document's ``idx`` to its ``passage`` text.

    Built from ``self.docs``; if two documents share an ``idx`` the later one
    wins, as with any dict comprehension.
    """
    # A dict comprehension already yields a new dict; no dict() wrapper needed.
    return {doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}

View File

@ -46,10 +46,7 @@ class QAManager:
# 过滤阈值 # 过滤阈值
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
if ( if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
relation_search_res[0][1]
< global_config["qa"]["params"]["relation_threshold"]
):
# 未找到相关关系 # 未找到相关关系
relation_search_res = [] relation_search_res = []
@ -66,12 +63,10 @@ class QAManager:
# 根据问题Embedding查询Paragraph Embedding库 # 根据问题Embedding查询Paragraph Embedding库
part_start_time = time.perf_counter() part_start_time = time.perf_counter()
paragraph_search_res = ( paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
self.embed_manager.paragraphs_embedding_store.search_top_k(
question_embedding, question_embedding,
global_config["qa"]["params"]["paragraph_search_top_k"], global_config["qa"]["params"]["paragraph_search_top_k"],
) )
)
part_end_time = time.perf_counter() part_end_time = time.perf_counter()
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
@ -93,9 +88,7 @@ class QAManager:
result = dyn_select_top_k(result, 0.5, 1.0) result = dyn_select_top_k(result, 0.5, 1.0)
for res in result: for res in result:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[ raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
res[0]
].str
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights return result, ppr_node_weights

View File

@ -17,9 +17,7 @@ def load_raw_data() -> tuple[list[str], list[str]]:
""" """
# 读取import.json文件 # 读取import.json文件
if os.path.exists(global_config["persistence"]["raw_data_path"]) is True: if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
with open( with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
global_config["persistence"]["raw_data_path"], "r", encoding="utf-8"
) as f:
import_json = json.loads(f.read()) import_json = json.loads(f.read())
else: else:
raise Exception("原始数据文件读取失败") raise Exception("原始数据文件读取失败")

View File

@ -36,14 +36,10 @@ def dyn_select_top_k(
# 计算均值 # 计算均值
mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score) mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score)
# 计算方差 # 计算方差
var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len( var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(normalized_score)
normalized_score
)
# 动态阈值 # 动态阈值
threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * ( threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (mean_score + var_factor * var_score)
mean_score + var_factor * var_score
)
# 重新过滤 # 重新过滤
res = [s for s in normalized_score if s[2] > threshold] res = [s for s in normalized_score if s[2] > threshold]

View File

@ -29,10 +29,7 @@ def _find_unclosed(json_str):
elif char in "{[": elif char in "{[":
unclosed.append(char) unclosed.append(char)
elif char in "}]": elif char in "}]":
if unclosed and ( if unclosed and ((char == "}" and unclosed[-1] == "{") or (char == "]" and unclosed[-1] == "[")):
(char == "}" and unclosed[-1] == "{")
or (char == "]" and unclosed[-1] == "[")
):
unclosed.pop() unclosed.pop()
return unclosed return unclosed