🤖 Auto-format code [skip ci]

pull/787/head
github-actions[bot] 2025-04-17 15:01:56 +00:00
parent 997e4bd203
commit 51f398176f
24 changed files with 132 additions and 259 deletions

View File

@@ -20,6 +20,7 @@ import sys
 logger = get_module_logger("LPMM知识库-OpenIE导入")
 def hash_deduplicate(
     raw_paragraphs: Dict[str, str],
     triple_list_data: Dict[str, List[List[str]]],
@@ -43,14 +44,10 @@ def hash_deduplicate(
     # 保存去重后的三元组
     new_triple_list_data = dict()
-    for _, (raw_paragraph, triple_list) in enumerate(
-        zip(raw_paragraphs.values(), triple_list_data.values())
-    ):
+    for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())):
         # 段落hash
         paragraph_hash = get_sha256(raw_paragraph)
-        if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (
-            paragraph_hash in stored_paragraph_hashes
-        ):
+        if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (paragraph_hash in stored_paragraph_hashes):
             continue
         new_raw_paragraphs[paragraph_hash] = raw_paragraph
         new_triple_list_data[paragraph_hash] = triple_list
@@ -58,9 +55,7 @@ def hash_deduplicate(
     return new_raw_paragraphs, new_triple_list_data
-def handle_import_openie(
-    openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager
-) -> bool:
+def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
     # 从OpenIE数据中提取段落原文与三元组列表
     # 索引的段落原文
     raw_paragraphs = openie_data.extract_raw_paragraph_dict()
@@ -68,9 +63,7 @@ def handle_import_openie(
     entity_list_data = openie_data.extract_entity_dict()
     # 索引的三元组列表
     triple_list_data = openie_data.extract_triple_dict()
-    if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(
-        triple_list_data
-    ):
+    if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
         logger.error("OpenIE数据存在异常")
         return False
     # 将索引换为对应段落的hash值
@@ -112,11 +105,11 @@ def main():
     print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
     print("同上样例导入时10700K几乎跑满14900HX占用80%峰值内存占用约3G")
     confirm = input("确认继续执行?(y/n): ").strip().lower()
-    if confirm != 'y':
+    if confirm != "y":
         logger.info("用户取消操作")
         print("操作已取消")
         sys.exit(1)
-    print("\n" + "="*40 + "\n")
+    print("\n" + "=" * 40 + "\n")
     logger.info("----开始导入openie数据----\n")
@@ -129,9 +122,7 @@ def main():
     )
     # 初始化Embedding库
-    embed_manager = embed_manager = EmbeddingManager(
-        llm_client_list[global_config["embedding"]["provider"]]
-    )
+    embed_manager = embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
     logger.info("正在从文件加载Embedding库")
     try:
         embed_manager.load_from_file()
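
Aside: the hash_deduplicate hunks above key each paragraph by its SHA-256 digest and skip anything whose namespaced hash is already stored. A minimal sketch of that pattern, with get_sha256 reconstructed here and the "namespace-hash" key format assumed to match the diffed config module:

import hashlib
from typing import Dict

PG_NAMESPACE = "paragraph"  # assumed to match the namespace constant in the config module


def get_sha256(s: str) -> str:
    # Hex digest of the UTF-8 encoded string.
    return hashlib.sha256(s.encode("utf-8")).hexdigest()


def dedup_paragraphs(raw_paragraphs: Dict[str, str], stored_pg_hashes: set) -> Dict[str, str]:
    """Keep only paragraphs whose namespaced content hash is not stored yet (sketch)."""
    kept = {}
    for raw in raw_paragraphs.values():
        h = get_sha256(raw)
        if (PG_NAMESPACE + "-" + h) in stored_pg_hashes:
            continue  # already imported, skip
        kept[h] = raw  # re-key the paragraph by its content hash
    return kept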

View File

@@ -91,11 +91,11 @@ def main():
     print("或者使用可以用赠金抵扣的Pro模型")
     print("请确保账户余额充足,并且在执行前确认无误。")
     confirm = input("确认继续执行?(y/n): ").strip().lower()
-    if confirm != 'y':
+    if confirm != "y":
         logger.info("用户取消操作")
         print("操作已取消")
         sys.exit(1)
-    print("\n" + "="*40 + "\n")
+    print("\n" + "=" * 40 + "\n")
     logger.info("--------进行信息提取--------\n")

View File

@@ -3,18 +3,17 @@ import os
 from pathlib import Path
 import sys  # 新增系统模块导入
 def check_and_create_dirs():
     """检查并创建必要的目录"""
-    required_dirs = [
-        "data/lpmm_raw_data",
-        "data/imported_lpmm_data"
-    ]
+    required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"]
     for dir_path in required_dirs:
         if not os.path.exists(dir_path):
             os.makedirs(dir_path)
             print(f"已创建目录: {dir_path}")
 def process_text_file(file_path):
     """处理单个文本文件,返回段落列表"""
     with open(file_path, "r", encoding="utf-8") as f:
@@ -29,12 +28,13 @@ def process_text_file(file_path):
             paragraph = ""
         else:
             paragraph += line + "\n"
     if paragraph != "":
         paragraphs.append(paragraph.strip())
     return paragraphs
 def main():
     # 新增用户确认提示
     print("=== 重要操作确认 ===")
@@ -43,42 +43,43 @@ def main():
     print("在进行知识库导入之前")
     print("请修改config/lpmm_config.toml中的配置项")
     confirm = input("确认继续执行?(y/n): ").strip().lower()
-    if confirm != 'y':
+    if confirm != "y":
         print("操作已取消")
         sys.exit(1)
-    print("\n" + "="*40 + "\n")
+    print("\n" + "=" * 40 + "\n")
     # 检查并创建必要的目录
     check_and_create_dirs()
     # 检查输出文件是否存在
     if os.path.exists("data/import.json"):
         print("错误: data/import.json 已存在,请先处理或删除该文件")
         sys.exit(1)
     if os.path.exists("data/openie.json"):
         print("错误: data/openie.json 已存在,请先处理或删除该文件")
         sys.exit(1)
     # 获取所有原始文本文件
     raw_files = list(Path("data/lpmm_raw_data").glob("*.txt"))
     if not raw_files:
         print("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
         sys.exit(1)
     # 处理所有文件
     all_paragraphs = []
     for file in raw_files:
         print(f"正在处理文件: {file.name}")
         paragraphs = process_text_file(file)
         all_paragraphs.extend(paragraphs)
     # 保存合并后的结果
     output_path = "data/import.json"
     with open(output_path, "w", encoding="utf-8") as f:
         json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
     print(f"处理完成,结果已保存到: {output_path}")
 if __name__ == "__main__":
     main()
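
Aside: process_text_file splits on blank lines, so an input like the following would yield two entries in data/import.json. An illustrative reconstruction of that splitting logic (the file-reading code at the top of the loop sits outside the visible hunks):

# Illustrative input only: blank lines separate paragraphs, mirroring process_text_file.
text = "第一段第一行\n第一段第二行\n\n第二段\n"
paragraphs, paragraph = [], ""
for line in text.split("\n"):
    if line == "":
        if paragraph != "":
            paragraphs.append(paragraph.strip())
        paragraph = ""
    else:
        paragraph += line + "\n"
if paragraph != "":
    paragraphs.append(paragraph.strip())
print(paragraphs)  # ['第一段第一行\n第一段第二行', '第二段']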

View File

@@ -1,5 +1,6 @@
 from src.do_tool.tool_can_use.base_tool import BaseTool
 from src.plugins.chat.utils import get_embedding
 # from src.common.database import db
 from src.common.logger import get_module_logger
 from typing import Dict, Any

View File

@@ -364,7 +364,7 @@ class PromptBuilder:
         # grouped_results[topic] = []
         # grouped_results[topic].append(result)
         # 按主题组织输出
         # for topic, results in grouped_results.items():
         # related_info += f"【主题: {topic}】\n"
         # for _i, result in enumerate(results, 1):

View File

@@ -22,9 +22,7 @@ for key in global_config["llm_providers"]:
     )
 # 初始化Embedding库
-embed_manager = EmbeddingManager(
-    llm_client_list[global_config["embedding"]["provider"]]
-)
+embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
 logger.info("正在从文件加载Embedding库")
 try:
     embed_manager.load_from_file()
@@ -63,4 +61,3 @@ inspire_manager = MemoryActiveManager(
     embed_manager,
     llm_client_list[global_config["embedding"]["provider"]],
 )

View File

@@ -23,9 +23,7 @@ def _nx_graph_to_lists(
         A tuple containing the list of edges and the list of nodes.
     """
     nodes = [node for node in graph.nodes()]
-    edges = [
-        (u, v, graph.get_edge_data(u, v).get("weight", 0.0)) for u, v in graph.edges()
-    ]
+    edges = [(u, v, graph.get_edge_data(u, v).get("weight", 0.0)) for u, v in graph.edges()]
     return edges, nodes

View File

@@ -9,7 +9,7 @@ def pagerank_py(
     personalization: Optional[Dict[str, float]] = None,
     alpha: float = 0.85,
     max_iter: int = 100,
-    tol: float = 1e-6
+    tol: float = 1e-6,
 ) -> Dict[str, float]:
     """使用 Python、NumPy 和 SciPy 计算个性化 PageRank。
@@ -44,14 +44,14 @@ def pagerank_py(
         raw_values = np.maximum(raw_values, 0)
         norm_sum = np.sum(raw_values)
         if norm_sum > 1e-9:  # 避免除以零
             personalization_vec = raw_values / norm_sum
         else:
             # 如果所有提供的个性化值都为零或负数,则回退到均匀分布
             print("警告:个性化值总和为零或所有值均为非正数。回退到均匀个性化设置。")
             personalization_vec.fill(1.0 / num_nodes)
     # --- 构建稀疏邻接矩阵 ---
     # 标准 PageRank 需要基于出度的归一化
     row_ind = []
     col_ind = []
@@ -66,9 +66,9 @@ def pagerank_py(
         col_ind.append(src_idx)
         # 暂存原始权重,如果需要加权 PageRank,可以在此使用 w
         # 对于标准 PageRank,我们只需要知道连接存在
         data.append(1.0)  # 初始数据设为 1,之后归一化
         # 标准 PageRank 的出度是边的数量,加权 PageRank 可以用 w
         out_degree[src_idx] += 1
     # 归一化权重(构建转移矩阵 M 的转置 M.T)
     # M[j, i] 是从 i 到 j 的概率
@@ -82,17 +82,16 @@ def pagerank_py(
         if out_degree[c] > 0:
             # 标准 PageRank: 1.0 / out_degree[c]
             # 如果要用原始权重 w 作为转移概率(需确保它们已归一化),则用 w / sum(w for edges from c)
             normalized_data.append(d / out_degree[c])
             new_row_ind.append(c)  # M.T 的行索引是 src_idx
             new_col_ind.append(r)  # M.T 的列索引是 dst_idx
     # 创建稀疏矩阵 (M.T)
     # 注意:scipy.sparse 期望 (data, (row_ind, col_ind)) 格式
     # 这里构建的是 M 的转置,方便后续计算 scores = alpha * M.T @ scores + ...
     if len(normalized_data) > 0:
         # 使用 csc_matrix 以便高效地进行列操作(矩阵向量乘法)
-        M_T = sp.csc_matrix((normalized_data, (new_row_ind, new_col_ind)),
-                            shape=(num_nodes, num_nodes))
+        M_T = sp.csc_matrix((normalized_data, (new_row_ind, new_col_ind)), shape=(num_nodes, num_nodes))
     else:
         M_T = sp.csc_matrix((num_nodes, num_nodes))
@@ -109,44 +108,45 @@ def pagerank_py(
     # 还有一种做法是仅分配给个性化向量中非零的节点
     # --- PageRank 迭代 ---
     scores = personalization_vec.copy()  # 从个性化向量开始
     for iteration in range(max_iter):
         prev_scores = scores.copy()
         # 计算来自链接的贡献
         linked_scores = M_T @ scores
         # 计算来自悬挂节点的贡献
         # 悬挂节点的总分数 * 悬挂权重向量
         dangling_sum = np.sum(scores[is_dangling])
         dangling_contribution = dangling_sum * dangling_weights
         # 结合瞬移、链接贡献和悬挂节点贡献
         scores = alpha * (linked_scores + dangling_contribution) + (1 - alpha) * personalization_vec
         # 检查收敛性 (L1 范数)
         diff = np.sum(np.abs(scores - prev_scores))
         if diff < tol:
             print(f"在 {iteration + 1} 次迭代后收敛。")
             break
     else:  # 循环完成但未中断
         print(f"达到最大迭代次数 ({max_iter}) 但未收敛。")
     # --- 格式化输出 ---
     result_dict = {index_to_node[i]: scores[i] for i in range(num_nodes)}
     return result_dict
 # --- 示例用法(类似于 pr.c 中的 main) ---
 if __name__ == "__main__":
     nodes_test = ["0", "1", "2", "3", "4"]
     edges_test = [
         ("0", "1", 0.5),  # 权重在此实现中仅用于确定出度
         ("1", "2", 0.3),
         ("2", "0", 0.2),
         ("1", "3", 0.4),
         ("3", "4", 0.6),
-        ("4", "1", 0.7)
+        ("4", "1", 0.7),
     ]
     # 添加一个悬挂节点示例
     nodes_test.append("5")
@@ -161,37 +161,33 @@ if __name__ == "__main__":
     print("运行优化的 Python PageRank 实现...")
     result = pagerank_py(
-        nodes_test,
-        edges_test,
-        personalization_test,
-        alpha=alpha_test,
-        max_iter=max_iter_test,
-        tol=tol_test
+        nodes_test, edges_test, personalization_test, alpha=alpha_test, max_iter=max_iter_test, tol=tol_test
     )
     print("\nPageRank 分数:")
     # 按节点索引排序以获得一致的输出
     sorted_nodes = sorted(result.keys(), key=lambda x: int(x))
     for node_id in sorted_nodes:
         print(f"节点 {node_id}: {result[node_id]:.6f}")
     print("\n使用默认个性化设置运行...")
     result_default_pers = pagerank_py(
         nodes_test,
         edges_test,
         personalization=None,  # 使用默认的统一性化设置
         alpha=alpha_test,
         max_iter=max_iter_test,
-        tol=tol_test
+        tol=tol_test,
     )
     print("\nPageRank 分数(默认个性化):")
     sorted_nodes_default = sorted(result_default_pers.keys(), key=lambda x: int(x))
     for node_id in sorted_nodes_default:
         print(f"节点 {node_id}: {result_default_pers[node_id]:.6f}")
     # 与 NetworkX 对比 (如果安装了)
     try:
         import networkx as nx
         print("\n与 NetworkX PageRank 对比 (个性化)...")
         G = nx.DiGraph()
         G.add_nodes_from(nodes_test)
@@ -200,25 +196,29 @@ if __name__ == "__main__":
         # 为了更接近我们的实现,我们不传递权重给 add_edges_from
         edges_for_nx = [(u, v) for u, v, w in edges_test]
         G.add_edges_from(edges_for_nx)
         # 归一化 NetworkX 的个性化向量
         nx_pers = {node: personalization_test.get(node, 0.0) for node in nodes_test}
         pers_sum = sum(nx_pers.values())
         if pers_sum > 0:
             nx_pers = {k: v / pers_sum for k, v in nx_pers.items()}
         else:  # 如果全为0,NetworkX 会报错或行为未定义,我们设为 None
             nx_pers = None
-        nx_result = nx.pagerank(G, alpha=alpha_test, personalization=nx_pers, max_iter=max_iter_test, tol=tol_test, weight=None)  # weight=None 强制标准 PageRank
+        nx_result = nx.pagerank(
+            G, alpha=alpha_test, personalization=nx_pers, max_iter=max_iter_test, tol=tol_test, weight=None
+        )  # weight=None 强制标准 PageRank
         for node_id in sorted_nodes:
             print(f"节点 {node_id}: {nx_result.get(node_id, 0.0):.6f}")
         print("\n与 NetworkX PageRank 对比 (默认)...")
-        nx_result_default = nx.pagerank(G, alpha=alpha_test, personalization=None, max_iter=max_iter_test, tol=tol_test, weight=None)
+        nx_result_default = nx.pagerank(
+            G, alpha=alpha_test, personalization=None, max_iter=max_iter_test, tol=tol_test, weight=None
+        )
         for node_id in sorted_nodes_default:
             print(f"节点 {node_id}: {nx_result_default.get(node_id, 0.0):.6f}")
     except ImportError:
         print("\n未安装 NetworkX,跳过对比。")
     except Exception as e:
         print(f"\n运行 NetworkX PageRank 时出错: {e}")

View File

@@ -1,6 +1,7 @@
 import os
 import toml
 from .global_logger import logger
 PG_NAMESPACE = "paragraph"
 ENT_NAMESPACE = "entity"
 REL_NAMESPACE = "relation"
@@ -58,6 +59,7 @@ def _load_config(config, config_file_path):
     logger.info(f"Configurations loaded from file: {config_file_path}")
 global_config = dict(
     {
         "llm_providers": {
@@ -119,4 +121,4 @@ file_path = os.path.abspath(__file__)
 dir_path = os.path.dirname(file_path)
 root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
 config_path = os.path.join(root_path, "config", "lpmm_config.toml")
 _load_config(global_config, config_path)

View File

@@ -47,9 +47,7 @@ class EmbeddingStore:
         self.idx2hash = None
     def _get_embedding(self, s: str) -> List[float]:
-        return self.llm_client.send_embedding_request(
-            global_config["embedding"]["model"], s
-        )
+        return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
     def batch_insert_strs(self, strs: List[str]) -> None:
         """向库中存入字符串"""
@@ -83,14 +81,10 @@ class EmbeddingStore:
         logger.info(f"{self.namespace}嵌入库保存成功")
         if self.faiss_index is not None and self.idx2hash is not None:
-            logger.info(
-                f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}"
-            )
+            logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}")
             faiss.write_index(self.faiss_index, self.index_file_path)
             logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
-            logger.info(
-                f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}"
-            )
+            logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}")
             with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
                 f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4))
             logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
@@ -103,24 +97,18 @@ class EmbeddingStore:
         logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
         data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
         for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
-            self.store[row["hash"]] = EmbeddingStoreItem(
-                row["hash"], row["embedding"], row["str"]
-            )
+            self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
         logger.info(f"{self.namespace}嵌入库加载成功")
         try:
             if os.path.exists(self.index_file_path):
-                logger.info(
-                    f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex"
-                )
+                logger.info(f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex")
                 self.faiss_index = faiss.read_index(self.index_file_path)
                 logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功")
             else:
                 raise Exception(f"文件{self.index_file_path}不存在")
             if os.path.exists(self.idx2hash_file_path):
-                logger.info(
-                    f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射"
-                )
+                logger.info(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
                 with open(self.idx2hash_file_path, "r") as f:
                     self.idx2hash = json.load(f)
                 logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
@@ -215,9 +203,7 @@ class EmbeddingManager:
         for triples in triple_list_data.values():
             graph_triples.extend([tuple(t) for t in triples])
         graph_triples = list(set(graph_triples))
-        self.relation_embedding_store.batch_insert_strs(
-            [str(triple) for triple in graph_triples]
-        )
+        self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
     def load_from_file(self):
         """从文件加载"""

View File

@@ -7,8 +7,6 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 console_logging_handler = logging.StreamHandler()
-console_logging_handler.setFormatter(
-    logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-)
+console_logging_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
 console_logging_handler.setLevel(logging.DEBUG)
 logger.addHandler(console_logging_handler)

View File

@@ -38,16 +38,12 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
     return entity_extract_result
-def _rdf_triple_extract(
-    llm_client: LLMClient, paragraph: str, entities: list
-) -> List[List[str]]:
+def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
     """对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
     entity_extract_context = prompt_template.build_rdf_triple_extract_context(
         paragraph, entities=json.dumps(entities, ensure_ascii=False)
     )
-    _, request_result = llm_client.send_chat_request(
-        global_config["rdf_build"]["llm"]["model"], entity_extract_context
-    )
+    _, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
     # 去除‘{’前的内容(结果中可能有多个‘{’)
     if "[" in request_result:
@@ -60,11 +56,7 @@ def _rdf_triple_extract(
     entity_extract_result = json.loads(fix_broken_generated_json(request_result))
     for triple in entity_extract_result:
-        if (
-            len(triple) != 3
-            or (triple[0] is None or triple[1] is None or triple[2] is None)
-            or "" in triple
-        ):
+        if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
             raise Exception("RDF提取结果格式错误")
     return entity_extract_result
@@ -91,9 +83,7 @@ def info_extract_from_str(
     try_count = 0
     while True:
         try:
-            rdf_triple_extract_result = _rdf_triple_extract(
-                llm_client_for_rdf, paragraph, entity_extract_result
-            )
+            rdf_triple_extract_result = _rdf_triple_extract(llm_client_for_rdf, paragraph, entity_extract_result)
             break
         except Exception as e:
             logger.warning(f"实体提取失败,错误信息:{e}")

View File

@@ -22,6 +22,7 @@ from .config import (
 from .global_logger import logger
 class KGManager:
     def __init__(self):
         # 会被保存的字段
@@ -35,9 +36,7 @@ class KGManager:
         # 持久化相关
         self.dir_path = global_config["persistence"]["rag_data_dir"]
         self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphmlz"
-        self.ent_cnt_data_path = (
-            self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
-        )
+        self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
         self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
     def save_to_file(self):
@@ -50,9 +49,7 @@ class KGManager:
         nx.write_graphml(self.graph, path=self.graph_data_path, encoding="utf-8")
         # 保存实体计数到文件
-        ent_cnt_df = pd.DataFrame(
-            [{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()]
-        )
+        ent_cnt_df = pd.DataFrame([{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()])
         ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False)
         # 保存段落hash到文件
@@ -77,9 +74,7 @@ class KGManager:
         # 加载实体计数
         ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow")
-        self.ent_appear_cnt = dict(
-            {row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()}
-        )
+        self.ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()})
         # 加载KG
         self.graph = nx.read_graphml(self.graph_data_path)
@@ -101,20 +96,14 @@ class KGManager:
             # 一个triple就是一条边,同时构建双向联系
             hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
             hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
-            node_to_node[(hash_key1, hash_key2)] = (
-                node_to_node.get((hash_key1, hash_key2), 0) + 1.0
-            )
-            node_to_node[(hash_key2, hash_key1)] = (
-                node_to_node.get((hash_key2, hash_key1), 0) + 1.0
-            )
+            node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
+            node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
             entity_set.add(hash_key1)
             entity_set.add(hash_key2)
         # 实体出现次数统计
         for hash_key in entity_set:
-            self.ent_appear_cnt[hash_key] = (
-                self.ent_appear_cnt.get(hash_key, 0) + 1.0
-            )
+            self.ent_appear_cnt[hash_key] = self.ent_appear_cnt.get(hash_key, 0) + 1.0
     @staticmethod
     def _build_edges_between_ent_pg(
@@ -126,9 +115,7 @@ class KGManager:
         for triple in triple_list_data[idx]:
             ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
             pg_hash_key = PG_NAMESPACE + "-" + str(idx)
-            node_to_node[(ent_hash_key, pg_hash_key)] = (
-                node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
-            )
+            node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
     @staticmethod
     def _synonym_connect(
@@ -177,9 +164,7 @@ class KGManager:
                 new_edge_cnt += 1
                 res_ent.append(
                     (
-                        embedding_manager.entities_embedding_store.store[
-                            res_ent_hash
-                        ].str,
+                        embedding_manager.entities_embedding_store.store[res_ent_hash].str,
                         similarity,
                     )
                 )  # Debug
@@ -236,22 +221,16 @@ class KGManager:
             if node_hash not in existed_nodes:
                 if node_hash.startswith(ENT_NAMESPACE):
                     # 新增实体节点
-                    node = embedding_manager.entities_embedding_store.store[
-                        node_hash
-                    ]
+                    node = embedding_manager.entities_embedding_store.store[node_hash]
                     assert isinstance(node, EmbeddingStoreItem)
                     self.graph.nodes[node_hash]["content"] = node.str
                     self.graph.nodes[node_hash]["type"] = "ent"
                 elif node_hash.startswith(PG_NAMESPACE):
                     # 新增文段节点
-                    node = embedding_manager.paragraphs_embedding_store.store[
-                        node_hash
-                    ]
+                    node = embedding_manager.paragraphs_embedding_store.store[node_hash]
                     assert isinstance(node, EmbeddingStoreItem)
                     content = node.str.replace("\n", " ")
-                    self.graph.nodes[node_hash]["content"] = (
-                        content if len(content) < 8 else content[:8] + "..."
-                    )
+                    self.graph.nodes[node_hash]["content"] = content if len(content) < 8 else content[:8] + "..."
                     self.graph.nodes[node_hash]["type"] = "pg"
     def build_kg(
@@ -319,9 +298,7 @@ class KGManager:
         ent_sim_scores = {}
         for relation_hash, similarity, _ in relation_search_result:
             # 提取主宾短语
-            relation = embed_manager.relation_embedding_store.store.get(
-                relation_hash
-            ).str
+            relation = embed_manager.relation_embedding_store.store.get(relation_hash).str
             assert relation is not None  # 断言relation不为空
             # 关系三元组
             triple = relation[2:-2].split("', '")
@@ -335,9 +312,7 @@ class KGManager:
         ent_mean_scores = {}  # 记录实体的平均相似度
         for ent_hash, scores in ent_sim_scores.items():
             # 先对相似度进行累加,然后与实体计数相除获取最终权重
-            ent_weights[ent_hash] = (
-                float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
-            )
+            ent_weights[ent_hash] = float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
             # 记录实体的平均相似度,用于后续的top_k筛选
             ent_mean_scores[ent_hash] = float(np.mean(scores))
         del ent_sim_scores
@@ -354,21 +329,14 @@ class KGManager:
         for ent_hash, score in ent_weights.items():
             # 缩放相似度
             ent_weights[ent_hash] = (
-                (score - ent_weights_min)
-                * (1 - down_edge)
-                / (ent_weights_max - ent_weights_min)
+                (score - ent_weights_min) * (1 - down_edge) / (ent_weights_max - ent_weights_min)
             ) + down_edge
         # 取平均相似度的top_k实体
         top_k = global_config["qa"]["params"]["ent_filter_top_k"]
         if len(ent_mean_scores) > top_k:
             # 从大到小排序,取后len - k个
-            ent_mean_scores = {
-                k: v
-                for k, v in sorted(
-                    ent_mean_scores.items(), key=lambda item: item[1], reverse=True
-                )
-            }
+            ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
             for ent_hash, _ in ent_mean_scores.items():
                 # 删除被淘汰的实体节点权重设置
                 del ent_weights[ent_hash]
@@ -389,9 +357,7 @@ class KGManager:
         # 归一化
         for pg_hash, similarity in pg_sim_scores.items():
             # 归一化相似度
-            pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (
-                pg_sim_score_max - pg_sim_score_min
-            )
+            pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (pg_sim_score_max - pg_sim_score_min)
         del pg_sim_score_max, pg_sim_score_min
         for pg_hash, score in pg_sim_scores.items():
@@ -401,9 +367,7 @@ class KGManager:
         del pg_sim_scores
         # 最终权重数据 = 实体权重 + 文段权重
-        ppr_node_weights = {
-            k: v for d in [ent_weights, pg_weights] for k, v in d.items()
-        }
+        ppr_node_weights = {k: v for d in [ent_weights, pg_weights] for k, v in d.items()}
         del ent_weights, pg_weights
         # PersonalizedPageRank
@@ -418,14 +382,12 @@ class KGManager:
         # 从搜索结果中提取文段节点的结果
         passage_node_res = [
             (node_key, score)
             for node_key, score in ppr_res.items()  # Iterate over dictionary items
             if node_key.startswith(PG_NAMESPACE)
         ]
         del ppr_res
         # 排序:按照分数从大到小
-        passage_node_res = sorted(
-            passage_node_res, key=lambda item: item[1], reverse=True
-        )
+        passage_node_res = sorted(passage_node_res, key=lambda item: item[1], reverse=True)
         return passage_node_res, ppr_node_weights
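
Aside: the tail of kg_search above merges entity and paragraph weights into one personalization vector, runs personalized PageRank, then keeps and ranks only paragraph nodes. A condensed sketch of that flow; rank_passages, pagerank_fn, and the namespace default are illustrative names, not the class's API:

# Condensed sketch of the kg_search ranking flow (names are illustrative).
def rank_passages(ent_weights: dict, pg_weights: dict, pagerank_fn, pg_namespace: str = "paragraph"):
    # Merge entity and paragraph weights into one personalization vector.
    ppr_node_weights = {k: v for d in [ent_weights, pg_weights] for k, v in d.items()}
    # Run personalized PageRank over the KG (hypothetical callable).
    ppr_res = pagerank_fn(personalization=ppr_node_weights)
    # Keep only paragraph nodes and sort by score, best first.
    passages = [(k, s) for k, s in ppr_res.items() if k.startswith(pg_namespace)]
    return sorted(passages, key=lambda item: item[1], reverse=True), ppr_node_weights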

View File

@@ -21,20 +21,14 @@ class LLMClient:
     def send_chat_request(self, model, messages):
         """发送对话请求,等待返回结果"""
-        response = self.client.chat.completions.create(
-            model=model, messages=messages, stream=False
-        )
+        response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
         if hasattr(response.choices[0].message, "reasoning_content"):
             # 有单独的推理内容块
             reasoning_content = response.choices[0].message.reasoning_content
             content = response.choices[0].message.content
         else:
             # 无单独的推理内容块
-            response = (
-                response.choices[0]
-                .message.content.split("<think>")[-1]
-                .split("</think>")
-            )
+            response = response.choices[0].message.content.split("<think>")[-1].split("</think>")
             # 如果有推理内容,则分割推理内容和内容
             if len(response) == 2:
                 reasoning_content = response[0]
@@ -48,6 +42,4 @@ class LLMClient:
     def send_embedding_request(self, model, text):
         """发送嵌入请求,等待返回结果"""
         text = text.replace("\n", " ")
-        return (
-            self.client.embeddings.create(input=[text], model=model).data[0].embedding
-        )
+        return self.client.embeddings.create(input=[text], model=model).data[0].embedding
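
Aside: the fallback branch in send_chat_request splits inline reasoning out of the model output. With an illustrative string (not from the commit), the split behaves like this:

# Illustrative: how the <think> fallback splits reasoning from content.
content = "<think>先推理一下</think>最终回答"
parts = content.split("<think>")[-1].split("</think>")
if len(parts) == 2:
    reasoning_content, answer = parts
print(parts)  # ['先推理一下', '最终回答']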

View File

@@ -58,7 +58,6 @@ def _load_config(config, config_file_path):
     config["persistence"] = file_config["persistence"]
     print(config)
     print("Configurations loaded from file: ", config_file_path)
 parser = argparse.ArgumentParser(description="Configurations for the pipeline")
@@ -122,9 +121,9 @@ global_config = dict(
             "embedding_data_dir": "data/embedding",
             "rag_data_dir": "data/rag",
         },
-        "info_extraction":{
+        "info_extraction": {
             "workers": 10,
-        }
+        },
     }
 )

View File

@@ -16,13 +16,9 @@ class MemoryActiveManager:
     def get_activation(self, question: str) -> float:
         """获取记忆激活度"""
         # 生成问题的Embedding
-        question_embedding = self.embedding_client.send_embedding_request(
-            "text-embedding", question
-        )
+        question_embedding = self.embedding_client.send_embedding_request("text-embedding", question)
         # 查询关系库中的相似度
-        rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(
-            question_embedding, 10
-        )
+        rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10)
         # 动态过滤阈值
         rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)

View File

@@ -9,12 +9,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]:
     """过滤无效的实体"""
     valid_entities = set()
     for entity in entities:
-        if (
-            not isinstance(entity, str)
-            or entity.strip() == ""
-            or entity in INVALID_ENTITY
-            or entity in valid_entities
-        ):
+        if not isinstance(entity, str) or entity.strip() == "" or entity in INVALID_ENTITY or entity in valid_entities:
            # 非字符串/空字符串/在无效实体列表中/重复
            continue
         valid_entities.add(entity)
@@ -74,9 +69,7 @@ class OpenIE:
         for doc in self.docs:
             # 过滤实体列表
-            doc["extracted_entities"] = _filter_invalid_entities(
-                doc["extracted_entities"]
-            )
+            doc["extracted_entities"] = _filter_invalid_entities(doc["extracted_entities"])
             # 过滤无效的三元组
             doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
@@ -100,9 +93,7 @@ class OpenIE:
     @staticmethod
     def load() -> "OpenIE":
         """从文件中加载OpenIE数据"""
-        with open(
-            global_config["persistence"]["openie_data_path"], "r", encoding="utf-8"
-        ) as f:
+        with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
             data = json.loads(f.read())
         openie_data = OpenIE._from_dict(data)
@@ -112,9 +103,7 @@ class OpenIE:
     @staticmethod
     def save(openie_data: "OpenIE"):
         """保存OpenIE数据到文件"""
-        with open(
-            global_config["persistence"]["openie_data_path"], "w", encoding="utf-8"
-        ) as f:
+        with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f:
             f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
     def extract_entity_dict(self):
@@ -141,7 +130,5 @@ class OpenIE:
     def extract_raw_paragraph_dict(self):
         """提取原始段落"""
-        raw_paragraph_dict = dict(
-            {doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}
-        )
+        raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
         return raw_paragraph_dict

View File

@@ -41,9 +41,7 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF资源描
 def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
     messages = [
         LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
-        LLMMessage(
-            "user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```"""
-        ).to_dict(),
+        LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
     ]
     return messages
@@ -58,16 +56,10 @@ qa_system_prompt = """
 """
-def build_qa_context(
-    question: str, knowledge: list[(str, str, str)]
-) -> List[LLMMessage]:
-    knowledge = "\n".join(
-        [f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]
-    )
+def build_qa_context(question: str, knowledge: list[(str, str, str)]) -> List[LLMMessage]:
+    knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
     messages = [
         LLMMessage("system", qa_system_prompt).to_dict(),
-        LLMMessage(
-            "user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}"
-        ).to_dict(),
+        LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(),
     ]
     return messages

View File

@@ -2,6 +2,7 @@ import time
 from typing import Tuple, List, Dict
 from .global_logger import logger
 # from . import prompt_template
 from .embedding_store import EmbeddingManager
 from .llm_client import LLMClient
@@ -31,7 +32,7 @@ class QAManager:
         """处理查询"""
         # 生成问题的Embedding
-        part_start_time =time.perf_counter()
+        part_start_time = time.perf_counter()
         question_embedding = self.llm_client_list["embedding"].send_embedding_request(
             global_config["embedding"]["model"], question
         )
@@ -39,7 +40,7 @@ class QAManager:
         logger.debug(f"Embedding用时{part_end_time - part_start_time:.5f}s")
         # 根据问题Embedding查询Relation Embedding库
-        part_start_time =time.perf_counter()
+        part_start_time = time.perf_counter()
         relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
             question_embedding,
             global_config["qa"]["params"]["relation_search_top_k"],
@@ -47,10 +48,7 @@ class QAManager:
         # 过滤阈值
         # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
         relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
-        if (
-            relation_search_res[0][1]
-            < global_config["qa"]["params"]["relation_threshold"]
-        ):
+        if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
             # 未找到相关关系
             relation_search_res = []
@@ -66,12 +64,10 @@ class QAManager:
         # part_start_time = time.time()
         # 根据问题Embedding查询Paragraph Embedding库
-        part_start_time =time.perf_counter()
-        paragraph_search_res = (
-            self.embed_manager.paragraphs_embedding_store.search_top_k(
-                question_embedding,
-                global_config["qa"]["params"]["paragraph_search_top_k"],
-            )
-        )
+        part_start_time = time.perf_counter()
+        paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
+            question_embedding,
+            global_config["qa"]["params"]["paragraph_search_top_k"],
+        )
         part_end_time = time.perf_counter()
         logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
@@ -79,7 +75,7 @@ class QAManager:
         if len(relation_search_res) != 0:
             logger.info("找到相关关系,将使用RAG进行检索")
             # 使用KG检索
-            part_start_time =time.perf_counter()
+            part_start_time = time.perf_counter()
             result, ppr_node_weights = self.kg_manager.kg_search(
                 relation_search_res, paragraph_search_res, self.embed_manager
             )
@@ -94,9 +90,7 @@ class QAManager:
             result = dyn_select_top_k(result, 0.5, 1.0)
             for res in result:
-                raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[
-                    res[0]
-                ].str
+                raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
                 print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
             return result, ppr_node_weights
@@ -114,4 +108,4 @@ class QAManager:
             for res in query_res
         ]
         found_knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
         return found_knowledge
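
Aside: the control flow in process_query falls back to plain paragraph retrieval when no relation clears the threshold (the hunk empties relation_search_res, and the later branch only runs kg_search when it is non-empty). Schematically, under that simplification:

# Schematic routing only; names and the threshold lookup are simplified from the hunks above.
def answer_route(relation_hits, relation_threshold):
    if relation_hits and relation_hits[0][1] >= relation_threshold:
        return "kg_search"  # relations found: RAG over the knowledge graph
    return "paragraph_search"  # no relation cleared the threshold: plain embedding retrieval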

View File

@@ -17,9 +17,7 @@ def load_raw_data() -> tuple[list[str], list[str]]:
     """
     # 读取import.json文件
     if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
-        with open(
-            global_config["persistence"]["raw_data_path"], "r", encoding="utf-8"
-        ) as f:
+        with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
             import_json = json.loads(f.read())
     else:
         raise Exception("原始数据文件读取失败")

View File

@@ -14,11 +14,7 @@ class DataLoader:
         Args:
             custom_data_dir: 可选的自定义数据目录路径,如果不提供则使用配置文件中的默认路径
         """
-        self.data_dir = (
-            Path(custom_data_dir)
-            if custom_data_dir
-            else Path(config["persistence"]["data_root_path"])
-        )
+        self.data_dir = Path(custom_data_dir) if custom_data_dir else Path(config["persistence"]["data_root_path"])
         if not self.data_dir.exists():
             raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在")

View File

@@ -36,14 +36,10 @@ def dyn_select_top_k(
     # 计算均值
     mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score)
     # 计算方差
-    var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(
-        normalized_score
-    )
+    var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(normalized_score)
     # 动态阈值
-    threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (
-        mean_score + var_factor * var_score
-    )
+    threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (mean_score + var_factor * var_score)
     # 重新过滤
     res = [s for s in normalized_score if s[2] > threshold]
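
Aside: written out, the dynamic threshold computed above blends a fixed jump threshold with a mean-plus-variance term over the normalized scores (a direct transcription of the code, with j = jmp_factor, v = var_factor, and t_jump = jump_threshold):

\[
t = j \cdot t_{\text{jump}} + (1 - j)\,\big(\mu + v\,\sigma^2\big)
\]

where \(\mu\) and \(\sigma^2\) are the mean and variance of the normalized scores; entries scoring above \(t\) are kept.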

View File

@@ -29,10 +29,7 @@ def _find_unclosed(json_str):
         elif char in "{[":
             unclosed.append(char)
         elif char in "}]":
-            if unclosed and (
-                (char == "}" and unclosed[-1] == "{")
-                or (char == "]" and unclosed[-1] == "[")
-            ):
+            if unclosed and ((char == "}" and unclosed[-1] == "{") or (char == "]" and unclosed[-1] == "[")):
                 unclosed.pop()
     return unclosed
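
Aside: _find_unclosed returns the stack of bracket openers that were never closed, which a repair helper can close in reverse order. A sketch of that usage; the real fix_broken_generated_json (referenced elsewhere in this commit) may differ:

# Hypothetical helper showing how _find_unclosed can patch truncated LLM JSON output.
def fix_broken_json_sketch(json_str: str) -> str:
    closers = {"{": "}", "[": "]"}
    unclosed = _find_unclosed(json_str)  # e.g. ['{', '['] for '{"a": [1, 2'
    return json_str + "".join(closers[c] for c in reversed(unclosed))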

View File

@@ -14,4 +14,4 @@ def draw_graph_and_show(graph):
         font_family="Sarasa Mono SC",
         font_size=8,
     )
     fig.show()