mirror of https://github.com/Mai-with-u/MaiBot.git
🤖 自动格式化代码 [skip ci]
parent
997e4bd203
commit
51f398176f
|
|
@ -20,6 +20,7 @@ import sys
|
||||||
|
|
||||||
logger = get_module_logger("LPMM知识库-OpenIE导入")
|
logger = get_module_logger("LPMM知识库-OpenIE导入")
|
||||||
|
|
||||||
|
|
||||||
def hash_deduplicate(
|
def hash_deduplicate(
|
||||||
raw_paragraphs: Dict[str, str],
|
raw_paragraphs: Dict[str, str],
|
||||||
triple_list_data: Dict[str, List[List[str]]],
|
triple_list_data: Dict[str, List[List[str]]],
|
||||||
|
|
@ -43,14 +44,10 @@ def hash_deduplicate(
|
||||||
# 保存去重后的三元组
|
# 保存去重后的三元组
|
||||||
new_triple_list_data = dict()
|
new_triple_list_data = dict()
|
||||||
|
|
||||||
for _, (raw_paragraph, triple_list) in enumerate(
|
for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())):
|
||||||
zip(raw_paragraphs.values(), triple_list_data.values())
|
|
||||||
):
|
|
||||||
# 段落hash
|
# 段落hash
|
||||||
paragraph_hash = get_sha256(raw_paragraph)
|
paragraph_hash = get_sha256(raw_paragraph)
|
||||||
if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (
|
if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (paragraph_hash in stored_paragraph_hashes):
|
||||||
paragraph_hash in stored_paragraph_hashes
|
|
||||||
):
|
|
||||||
continue
|
continue
|
||||||
new_raw_paragraphs[paragraph_hash] = raw_paragraph
|
new_raw_paragraphs[paragraph_hash] = raw_paragraph
|
||||||
new_triple_list_data[paragraph_hash] = triple_list
|
new_triple_list_data[paragraph_hash] = triple_list
|
||||||
|
|
@ -58,9 +55,7 @@ def hash_deduplicate(
|
||||||
return new_raw_paragraphs, new_triple_list_data
|
return new_raw_paragraphs, new_triple_list_data
|
||||||
|
|
||||||
|
|
||||||
def handle_import_openie(
|
def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
|
||||||
openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager
|
|
||||||
) -> bool:
|
|
||||||
# 从OpenIE数据中提取段落原文与三元组列表
|
# 从OpenIE数据中提取段落原文与三元组列表
|
||||||
# 索引的段落原文
|
# 索引的段落原文
|
||||||
raw_paragraphs = openie_data.extract_raw_paragraph_dict()
|
raw_paragraphs = openie_data.extract_raw_paragraph_dict()
|
||||||
|
|
@ -68,9 +63,7 @@ def handle_import_openie(
|
||||||
entity_list_data = openie_data.extract_entity_dict()
|
entity_list_data = openie_data.extract_entity_dict()
|
||||||
# 索引的三元组列表
|
# 索引的三元组列表
|
||||||
triple_list_data = openie_data.extract_triple_dict()
|
triple_list_data = openie_data.extract_triple_dict()
|
||||||
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(
|
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
|
||||||
triple_list_data
|
|
||||||
):
|
|
||||||
logger.error("OpenIE数据存在异常")
|
logger.error("OpenIE数据存在异常")
|
||||||
return False
|
return False
|
||||||
# 将索引换为对应段落的hash值
|
# 将索引换为对应段落的hash值
|
||||||
|
|
@ -112,11 +105,11 @@ def main():
|
||||||
print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
|
print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
|
||||||
print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G")
|
print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G")
|
||||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||||
if confirm != 'y':
|
if confirm != "y":
|
||||||
logger.info("用户取消操作")
|
logger.info("用户取消操作")
|
||||||
print("操作已取消")
|
print("操作已取消")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
print("\n" + "="*40 + "\n")
|
print("\n" + "=" * 40 + "\n")
|
||||||
|
|
||||||
logger.info("----开始导入openie数据----\n")
|
logger.info("----开始导入openie数据----\n")
|
||||||
|
|
||||||
|
|
@ -129,9 +122,7 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
# 初始化Embedding库
|
# 初始化Embedding库
|
||||||
embed_manager = embed_manager = EmbeddingManager(
|
embed_manager = embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
|
||||||
llm_client_list[global_config["embedding"]["provider"]]
|
|
||||||
)
|
|
||||||
logger.info("正在从文件加载Embedding库")
|
logger.info("正在从文件加载Embedding库")
|
||||||
try:
|
try:
|
||||||
embed_manager.load_from_file()
|
embed_manager.load_from_file()
|
||||||
|
|
|
||||||
|
|
@ -91,11 +91,11 @@ def main():
|
||||||
print("或者使用可以用赠金抵扣的Pro模型")
|
print("或者使用可以用赠金抵扣的Pro模型")
|
||||||
print("请确保账户余额充足,并且在执行前确认无误。")
|
print("请确保账户余额充足,并且在执行前确认无误。")
|
||||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||||
if confirm != 'y':
|
if confirm != "y":
|
||||||
logger.info("用户取消操作")
|
logger.info("用户取消操作")
|
||||||
print("操作已取消")
|
print("操作已取消")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
print("\n" + "="*40 + "\n")
|
print("\n" + "=" * 40 + "\n")
|
||||||
|
|
||||||
logger.info("--------进行信息提取--------\n")
|
logger.info("--------进行信息提取--------\n")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,18 +3,17 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys # 新增系统模块导入
|
import sys # 新增系统模块导入
|
||||||
|
|
||||||
|
|
||||||
def check_and_create_dirs():
|
def check_and_create_dirs():
|
||||||
"""检查并创建必要的目录"""
|
"""检查并创建必要的目录"""
|
||||||
required_dirs = [
|
required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"]
|
||||||
"data/lpmm_raw_data",
|
|
||||||
"data/imported_lpmm_data"
|
|
||||||
]
|
|
||||||
|
|
||||||
for dir_path in required_dirs:
|
for dir_path in required_dirs:
|
||||||
if not os.path.exists(dir_path):
|
if not os.path.exists(dir_path):
|
||||||
os.makedirs(dir_path)
|
os.makedirs(dir_path)
|
||||||
print(f"已创建目录: {dir_path}")
|
print(f"已创建目录: {dir_path}")
|
||||||
|
|
||||||
|
|
||||||
def process_text_file(file_path):
|
def process_text_file(file_path):
|
||||||
"""处理单个文本文件,返回段落列表"""
|
"""处理单个文本文件,返回段落列表"""
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
|
|
@ -29,12 +28,13 @@ def process_text_file(file_path):
|
||||||
paragraph = ""
|
paragraph = ""
|
||||||
else:
|
else:
|
||||||
paragraph += line + "\n"
|
paragraph += line + "\n"
|
||||||
|
|
||||||
if paragraph != "":
|
if paragraph != "":
|
||||||
paragraphs.append(paragraph.strip())
|
paragraphs.append(paragraph.strip())
|
||||||
|
|
||||||
return paragraphs
|
return paragraphs
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# 新增用户确认提示
|
# 新增用户确认提示
|
||||||
print("=== 重要操作确认 ===")
|
print("=== 重要操作确认 ===")
|
||||||
|
|
@ -43,42 +43,43 @@ def main():
|
||||||
print("在进行知识库导入之前")
|
print("在进行知识库导入之前")
|
||||||
print("请修改config/lpmm_config.toml中的配置项")
|
print("请修改config/lpmm_config.toml中的配置项")
|
||||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
||||||
if confirm != 'y':
|
if confirm != "y":
|
||||||
print("操作已取消")
|
print("操作已取消")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
print("\n" + "="*40 + "\n")
|
print("\n" + "=" * 40 + "\n")
|
||||||
|
|
||||||
# 检查并创建必要的目录
|
# 检查并创建必要的目录
|
||||||
check_and_create_dirs()
|
check_and_create_dirs()
|
||||||
|
|
||||||
# 检查输出文件是否存在
|
# 检查输出文件是否存在
|
||||||
if os.path.exists("data/import.json"):
|
if os.path.exists("data/import.json"):
|
||||||
print("错误: data/import.json 已存在,请先处理或删除该文件")
|
print("错误: data/import.json 已存在,请先处理或删除该文件")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if os.path.exists("data/openie.json"):
|
if os.path.exists("data/openie.json"):
|
||||||
print("错误: data/openie.json 已存在,请先处理或删除该文件")
|
print("错误: data/openie.json 已存在,请先处理或删除该文件")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# 获取所有原始文本文件
|
# 获取所有原始文本文件
|
||||||
raw_files = list(Path("data/lpmm_raw_data").glob("*.txt"))
|
raw_files = list(Path("data/lpmm_raw_data").glob("*.txt"))
|
||||||
if not raw_files:
|
if not raw_files:
|
||||||
print("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
print("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# 处理所有文件
|
# 处理所有文件
|
||||||
all_paragraphs = []
|
all_paragraphs = []
|
||||||
for file in raw_files:
|
for file in raw_files:
|
||||||
print(f"正在处理文件: {file.name}")
|
print(f"正在处理文件: {file.name}")
|
||||||
paragraphs = process_text_file(file)
|
paragraphs = process_text_file(file)
|
||||||
all_paragraphs.extend(paragraphs)
|
all_paragraphs.extend(paragraphs)
|
||||||
|
|
||||||
# 保存合并后的结果
|
# 保存合并后的结果
|
||||||
output_path = "data/import.json"
|
output_path = "data/import.json"
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
|
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
print(f"处理完成,结果已保存到: {output_path}")
|
print(f"处理完成,结果已保存到: {output_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from src.do_tool.tool_can_use.base_tool import BaseTool
|
from src.do_tool.tool_can_use.base_tool import BaseTool
|
||||||
from src.plugins.chat.utils import get_embedding
|
from src.plugins.chat.utils import get_embedding
|
||||||
|
|
||||||
# from src.common.database import db
|
# from src.common.database import db
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
|
||||||
|
|
@ -364,7 +364,7 @@ class PromptBuilder:
|
||||||
# grouped_results[topic] = []
|
# grouped_results[topic] = []
|
||||||
# grouped_results[topic].append(result)
|
# grouped_results[topic].append(result)
|
||||||
|
|
||||||
# 按主题组织输出
|
# 按主题组织输出
|
||||||
# for topic, results in grouped_results.items():
|
# for topic, results in grouped_results.items():
|
||||||
# related_info += f"【主题: {topic}】\n"
|
# related_info += f"【主题: {topic}】\n"
|
||||||
# for _i, result in enumerate(results, 1):
|
# for _i, result in enumerate(results, 1):
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,7 @@ for key in global_config["llm_providers"]:
|
||||||
)
|
)
|
||||||
|
|
||||||
# 初始化Embedding库
|
# 初始化Embedding库
|
||||||
embed_manager = EmbeddingManager(
|
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
|
||||||
llm_client_list[global_config["embedding"]["provider"]]
|
|
||||||
)
|
|
||||||
logger.info("正在从文件加载Embedding库")
|
logger.info("正在从文件加载Embedding库")
|
||||||
try:
|
try:
|
||||||
embed_manager.load_from_file()
|
embed_manager.load_from_file()
|
||||||
|
|
@ -63,4 +61,3 @@ inspire_manager = MemoryActiveManager(
|
||||||
embed_manager,
|
embed_manager,
|
||||||
llm_client_list[global_config["embedding"]["provider"]],
|
llm_client_list[global_config["embedding"]["provider"]],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,7 @@ def _nx_graph_to_lists(
|
||||||
A tuple containing the list of edges and the list of nodes.
|
A tuple containing the list of edges and the list of nodes.
|
||||||
"""
|
"""
|
||||||
nodes = [node for node in graph.nodes()]
|
nodes = [node for node in graph.nodes()]
|
||||||
edges = [
|
edges = [(u, v, graph.get_edge_data(u, v).get("weight", 0.0)) for u, v in graph.edges()]
|
||||||
(u, v, graph.get_edge_data(u, v).get("weight", 0.0)) for u, v in graph.edges()
|
|
||||||
]
|
|
||||||
|
|
||||||
return edges, nodes
|
return edges, nodes
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ def pagerank_py(
|
||||||
personalization: Optional[Dict[str, float]] = None,
|
personalization: Optional[Dict[str, float]] = None,
|
||||||
alpha: float = 0.85,
|
alpha: float = 0.85,
|
||||||
max_iter: int = 100,
|
max_iter: int = 100,
|
||||||
tol: float = 1e-6
|
tol: float = 1e-6,
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
"""使用 Python、NumPy 和 SciPy 计算个性化 PageRank。
|
"""使用 Python、NumPy 和 SciPy 计算个性化 PageRank。
|
||||||
|
|
||||||
|
|
@ -44,14 +44,14 @@ def pagerank_py(
|
||||||
raw_values = np.maximum(raw_values, 0)
|
raw_values = np.maximum(raw_values, 0)
|
||||||
norm_sum = np.sum(raw_values)
|
norm_sum = np.sum(raw_values)
|
||||||
|
|
||||||
if norm_sum > 1e-9: # 避免除以零
|
if norm_sum > 1e-9: # 避免除以零
|
||||||
personalization_vec = raw_values / norm_sum
|
personalization_vec = raw_values / norm_sum
|
||||||
else:
|
else:
|
||||||
# 如果所有提供的个性化值都为零或负数,则回退到均匀分布
|
# 如果所有提供的个性化值都为零或负数,则回退到均匀分布
|
||||||
print("警告:个性化值总和为零或所有值均为非正数。回退到均匀个性化设置。")
|
print("警告:个性化值总和为零或所有值均为非正数。回退到均匀个性化设置。")
|
||||||
personalization_vec.fill(1.0 / num_nodes)
|
personalization_vec.fill(1.0 / num_nodes)
|
||||||
|
|
||||||
# --- 构建稀疏邻接矩阵 ---
|
# --- 构建稀疏邻接矩阵 ---
|
||||||
# 标准 PageRank 需要基于出度的归一化
|
# 标准 PageRank 需要基于出度的归一化
|
||||||
row_ind = []
|
row_ind = []
|
||||||
col_ind = []
|
col_ind = []
|
||||||
|
|
@ -66,9 +66,9 @@ def pagerank_py(
|
||||||
col_ind.append(src_idx)
|
col_ind.append(src_idx)
|
||||||
# 暂存原始权重,如果需要加权 PageRank,可以在此使用 w
|
# 暂存原始权重,如果需要加权 PageRank,可以在此使用 w
|
||||||
# 对于标准 PageRank,我们只需要知道连接存在
|
# 对于标准 PageRank,我们只需要知道连接存在
|
||||||
data.append(1.0) # 初始数据设为 1,之后归一化
|
data.append(1.0) # 初始数据设为 1,之后归一化
|
||||||
# 标准 PageRank 的出度是边的数量,加权 PageRank 可以用 w
|
# 标准 PageRank 的出度是边的数量,加权 PageRank 可以用 w
|
||||||
out_degree[src_idx] += 1
|
out_degree[src_idx] += 1
|
||||||
|
|
||||||
# 归一化权重(构建转移矩阵 M 的转置 M.T)
|
# 归一化权重(构建转移矩阵 M 的转置 M.T)
|
||||||
# M[j, i] 是从 i 到 j 的概率
|
# M[j, i] 是从 i 到 j 的概率
|
||||||
|
|
@ -82,17 +82,16 @@ def pagerank_py(
|
||||||
if out_degree[c] > 0:
|
if out_degree[c] > 0:
|
||||||
# 标准 PageRank: 1.0 / out_degree[c]
|
# 标准 PageRank: 1.0 / out_degree[c]
|
||||||
# 如果要用原始权重 w 作为转移概率(需确保它们已归一化),则用 w / sum(w for edges from c)
|
# 如果要用原始权重 w 作为转移概率(需确保它们已归一化),则用 w / sum(w for edges from c)
|
||||||
normalized_data.append(d / out_degree[c])
|
normalized_data.append(d / out_degree[c])
|
||||||
new_row_ind.append(c) # M.T 的行索引是 src_idx
|
new_row_ind.append(c) # M.T 的行索引是 src_idx
|
||||||
new_col_ind.append(r) # M.T 的列索引是 dst_idx
|
new_col_ind.append(r) # M.T 的列索引是 dst_idx
|
||||||
|
|
||||||
# 创建稀疏矩阵 (M.T)
|
# 创建稀疏矩阵 (M.T)
|
||||||
# 注意:scipy.sparse 期望 (data, (row_ind, col_ind)) 格式
|
# 注意:scipy.sparse 期望 (data, (row_ind, col_ind)) 格式
|
||||||
# 这里构建的是 M 的转置,方便后续计算 scores = alpha * M.T @ scores + ...
|
# 这里构建的是 M 的转置,方便后续计算 scores = alpha * M.T @ scores + ...
|
||||||
if len(normalized_data) > 0:
|
if len(normalized_data) > 0:
|
||||||
# 使用 csc_matrix 以便高效地进行列操作(矩阵向量乘法)
|
# 使用 csc_matrix 以便高效地进行列操作(矩阵向量乘法)
|
||||||
M_T = sp.csc_matrix((normalized_data, (new_row_ind, new_col_ind)),
|
M_T = sp.csc_matrix((normalized_data, (new_row_ind, new_col_ind)), shape=(num_nodes, num_nodes))
|
||||||
shape=(num_nodes, num_nodes))
|
|
||||||
else:
|
else:
|
||||||
M_T = sp.csc_matrix((num_nodes, num_nodes))
|
M_T = sp.csc_matrix((num_nodes, num_nodes))
|
||||||
|
|
||||||
|
|
@ -109,44 +108,45 @@ def pagerank_py(
|
||||||
# 还有一种做法是仅分配给个性化向量中非零的节点
|
# 还有一种做法是仅分配给个性化向量中非零的节点
|
||||||
|
|
||||||
# --- PageRank 迭代 ---
|
# --- PageRank 迭代 ---
|
||||||
scores = personalization_vec.copy() # 从个性化向量开始
|
scores = personalization_vec.copy() # 从个性化向量开始
|
||||||
|
|
||||||
for iteration in range(max_iter):
|
for iteration in range(max_iter):
|
||||||
prev_scores = scores.copy()
|
prev_scores = scores.copy()
|
||||||
|
|
||||||
# 计算来自链接的贡献
|
# 计算来自链接的贡献
|
||||||
linked_scores = M_T @ scores
|
linked_scores = M_T @ scores
|
||||||
|
|
||||||
# 计算来自悬挂节点的贡献
|
# 计算来自悬挂节点的贡献
|
||||||
# 悬挂节点的总分数 * 悬挂权重向量
|
# 悬挂节点的总分数 * 悬挂权重向量
|
||||||
dangling_sum = np.sum(scores[is_dangling])
|
dangling_sum = np.sum(scores[is_dangling])
|
||||||
dangling_contribution = dangling_sum * dangling_weights
|
dangling_contribution = dangling_sum * dangling_weights
|
||||||
|
|
||||||
# 结合瞬移、链接贡献和悬挂节点贡献
|
# 结合瞬移、链接贡献和悬挂节点贡献
|
||||||
scores = alpha * (linked_scores + dangling_contribution) + (1 - alpha) * personalization_vec
|
scores = alpha * (linked_scores + dangling_contribution) + (1 - alpha) * personalization_vec
|
||||||
|
|
||||||
# 检查收敛性 (L1 范数)
|
# 检查收敛性 (L1 范数)
|
||||||
diff = np.sum(np.abs(scores - prev_scores))
|
diff = np.sum(np.abs(scores - prev_scores))
|
||||||
if diff < tol:
|
if diff < tol:
|
||||||
print(f"在 {iteration + 1} 次迭代后收敛。")
|
print(f"在 {iteration + 1} 次迭代后收敛。")
|
||||||
break
|
break
|
||||||
else: # 循环完成但未中断
|
else: # 循环完成但未中断
|
||||||
print(f"达到最大迭代次数 ({max_iter}) 但未收敛。")
|
print(f"达到最大迭代次数 ({max_iter}) 但未收敛。")
|
||||||
|
|
||||||
# --- 格式化输出 ---
|
# --- 格式化输出 ---
|
||||||
result_dict = {index_to_node[i]: scores[i] for i in range(num_nodes)}
|
result_dict = {index_to_node[i]: scores[i] for i in range(num_nodes)}
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
|
|
||||||
# --- 示例用法(类似于 pr.c 中的 main)---
|
# --- 示例用法(类似于 pr.c 中的 main)---
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
nodes_test = ["0", "1", "2", "3", "4"]
|
nodes_test = ["0", "1", "2", "3", "4"]
|
||||||
edges_test = [
|
edges_test = [
|
||||||
("0", "1", 0.5), # 权重在此实现中仅用于确定出度
|
("0", "1", 0.5), # 权重在此实现中仅用于确定出度
|
||||||
("1", "2", 0.3),
|
("1", "2", 0.3),
|
||||||
("2", "0", 0.2),
|
("2", "0", 0.2),
|
||||||
("1", "3", 0.4),
|
("1", "3", 0.4),
|
||||||
("3", "4", 0.6),
|
("3", "4", 0.6),
|
||||||
("4", "1", 0.7)
|
("4", "1", 0.7),
|
||||||
]
|
]
|
||||||
# 添加一个悬挂节点示例
|
# 添加一个悬挂节点示例
|
||||||
nodes_test.append("5")
|
nodes_test.append("5")
|
||||||
|
|
@ -161,37 +161,33 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print("运行优化的 Python PageRank 实现...")
|
print("运行优化的 Python PageRank 实现...")
|
||||||
result = pagerank_py(
|
result = pagerank_py(
|
||||||
nodes_test,
|
nodes_test, edges_test, personalization_test, alpha=alpha_test, max_iter=max_iter_test, tol=tol_test
|
||||||
edges_test,
|
|
||||||
personalization_test,
|
|
||||||
alpha=alpha_test,
|
|
||||||
max_iter=max_iter_test,
|
|
||||||
tol=tol_test
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print("\nPageRank 分数:")
|
print("\nPageRank 分数:")
|
||||||
# 按节点索引排序以获得一致的输出
|
# 按节点索引排序以获得一致的输出
|
||||||
sorted_nodes = sorted(result.keys(), key=lambda x: int(x))
|
sorted_nodes = sorted(result.keys(), key=lambda x: int(x))
|
||||||
for node_id in sorted_nodes:
|
for node_id in sorted_nodes:
|
||||||
print(f"节点 {node_id}: {result[node_id]:.6f}")
|
print(f"节点 {node_id}: {result[node_id]:.6f}")
|
||||||
|
|
||||||
print("\n使用默认个性化设置运行...")
|
print("\n使用默认个性化设置运行...")
|
||||||
result_default_pers = pagerank_py(
|
result_default_pers = pagerank_py(
|
||||||
nodes_test,
|
nodes_test,
|
||||||
edges_test,
|
edges_test,
|
||||||
personalization=None, # 使用默认的统一性化设置
|
personalization=None, # 使用默认的统一性化设置
|
||||||
alpha=alpha_test,
|
alpha=alpha_test,
|
||||||
max_iter=max_iter_test,
|
max_iter=max_iter_test,
|
||||||
tol=tol_test
|
tol=tol_test,
|
||||||
)
|
)
|
||||||
print("\nPageRank 分数(默认个性化):")
|
print("\nPageRank 分数(默认个性化):")
|
||||||
sorted_nodes_default = sorted(result_default_pers.keys(), key=lambda x: int(x))
|
sorted_nodes_default = sorted(result_default_pers.keys(), key=lambda x: int(x))
|
||||||
for node_id in sorted_nodes_default:
|
for node_id in sorted_nodes_default:
|
||||||
print(f"节点 {node_id}: {result_default_pers[node_id]:.6f}")
|
print(f"节点 {node_id}: {result_default_pers[node_id]:.6f}")
|
||||||
|
|
||||||
# 与 NetworkX 对比 (如果安装了)
|
# 与 NetworkX 对比 (如果安装了)
|
||||||
try:
|
try:
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
print("\n与 NetworkX PageRank 对比 (个性化)...")
|
print("\n与 NetworkX PageRank 对比 (个性化)...")
|
||||||
G = nx.DiGraph()
|
G = nx.DiGraph()
|
||||||
G.add_nodes_from(nodes_test)
|
G.add_nodes_from(nodes_test)
|
||||||
|
|
@ -200,25 +196,29 @@ if __name__ == "__main__":
|
||||||
# 为了更接近我们的实现,我们不传递权重给 add_edges_from
|
# 为了更接近我们的实现,我们不传递权重给 add_edges_from
|
||||||
edges_for_nx = [(u, v) for u, v, w in edges_test]
|
edges_for_nx = [(u, v) for u, v, w in edges_test]
|
||||||
G.add_edges_from(edges_for_nx)
|
G.add_edges_from(edges_for_nx)
|
||||||
|
|
||||||
# 归一化 NetworkX 的个性化向量
|
# 归一化 NetworkX 的个性化向量
|
||||||
nx_pers = {node: personalization_test.get(node, 0.0) for node in nodes_test}
|
nx_pers = {node: personalization_test.get(node, 0.0) for node in nodes_test}
|
||||||
pers_sum = sum(nx_pers.values())
|
pers_sum = sum(nx_pers.values())
|
||||||
if pers_sum > 0:
|
if pers_sum > 0:
|
||||||
nx_pers = {k: v / pers_sum for k, v in nx_pers.items()}
|
nx_pers = {k: v / pers_sum for k, v in nx_pers.items()}
|
||||||
else: # 如果全为0,NetworkX 会报错或行为未定义,我们设为 None
|
else: # 如果全为0,NetworkX 会报错或行为未定义,我们设为 None
|
||||||
nx_pers = None
|
nx_pers = None
|
||||||
|
|
||||||
nx_result = nx.pagerank(G, alpha=alpha_test, personalization=nx_pers, max_iter=max_iter_test, tol=tol_test, weight=None) # weight=None 强制标准 PageRank
|
nx_result = nx.pagerank(
|
||||||
|
G, alpha=alpha_test, personalization=nx_pers, max_iter=max_iter_test, tol=tol_test, weight=None
|
||||||
|
) # weight=None 强制标准 PageRank
|
||||||
for node_id in sorted_nodes:
|
for node_id in sorted_nodes:
|
||||||
print(f"节点 {node_id}: {nx_result.get(node_id, 0.0):.6f}")
|
print(f"节点 {node_id}: {nx_result.get(node_id, 0.0):.6f}")
|
||||||
|
|
||||||
print("\n与 NetworkX PageRank 对比 (默认)...")
|
print("\n与 NetworkX PageRank 对比 (默认)...")
|
||||||
nx_result_default = nx.pagerank(G, alpha=alpha_test, personalization=None, max_iter=max_iter_test, tol=tol_test, weight=None)
|
nx_result_default = nx.pagerank(
|
||||||
|
G, alpha=alpha_test, personalization=None, max_iter=max_iter_test, tol=tol_test, weight=None
|
||||||
|
)
|
||||||
for node_id in sorted_nodes_default:
|
for node_id in sorted_nodes_default:
|
||||||
print(f"节点 {node_id}: {nx_result_default.get(node_id, 0.0):.6f}")
|
print(f"节点 {node_id}: {nx_result_default.get(node_id, 0.0):.6f}")
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("\n未安装 NetworkX,跳过对比。")
|
print("\n未安装 NetworkX,跳过对比。")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\n运行 NetworkX PageRank 时出错: {e}")
|
print(f"\n运行 NetworkX PageRank 时出错: {e}")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import toml
|
import toml
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
|
|
||||||
PG_NAMESPACE = "paragraph"
|
PG_NAMESPACE = "paragraph"
|
||||||
ENT_NAMESPACE = "entity"
|
ENT_NAMESPACE = "entity"
|
||||||
REL_NAMESPACE = "relation"
|
REL_NAMESPACE = "relation"
|
||||||
|
|
@ -58,6 +59,7 @@ def _load_config(config, config_file_path):
|
||||||
|
|
||||||
logger.info(f"Configurations loaded from file: {config_file_path}")
|
logger.info(f"Configurations loaded from file: {config_file_path}")
|
||||||
|
|
||||||
|
|
||||||
global_config = dict(
|
global_config = dict(
|
||||||
{
|
{
|
||||||
"llm_providers": {
|
"llm_providers": {
|
||||||
|
|
@ -119,4 +121,4 @@ file_path = os.path.abspath(__file__)
|
||||||
dir_path = os.path.dirname(file_path)
|
dir_path = os.path.dirname(file_path)
|
||||||
root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
|
root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
|
||||||
config_path = os.path.join(root_path, "config", "lpmm_config.toml")
|
config_path = os.path.join(root_path, "config", "lpmm_config.toml")
|
||||||
_load_config(global_config, config_path)
|
_load_config(global_config, config_path)
|
||||||
|
|
|
||||||
|
|
@ -47,9 +47,7 @@ class EmbeddingStore:
|
||||||
self.idx2hash = None
|
self.idx2hash = None
|
||||||
|
|
||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
return self.llm_client.send_embedding_request(
|
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
|
||||||
global_config["embedding"]["model"], s
|
|
||||||
)
|
|
||||||
|
|
||||||
def batch_insert_strs(self, strs: List[str]) -> None:
|
def batch_insert_strs(self, strs: List[str]) -> None:
|
||||||
"""向库中存入字符串"""
|
"""向库中存入字符串"""
|
||||||
|
|
@ -83,14 +81,10 @@ class EmbeddingStore:
|
||||||
logger.info(f"{self.namespace}嵌入库保存成功")
|
logger.info(f"{self.namespace}嵌入库保存成功")
|
||||||
|
|
||||||
if self.faiss_index is not None and self.idx2hash is not None:
|
if self.faiss_index is not None and self.idx2hash is not None:
|
||||||
logger.info(
|
logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}")
|
||||||
f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}"
|
|
||||||
)
|
|
||||||
faiss.write_index(self.faiss_index, self.index_file_path)
|
faiss.write_index(self.faiss_index, self.index_file_path)
|
||||||
logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
|
logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
|
||||||
logger.info(
|
logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}")
|
||||||
f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}"
|
|
||||||
)
|
|
||||||
with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
|
with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4))
|
f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4))
|
||||||
logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
|
logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
|
||||||
|
|
@ -103,24 +97,18 @@ class EmbeddingStore:
|
||||||
logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
|
logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
|
||||||
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
|
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
|
||||||
for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
|
for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
|
||||||
self.store[row["hash"]] = EmbeddingStoreItem(
|
self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
|
||||||
row["hash"], row["embedding"], row["str"]
|
|
||||||
)
|
|
||||||
logger.info(f"{self.namespace}嵌入库加载成功")
|
logger.info(f"{self.namespace}嵌入库加载成功")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if os.path.exists(self.index_file_path):
|
if os.path.exists(self.index_file_path):
|
||||||
logger.info(
|
logger.info(f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex")
|
||||||
f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex"
|
|
||||||
)
|
|
||||||
self.faiss_index = faiss.read_index(self.index_file_path)
|
self.faiss_index = faiss.read_index(self.index_file_path)
|
||||||
logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功")
|
logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功")
|
||||||
else:
|
else:
|
||||||
raise Exception(f"文件{self.index_file_path}不存在")
|
raise Exception(f"文件{self.index_file_path}不存在")
|
||||||
if os.path.exists(self.idx2hash_file_path):
|
if os.path.exists(self.idx2hash_file_path):
|
||||||
logger.info(
|
logger.info(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
|
||||||
f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射"
|
|
||||||
)
|
|
||||||
with open(self.idx2hash_file_path, "r") as f:
|
with open(self.idx2hash_file_path, "r") as f:
|
||||||
self.idx2hash = json.load(f)
|
self.idx2hash = json.load(f)
|
||||||
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
|
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
|
||||||
|
|
@ -215,9 +203,7 @@ class EmbeddingManager:
|
||||||
for triples in triple_list_data.values():
|
for triples in triple_list_data.values():
|
||||||
graph_triples.extend([tuple(t) for t in triples])
|
graph_triples.extend([tuple(t) for t in triples])
|
||||||
graph_triples = list(set(graph_triples))
|
graph_triples = list(set(graph_triples))
|
||||||
self.relation_embedding_store.batch_insert_strs(
|
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
|
||||||
[str(triple) for triple in graph_triples]
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_from_file(self):
|
def load_from_file(self):
|
||||||
"""从文件加载"""
|
"""从文件加载"""
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,6 @@ logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
console_logging_handler = logging.StreamHandler()
|
console_logging_handler = logging.StreamHandler()
|
||||||
console_logging_handler.setFormatter(
|
console_logging_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
|
||||||
logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
|
||||||
)
|
|
||||||
console_logging_handler.setLevel(logging.DEBUG)
|
console_logging_handler.setLevel(logging.DEBUG)
|
||||||
logger.addHandler(console_logging_handler)
|
logger.addHandler(console_logging_handler)
|
||||||
|
|
|
||||||
|
|
@ -38,16 +38,12 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
||||||
return entity_extract_result
|
return entity_extract_result
|
||||||
|
|
||||||
|
|
||||||
def _rdf_triple_extract(
|
def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
|
||||||
llm_client: LLMClient, paragraph: str, entities: list
|
|
||||||
) -> List[List[str]]:
|
|
||||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||||
entity_extract_context = prompt_template.build_rdf_triple_extract_context(
|
entity_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||||
)
|
)
|
||||||
_, request_result = llm_client.send_chat_request(
|
_, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
|
||||||
global_config["rdf_build"]["llm"]["model"], entity_extract_context
|
|
||||||
)
|
|
||||||
|
|
||||||
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
||||||
if "[" in request_result:
|
if "[" in request_result:
|
||||||
|
|
@ -60,11 +56,7 @@ def _rdf_triple_extract(
|
||||||
entity_extract_result = json.loads(fix_broken_generated_json(request_result))
|
entity_extract_result = json.loads(fix_broken_generated_json(request_result))
|
||||||
|
|
||||||
for triple in entity_extract_result:
|
for triple in entity_extract_result:
|
||||||
if (
|
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||||
len(triple) != 3
|
|
||||||
or (triple[0] is None or triple[1] is None or triple[2] is None)
|
|
||||||
or "" in triple
|
|
||||||
):
|
|
||||||
raise Exception("RDF提取结果格式错误")
|
raise Exception("RDF提取结果格式错误")
|
||||||
|
|
||||||
return entity_extract_result
|
return entity_extract_result
|
||||||
|
|
@ -91,9 +83,7 @@ def info_extract_from_str(
|
||||||
try_count = 0
|
try_count = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
rdf_triple_extract_result = _rdf_triple_extract(
|
rdf_triple_extract_result = _rdf_triple_extract(llm_client_for_rdf, paragraph, entity_extract_result)
|
||||||
llm_client_for_rdf, paragraph, entity_extract_result
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"实体提取失败,错误信息:{e}")
|
logger.warning(f"实体提取失败,错误信息:{e}")
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from .config import (
|
||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
|
|
||||||
|
|
||||||
class KGManager:
|
class KGManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 会被保存的字段
|
# 会被保存的字段
|
||||||
|
|
@ -35,9 +36,7 @@ class KGManager:
|
||||||
# 持久化相关
|
# 持久化相关
|
||||||
self.dir_path = global_config["persistence"]["rag_data_dir"]
|
self.dir_path = global_config["persistence"]["rag_data_dir"]
|
||||||
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphmlz"
|
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphmlz"
|
||||||
self.ent_cnt_data_path = (
|
self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
|
||||||
self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
|
|
||||||
)
|
|
||||||
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
|
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
|
||||||
|
|
||||||
def save_to_file(self):
|
def save_to_file(self):
|
||||||
|
|
@ -50,9 +49,7 @@ class KGManager:
|
||||||
nx.write_graphml(self.graph, path=self.graph_data_path, encoding="utf-8")
|
nx.write_graphml(self.graph, path=self.graph_data_path, encoding="utf-8")
|
||||||
|
|
||||||
# 保存实体计数到文件
|
# 保存实体计数到文件
|
||||||
ent_cnt_df = pd.DataFrame(
|
ent_cnt_df = pd.DataFrame([{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()])
|
||||||
[{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()]
|
|
||||||
)
|
|
||||||
ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False)
|
ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False)
|
||||||
|
|
||||||
# 保存段落hash到文件
|
# 保存段落hash到文件
|
||||||
|
|
@ -77,9 +74,7 @@ class KGManager:
|
||||||
|
|
||||||
# 加载实体计数
|
# 加载实体计数
|
||||||
ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow")
|
ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow")
|
||||||
self.ent_appear_cnt = dict(
|
self.ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()})
|
||||||
{row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 加载KG
|
# 加载KG
|
||||||
self.graph = nx.read_graphml(self.graph_data_path)
|
self.graph = nx.read_graphml(self.graph_data_path)
|
||||||
|
|
@ -101,20 +96,14 @@ class KGManager:
|
||||||
# 一个triple就是一条边(同时构建双向联系)
|
# 一个triple就是一条边(同时构建双向联系)
|
||||||
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
||||||
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
|
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
|
||||||
node_to_node[(hash_key1, hash_key2)] = (
|
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
||||||
node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
||||||
)
|
|
||||||
node_to_node[(hash_key2, hash_key1)] = (
|
|
||||||
node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
|
||||||
)
|
|
||||||
entity_set.add(hash_key1)
|
entity_set.add(hash_key1)
|
||||||
entity_set.add(hash_key2)
|
entity_set.add(hash_key2)
|
||||||
|
|
||||||
# 实体出现次数统计
|
# 实体出现次数统计
|
||||||
for hash_key in entity_set:
|
for hash_key in entity_set:
|
||||||
self.ent_appear_cnt[hash_key] = (
|
self.ent_appear_cnt[hash_key] = self.ent_appear_cnt.get(hash_key, 0) + 1.0
|
||||||
self.ent_appear_cnt.get(hash_key, 0) + 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_edges_between_ent_pg(
|
def _build_edges_between_ent_pg(
|
||||||
|
|
@ -126,9 +115,7 @@ class KGManager:
|
||||||
for triple in triple_list_data[idx]:
|
for triple in triple_list_data[idx]:
|
||||||
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
||||||
pg_hash_key = PG_NAMESPACE + "-" + str(idx)
|
pg_hash_key = PG_NAMESPACE + "-" + str(idx)
|
||||||
node_to_node[(ent_hash_key, pg_hash_key)] = (
|
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
||||||
node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _synonym_connect(
|
def _synonym_connect(
|
||||||
|
|
@ -177,9 +164,7 @@ class KGManager:
|
||||||
new_edge_cnt += 1
|
new_edge_cnt += 1
|
||||||
res_ent.append(
|
res_ent.append(
|
||||||
(
|
(
|
||||||
embedding_manager.entities_embedding_store.store[
|
embedding_manager.entities_embedding_store.store[res_ent_hash].str,
|
||||||
res_ent_hash
|
|
||||||
].str,
|
|
||||||
similarity,
|
similarity,
|
||||||
)
|
)
|
||||||
) # Debug
|
) # Debug
|
||||||
|
|
@ -236,22 +221,16 @@ class KGManager:
|
||||||
if node_hash not in existed_nodes:
|
if node_hash not in existed_nodes:
|
||||||
if node_hash.startswith(ENT_NAMESPACE):
|
if node_hash.startswith(ENT_NAMESPACE):
|
||||||
# 新增实体节点
|
# 新增实体节点
|
||||||
node = embedding_manager.entities_embedding_store.store[
|
node = embedding_manager.entities_embedding_store.store[node_hash]
|
||||||
node_hash
|
|
||||||
]
|
|
||||||
assert isinstance(node, EmbeddingStoreItem)
|
assert isinstance(node, EmbeddingStoreItem)
|
||||||
self.graph.nodes[node_hash]["content"] = node.str
|
self.graph.nodes[node_hash]["content"] = node.str
|
||||||
self.graph.nodes[node_hash]["type"] = "ent"
|
self.graph.nodes[node_hash]["type"] = "ent"
|
||||||
elif node_hash.startswith(PG_NAMESPACE):
|
elif node_hash.startswith(PG_NAMESPACE):
|
||||||
# 新增文段节点
|
# 新增文段节点
|
||||||
node = embedding_manager.paragraphs_embedding_store.store[
|
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
|
||||||
node_hash
|
|
||||||
]
|
|
||||||
assert isinstance(node, EmbeddingStoreItem)
|
assert isinstance(node, EmbeddingStoreItem)
|
||||||
content = node.str.replace("\n", " ")
|
content = node.str.replace("\n", " ")
|
||||||
self.graph.nodes[node_hash]["content"] = (
|
self.graph.nodes[node_hash]["content"] = content if len(content) < 8 else content[:8] + "..."
|
||||||
content if len(content) < 8 else content[:8] + "..."
|
|
||||||
)
|
|
||||||
self.graph.nodes[node_hash]["type"] = "pg"
|
self.graph.nodes[node_hash]["type"] = "pg"
|
||||||
|
|
||||||
def build_kg(
|
def build_kg(
|
||||||
|
|
@ -319,9 +298,7 @@ class KGManager:
|
||||||
ent_sim_scores = {}
|
ent_sim_scores = {}
|
||||||
for relation_hash, similarity, _ in relation_search_result:
|
for relation_hash, similarity, _ in relation_search_result:
|
||||||
# 提取主宾短语
|
# 提取主宾短语
|
||||||
relation = embed_manager.relation_embedding_store.store.get(
|
relation = embed_manager.relation_embedding_store.store.get(relation_hash).str
|
||||||
relation_hash
|
|
||||||
).str
|
|
||||||
assert relation is not None # 断言:relation不为空
|
assert relation is not None # 断言:relation不为空
|
||||||
# 关系三元组
|
# 关系三元组
|
||||||
triple = relation[2:-2].split("', '")
|
triple = relation[2:-2].split("', '")
|
||||||
|
|
@ -335,9 +312,7 @@ class KGManager:
|
||||||
ent_mean_scores = {} # 记录实体的平均相似度
|
ent_mean_scores = {} # 记录实体的平均相似度
|
||||||
for ent_hash, scores in ent_sim_scores.items():
|
for ent_hash, scores in ent_sim_scores.items():
|
||||||
# 先对相似度进行累加,然后与实体计数相除获取最终权重
|
# 先对相似度进行累加,然后与实体计数相除获取最终权重
|
||||||
ent_weights[ent_hash] = (
|
ent_weights[ent_hash] = float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
|
||||||
float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
|
|
||||||
)
|
|
||||||
# 记录实体的平均相似度,用于后续的top_k筛选
|
# 记录实体的平均相似度,用于后续的top_k筛选
|
||||||
ent_mean_scores[ent_hash] = float(np.mean(scores))
|
ent_mean_scores[ent_hash] = float(np.mean(scores))
|
||||||
del ent_sim_scores
|
del ent_sim_scores
|
||||||
|
|
@ -354,21 +329,14 @@ class KGManager:
|
||||||
for ent_hash, score in ent_weights.items():
|
for ent_hash, score in ent_weights.items():
|
||||||
# 缩放相似度
|
# 缩放相似度
|
||||||
ent_weights[ent_hash] = (
|
ent_weights[ent_hash] = (
|
||||||
(score - ent_weights_min)
|
(score - ent_weights_min) * (1 - down_edge) / (ent_weights_max - ent_weights_min)
|
||||||
* (1 - down_edge)
|
|
||||||
/ (ent_weights_max - ent_weights_min)
|
|
||||||
) + down_edge
|
) + down_edge
|
||||||
|
|
||||||
# 取平均相似度的top_k实体
|
# 取平均相似度的top_k实体
|
||||||
top_k = global_config["qa"]["params"]["ent_filter_top_k"]
|
top_k = global_config["qa"]["params"]["ent_filter_top_k"]
|
||||||
if len(ent_mean_scores) > top_k:
|
if len(ent_mean_scores) > top_k:
|
||||||
# 从大到小排序,取后len - k个
|
# 从大到小排序,取后len - k个
|
||||||
ent_mean_scores = {
|
ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
|
||||||
k: v
|
|
||||||
for k, v in sorted(
|
|
||||||
ent_mean_scores.items(), key=lambda item: item[1], reverse=True
|
|
||||||
)
|
|
||||||
}
|
|
||||||
for ent_hash, _ in ent_mean_scores.items():
|
for ent_hash, _ in ent_mean_scores.items():
|
||||||
# 删除被淘汰的实体节点权重设置
|
# 删除被淘汰的实体节点权重设置
|
||||||
del ent_weights[ent_hash]
|
del ent_weights[ent_hash]
|
||||||
|
|
@ -389,9 +357,7 @@ class KGManager:
|
||||||
# 归一化
|
# 归一化
|
||||||
for pg_hash, similarity in pg_sim_scores.items():
|
for pg_hash, similarity in pg_sim_scores.items():
|
||||||
# 归一化相似度
|
# 归一化相似度
|
||||||
pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (
|
pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (pg_sim_score_max - pg_sim_score_min)
|
||||||
pg_sim_score_max - pg_sim_score_min
|
|
||||||
)
|
|
||||||
del pg_sim_score_max, pg_sim_score_min
|
del pg_sim_score_max, pg_sim_score_min
|
||||||
|
|
||||||
for pg_hash, score in pg_sim_scores.items():
|
for pg_hash, score in pg_sim_scores.items():
|
||||||
|
|
@ -401,9 +367,7 @@ class KGManager:
|
||||||
del pg_sim_scores
|
del pg_sim_scores
|
||||||
|
|
||||||
# 最终权重数据 = 实体权重 + 文段权重
|
# 最终权重数据 = 实体权重 + 文段权重
|
||||||
ppr_node_weights = {
|
ppr_node_weights = {k: v for d in [ent_weights, pg_weights] for k, v in d.items()}
|
||||||
k: v for d in [ent_weights, pg_weights] for k, v in d.items()
|
|
||||||
}
|
|
||||||
del ent_weights, pg_weights
|
del ent_weights, pg_weights
|
||||||
|
|
||||||
# PersonalizedPageRank
|
# PersonalizedPageRank
|
||||||
|
|
@ -418,14 +382,12 @@ class KGManager:
|
||||||
# 从搜索结果中提取文段节点的结果
|
# 从搜索结果中提取文段节点的结果
|
||||||
passage_node_res = [
|
passage_node_res = [
|
||||||
(node_key, score)
|
(node_key, score)
|
||||||
for node_key, score in ppr_res.items() # Iterate over dictionary items
|
for node_key, score in ppr_res.items() # Iterate over dictionary items
|
||||||
if node_key.startswith(PG_NAMESPACE)
|
if node_key.startswith(PG_NAMESPACE)
|
||||||
]
|
]
|
||||||
del ppr_res
|
del ppr_res
|
||||||
|
|
||||||
# 排序:按照分数从大到小
|
# 排序:按照分数从大到小
|
||||||
passage_node_res = sorted(
|
passage_node_res = sorted(passage_node_res, key=lambda item: item[1], reverse=True)
|
||||||
passage_node_res, key=lambda item: item[1], reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return passage_node_res, ppr_node_weights
|
return passage_node_res, ppr_node_weights
|
||||||
|
|
|
||||||
|
|
@ -21,20 +21,14 @@ class LLMClient:
|
||||||
|
|
||||||
def send_chat_request(self, model, messages):
|
def send_chat_request(self, model, messages):
|
||||||
"""发送对话请求,等待返回结果"""
|
"""发送对话请求,等待返回结果"""
|
||||||
response = self.client.chat.completions.create(
|
response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
|
||||||
model=model, messages=messages, stream=False
|
|
||||||
)
|
|
||||||
if hasattr(response.choices[0].message, "reasoning_content"):
|
if hasattr(response.choices[0].message, "reasoning_content"):
|
||||||
# 有单独的推理内容块
|
# 有单独的推理内容块
|
||||||
reasoning_content = response.choices[0].message.reasoning_content
|
reasoning_content = response.choices[0].message.reasoning_content
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
else:
|
else:
|
||||||
# 无单独的推理内容块
|
# 无单独的推理内容块
|
||||||
response = (
|
response = response.choices[0].message.content.split("<think>")[-1].split("</think>")
|
||||||
response.choices[0]
|
|
||||||
.message.content.split("<think>")[-1]
|
|
||||||
.split("</think>")
|
|
||||||
)
|
|
||||||
# 如果有推理内容,则分割推理内容和内容
|
# 如果有推理内容,则分割推理内容和内容
|
||||||
if len(response) == 2:
|
if len(response) == 2:
|
||||||
reasoning_content = response[0]
|
reasoning_content = response[0]
|
||||||
|
|
@ -48,6 +42,4 @@ class LLMClient:
|
||||||
def send_embedding_request(self, model, text):
|
def send_embedding_request(self, model, text):
|
||||||
"""发送嵌入请求,等待返回结果"""
|
"""发送嵌入请求,等待返回结果"""
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
return (
|
return self.client.embeddings.create(input=[text], model=model).data[0].embedding
|
||||||
self.client.embeddings.create(input=[text], model=model).data[0].embedding
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,6 @@ def _load_config(config, config_file_path):
|
||||||
config["persistence"] = file_config["persistence"]
|
config["persistence"] = file_config["persistence"]
|
||||||
print(config)
|
print(config)
|
||||||
print("Configurations loaded from file: ", config_file_path)
|
print("Configurations loaded from file: ", config_file_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Configurations for the pipeline")
|
parser = argparse.ArgumentParser(description="Configurations for the pipeline")
|
||||||
|
|
@ -122,9 +121,9 @@ global_config = dict(
|
||||||
"embedding_data_dir": "data/embedding",
|
"embedding_data_dir": "data/embedding",
|
||||||
"rag_data_dir": "data/rag",
|
"rag_data_dir": "data/rag",
|
||||||
},
|
},
|
||||||
"info_extraction":{
|
"info_extraction": {
|
||||||
"workers": 10,
|
"workers": 10,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,13 +16,9 @@ class MemoryActiveManager:
|
||||||
def get_activation(self, question: str) -> float:
|
def get_activation(self, question: str) -> float:
|
||||||
"""获取记忆激活度"""
|
"""获取记忆激活度"""
|
||||||
# 生成问题的Embedding
|
# 生成问题的Embedding
|
||||||
question_embedding = self.embedding_client.send_embedding_request(
|
question_embedding = self.embedding_client.send_embedding_request("text-embedding", question)
|
||||||
"text-embedding", question
|
|
||||||
)
|
|
||||||
# 查询关系库中的相似度
|
# 查询关系库中的相似度
|
||||||
rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10)
|
||||||
question_embedding, 10
|
|
||||||
)
|
|
||||||
|
|
||||||
# 动态过滤阈值
|
# 动态过滤阈值
|
||||||
rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)
|
rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||||
"""过滤无效的实体"""
|
"""过滤无效的实体"""
|
||||||
valid_entities = set()
|
valid_entities = set()
|
||||||
for entity in entities:
|
for entity in entities:
|
||||||
if (
|
if not isinstance(entity, str) or entity.strip() == "" or entity in INVALID_ENTITY or entity in valid_entities:
|
||||||
not isinstance(entity, str)
|
|
||||||
or entity.strip() == ""
|
|
||||||
or entity in INVALID_ENTITY
|
|
||||||
or entity in valid_entities
|
|
||||||
):
|
|
||||||
# 非字符串/空字符串/在无效实体列表中/重复
|
# 非字符串/空字符串/在无效实体列表中/重复
|
||||||
continue
|
continue
|
||||||
valid_entities.add(entity)
|
valid_entities.add(entity)
|
||||||
|
|
@ -74,9 +69,7 @@ class OpenIE:
|
||||||
|
|
||||||
for doc in self.docs:
|
for doc in self.docs:
|
||||||
# 过滤实体列表
|
# 过滤实体列表
|
||||||
doc["extracted_entities"] = _filter_invalid_entities(
|
doc["extracted_entities"] = _filter_invalid_entities(doc["extracted_entities"])
|
||||||
doc["extracted_entities"]
|
|
||||||
)
|
|
||||||
# 过滤无效的三元组
|
# 过滤无效的三元组
|
||||||
doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
|
doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
|
||||||
|
|
||||||
|
|
@ -100,9 +93,7 @@ class OpenIE:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load() -> "OpenIE":
|
def load() -> "OpenIE":
|
||||||
"""从文件中加载OpenIE数据"""
|
"""从文件中加载OpenIE数据"""
|
||||||
with open(
|
with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
|
||||||
global_config["persistence"]["openie_data_path"], "r", encoding="utf-8"
|
|
||||||
) as f:
|
|
||||||
data = json.loads(f.read())
|
data = json.loads(f.read())
|
||||||
|
|
||||||
openie_data = OpenIE._from_dict(data)
|
openie_data = OpenIE._from_dict(data)
|
||||||
|
|
@ -112,9 +103,7 @@ class OpenIE:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def save(openie_data: "OpenIE"):
|
def save(openie_data: "OpenIE"):
|
||||||
"""保存OpenIE数据到文件"""
|
"""保存OpenIE数据到文件"""
|
||||||
with open(
|
with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f:
|
||||||
global_config["persistence"]["openie_data_path"], "w", encoding="utf-8"
|
|
||||||
) as f:
|
|
||||||
f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
|
f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
def extract_entity_dict(self):
|
def extract_entity_dict(self):
|
||||||
|
|
@ -141,7 +130,5 @@ class OpenIE:
|
||||||
|
|
||||||
def extract_raw_paragraph_dict(self):
|
def extract_raw_paragraph_dict(self):
|
||||||
"""提取原始段落"""
|
"""提取原始段落"""
|
||||||
raw_paragraph_dict = dict(
|
raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
|
||||||
{doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}
|
|
||||||
)
|
|
||||||
return raw_paragraph_dict
|
return raw_paragraph_dict
|
||||||
|
|
|
||||||
|
|
@ -41,9 +41,7 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描
|
||||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
|
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
|
||||||
messages = [
|
messages = [
|
||||||
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
|
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
|
||||||
LLMMessage(
|
LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
|
||||||
"user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```"""
|
|
||||||
).to_dict(),
|
|
||||||
]
|
]
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
@ -58,16 +56,10 @@ qa_system_prompt = """
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def build_qa_context(
|
def build_qa_context(question: str, knowledge: list[(str, str, str)]) -> List[LLMMessage]:
|
||||||
question: str, knowledge: list[(str, str, str)]
|
knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
|
||||||
) -> List[LLMMessage]:
|
|
||||||
knowledge = "\n".join(
|
|
||||||
[f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]
|
|
||||||
)
|
|
||||||
messages = [
|
messages = [
|
||||||
LLMMessage("system", qa_system_prompt).to_dict(),
|
LLMMessage("system", qa_system_prompt).to_dict(),
|
||||||
LLMMessage(
|
LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(),
|
||||||
"user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}"
|
|
||||||
).to_dict(),
|
|
||||||
]
|
]
|
||||||
return messages
|
return messages
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import time
|
||||||
from typing import Tuple, List, Dict
|
from typing import Tuple, List, Dict
|
||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
|
|
||||||
# from . import prompt_template
|
# from . import prompt_template
|
||||||
from .embedding_store import EmbeddingManager
|
from .embedding_store import EmbeddingManager
|
||||||
from .llm_client import LLMClient
|
from .llm_client import LLMClient
|
||||||
|
|
@ -31,7 +32,7 @@ class QAManager:
|
||||||
"""处理查询"""
|
"""处理查询"""
|
||||||
|
|
||||||
# 生成问题的Embedding
|
# 生成问题的Embedding
|
||||||
part_start_time =time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
|
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
|
||||||
global_config["embedding"]["model"], question
|
global_config["embedding"]["model"], question
|
||||||
)
|
)
|
||||||
|
|
@ -39,7 +40,7 @@ class QAManager:
|
||||||
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
|
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
||||||
# 根据问题Embedding查询Relation Embedding库
|
# 根据问题Embedding查询Relation Embedding库
|
||||||
part_start_time =time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
||||||
question_embedding,
|
question_embedding,
|
||||||
global_config["qa"]["params"]["relation_search_top_k"],
|
global_config["qa"]["params"]["relation_search_top_k"],
|
||||||
|
|
@ -47,10 +48,7 @@ class QAManager:
|
||||||
# 过滤阈值
|
# 过滤阈值
|
||||||
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
||||||
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
||||||
if (
|
if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
|
||||||
relation_search_res[0][1]
|
|
||||||
< global_config["qa"]["params"]["relation_threshold"]
|
|
||||||
):
|
|
||||||
# 未找到相关关系
|
# 未找到相关关系
|
||||||
relation_search_res = []
|
relation_search_res = []
|
||||||
|
|
||||||
|
|
@ -66,12 +64,10 @@ class QAManager:
|
||||||
# part_start_time = time.time()
|
# part_start_time = time.time()
|
||||||
|
|
||||||
# 根据问题Embedding查询Paragraph Embedding库
|
# 根据问题Embedding查询Paragraph Embedding库
|
||||||
part_start_time =time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
paragraph_search_res = (
|
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
||||||
self.embed_manager.paragraphs_embedding_store.search_top_k(
|
question_embedding,
|
||||||
question_embedding,
|
global_config["qa"]["params"]["paragraph_search_top_k"],
|
||||||
global_config["qa"]["params"]["paragraph_search_top_k"],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
@ -79,7 +75,7 @@ class QAManager:
|
||||||
if len(relation_search_res) != 0:
|
if len(relation_search_res) != 0:
|
||||||
logger.info("找到相关关系,将使用RAG进行检索")
|
logger.info("找到相关关系,将使用RAG进行检索")
|
||||||
# 使用KG检索
|
# 使用KG检索
|
||||||
part_start_time =time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
result, ppr_node_weights = self.kg_manager.kg_search(
|
result, ppr_node_weights = self.kg_manager.kg_search(
|
||||||
relation_search_res, paragraph_search_res, self.embed_manager
|
relation_search_res, paragraph_search_res, self.embed_manager
|
||||||
)
|
)
|
||||||
|
|
@ -94,9 +90,7 @@ class QAManager:
|
||||||
result = dyn_select_top_k(result, 0.5, 1.0)
|
result = dyn_select_top_k(result, 0.5, 1.0)
|
||||||
|
|
||||||
for res in result:
|
for res in result:
|
||||||
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[
|
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
||||||
res[0]
|
|
||||||
].str
|
|
||||||
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
||||||
|
|
||||||
return result, ppr_node_weights
|
return result, ppr_node_weights
|
||||||
|
|
@ -114,4 +108,4 @@ class QAManager:
|
||||||
for res in query_res
|
for res in query_res
|
||||||
]
|
]
|
||||||
found_knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
|
found_knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
|
||||||
return found_knowledge
|
return found_knowledge
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,7 @@ def load_raw_data() -> tuple[list[str], list[str]]:
|
||||||
"""
|
"""
|
||||||
# 读取import.json文件
|
# 读取import.json文件
|
||||||
if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
|
if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
|
||||||
with open(
|
with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
|
||||||
global_config["persistence"]["raw_data_path"], "r", encoding="utf-8"
|
|
||||||
) as f:
|
|
||||||
import_json = json.loads(f.read())
|
import_json = json.loads(f.read())
|
||||||
else:
|
else:
|
||||||
raise Exception("原始数据文件读取失败")
|
raise Exception("原始数据文件读取失败")
|
||||||
|
|
|
||||||
|
|
@ -14,11 +14,7 @@ class DataLoader:
|
||||||
Args:
|
Args:
|
||||||
custom_data_dir: 可选的自定义数据目录路径,如果不提供则使用配置文件中的默认路径
|
custom_data_dir: 可选的自定义数据目录路径,如果不提供则使用配置文件中的默认路径
|
||||||
"""
|
"""
|
||||||
self.data_dir = (
|
self.data_dir = Path(custom_data_dir) if custom_data_dir else Path(config["persistence"]["data_root_path"])
|
||||||
Path(custom_data_dir)
|
|
||||||
if custom_data_dir
|
|
||||||
else Path(config["persistence"]["data_root_path"])
|
|
||||||
)
|
|
||||||
if not self.data_dir.exists():
|
if not self.data_dir.exists():
|
||||||
raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在")
|
raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,14 +36,10 @@ def dyn_select_top_k(
|
||||||
# 计算均值
|
# 计算均值
|
||||||
mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score)
|
mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score)
|
||||||
# 计算方差
|
# 计算方差
|
||||||
var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(
|
var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(normalized_score)
|
||||||
normalized_score
|
|
||||||
)
|
|
||||||
|
|
||||||
# 动态阈值
|
# 动态阈值
|
||||||
threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (
|
threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (mean_score + var_factor * var_score)
|
||||||
mean_score + var_factor * var_score
|
|
||||||
)
|
|
||||||
|
|
||||||
# 重新过滤
|
# 重新过滤
|
||||||
res = [s for s in normalized_score if s[2] > threshold]
|
res = [s for s in normalized_score if s[2] > threshold]
|
||||||
|
|
|
||||||
|
|
@ -29,10 +29,7 @@ def _find_unclosed(json_str):
|
||||||
elif char in "{[":
|
elif char in "{[":
|
||||||
unclosed.append(char)
|
unclosed.append(char)
|
||||||
elif char in "}]":
|
elif char in "}]":
|
||||||
if unclosed and (
|
if unclosed and ((char == "}" and unclosed[-1] == "{") or (char == "]" and unclosed[-1] == "[")):
|
||||||
(char == "}" and unclosed[-1] == "{")
|
|
||||||
or (char == "]" and unclosed[-1] == "[")
|
|
||||||
):
|
|
||||||
unclosed.pop()
|
unclosed.pop()
|
||||||
|
|
||||||
return unclosed
|
return unclosed
|
||||||
|
|
|
||||||
|
|
@ -14,4 +14,4 @@ def draw_graph_and_show(graph):
|
||||||
font_family="Sarasa Mono SC",
|
font_family="Sarasa Mono SC",
|
||||||
font_size=8,
|
font_size=8,
|
||||||
)
|
)
|
||||||
fig.show()
|
fig.show()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue