diff --git a/import_openie.py b/import_openie.py
index 537187db..5e347ef5 100644
--- a/import_openie.py
+++ b/import_openie.py
@@ -20,6 +20,7 @@ import sys
logger = get_module_logger("LPMM知识库-OpenIE导入")
+
def hash_deduplicate(
raw_paragraphs: Dict[str, str],
triple_list_data: Dict[str, List[List[str]]],
@@ -43,14 +44,10 @@ def hash_deduplicate(
# 保存去重后的三元组
new_triple_list_data = dict()
- for _, (raw_paragraph, triple_list) in enumerate(
- zip(raw_paragraphs.values(), triple_list_data.values())
- ):
+ for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())):
# 段落hash
paragraph_hash = get_sha256(raw_paragraph)
- if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (
- paragraph_hash in stored_paragraph_hashes
- ):
+ if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (paragraph_hash in stored_paragraph_hashes):
continue
new_raw_paragraphs[paragraph_hash] = raw_paragraph
new_triple_list_data[paragraph_hash] = triple_list
@@ -58,9 +55,7 @@ def hash_deduplicate(
return new_raw_paragraphs, new_triple_list_data
-def handle_import_openie(
- openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager
-) -> bool:
+def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
# 从OpenIE数据中提取段落原文与三元组列表
# 索引的段落原文
raw_paragraphs = openie_data.extract_raw_paragraph_dict()
@@ -68,9 +63,7 @@ def handle_import_openie(
entity_list_data = openie_data.extract_entity_dict()
# 索引的三元组列表
triple_list_data = openie_data.extract_triple_dict()
- if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(
- triple_list_data
- ):
+ if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
logger.error("OpenIE数据存在异常")
return False
# 将索引换为对应段落的hash值
@@ -112,11 +105,11 @@ def main():
print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G")
confirm = input("确认继续执行?(y/n): ").strip().lower()
- if confirm != 'y':
+ if confirm != "y":
logger.info("用户取消操作")
print("操作已取消")
sys.exit(1)
- print("\n" + "="*40 + "\n")
+ print("\n" + "=" * 40 + "\n")
logger.info("----开始导入openie数据----\n")
@@ -129,9 +122,7 @@ def main():
)
# 初始化Embedding库
- embed_manager = embed_manager = EmbeddingManager(
- llm_client_list[global_config["embedding"]["provider"]]
- )
+ embed_manager = embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
logger.info("正在从文件加载Embedding库")
try:
embed_manager.load_from_file()
diff --git a/info_extraction.py b/info_extraction.py
index 04f2dd4f..b6ad8a9c 100644
--- a/info_extraction.py
+++ b/info_extraction.py
@@ -91,11 +91,11 @@ def main():
print("或者使用可以用赠金抵扣的Pro模型")
print("请确保账户余额充足,并且在执行前确认无误。")
confirm = input("确认继续执行?(y/n): ").strip().lower()
- if confirm != 'y':
+ if confirm != "y":
logger.info("用户取消操作")
print("操作已取消")
sys.exit(1)
- print("\n" + "="*40 + "\n")
+ print("\n" + "=" * 40 + "\n")
logger.info("--------进行信息提取--------\n")
diff --git a/raw_data_preprocessor.py b/raw_data_preprocessor.py
index abb1ad64..645a8506 100644
--- a/raw_data_preprocessor.py
+++ b/raw_data_preprocessor.py
@@ -3,18 +3,17 @@ import os
from pathlib import Path
import sys # 新增系统模块导入
+
def check_and_create_dirs():
"""检查并创建必要的目录"""
- required_dirs = [
- "data/lpmm_raw_data",
- "data/imported_lpmm_data"
- ]
-
+ required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"]
+
for dir_path in required_dirs:
if not os.path.exists(dir_path):
os.makedirs(dir_path)
print(f"已创建目录: {dir_path}")
+
def process_text_file(file_path):
"""处理单个文本文件,返回段落列表"""
with open(file_path, "r", encoding="utf-8") as f:
@@ -29,12 +28,13 @@ def process_text_file(file_path):
paragraph = ""
else:
paragraph += line + "\n"
-
+
if paragraph != "":
paragraphs.append(paragraph.strip())
-
+
return paragraphs
+
def main():
# 新增用户确认提示
print("=== 重要操作确认 ===")
@@ -43,42 +43,43 @@ def main():
print("在进行知识库导入之前")
print("请修改config/lpmm_config.toml中的配置项")
confirm = input("确认继续执行?(y/n): ").strip().lower()
- if confirm != 'y':
+ if confirm != "y":
print("操作已取消")
sys.exit(1)
- print("\n" + "="*40 + "\n")
+ print("\n" + "=" * 40 + "\n")
# 检查并创建必要的目录
check_and_create_dirs()
-
+
# 检查输出文件是否存在
if os.path.exists("data/import.json"):
print("错误: data/import.json 已存在,请先处理或删除该文件")
sys.exit(1)
-
+
if os.path.exists("data/openie.json"):
print("错误: data/openie.json 已存在,请先处理或删除该文件")
sys.exit(1)
-
+
# 获取所有原始文本文件
raw_files = list(Path("data/lpmm_raw_data").glob("*.txt"))
if not raw_files:
print("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
sys.exit(1)
-
+
# 处理所有文件
all_paragraphs = []
for file in raw_files:
print(f"正在处理文件: {file.name}")
paragraphs = process_text_file(file)
all_paragraphs.extend(paragraphs)
-
+
# 保存合并后的结果
output_path = "data/import.json"
with open(output_path, "w", encoding="utf-8") as f:
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
-
+
print(f"处理完成,结果已保存到: {output_path}")
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/src/do_tool/tool_can_use/lpmm_get_knowledge.py b/src/do_tool/tool_can_use/lpmm_get_knowledge.py
index 5c70adc1..601d6083 100644
--- a/src/do_tool/tool_can_use/lpmm_get_knowledge.py
+++ b/src/do_tool/tool_can_use/lpmm_get_knowledge.py
@@ -1,5 +1,6 @@
from src.do_tool.tool_can_use.base_tool import BaseTool
from src.plugins.chat.utils import get_embedding
+
# from src.common.database import db
from src.common.logger import get_module_logger
from typing import Dict, Any
diff --git a/src/plugins/knowledge/knowledge_lib.py b/src/plugins/knowledge/knowledge_lib.py
index 31167391..c0d2fe61 100644
--- a/src/plugins/knowledge/knowledge_lib.py
+++ b/src/plugins/knowledge/knowledge_lib.py
@@ -20,9 +20,7 @@ for key in global_config["llm_providers"]:
)
# 初始化Embedding库
-embed_manager = EmbeddingManager(
- llm_client_list[global_config["embedding"]["provider"]]
-)
+embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
logger.info("正在从文件加载Embedding库")
try:
embed_manager.load_from_file()
@@ -62,4 +60,3 @@ inspire_manager = MemoryActiveManager(
embed_manager,
llm_client_list[global_config["embedding"]["provider"]],
)
-
diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py
index 59c804fa..e972db57 100644
--- a/src/plugins/knowledge/src/embedding_store.py
+++ b/src/plugins/knowledge/src/embedding_store.py
@@ -47,9 +47,7 @@ class EmbeddingStore:
self.idx2hash = None
def _get_embedding(self, s: str) -> List[float]:
- return self.llm_client.send_embedding_request(
- global_config["embedding"]["model"], s
- )
+ return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
def batch_insert_strs(self, strs: List[str]) -> None:
"""向库中存入字符串"""
@@ -83,14 +81,10 @@ class EmbeddingStore:
logger.info(f"{self.namespace}嵌入库保存成功")
if self.faiss_index is not None and self.idx2hash is not None:
- logger.info(
- f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}"
- )
+ logger.info(f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}")
faiss.write_index(self.faiss_index, self.index_file_path)
logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
- logger.info(
- f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}"
- )
+ logger.info(f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}")
with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4))
logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
@@ -103,24 +97,18 @@ class EmbeddingStore:
logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
- self.store[row["hash"]] = EmbeddingStoreItem(
- row["hash"], row["embedding"], row["str"]
- )
+ self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
logger.info(f"{self.namespace}嵌入库加载成功")
try:
if os.path.exists(self.index_file_path):
- logger.info(
- f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex"
- )
+ logger.info(f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex")
self.faiss_index = faiss.read_index(self.index_file_path)
logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功")
else:
raise Exception(f"文件{self.index_file_path}不存在")
if os.path.exists(self.idx2hash_file_path):
- logger.info(
- f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射"
- )
+ logger.info(f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射")
with open(self.idx2hash_file_path, "r") as f:
self.idx2hash = json.load(f)
logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
@@ -215,9 +203,7 @@ class EmbeddingManager:
for triples in triple_list_data.values():
graph_triples.extend([tuple(t) for t in triples])
graph_triples = list(set(graph_triples))
- self.relation_embedding_store.batch_insert_strs(
- [str(triple) for triple in graph_triples]
- )
+ self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
def load_from_file(self):
"""从文件加载"""
diff --git a/src/plugins/knowledge/src/global_logger.py b/src/plugins/knowledge/src/global_logger.py
index 0311db5f..f7d8297e 100644
--- a/src/plugins/knowledge/src/global_logger.py
+++ b/src/plugins/knowledge/src/global_logger.py
@@ -7,8 +7,6 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
console_logging_handler = logging.StreamHandler()
-console_logging_handler.setFormatter(
- logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-)
+console_logging_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
console_logging_handler.setLevel(logging.DEBUG)
logger.addHandler(console_logging_handler)
diff --git a/src/plugins/knowledge/src/ie_process.py b/src/plugins/knowledge/src/ie_process.py
index 5da9ad9e..3e53e4b2 100644
--- a/src/plugins/knowledge/src/ie_process.py
+++ b/src/plugins/knowledge/src/ie_process.py
@@ -38,16 +38,12 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
return entity_extract_result
-def _rdf_triple_extract(
- llm_client: LLMClient, paragraph: str, entities: list
-) -> List[List[str]]:
+def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
entity_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False)
)
- _, request_result = llm_client.send_chat_request(
- global_config["rdf_build"]["llm"]["model"], entity_extract_context
- )
+ _, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
# 去除‘{’前的内容(结果中可能有多个‘{’)
if "[" in request_result:
@@ -60,11 +56,7 @@ def _rdf_triple_extract(
entity_extract_result = json.loads(fix_broken_generated_json(request_result))
for triple in entity_extract_result:
- if (
- len(triple) != 3
- or (triple[0] is None or triple[1] is None or triple[2] is None)
- or "" in triple
- ):
+ if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
raise Exception("RDF提取结果格式错误")
return entity_extract_result
@@ -91,9 +83,7 @@ def info_extract_from_str(
try_count = 0
while True:
try:
- rdf_triple_extract_result = _rdf_triple_extract(
- llm_client_for_rdf, paragraph, entity_extract_result
- )
+ rdf_triple_extract_result = _rdf_triple_extract(llm_client_for_rdf, paragraph, entity_extract_result)
break
except Exception as e:
logger.warning(f"实体提取失败,错误信息:{e}")
diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py
index 4fcdcf80..71ce65ef 100644
--- a/src/plugins/knowledge/src/kg_manager.py
+++ b/src/plugins/knowledge/src/kg_manager.py
@@ -22,6 +22,7 @@ from .lpmmconfig import (
from .global_logger import logger
+
class KGManager:
def __init__(self):
# 会被保存的字段
@@ -35,9 +36,7 @@ class KGManager:
# 持久化相关
self.dir_path = global_config["persistence"]["rag_data_dir"]
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml"
- self.ent_cnt_data_path = (
- self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
- )
+ self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
def save_to_file(self):
@@ -50,9 +49,7 @@ class KGManager:
di_graph.save_to_file(self.graph, self.graph_data_path)
# 保存实体计数到文件
- ent_cnt_df = pd.DataFrame(
- [{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()]
- )
+ ent_cnt_df = pd.DataFrame([{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()])
ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False)
# 保存段落hash到文件
@@ -77,9 +74,7 @@ class KGManager:
# 加载实体计数
ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow")
- self.ent_appear_cnt = dict(
- {row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()}
- )
+ self.ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()})
# 加载KG
self.graph = di_graph.load_from_file(self.graph_data_path)
@@ -99,20 +94,14 @@ class KGManager:
# 一个triple就是一条边(同时构建双向联系)
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
- node_to_node[(hash_key1, hash_key2)] = (
- node_to_node.get((hash_key1, hash_key2), 0) + 1.0
- )
- node_to_node[(hash_key2, hash_key1)] = (
- node_to_node.get((hash_key2, hash_key1), 0) + 1.0
- )
+ node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
+ node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1)
entity_set.add(hash_key2)
# 实体出现次数统计
for hash_key in entity_set:
- self.ent_appear_cnt[hash_key] = (
- self.ent_appear_cnt.get(hash_key, 0) + 1.0
- )
+ self.ent_appear_cnt[hash_key] = self.ent_appear_cnt.get(hash_key, 0) + 1.0
@staticmethod
def _build_edges_between_ent_pg(
@@ -124,9 +113,7 @@ class KGManager:
for triple in triple_list_data[idx]:
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
pg_hash_key = PG_NAMESPACE + "-" + str(idx)
- node_to_node[(ent_hash_key, pg_hash_key)] = (
- node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
- )
+ node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod
def _synonym_connect(
@@ -175,9 +162,7 @@ class KGManager:
new_edge_cnt += 1
res_ent.append(
(
- embedding_manager.entities_embedding_store.store[
- res_ent_hash
- ].str,
+ embedding_manager.entities_embedding_store.store[res_ent_hash].str,
similarity,
)
) # Debug
@@ -235,9 +220,7 @@ class KGManager:
if node_hash not in existed_nodes:
if node_hash.startswith(ENT_NAMESPACE):
# 新增实体节点
- node = embedding_manager.entities_embedding_store.store[
- node_hash
- ]
+ node = embedding_manager.entities_embedding_store.store[node_hash]
assert isinstance(node, EmbeddingStoreItem)
node_item = self.graph[node_hash]
node_item["content"] = node.str
@@ -246,15 +229,11 @@ class KGManager:
self.graph.update_node(node_item)
elif node_hash.startswith(PG_NAMESPACE):
# 新增文段节点
- node = embedding_manager.paragraphs_embedding_store.store[
- node_hash
- ]
+ node = embedding_manager.paragraphs_embedding_store.store[node_hash]
assert isinstance(node, EmbeddingStoreItem)
content = node.str.replace("\n", " ")
node_item = self.graph[node_hash]
- node_item["content"] = (
- content if len(content) < 8 else content[:8] + "..."
- )
+ node_item["content"] = content if len(content) < 8 else content[:8] + "..."
node_item["type"] = "pg"
node_item["create_time"] = now_time
self.graph.update_node(node_item)
@@ -324,9 +303,7 @@ class KGManager:
ent_sim_scores = {}
for relation_hash, similarity, _ in relation_search_result:
# 提取主宾短语
- relation = embed_manager.relation_embedding_store.store.get(
- relation_hash
- ).str
+ relation = embed_manager.relation_embedding_store.store.get(relation_hash).str
assert relation is not None # 断言:relation不为空
# 关系三元组
triple = relation[2:-2].split("', '")
@@ -340,9 +317,7 @@ class KGManager:
ent_mean_scores = {} # 记录实体的平均相似度
for ent_hash, scores in ent_sim_scores.items():
# 先对相似度进行累加,然后与实体计数相除获取最终权重
- ent_weights[ent_hash] = (
- float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
- )
+ ent_weights[ent_hash] = float(np.sum(scores)) / self.ent_appear_cnt[ent_hash]
# 记录实体的平均相似度,用于后续的top_k筛选
ent_mean_scores[ent_hash] = float(np.mean(scores))
del ent_sim_scores
@@ -359,21 +334,14 @@ class KGManager:
for ent_hash, score in ent_weights.items():
# 缩放相似度
ent_weights[ent_hash] = (
- (score - ent_weights_min)
- * (1 - down_edge)
- / (ent_weights_max - ent_weights_min)
+ (score - ent_weights_min) * (1 - down_edge) / (ent_weights_max - ent_weights_min)
) + down_edge
# 取平均相似度的top_k实体
top_k = global_config["qa"]["params"]["ent_filter_top_k"]
if len(ent_mean_scores) > top_k:
# 从大到小排序,取后len - k个
- ent_mean_scores = {
- k: v
- for k, v in sorted(
- ent_mean_scores.items(), key=lambda item: item[1], reverse=True
- )
- }
+ ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
for ent_hash, _ in ent_mean_scores.items():
# 删除被淘汰的实体节点权重设置
del ent_weights[ent_hash]
@@ -394,9 +362,7 @@ class KGManager:
# 归一化
for pg_hash, similarity in pg_sim_scores.items():
# 归一化相似度
- pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (
- pg_sim_score_max - pg_sim_score_min
- )
+ pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (pg_sim_score_max - pg_sim_score_min)
del pg_sim_score_max, pg_sim_score_min
for pg_hash, score in pg_sim_scores.items():
@@ -406,9 +372,7 @@ class KGManager:
del pg_sim_scores
# 最终权重数据 = 实体权重 + 文段权重
- ppr_node_weights = {
- k: v for d in [ent_weights, pg_weights] for k, v in d.items()
- }
+ ppr_node_weights = {k: v for d in [ent_weights, pg_weights] for k, v in d.items()}
del ent_weights, pg_weights
# PersonalizedPageRank
@@ -422,15 +386,11 @@ class KGManager:
# 获取最终结果
# 从搜索结果中提取文段节点的结果
passage_node_res = [
- (node_key, score)
- for node_key, score in ppr_res.items()
- if node_key.startswith(PG_NAMESPACE)
+ (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE)
]
del ppr_res
# 排序:按照分数从大到小
- passage_node_res = sorted(
- passage_node_res, key=lambda item: item[1], reverse=True
- )
+ passage_node_res = sorted(passage_node_res, key=lambda item: item[1], reverse=True)
return passage_node_res, ppr_node_weights
diff --git a/src/plugins/knowledge/src/llm_client.py b/src/plugins/knowledge/src/llm_client.py
index 3036662c..52d0dca0 100644
--- a/src/plugins/knowledge/src/llm_client.py
+++ b/src/plugins/knowledge/src/llm_client.py
@@ -21,20 +21,14 @@ class LLMClient:
def send_chat_request(self, model, messages):
"""发送对话请求,等待返回结果"""
- response = self.client.chat.completions.create(
- model=model, messages=messages, stream=False
- )
+ response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
if hasattr(response.choices[0].message, "reasoning_content"):
# 有单独的推理内容块
reasoning_content = response.choices[0].message.reasoning_content
content = response.choices[0].message.content
else:
# 无单独的推理内容块
- response = (
- response.choices[0]
- .message.content.split("")[-1]
- .split("")
- )
+ response = response.choices[0].message.content.split("")[-1].split("")
# 如果有推理内容,则分割推理内容和内容
if len(response) == 2:
reasoning_content = response[0]
@@ -48,6 +42,4 @@ class LLMClient:
def send_embedding_request(self, model, text):
"""发送嵌入请求,等待返回结果"""
text = text.replace("\n", " ")
- return (
- self.client.embeddings.create(input=[text], model=model).data[0].embedding
- )
+ return self.client.embeddings.create(input=[text], model=model).data[0].embedding
diff --git a/src/plugins/knowledge/src/lpmmconfig.py b/src/plugins/knowledge/src/lpmmconfig.py
index ff1ac8fa..7f59bc89 100644
--- a/src/plugins/knowledge/src/lpmmconfig.py
+++ b/src/plugins/knowledge/src/lpmmconfig.py
@@ -65,7 +65,6 @@ def _load_config(config, config_file_path):
config["persistence"] = file_config["persistence"]
print(config)
print("Configurations loaded from file: ", config_file_path)
-
parser = argparse.ArgumentParser(description="Configurations for the pipeline")
@@ -129,9 +128,9 @@ global_config = dict(
"embedding_data_dir": "data/embedding",
"rag_data_dir": "data/rag",
},
- "info_extraction":{
+ "info_extraction": {
"workers": 10,
- }
+ },
}
)
diff --git a/src/plugins/knowledge/src/mem_active_manager.py b/src/plugins/knowledge/src/mem_active_manager.py
index 073e3bda..3998c066 100644
--- a/src/plugins/knowledge/src/mem_active_manager.py
+++ b/src/plugins/knowledge/src/mem_active_manager.py
@@ -16,13 +16,9 @@ class MemoryActiveManager:
def get_activation(self, question: str) -> float:
"""获取记忆激活度"""
# 生成问题的Embedding
- question_embedding = self.embedding_client.send_embedding_request(
- "text-embedding", question
- )
+ question_embedding = self.embedding_client.send_embedding_request("text-embedding", question)
# 查询关系库中的相似度
- rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(
- question_embedding, 10
- )
+ rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(question_embedding, 10)
# 动态过滤阈值
rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)
diff --git a/src/plugins/knowledge/src/open_ie.py b/src/plugins/knowledge/src/open_ie.py
index 58259ef8..5fe163bb 100644
--- a/src/plugins/knowledge/src/open_ie.py
+++ b/src/plugins/knowledge/src/open_ie.py
@@ -9,12 +9,7 @@ def _filter_invalid_entities(entities: List[str]) -> List[str]:
"""过滤无效的实体"""
valid_entities = set()
for entity in entities:
- if (
- not isinstance(entity, str)
- or entity.strip() == ""
- or entity in INVALID_ENTITY
- or entity in valid_entities
- ):
+ if not isinstance(entity, str) or entity.strip() == "" or entity in INVALID_ENTITY or entity in valid_entities:
# 非字符串/空字符串/在无效实体列表中/重复
continue
valid_entities.add(entity)
@@ -74,9 +69,7 @@ class OpenIE:
for doc in self.docs:
# 过滤实体列表
- doc["extracted_entities"] = _filter_invalid_entities(
- doc["extracted_entities"]
- )
+ doc["extracted_entities"] = _filter_invalid_entities(doc["extracted_entities"])
# 过滤无效的三元组
doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
@@ -100,9 +93,7 @@ class OpenIE:
@staticmethod
def load() -> "OpenIE":
"""从文件中加载OpenIE数据"""
- with open(
- global_config["persistence"]["openie_data_path"], "r", encoding="utf-8"
- ) as f:
+ with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
data = json.loads(f.read())
openie_data = OpenIE._from_dict(data)
@@ -112,9 +103,7 @@ class OpenIE:
@staticmethod
def save(openie_data: "OpenIE"):
"""保存OpenIE数据到文件"""
- with open(
- global_config["persistence"]["openie_data_path"], "w", encoding="utf-8"
- ) as f:
+ with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f:
f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
def extract_entity_dict(self):
@@ -141,7 +130,5 @@ class OpenIE:
def extract_raw_paragraph_dict(self):
"""提取原始段落"""
- raw_paragraph_dict = dict(
- {doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}
- )
+ raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
return raw_paragraph_dict
diff --git a/src/plugins/knowledge/src/prompt_template.py b/src/plugins/knowledge/src/prompt_template.py
index ab6ac08c..18a5002e 100644
--- a/src/plugins/knowledge/src/prompt_template.py
+++ b/src/plugins/knowledge/src/prompt_template.py
@@ -41,9 +41,7 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
messages = [
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
- LLMMessage(
- "user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```"""
- ).to_dict(),
+ LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
]
return messages
@@ -58,16 +56,10 @@ qa_system_prompt = """
"""
-def build_qa_context(
- question: str, knowledge: list[(str, str, str)]
-) -> List[LLMMessage]:
- knowledge = "\n".join(
- [f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]
- )
+def build_qa_context(question: str, knowledge: list[(str, str, str)]) -> List[LLMMessage]:
+ knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
messages = [
LLMMessage("system", qa_system_prompt).to_dict(),
- LLMMessage(
- "user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}"
- ).to_dict(),
+ LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(),
]
return messages
diff --git a/src/plugins/knowledge/src/qa_manager.py b/src/plugins/knowledge/src/qa_manager.py
index 986c8419..894c0124 100644
--- a/src/plugins/knowledge/src/qa_manager.py
+++ b/src/plugins/knowledge/src/qa_manager.py
@@ -2,6 +2,7 @@ import time
from typing import Tuple, List, Dict
from .global_logger import logger
+
# from . import prompt_template
from .embedding_store import EmbeddingManager
from .llm_client import LLMClient
@@ -31,7 +32,7 @@ class QAManager:
"""处理查询"""
# 生成问题的Embedding
- part_start_time =time.perf_counter()
+ part_start_time = time.perf_counter()
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
global_config["embedding"]["model"], question
)
@@ -39,7 +40,7 @@ class QAManager:
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
# 根据问题Embedding查询Relation Embedding库
- part_start_time =time.perf_counter()
+ part_start_time = time.perf_counter()
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
question_embedding,
global_config["qa"]["params"]["relation_search_top_k"],
@@ -47,10 +48,7 @@ class QAManager:
# 过滤阈值
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
- if (
- relation_search_res[0][1]
- < global_config["qa"]["params"]["relation_threshold"]
- ):
+ if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
# 未找到相关关系
relation_search_res = []
@@ -66,12 +64,10 @@ class QAManager:
# part_start_time = time.time()
# 根据问题Embedding查询Paragraph Embedding库
- part_start_time =time.perf_counter()
- paragraph_search_res = (
- self.embed_manager.paragraphs_embedding_store.search_top_k(
- question_embedding,
- global_config["qa"]["params"]["paragraph_search_top_k"],
- )
+ part_start_time = time.perf_counter()
+ paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
+ question_embedding,
+ global_config["qa"]["params"]["paragraph_search_top_k"],
)
part_end_time = time.perf_counter()
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
@@ -79,7 +75,7 @@ class QAManager:
if len(relation_search_res) != 0:
logger.info("找到相关关系,将使用RAG进行检索")
# 使用KG检索
- part_start_time =time.perf_counter()
+ part_start_time = time.perf_counter()
result, ppr_node_weights = self.kg_manager.kg_search(
relation_search_res, paragraph_search_res, self.embed_manager
)
@@ -94,9 +90,7 @@ class QAManager:
result = dyn_select_top_k(result, 0.5, 1.0)
for res in result:
- raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[
- res[0]
- ].str
+ raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights
@@ -113,5 +107,7 @@ class QAManager:
)
for res in query_res
]
- found_knowledge = "\n".join([f"第{i + 1}条知识:{k[1]}\n 该条知识对于问题的相关性:{k[0]}" for i, k in enumerate(knowledge)])
- return found_knowledge
\ No newline at end of file
+ found_knowledge = "\n".join(
+ [f"第{i + 1}条知识:{k[1]}\n 该条知识对于问题的相关性:{k[0]}" for i, k in enumerate(knowledge)]
+ )
+ return found_knowledge
diff --git a/src/plugins/knowledge/src/raw_processing.py b/src/plugins/knowledge/src/raw_processing.py
index 75e4b863..91e681c7 100644
--- a/src/plugins/knowledge/src/raw_processing.py
+++ b/src/plugins/knowledge/src/raw_processing.py
@@ -17,9 +17,7 @@ def load_raw_data() -> tuple[list[str], list[str]]:
"""
# 读取import.json文件
if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
- with open(
- global_config["persistence"]["raw_data_path"], "r", encoding="utf-8"
- ) as f:
+ with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
import_json = json.loads(f.read())
else:
raise Exception("原始数据文件读取失败")
diff --git a/src/plugins/knowledge/src/utils/data_loader.py b/src/plugins/knowledge/src/utils/data_loader.py
index 08bdf414..3b5e8d2e 100644
--- a/src/plugins/knowledge/src/utils/data_loader.py
+++ b/src/plugins/knowledge/src/utils/data_loader.py
@@ -14,11 +14,7 @@ class DataLoader:
Args:
custom_data_dir: 可选的自定义数据目录路径,如果不提供则使用配置文件中的默认路径
"""
- self.data_dir = (
- Path(custom_data_dir)
- if custom_data_dir
- else Path(config["persistence"]["data_root_path"])
- )
+ self.data_dir = Path(custom_data_dir) if custom_data_dir else Path(config["persistence"]["data_root_path"])
if not self.data_dir.exists():
raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在")
diff --git a/src/plugins/knowledge/src/utils/dyn_topk.py b/src/plugins/knowledge/src/utils/dyn_topk.py
index 02bc5e3e..eb40ef3a 100644
--- a/src/plugins/knowledge/src/utils/dyn_topk.py
+++ b/src/plugins/knowledge/src/utils/dyn_topk.py
@@ -36,14 +36,10 @@ def dyn_select_top_k(
# 计算均值
mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score)
# 计算方差
- var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(
- normalized_score
- )
+ var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len(normalized_score)
# 动态阈值
- threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (
- mean_score + var_factor * var_score
- )
+ threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (mean_score + var_factor * var_score)
# 重新过滤
res = [s for s in normalized_score if s[2] > threshold]
diff --git a/src/plugins/knowledge/src/utils/json_fix.py b/src/plugins/knowledge/src/utils/json_fix.py
index 672fb1f8..a83eb491 100644
--- a/src/plugins/knowledge/src/utils/json_fix.py
+++ b/src/plugins/knowledge/src/utils/json_fix.py
@@ -29,10 +29,7 @@ def _find_unclosed(json_str):
elif char in "{[":
unclosed.append(char)
elif char in "}]":
- if unclosed and (
- (char == "}" and unclosed[-1] == "{")
- or (char == "]" and unclosed[-1] == "[")
- ):
+ if unclosed and ((char == "}" and unclosed[-1] == "{") or (char == "]" and unclosed[-1] == "[")):
unclosed.pop()
return unclosed
diff --git a/src/plugins/knowledge/src/utils/visualize_graph.py b/src/plugins/knowledge/src/utils/visualize_graph.py
index 845e18a2..7ca9b7e6 100644
--- a/src/plugins/knowledge/src/utils/visualize_graph.py
+++ b/src/plugins/knowledge/src/utils/visualize_graph.py
@@ -14,4 +14,4 @@ def draw_graph_and_show(graph):
font_family="Sarasa Mono SC",
font_size=8,
)
- fig.show()
\ No newline at end of file
+ fig.show()