diff --git a/.gitignore b/.gitignore
index b9e101e4..c93feb2e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -226,4 +226,10 @@ logs
.vscode
-/config/*
\ No newline at end of file
+/config/*
+
+!**/knowledge/lib
+**/lib/quick_algo/build/
+**/lib/quick_algo/cython_debug/
+**/lib/quick_algo/pagerank.c
+**/lib/quick_algo/pagerank.cp312-win_amd64.pyd
\ No newline at end of file
diff --git a/bot.py b/bot.py
index a0bf3a3c..7cf0c044 100644
--- a/bot.py
+++ b/bot.py
@@ -46,6 +46,10 @@ def init_config():
shutil.copy("template/bot_config_template.toml", "config/bot_config.toml")
logger.info("复制完成,请修改config/bot_config.toml和.env中的配置后重新启动")
+ if not os.path.exists("config/lpmm_config.toml"):
+ logger.warning("检测到lpmm_config.toml不存在,正在从模板复制")
+ shutil.copy("template/lpmm_config_template.toml", "config/lpmm_config.toml")
+ logger.info("复制完成,请修改config/lpmm_config.toml中的配置后重新启动")
def init_env():
diff --git a/import_openie.py b/import_openie.py
new file mode 100644
index 00000000..0f0bf0cd
--- /dev/null
+++ b/import_openie.py
@@ -0,0 +1,153 @@
+# Probe for the quick_algo library before anything else is imported.
+# If the import fails, only build instructions are printed; nothing is raised
+# here, so the script continues past this point.
+try:
+ import src.plugins.knowledge.lib.quick_algo
+except ImportError:
+ print("未找到quick_algo库,无法使用quick_algo算法")
+ print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace")
+
+
+import sys
+from typing import Dict, List
+
+from src.plugins.knowledge.src.config import PG_NAMESPACE, global_config
+from src.plugins.knowledge.src.embedding_store import EmbeddingManager
+from src.plugins.knowledge.src.llm_client import LLMClient
+from src.plugins.knowledge.src.open_ie import OpenIE
+from src.plugins.knowledge.src.kg_manager import KGManager
+from src.plugins.knowledge.src.global_logger import logger
+from src.plugins.knowledge.src.utils.hash import get_sha256
+
+
+def hash_deduplicate(
+    raw_paragraphs: Dict[str, str],
+    triple_list_data: Dict[str, List[List[str]]],
+    stored_pg_hashes: set,
+    stored_paragraph_hashes: set,
+):
+    """Re-key paragraphs by their content hash and drop the ones already stored.
+
+    Paragraph texts and triple lists are paired positionally (the values of
+    both dicts are iterated side by side), so both dicts must list their
+    entries in the same order.
+
+    Args:
+        raw_paragraphs: Paragraph texts, keyed by their original index.
+        triple_list_data: Triple lists, keyed by their original index.
+        stored_pg_hashes: Known keys of the form
+            ``PG_NAMESPACE + "-" + <paragraph hash>``.
+        stored_paragraph_hashes: Known bare paragraph hashes.
+
+    Returns:
+        A ``(new_raw_paragraphs, new_triple_list_data)`` pair; both dicts are
+        keyed by the hash that ``get_sha256`` returns for the paragraph text
+        and contain only the paragraphs that were not skipped.
+    """
+    kept_paragraphs = {}
+    kept_triples = {}
+
+    for text, triples in zip(raw_paragraphs.values(), triple_list_data.values()):
+        digest = get_sha256(text)
+        namespaced_known = (PG_NAMESPACE + "-" + digest) in stored_pg_hashes
+        # A paragraph is skipped only when BOTH hash sets already contain it.
+        if namespaced_known and digest in stored_paragraph_hashes:
+            continue
+        kept_paragraphs[digest] = text
+        kept_triples[digest] = triples
+
+    return kept_paragraphs, kept_triples
+
+
+def handle_import_openie(
+    openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager
+) -> bool:
+    """Feed new OpenIE paragraphs into the embedding store and the knowledge graph.
+
+    Args:
+        openie_data: Source of the paragraph, entity and triple dicts.
+        embed_manager: Embedding store that receives the new paragraphs.
+        kg_manager: Knowledge graph that is extended with the new triples.
+
+    Returns:
+        False when the paragraph, entity and triple dicts differ in length,
+        True otherwise (including when there was nothing new to process).
+    """
+    paragraphs = openie_data.extract_raw_paragraph_dict()
+    entities = openie_data.extract_entity_dict()
+    triples = openie_data.extract_triple_dict()
+
+    # All three dicts must describe the same number of paragraphs.
+    paragraph_count = len(paragraphs)
+    if paragraph_count != len(entities) or paragraph_count != len(triples):
+        logger.error("OpenIE数据存在异常")
+        return False
+
+    # Replace the original indices with paragraph hashes and drop known ones.
+    logger.info("正在进行段落去重与重索引")
+    paragraphs, triples = hash_deduplicate(
+        paragraphs,
+        triples,
+        embed_manager.stored_pg_hashes,
+        kg_manager.stored_paragraph_hashes,
+    )
+
+    if len(paragraphs) == 0:
+        logger.info("无新段落需要处理")
+        return True
+
+    # Compute and persist embeddings for the remaining paragraphs.
+    logger.info(f"段落去重完成,剩余待处理的段落数量:{len(paragraphs)}")
+    logger.info("开始Embedding")
+    embed_manager.store_new_data_set(paragraphs, triples)
+    # Rebuild the Faiss index so it includes the new embeddings.
+    logger.info("正在重新构建向量索引")
+    embed_manager.rebuild_faiss_index()
+    logger.info("向量索引构建完成")
+    embed_manager.save_to_file()
+    logger.info("Embedding完成")
+
+    # Extend the knowledge graph with the new paragraphs' triples.
+    logger.info("开始构建RAG")
+    kg_manager.build_kg(triples, embed_manager)
+    kg_manager.save_to_file()
+    logger.info("RAG构建完成")
+    return True
+
+
+def main():
+    """Load the existing stores and import the OpenIE data file into them.
+
+    Builds one LLM client per configured provider, loads the embedding store
+    and the knowledge graph from disk, warns about paragraphs that the graph
+    knows but the embedding store does not, and finally imports the OpenIE
+    data.
+
+    Returns:
+        True when the import finished, False when the OpenIE file could not be
+        loaded or its data could not be processed.
+    """
+    logger.info("----开始导入openie数据----\n")
+
+    logger.info("创建LLM客户端")
+    llm_client_list = dict()
+    for key in global_config["llm_providers"]:
+        provider = global_config["llm_providers"][key]
+        llm_client_list[key] = LLMClient(provider["base_url"], provider["api_key"])
+
+    # Initialise the embedding store with the configured embedding provider.
+    embed_manager = EmbeddingManager(
+        llm_client_list[global_config["embedding"]["provider"]]
+    )
+    logger.info("正在从文件加载Embedding库")
+    try:
+        embed_manager.load_from_file()
+    except Exception as e:
+        # Best effort: a failed load is reported and the run continues with
+        # whatever state the manager has.
+        logger.error(f"从文件加载Embedding库时发生错误:{e}")
+    else:
+        logger.info("Embedding库加载完成")
+
+    # Initialise the knowledge graph.
+    kg_manager = KGManager()
+    logger.info("正在从文件加载KG")
+    try:
+        kg_manager.load_from_file()
+    except Exception as e:
+        # Same best-effort policy as for the embedding store.
+        logger.error(f"从文件加载KG时发生错误:{e}")
+    else:
+        logger.info("KG加载完成")
+
+    logger.info(f"KG节点数量:{len(kg_manager.graph.nodes)}")
+    logger.info(f"KG边数量:{len(kg_manager.graph.edges)}")
+
+    # Consistency check: every paragraph hash known to the graph should also
+    # be present (namespaced) in the embedding store.
+    for pg_hash in kg_manager.stored_paragraph_hashes:
+        key = PG_NAMESPACE + "-" + pg_hash
+        if key not in embed_manager.stored_pg_hashes:
+            logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
+
+    logger.info("正在导入OpenIE数据文件")
+    try:
+        openie_data = OpenIE.load()
+    except Exception as e:
+        logger.error(f"导入OpenIE数据文件时发生错误:{e}")
+        return False
+    if not handle_import_openie(openie_data, embed_manager, kg_manager):
+        logger.error("处理OpenIE数据时发生错误")
+        return False
+    return True
+
+
+if __name__ == "__main__":
+    # Report failure to the shell through the exit status.
+    sys.exit(0 if main() else 1)
diff --git a/info_extraction.py b/info_extraction.py
new file mode 100644
index 00000000..2cdd7301
--- /dev/null
+++ b/info_extraction.py
@@ -0,0 +1,157 @@
+import json
+import os
+import signal
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from threading import Lock, Event
+
+import tqdm
+
+from src.plugins.knowledge.src.global_logger import logger
+from src.plugins.knowledge.src.config import global_config
+from src.plugins.knowledge.src.ie_process import info_extract_from_str
+from src.plugins.knowledge.src.llm_client import LLMClient
+from src.plugins.knowledge.src.open_ie import OpenIE
+from src.plugins.knowledge.src.raw_processing import load_raw_data
+
+# Directory holding one cached JSON extraction result per paragraph hash
+TEMP_DIR = "./temp"
+
+# Thread-safe locks protecting file operations and shared data
+file_lock = Lock()
+open_ie_doc_lock = Lock()
+
+# Event flag used to signal that the program should terminate
+shutdown_event = Event()
+
+
+def process_single_text(pg_hash, raw_data, llm_client_list):
+    """Extract entities and triples for one paragraph; runs in the thread pool.
+
+    A cached result in ``TEMP_DIR`` is returned when one exists and parses;
+    a cache file that fails to parse is deleted and the paragraph is
+    extracted again. A fresh result is written back to the cache before it
+    is returned.
+
+    Args:
+        pg_hash: Hash identifying the paragraph; also the cache file name.
+        raw_data: The paragraph text.
+        llm_client_list: LLM clients keyed by provider name.
+
+    Returns:
+        ``(doc_item, None)`` on success, ``(None, pg_hash)`` when extraction
+        failed or the result could not be cached.
+    """
+    cache_path = f"{TEMP_DIR}/{pg_hash}.json"
+
+    # Cache lookup; file access is serialised through file_lock.
+    with file_lock:
+        if os.path.exists(cache_path):
+            try:
+                logger.info(f"找到缓存的提取结果:{pg_hash}")
+                with open(cache_path, "r", encoding="utf-8") as cached:
+                    return json.load(cached), None
+            except json.JSONDecodeError:
+                # Corrupt cache file: remove it and fall through to re-extract.
+                logger.warning(f"缓存文件损坏,重新处理:{pg_hash}")
+                os.remove(cache_path)
+
+    entity_client = llm_client_list[global_config["entity_extract"]["llm"]["provider"]]
+    rdf_client = llm_client_list[global_config["rdf_build"]["llm"]["provider"]]
+    entities, triples = info_extract_from_str(entity_client, rdf_client, raw_data)
+    if entities is None or triples is None:
+        return None, pg_hash
+
+    record = {
+        "idx": pg_hash,
+        "passage": raw_data,
+        "extracted_entities": entities,
+        "extracted_triples": triples,
+    }
+    # Persist the intermediate result so a later run can reuse it.
+    with file_lock:
+        try:
+            with open(cache_path, "w", encoding="utf-8") as out:
+                json.dump(record, out, ensure_ascii=False, indent=4)
+        except Exception as e:
+            logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
+            # Do not leave a partially written file behind.
+            if os.path.exists(cache_path):
+                os.remove(cache_path)
+            # Ask the whole program to stop.
+            shutdown_event.set()
+            return None, pg_hash
+    return record, None
+
+
+def signal_handler(signum, frame):
+    """Handle an interrupt signal by logging it and setting ``shutdown_event``.
+
+    Args:
+        signum: Signal number passed by the signal machinery (unused).
+        frame: Current stack frame passed by the signal machinery (unused).
+    """
+    logger.info("\n接收到中断信号,正在优雅地关闭程序...")
+    shutdown_event.set()
+
+
+def _cancel_pending(futures):
+    """Cancel every future in ``futures`` that has not finished yet."""
+    for pending in futures:
+        if not pending.done():
+            pending.cancel()
+
+
+def main():
+    """Run information extraction over all raw paragraphs and save the result.
+
+    Paragraphs are processed by ``process_single_text`` in a thread pool.
+    Hashes of paragraphs that failed (or whose worker raised) are collected
+    and logged at the end. When ``shutdown_event`` is set, the tasks that
+    have not finished are cancelled and whatever has been collected so far
+    is still saved.
+    """
+    # Install the SIGINT handler so Ctrl+C sets shutdown_event.
+    signal.signal(signal.SIGINT, signal_handler)
+
+    logger.info("--------进行信息提取--------\n")
+
+    logger.info("创建LLM客户端")
+    llm_client_list = dict()
+    for key in global_config["llm_providers"]:
+        provider = global_config["llm_providers"][key]
+        llm_client_list[key] = LLMClient(provider["base_url"], provider["api_key"])
+
+    logger.info("正在加载原始数据")
+    sha256_list, raw_datas = load_raw_data()
+    logger.info("原始数据加载完成\n")
+
+    # Create the cache directory; exist_ok avoids a check-then-create race.
+    os.makedirs(TEMP_DIR, exist_ok=True)
+
+    failed_sha256 = []
+    open_ie_doc = []
+
+    # Thread pool with at most 20 worker threads.
+    with ThreadPoolExecutor(max_workers=20) as executor:
+        # Submit every paragraph; remember which hash each future belongs to.
+        future_to_hash = {
+            executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
+            for pg_hash, raw_data in zip(sha256_list, raw_datas)
+        }
+
+        with tqdm.tqdm(total=len(future_to_hash), postfix="正在进行提取:") as pbar:
+            try:
+                for future in as_completed(future_to_hash):
+                    if shutdown_event.is_set():
+                        _cancel_pending(future_to_hash)
+                        break
+
+                    try:
+                        doc_item, failed_hash = future.result()
+                    except Exception as e:
+                        # A worker that raised must not abort the whole run:
+                        # record its paragraph as failed and keep going.
+                        doc_item, failed_hash = None, future_to_hash[future]
+                        logger.error(f"提取过程中发生异常:{failed_hash}, 错误:{e}")
+
+                    if failed_hash:
+                        failed_sha256.append(failed_hash)
+                        logger.error(f"提取失败:{failed_hash}")
+                    elif doc_item:
+                        with open_ie_doc_lock:
+                            open_ie_doc.append(doc_item)
+                    pbar.update(1)
+            except KeyboardInterrupt:
+                # Reaching this point means the signal handler did not absorb
+                # the interrupt; perform the same shutdown steps here.
+                logger.info("\n接收到中断信号,正在优雅地关闭程序...")
+                shutdown_event.set()
+                _cancel_pending(future_to_hash)
+
+    # Aggregate entity statistics and save the extraction result.
+    entity_lists = [chunk["extracted_entities"] for chunk in open_ie_doc]
+    num_phrases = sum(len(entities) for entities in entity_lists)
+    sum_phrase_chars = sum(len(e) for entities in entity_lists for e in entities)
+    sum_phrase_words = sum(len(e.split()) for entities in entity_lists for e in entities)
+    # With no extracted entities (everything failed, or the run was
+    # interrupted early) the averages are reported as 0.0 instead of
+    # dividing by zero.
+    avg_phrase_chars = round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0.0
+    avg_phrase_words = round(sum_phrase_words / num_phrases, 4) if num_phrases else 0.0
+    openie_obj = OpenIE(open_ie_doc, avg_phrase_chars, avg_phrase_words)
+    OpenIE.save(openie_obj)
+
+    logger.info("--------信息提取完成--------")
+    logger.info(f"提取失败的文段SHA256:{failed_sha256}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/requirements.txt b/requirements.txt
index cea511f1..d1c81c52 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/src/common/logger.py b/src/common/logger.py
index 9e118622..cf978fc3 100644
--- a/src/common/logger.py
+++ b/src/common/logger.py
@@ -290,6 +290,22 @@ WILLING_STYLE_CONFIG = {
},
}
+# Log format templates tagged "知识" (knowledge): an "advanced" and a "simple"
+# variant, each providing a console format and a file format. One of the two
+# variants is selected further below according to SIMPLE_OUTPUT.
+KNOWLEDGE_STYLE_CONFIG = {
+ "advanced": {
+ "console_format": (
+ "{time:YYYY-MM-DD HH:mm:ss} | "
+ "{level: <8} | "
+ "{extra[module]: <12} | "
+ "知识 | "
+ "{message}"
+ ),
+ "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 知识 | {message}"),
+ },
+ "simple": {
+ "console_format": ("{time:MM-DD HH:mm} | 知识 | {message}"),
+ "file_format": ("{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | 知识 | {message}"),
+ },
+}
# 根据SIMPLE_OUTPUT选择配置
MEMORY_STYLE_CONFIG = MEMORY_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MEMORY_STYLE_CONFIG["advanced"]
@@ -306,7 +322,7 @@ SUB_HEARTFLOW_STYLE_CONFIG = (
) # noqa: E501
WILLING_STYLE_CONFIG = WILLING_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else WILLING_STYLE_CONFIG["advanced"]
CONFIG_STYLE_CONFIG = CONFIG_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CONFIG_STYLE_CONFIG["advanced"]
-
+KNOWLEDGE_STYLE_CONFIG = KNOWLEDGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else KNOWLEDGE_STYLE_CONFIG["advanced"]
def is_registered_module(record: dict) -> bool:
"""检查是否为已注册的模块"""
diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py
index 508febec..3185793c 100644
--- a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py
+++ b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py
@@ -10,6 +10,7 @@ from ...config.config import global_config
from ...chat.utils import get_embedding, get_recent_group_detailed_plain_text
from ...chat.chat_stream import chat_manager
from src.common.logger import get_module_logger
+from ...knowledge.knowledge_lib import qa_manager
logger = get_module_logger("prompt")
@@ -140,74 +141,9 @@ class PromptBuilder:
async def get_prompt_info(self, message: str, threshold: float):
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
- embedding = await get_embedding(message, request_type="prompt_build")
- related_info += self.get_info_from_db(embedding, limit=1, threshold=threshold)
+ related_info += qa_manager.get_knowledge(message)
return related_info
- def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
- if not query_embedding:
- return ""
- # 使用余弦相似度计算
- pipeline = [
- {
- "$addFields": {
- "dotProduct": {
- "$reduce": {
- "input": {"$range": [0, {"$size": "$embedding"}]},
- "initialValue": 0,
- "in": {
- "$add": [
- "$$value",
- {
- "$multiply": [
- {"$arrayElemAt": ["$embedding", "$$this"]},
- {"$arrayElemAt": [query_embedding, "$$this"]},
- ]
- },
- ]
- },
- }
- },
- "magnitude1": {
- "$sqrt": {
- "$reduce": {
- "input": "$embedding",
- "initialValue": 0,
- "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
- }
- }
- },
- "magnitude2": {
- "$sqrt": {
- "$reduce": {
- "input": query_embedding,
- "initialValue": 0,
- "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
- }
- }
- },
- }
- },
- {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
- {
- "$match": {
- "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
- }
- },
- {"$sort": {"similarity": -1}},
- {"$limit": limit},
- {"$project": {"content": 1, "similarity": 1}},
- ]
-
- results = list(db.knowledges.aggregate(pipeline))
- # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
-
- if not results:
- return ""
-
- # 返回所有找到的内容,用换行分隔
- return "\n".join(str(result["content"]) for result in results)
-
prompt_builder = PromptBuilder()
diff --git a/src/plugins/knowledge/LICENSE b/src/plugins/knowledge/LICENSE
new file mode 100644
index 00000000..f288702d
--- /dev/null
+++ b/src/plugins/knowledge/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program> Copyright (C) <year> <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/src/plugins/knowledge/knowledge_lib.py b/src/plugins/knowledge/knowledge_lib.py
new file mode 100644
index 00000000..7822ac01
--- /dev/null
+++ b/src/plugins/knowledge/knowledge_lib.py
@@ -0,0 +1,66 @@
+try:
+ import lib.quick_algo
+except ImportError:
+ print("未找到quick_algo库,无法使用quick_algo算法")
+ print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace")
+
+from .src.config import PG_NAMESPACE, global_config
+from .src.embedding_store import EmbeddingManager
+from .src.llm_client import LLMClient
+from .src.mem_active_manager import MemoryActiveManager
+from .src.qa_manager import QAManager
+from .src.kg_manager import KGManager
+from .src.global_logger import logger
+
+logger.info("正在初始化Mai-LPMM\n")
+logger.info("创建LLM客户端")
+llm_client_list = dict()
+for key in global_config["llm_providers"]:
+ llm_client_list[key] = LLMClient(
+ global_config["llm_providers"][key]["base_url"],
+ global_config["llm_providers"][key]["api_key"],
+ )
+
+# 初始化Embedding库
+embed_manager = EmbeddingManager(
+ llm_client_list[global_config["embedding"]["provider"]]
+)
+logger.info("正在从文件加载Embedding库")
+try:
+ embed_manager.load_from_file()
+except Exception as e:
+ logger.error("从文件加载Embedding库时发生错误:{}".format(e))
+logger.info("Embedding库加载完成")
+# 初始化KG
+kg_manager = KGManager()
+logger.info("正在从文件加载KG")
+try:
+ kg_manager.load_from_file()
+except Exception as e:
+ logger.error("从文件加载KG时发生错误:{}".format(e))
+logger.info("KG加载完成")
+
+logger.info(f"KG节点数量:{len(kg_manager.graph.nodes)}")
+logger.info(f"KG边数量:{len(kg_manager.graph.edges)}")
+
+# 数据比对:Embedding库与KG的段落hash集合
+for pg_hash in kg_manager.stored_paragraph_hashes:
+ key = PG_NAMESPACE + "-" + pg_hash
+ if key not in embed_manager.stored_pg_hashes:
+ logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
+
+# 问答系统(用于知识库)
+qa_manager = QAManager(
+ embed_manager,
+ kg_manager,
+ llm_client_list[global_config["embedding"]["provider"]],
+ llm_client_list[global_config["qa"]["llm"]["provider"]],
+ llm_client_list[global_config["qa"]["llm"]["provider"]],
+)
+
+# 记忆激活(用于记忆库)
+inspire_manager = MemoryActiveManager(
+ embed_manager,
+ llm_client_list[global_config["embedding"]["provider"]],
+)
+
diff --git a/src/plugins/knowledge/lib/quick_algo/__init__.py b/src/plugins/knowledge/lib/quick_algo/__init__.py
new file mode 100644
index 00000000..c3269ea6
--- /dev/null
+++ b/src/plugins/knowledge/lib/quick_algo/__init__.py
@@ -0,0 +1,64 @@
+from typing import Tuple, List, Dict
+
+import networkx as nx
+
+from .pagerank import run_personalized_pagerank
+
+
+def _nx_graph_to_lists(
+ graph: nx.Graph,
+) -> Tuple[List[Tuple[str, str, float]], List[str]]:
+ """
+ Convert a NetworkX graph to lists of edges and nodes.
+
+ Parameters
+ ----------
+ graph : NetworkX graph
+ The input graph.
+
+ Returns
+ -------
+ tuple
+ A tuple containing the list of edges and the list of nodes.
+ """
+ nodes = [node for node in graph.nodes()]
+ edges = [
+ (u, v, graph.get_edge_data(u, v).get("weight", 0.0)) for u, v in graph.edges()
+ ]
+
+ return edges, nodes
+
+
+def pagerank(
+ graph: nx.Graph,
+ personalized: None | Dict[str, float] = None,
+ alpha: float = 0.85,
+ max_iter: int = 100,
+ tol: float = 1e-6,
+) -> List[Tuple[str, float]]:
+ """
+ Compute the PageRank of a graph.
+
+ Parameters
+ ----------
+ graph : NetworkX graph
+ The input graph.
+ personalized : dict, optional
+ The personalization vector.
+ alpha : float, optional
+ The teleport probability.
+ max_iter : int, optional
+ The maximum number of iterations.
+ tol : float, optional
+ The tolerance for convergence.
+ return_type : str, optional
+ The return type. Can be 'numpy' or 'list'.
+
+ Returns
+ -------
+ numpy.ndarray or list
+ The PageRank vector.
+ """
+ edges, nodes = _nx_graph_to_lists(graph)
+
+ return run_personalized_pagerank(nodes, edges, personalized, alpha, max_iter, tol)
diff --git a/src/plugins/knowledge/lib/quick_algo/pagerank.pxd b/src/plugins/knowledge/lib/quick_algo/pagerank.pxd
new file mode 100644
index 00000000..e508cd4f
--- /dev/null
+++ b/src/plugins/knowledge/lib/quick_algo/pagerank.pxd
@@ -0,0 +1,15 @@
+# C declarations imported from quick_algo.h for use by the Cython wrapper.
+cdef extern from "quick_algo.h":
+    # One weighted edge of the graph.
+    struct Edge:
+        long long src
+        long long dst
+        double weight
+
+    # C PageRank routine; returns a pointer to an array of doubles.
+    double *pagerank(
+        Edge *edges,
+        long long num_edges,
+        double *personalization,
+        long long num_nodes,
+        double alpha,
+        int max_iter,
+        double tol
+    )
\ No newline at end of file
diff --git a/src/plugins/knowledge/lib/quick_algo/pagerank.pyx b/src/plugins/knowledge/lib/quick_algo/pagerank.pyx
new file mode 100644
index 00000000..4edd0081
--- /dev/null
+++ b/src/plugins/knowledge/lib/quick_algo/pagerank.pyx
@@ -0,0 +1,53 @@
+import time
+
+from cpython.mem cimport PyMem_Malloc, PyMem_Free
+from libc.stdlib cimport free
+
+def run_personalized_pagerank(
+    node_list: list[str],
+    edge_list: list[tuple[str, str, float]],
+    personalization: dict[str, float] = None,
+    double alpha=0.85,
+    int max_iter=100,
+    double tol=1e-6
+) -> list[tuple[str, double]]:
+    """Run the C ``pagerank`` routine on Python node and edge lists.
+
+    Args:
+        node_list: Node names; a node's position becomes its C-side index.
+        edge_list: ``(source, target, weight)`` tuples whose endpoints must
+            appear in ``node_list``.
+        personalization: Optional per-node values; nodes missing from the
+            dict get 0.0. ``None`` or an empty dict means a uniform vector.
+        alpha: Passed to the C routine as ``alpha``.
+        max_iter: Passed to the C routine as ``max_iter``.
+        tol: Passed to the C routine as ``tol``.
+
+    Returns:
+        A list of ``(node, score)`` tuples, one per distinct node name.
+
+    Raises:
+        MemoryError: If a working buffer cannot be allocated.
+        KeyError: If an edge references a node that is not in ``node_list``.
+    """
+    cdef long long num_nodes = len(node_list)
+    cdef long long num_edges = len(edge_list)
+    # PyMem_Malloc returns void*, so the result is cast to the target type.
+    cdef Edge* edges = <Edge*>PyMem_Malloc(num_edges * sizeof(Edge))
+    cdef double* personalization_array = <double*>PyMem_Malloc(num_nodes * sizeof(double))
+    cdef double* result_items = NULL
+    cdef long long i
+
+    # try/finally releases every buffer even when an exception is raised
+    # part-way through (for example a KeyError on an unknown node name).
+    try:
+        if edges == NULL or personalization_array == NULL:
+            raise MemoryError("failed to allocate PageRank working buffers")
+
+        # Map each node name to its index.
+        node_to_index = {node: i for i, node in enumerate(node_list)}
+
+        # Convert the edge data into the C structure.
+        for i, (u, v, w) in enumerate(edge_list):
+            edges[i].src = node_to_index[u]
+            edges[i].dst = node_to_index[v]
+            edges[i].weight = w
+
+        # Convert the personalization mapping into a C array.
+        if personalization is None or len(personalization) == 0:
+            for i in range(num_nodes):
+                personalization_array[i] = 1.0 / num_nodes
+        else:
+            for node, i in node_to_index.items():
+                personalization_array[i] = personalization.get(node, 0.0)
+
+        ppr_start_time = time.perf_counter()
+        # Call the PageRank implementation written in C.
+        result_items = pagerank(edges, num_edges, personalization_array, num_nodes, alpha, max_iter, tol)
+        ppr_end_time = time.perf_counter()
+        print(f"PageRank计算耗时: {ppr_end_time - ppr_start_time:.8f}秒")
+
+        # Convert the returned array into a Python list.
+        result_list = []
+        for node, i in node_to_index.items():
+            result_list.append((node, result_items[i]))
+
+        return result_list
+    finally:
+        # Both PyMem_Free and free accept NULL, so this is safe on every path.
+        PyMem_Free(edges)
+        PyMem_Free(personalization_array)
+        # The result array is allocated by the C code with calloc, so it must
+        # be released with the C library's free, not with PyMem_Free.
+        free(result_items)
\ No newline at end of file
diff --git a/src/plugins/knowledge/lib/quick_algo/pr.c b/src/plugins/knowledge/lib/quick_algo/pr.c
new file mode 100644
index 00000000..07f79a20
--- /dev/null
+++ b/src/plugins/knowledge/lib/quick_algo/pr.c
@@ -0,0 +1,220 @@
+#define __USE_MINGW_ANSI_STDIO 1
+#include "stdio.h"
+#include "malloc.h"
+#include "string.h"
+#include "math.h"
+#include "immintrin.h"
+#include "quick_algo.h"
+
+// 以下头文件用于多线程优化
+#include "omp.h"
+
+// Comparison function for qsort: orders edges by source node, then by
+// destination node. Returns a negative, zero or positive value.
+int compare_edges(const void *a, const void *b)
+{
+    const struct Edge *edge_a = (const struct Edge *)a;
+    const struct Edge *edge_b = (const struct Edge *)b;
+    // Compare explicitly instead of returning the difference: the fields are
+    // long long, and a difference truncated to int can have the wrong sign.
+    if (edge_a->src != edge_b->src)
+        return (edge_a->src > edge_b->src) - (edge_a->src < edge_b->src);
+    return (edge_a->dst > edge_b->dst) - (edge_a->dst < edge_b->dst);
+}
+
+/**
+ * Personalized PageRank over an edge list.
+ *
+ * Side effects on the inputs: `edges` is sorted and its weights are rescaled
+ * in place, and `personalization` is min-max normalised in place.
+ *
+ * Returns a heap-allocated array of `num_nodes` scores (allocated with
+ * calloc); the caller releases it with free().
+ */
+double *pagerank(
+    struct Edge *edges,      // edge array
+    long long num_edges,     // number of edges
+    double *personalization, // personalization vector
+    long long num_nodes,     // number of nodes
+    double alpha,            // damping factor
+    int max_iter,            // maximum number of iterations
+    double tol               // convergence threshold
+)
+{
+    // Sort edges by source node, then by destination node, so that edges
+    // sharing a source node are contiguous.
+    qsort(edges, num_edges, sizeof(struct Edge), compare_edges);
+
+    {
+        // Turn the weights of every same-source run into a probability
+        // distribution: after this step they sum to 1 within the run.
+        // Each run gets its own weight sum.
+        long long group_start = 0;
+        while (group_start < num_edges)
+        {
+            const long long now_src = edges[group_start].src; // current source node
+            double sum_weight = 0.0;
+            long long group_end = group_start;
+            while (group_end < num_edges && edges[group_end].src == now_src)
+            {
+                sum_weight += edges[group_end].weight;
+                group_end++;
+            }
+            // A zero total cannot be normalised; leave such weights untouched
+            // instead of producing NaN/inf by dividing by zero.
+            if (sum_weight != 0.0)
+            {
+                for (long long i = group_start; i < group_end; i++)
+                {
+                    edges[i].weight /= sum_weight;
+                }
+            }
+            group_start = group_end;
+        }
+    }
+
+    if (num_nodes > 0)
+    {
+        // Min-max normalisation of the personalization vector. The extremes
+        // are seeded from the data itself so that vectors lying entirely
+        // above 1 or below 0 are handled correctly.
+        double max_value = personalization[0];
+        double min_value = personalization[0];
+        for (long long i = 1; i < num_nodes; i++)
+        {
+            if (personalization[i] > max_value)
+            {
+                max_value = personalization[i];
+            }
+            if (personalization[i] < min_value)
+            {
+                min_value = personalization[i];
+            }
+        }
+        if (max_value == min_value)
+        {
+            // All values are identical: fall back to the uniform 1.0 / num_nodes.
+            for (long long i = 0; i < num_nodes; i++)
+            {
+                personalization[i] = 1.0L / num_nodes;
+            }
+        }
+        else
+        {
+            for (long long i = 0; i < num_nodes; i++)
+            {
+                personalization[i] = (personalization[i] - min_value) / (max_value - min_value);
+            }
+        }
+    }
+
+    // The score vector starts as a copy of the personalization vector.
+    double *score = (double *)calloc(num_nodes, sizeof(double));
+    for (long long i = 0; i < num_nodes; i++)
+    {
+        score[i] = personalization[i];
+    }
+
+    // Each iteration:
+    // 1. new_score[i] = (1 - alpha) * personalization[i]
+    //                   + alpha * sum(score[j] * weight(j -> i)) over edges j -> i
+    // 2. measure the L1 distance between the new and the old scores
+    // 3. replace the score vector and stop once the distance is below tol
+    double tmp_score[4]; // scratch space for one SIMD result
+
+    for (int iter = 0; iter < max_iter; iter++)
+    {
+        double *new_score = (double *)calloc(num_nodes, sizeof(double));
+
+        // Teleport term, parallelised with OpenMP when it is enabled.
+#pragma omp parallel for
+        for (long long i = 0; i < num_nodes; i++)
+        {
+            new_score[i] = (1 - alpha) * personalization[i];
+        }
+
+        // Propagation term. Scalar form:
+        //   new_score[edges[i].dst] += alpha * score[edges[i].src] * edges[i].weight;
+        // Four edges at a time are multiplied with AVX.
+        {
+            long long i;
+            __m256d alpha_val = _mm256_set1_pd(alpha);
+
+            for (i = 0; i < num_edges - 3; i += 4)
+            {
+                __m256d src_score = _mm256_set_pd(score[edges[i].src], score[edges[i + 1].src], score[edges[i + 2].src], score[edges[i + 3].src]);
+                __m256d weights = _mm256_set_pd(edges[i].weight, edges[i + 1].weight, edges[i + 2].weight, edges[i + 3].weight);
+                __m256d new_score_val = _mm256_mul_pd(alpha_val, src_score);
+                new_score_val = _mm256_mul_pd(new_score_val, weights);
+                // Unaligned store: tmp_score has no 32-byte alignment
+                // guarantee, which the aligned store would require.
+                _mm256_storeu_pd(tmp_score, new_score_val);
+
+                for (int j = 0; j < 4; j++)
+                {
+                    // _mm256_set_pd takes its arguments from the highest
+                    // lane down, so the stored values are in reverse order.
+                    new_score[edges[i + j].dst] += tmp_score[3 - j];
+                }
+            }
+            // Remaining edges (fewer than four).
+            for (; i < num_edges; i++)
+            {
+                new_score[edges[i].dst] += alpha * score[edges[i].src] * edges[i].weight;
+            }
+        }
+
+        // Convergence check: L1 distance between consecutive score vectors.
+        double diff = 0.0L;
+        for (long long i = 0; i < num_nodes; i++)
+        {
+            diff += fabs(new_score[i] - score[i]);
+        }
+
+        // Replace the score vector.
+        free(score);
+        score = new_score;
+
+        if (diff < tol)
+            break;
+    }
+
+    return score;
+}
+
+int main()
+{
+    // Smoke test: run pagerank on a small fixed graph and print the scores.
+    struct Edge test_edges[] = {
+        {0, 1, 0.5},
+        {1, 2, 0.3},
+        {2, 0, 0.2},
+        {1, 3, 0.4},
+        {3, 4, 0.6},
+        {4, 1, 0.7}};
+    double test_personalization[] = {1.0, 2.0, 3.0, 4.0, 5.0};
+
+    // Element counts are derived from the array sizes.
+    long long edge_count = sizeof(test_edges) / sizeof(test_edges[0]);
+    long long node_count = sizeof(test_personalization) / sizeof(test_personalization[0]);
+
+    double *scores = pagerank(test_edges, edge_count, test_personalization, node_count, 0.85, 100, 1e-6);
+
+    long long idx = 0;
+    while (idx < node_count)
+    {
+        printf("Node %lld: %f\n", idx + 1, scores[idx]);
+        idx++;
+    }
+
+    free(scores);
+    return 0;
+}
\ No newline at end of file
diff --git a/src/plugins/knowledge/lib/quick_algo/quick_algo.h b/src/plugins/knowledge/lib/quick_algo/quick_algo.h
new file mode 100644
index 00000000..4a6d26f5
--- /dev/null
+++ b/src/plugins/knowledge/lib/quick_algo/quick_algo.h
@@ -0,0 +1,22 @@
+#ifndef PAGERANK_H
+#define PAGERANK_H
+
+// A weighted edge of the graph.
+struct Edge
+{
+    long long src; // source node of the edge
+    long long dst; // destination node of the edge
+    double weight; // weight of the edge
+};
+
+// Personalized PageRank; returns a pointer to an array of doubles.
+double *pagerank(
+    struct Edge *edges, // edge array
+    long long num_edges, // number of edges
+    double *personalization, // personalization vector
+    long long num_nodes, // number of nodes
+    double alpha, // damping factor
+    int max_iter, // maximum number of iterations
+    double tol // convergence threshold
+);
+
+#endif // PAGERANK_H
\ No newline at end of file
diff --git a/src/plugins/knowledge/lib/quick_algo/setup.py b/src/plugins/knowledge/lib/quick_algo/setup.py
new file mode 100644
index 00000000..6921e101
--- /dev/null
+++ b/src/plugins/knowledge/lib/quick_algo/setup.py
@@ -0,0 +1,17 @@
+from setuptools import setup, Extension
+from Cython.Build import cythonize
+
+ext_modules = [
+ Extension(
+ "pagerank",
+ sources=["pagerank.pyx", "pr.c"],
+ include_dirs=["."],
+ libraries=[],
+ language="c",
+ )
+]
+
+setup(
+ name="quick_algo",
+ ext_modules=cythonize(ext_modules, gdb_debug=True),
+)
diff --git a/src/plugins/knowledge/src/__init__.py b/src/plugins/knowledge/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/plugins/knowledge/src/config.py b/src/plugins/knowledge/src/config.py
new file mode 100644
index 00000000..fa42a481
--- /dev/null
+++ b/src/plugins/knowledge/src/config.py
@@ -0,0 +1,122 @@
+import os
+import toml
+from .global_logger import logger
+# Namespace names for the kinds of stored items.
+PG_NAMESPACE = "paragraph"
+ENT_NAMESPACE = "entity"
+REL_NAMESPACE = "relation"
+
+# Namespace names for RAG data.
+RAG_GRAPH_NAMESPACE = "rag-graph"
+RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
+RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
+
+# Invalid entities: the empty string and a set of Chinese pronouns.
+INVALID_ENTITY = [
+    "",
+    "你",
+    "他",
+    "她",
+    "它",
+    "我们",
+    "你们",
+    "他们",
+    "她们",
+    "它们",
+]
+
+
+def _load_config(config, config_file_path):
+ """读取TOML格式的配置文件"""
+ if not os.path.exists(config_file_path):
+ return
+ with open(config_file_path, "r", encoding="utf-8") as f:
+ file_config = toml.load(f)
+
+ if "llm_providers" in file_config:
+ for provider in file_config["llm_providers"]:
+ if provider["name"] not in config["llm_providers"]:
+ config["llm_providers"][provider["name"]] = dict()
+ config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"]
+ config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"]
+
+ if "entity_extract" in file_config:
+ config["entity_extract"] = file_config["entity_extract"]
+
+ if "rdf_build" in file_config:
+ config["rdf_build"] = file_config["rdf_build"]
+
+ if "embedding" in file_config:
+ config["embedding"] = file_config["embedding"]
+
+ if "rag" in file_config:
+ config["rag"] = file_config["rag"]
+
+ if "qa" in file_config:
+ config["qa"] = file_config["qa"]
+
+ if "persistence" in file_config:
+ config["persistence"] = file_config["persistence"]
+
+ logger.info(f"Configurations loaded from file: {config_file_path}")
+
+# Built-in default configuration; sections found in the TOML file loaded at
+# the bottom of this module override these values.
+global_config = dict(
+    {
+        # LLM providers, keyed by provider name.
+        "llm_providers": {
+            "localhost": {
+                "base_url": "http://localhost:8000",
+                "api_key": "",
+            }
+        },
+        "entity_extract": {
+            "llm": {
+                "provider": "localhost",
+                "model": "entity-extract",
+            }
+        },
+        "rdf_build": {
+            "llm": {
+                "provider": "localhost",
+                "model": "rdf-build",
+            }
+        },
+        "embedding": {
+            "provider": "localhost",
+            "model": "embed",
+            "dimension": 1024,
+        },
+        "rag": {
+            "params": {
+                "synonym_search_top_k": 10,
+                "synonym_threshold": 0.75,
+            }
+        },
+        "qa": {
+            "params": {
+                "relation_search_top_k": 10,
+                "relation_threshold": 0.75,
+                "paragraph_search_top_k": 10,
+                "paragraph_node_weight": 0.05,
+                "ent_filter_top_k": 10,
+                "ppr_damping": 0.8,
+                "res_top_k": 10,
+            },
+            "llm": {
+                "provider": "localhost",
+                "model": "qa",
+            },
+        },
+        # Locations of persisted data files and directories.
+        "persistence": {
+            "data_root_path": "data",
+            "raw_data_path": "data/raw.json",
+            "openie_data_path": "data/openie.json",
+            "embedding_data_dir": "data/embedding",
+            "rag_data_dir": "data/rag",
+        },
+    }
+)
+
+# The configuration file is expected at config/lpmm_config.toml, four
+# directory levels above the directory that contains this module.
+file_path = os.path.abspath(__file__)
+dir_path = os.path.dirname(file_path)
+root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
+config_path = os.path.join(root_path, "config", "lpmm_config.toml")
+# Overlay the file's settings onto the defaults (no-op if the file is absent).
+_load_config(global_config, config_path)
\ No newline at end of file
diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py
new file mode 100644
index 00000000..0f1dfdde
--- /dev/null
+++ b/src/plugins/knowledge/src/embedding_store.py
@@ -0,0 +1,251 @@
+from dataclasses import dataclass
+import json
+import os
+from typing import Dict, List, Tuple
+
+import numpy as np
+import pandas as pd
+import tqdm
+import faiss
+
+from .llm_client import LLMClient
+from .config import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
+from .utils.hash import get_sha256
+from .global_logger import logger
+
+
+class EmbeddingStoreItem:
+    """A single entry of an embedding store.
+
+    Attributes:
+        hash: Key of the item as passed by the caller.
+        embedding: Embedding vector of the content.
+        str: The original text content.
+    """
+
+    # NOTE: deliberately a plain class. A ``@dataclass`` decorator without any
+    # declared fields generates an ``__eq__`` that treats every two items as
+    # equal and makes instances unhashable.
+
+    def __init__(self, item_hash: str, embedding: List[float], content: str):
+        self.hash = item_hash
+        self.embedding = embedding
+        self.str = content
+
+    def to_dict(self) -> dict:
+        """Return the item as a dict with the keys "hash", "embedding" and "str"."""
+        return {
+            "hash": self.hash,
+            "embedding": self.embedding,
+            "str": self.str,
+        }
+
+
+class EmbeddingStore:
+    """Embedding store for a single namespace.
+
+    Items live in ``self.store`` keyed by ``<namespace>-<sha256(text)>`` and are
+    persisted as a parquet file. A Faiss inner-product index over L2-normalised
+    vectors (i.e. cosine similarity) can be built on top; ``self.idx2hash`` maps
+    the Faiss row numbers (as strings) back to the item keys.
+    """
+
+    def __init__(self, llm_client: LLMClient, namespace: str, dir_path: str):
+        self.namespace = namespace
+        self.llm_client = llm_client
+        self.dir = dir_path
+        self.embedding_file_path = dir_path + "/" + namespace + ".parquet"
+        self.index_file_path = dir_path + "/" + namespace + ".index"
+        self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
+
+        self.store = dict()
+
+        # Both stay None until an index is built or loaded
+        self.faiss_index = None
+        self.idx2hash = None
+
+    def _get_embedding(self, s: str) -> List[float]:
+        """Request the embedding of ``s`` with the configured embedding model."""
+        return self.llm_client.send_embedding_request(
+            global_config["embedding"]["model"], s
+        )
+
+    def batch_insert_strs(self, strs: List[str]) -> None:
+        """Embed and store every string that is not in the store yet."""
+        for s in tqdm.tqdm(strs, desc="存入嵌入库", unit="items"):
+            # The namespaced content hash is the key and de-duplicates entries
+            item_hash = self.namespace + "-" + get_sha256(s)
+            if item_hash in self.store:
+                continue
+
+            embedding = self._get_embedding(s)
+            self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+
+    def save_to_file(self) -> None:
+        """Persist the items and, if they exist, the Faiss index and idx2hash map."""
+        logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}")
+        data_frame = pd.DataFrame([item.to_dict() for item in self.store.values()])
+
+        os.makedirs(self.dir, exist_ok=True)
+        # No empty placeholder file is created beforehand: to_parquet creates the
+        # file itself, and a failed write must not leave an unreadable stub behind.
+        data_frame.to_parquet(self.embedding_file_path, engine="pyarrow", index=False)
+        logger.info(f"{self.namespace}嵌入库保存成功")
+
+        if self.faiss_index is not None and self.idx2hash is not None:
+            logger.info(
+                f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}"
+            )
+            faiss.write_index(self.faiss_index, self.index_file_path)
+            logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
+            logger.info(
+                f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}"
+            )
+            with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
+                f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4))
+            logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
+
+    def load_from_file(self) -> None:
+        """Load the items from disk, then the Faiss index and idx2hash map.
+
+        If the index or the map is missing or cannot be read, the index is
+        rebuilt from the loaded items and everything is saved again.
+
+        Raises:
+            Exception: if the parquet file with the items does not exist.
+        """
+        if not os.path.exists(self.embedding_file_path):
+            raise Exception(f"文件{self.embedding_file_path}不存在")
+
+        logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
+        data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
+        for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
+            self.store[row["hash"]] = EmbeddingStoreItem(
+                row["hash"], row["embedding"], row["str"]
+            )
+        logger.info(f"{self.namespace}嵌入库加载成功")
+
+        try:
+            if not os.path.exists(self.index_file_path):
+                raise Exception(f"文件{self.index_file_path}不存在")
+            logger.info(
+                f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex"
+            )
+            self.faiss_index = faiss.read_index(self.index_file_path)
+            logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功")
+
+            if not os.path.exists(self.idx2hash_file_path):
+                raise Exception(f"文件{self.idx2hash_file_path}不存在")
+            logger.info(
+                f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射"
+            )
+            # Read with the same encoding save_to_file writes with
+            with open(self.idx2hash_file_path, "r", encoding="utf-8") as f:
+                self.idx2hash = json.load(f)
+            logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
+        except Exception as e:
+            logger.error(f"加载{self.namespace}嵌入库的FaissIndex时发生错误:{e}")
+            logger.warning("正在重建Faiss索引")
+            self.build_faiss_index()
+            logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
+            self.save_to_file()
+
+    def build_faiss_index(self) -> None:
+        """Rebuild the Faiss index (cosine similarity) from all stored items."""
+        vectors = []
+        self.idx2hash = dict()
+        for row, key in enumerate(self.store):
+            vectors.append(self.store[key].embedding)
+            self.idx2hash[str(row)] = key
+
+        self.faiss_index = faiss.IndexFlatIP(global_config["embedding"]["dimension"])
+        if not vectors:
+            # An empty store yields a valid, empty index instead of failing on a
+            # one-dimensional empty array.
+            return
+
+        embeddings = np.array(vectors, dtype=np.float32)
+        # Inner product on L2-normalised vectors equals cosine similarity
+        faiss.normalize_L2(embeddings)
+        self.faiss_index.add(embeddings)
+
+    def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
+        """Return the k items most similar to ``query`` by cosine similarity.
+
+        Args:
+            query: Embedding of the query.
+            k: Number of neighbours to request.
+
+        Returns:
+            List of (item key, cosine similarity); it may hold fewer than k
+            entries when the index contains fewer items.
+
+        Raises:
+            Exception: if the Faiss index or the idx2hash map is not available.
+        """
+        if self.faiss_index is None:
+            raise Exception("Faiss索引尚未构建")
+        if self.idx2hash is None:
+            raise Exception("idx2hash映射尚未构建")
+
+        # normalize_L2 works in place, so the very array that is normalised must
+        # also be the one handed to search (as float32).
+        query_vec = np.array([query], dtype=np.float32)
+        faiss.normalize_L2(query_vec)
+        distances, indices = self.faiss_index.search(query_vec, k)
+
+        item_count = len(self.idx2hash)
+        return [
+            (self.idx2hash[str(int(idx))], float(sim))
+            for idx, sim in zip(indices.flatten(), distances.flatten())
+            # Faiss reports -1 for result slots it could not fill
+            if 0 <= idx < item_count
+        ]
+
+
+class EmbeddingManager:
+    """Bundles the paragraph, entity and relation embedding stores."""
+
+    def __init__(self, llm_client: LLMClient):
+        embedding_dir = global_config["persistence"]["embedding_data_dir"]
+        self.paragraphs_embedding_store = EmbeddingStore(
+            llm_client, PG_NAMESPACE, embedding_dir
+        )
+        self.entities_embedding_store = EmbeddingStore(
+            llm_client, ENT_NAMESPACE, embedding_dir
+        )
+        self.relation_embedding_store = EmbeddingStore(
+            llm_client, REL_NAMESPACE, embedding_dir
+        )
+        self.stored_pg_hashes = set()
+
+    def _all_stores(self):
+        """Return the three stores in paragraph, entity, relation order."""
+        return (
+            self.paragraphs_embedding_store,
+            self.entities_embedding_store,
+            self.relation_embedding_store,
+        )
+
+    def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
+        """Embed the paragraph texts into the paragraph store."""
+        paragraph_texts = list(raw_paragraphs.values())
+        self.paragraphs_embedding_store.batch_insert_strs(paragraph_texts)
+
+    def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
+        """Embed the subject and object of every triple into the entity store."""
+        entities = set()
+        for triple_list in triple_list_data.values():
+            for triple in triple_list:
+                entities.update((triple[0], triple[2]))
+        self.entities_embedding_store.batch_insert_strs(list(entities))
+
+    def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
+        """Embed every distinct triple, rendered as ``str(tuple)``, into the relation store."""
+        unique_triples = set(
+            tuple(triple)
+            for triples in triple_list_data.values()
+            for triple in triples
+        )
+        relation_strs = [str(triple) for triple in unique_triples]
+        self.relation_embedding_store.batch_insert_strs(relation_strs)
+
+    def load_from_file(self):
+        """Load all three stores from disk."""
+        for store in self._all_stores():
+            store.load_from_file()
+        # The keys of the paragraph store are the hashes of the stored paragraphs
+        self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys())
+
+    def store_new_data_set(
+        self,
+        raw_paragraphs: Dict[str, str],
+        triple_list_data: Dict[str, List[List[str]]],
+    ):
+        """Embed a new data set: paragraphs, then entities, then relations."""
+        self._store_pg_into_embedding(raw_paragraphs)
+        self._store_ent_into_embedding(triple_list_data)
+        self._store_rel_into_embedding(triple_list_data)
+        # NOTE(review): load_from_file fills this set with paragraph-store keys,
+        # while here the keys of raw_paragraphs are added - confirm that both
+        # use the same key format.
+        self.stored_pg_hashes.update(raw_paragraphs)
+
+    def save_to_file(self):
+        """Save all three stores to disk."""
+        for store in self._all_stores():
+            store.save_to_file()
+
+    def rebuild_faiss_index(self):
+        """Rebuild the Faiss index of every store (call after adding new data)."""
+        for store in self._all_stores():
+            store.build_faiss_index()
diff --git a/src/plugins/knowledge/src/global_logger.py b/src/plugins/knowledge/src/global_logger.py
new file mode 100644
index 00000000..a99197ed
--- /dev/null
+++ b/src/plugins/knowledge/src/global_logger.py
@@ -0,0 +1,10 @@
+# Configure logger
+
+from src.common.logger import get_module_logger, LogConfig, KNOWLEDGE_STYLE_CONFIG
+
+lib_config = LogConfig(
+    # Use the knowledge-specific log style
+    console_format=KNOWLEDGE_STYLE_CONFIG["console_format"],
+    file_format=KNOWLEDGE_STYLE_CONFIG["file_format"],
+)
+logger = get_module_logger("knowledge", config=lib_config)
diff --git a/src/plugins/knowledge/src/ie_process.py b/src/plugins/knowledge/src/ie_process.py
new file mode 100644
index 00000000..72435dc1
--- /dev/null
+++ b/src/plugins/knowledge/src/ie_process.py
@@ -0,0 +1,108 @@
+import json
+import time
+from typing import List
+
+from .global_logger import logger
+from . import prompt_template
+from .config import global_config, INVALID_ENTITY
+from .llm_client import LLMClient
+from .utils.json_fix import fix_broken_generated_json
+
+
+def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
+    """Ask the LLM for the entities of a paragraph and return them as a list.
+
+    Raises:
+        Exception: if no entity is left after filtering. Errors raised while
+            decoding the reply as JSON propagate unchanged.
+    """
+    messages = prompt_template.build_entity_extract_context(paragraph)
+    _, reply = llm_client.send_chat_request(
+        global_config["entity_extract"]["llm"]["model"], messages
+    )
+
+    # Keep only the JSON list: cut everything before the first '[' ...
+    if "[" in reply:
+        first_bracket = reply.index("[")
+        reply = reply[first_bracket:]
+
+    # ... and everything after the last ']'
+    if "]" in reply:
+        last_bracket = reply.rindex("]")
+        reply = reply[: last_bracket + 1]
+
+    parsed = json.loads(fix_broken_generated_json(reply))
+
+    entities = []
+    for entity in parsed:
+        if entity is None or entity == "" or entity in INVALID_ENTITY:
+            continue
+        entities.append(entity)
+
+    if not entities:
+        raise Exception("实体提取结果为空")
+
+    return entities
+
+
+def _rdf_triple_extract(
+    llm_client: LLMClient, paragraph: str, entities: list
+) -> List[List[str]]:
+    """Ask the LLM for the RDF triples of a paragraph, given its entities.
+
+    Raises:
+        Exception: if any returned triple does not have exactly three
+            non-None, non-empty elements. Errors raised while decoding the
+            reply as JSON propagate unchanged.
+    """
+    messages = prompt_template.build_rdf_triple_extract_context(
+        paragraph, entities=json.dumps(entities, ensure_ascii=False)
+    )
+    _, reply = llm_client.send_chat_request(
+        global_config["rdf_build"]["llm"]["model"], messages
+    )
+
+    # Keep only the JSON list: cut everything before the first '[' ...
+    if "[" in reply:
+        first_bracket = reply.index("[")
+        reply = reply[first_bracket:]
+
+    # ... and everything after the last ']'
+    if "]" in reply:
+        last_bracket = reply.rindex("]")
+        reply = reply[: last_bracket + 1]
+
+    triples = json.loads(fix_broken_generated_json(reply))
+
+    for triple in triples:
+        well_formed = (
+            len(triple) == 3
+            and not (triple[0] is None or triple[1] is None or triple[2] is None)
+            and "" not in triple
+        )
+        if not well_formed:
+            raise Exception("RDF提取结果格式错误")
+
+    return triples
+
+
+def info_extract_from_str(
+    llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str
+) -> tuple[None, None] | tuple[list[str], list[list[str]]]:
+    """Extract the entities and then the RDF triples of one paragraph.
+
+    Each of the two steps is attempted up to three times with a five second
+    pause between attempts.
+
+    Returns:
+        (entities, triples) on success, (None, None) if either step still
+        fails on its last attempt.
+    """
+    max_tries = 3
+    retry_delay = 5  # seconds; the retry log message below states the same value
+
+    for attempt in range(1, max_tries + 1):
+        try:
+            entity_extract_result = _entity_extract(llm_client_for_ner, paragraph)
+            break
+        except Exception as e:
+            logger.warning(f"实体提取失败,错误信息:{e}")
+            if attempt == max_tries:
+                logger.error("实体提取失败,已达最大重试次数")
+                return None, None
+            logger.warning("将于5秒后重试")
+            time.sleep(retry_delay)
+
+    for attempt in range(1, max_tries + 1):
+        try:
+            rdf_triple_extract_result = _rdf_triple_extract(
+                llm_client_for_rdf, paragraph, entity_extract_result
+            )
+            break
+        except Exception as e:
+            # This step is the RDF triple extraction, so the log must say so
+            logger.warning(f"RDF三元组提取失败,错误信息:{e}")
+            if attempt == max_tries:
+                logger.error("RDF三元组提取失败,已达最大重试次数")
+                return None, None
+            logger.warning("将于5秒后重试")
+            time.sleep(retry_delay)
+
+    return entity_extract_result, rdf_triple_extract_result
diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py
new file mode 100644
index 00000000..dae90d11
--- /dev/null
+++ b/src/plugins/knowledge/src/kg_manager.py
@@ -0,0 +1,431 @@
+import json
+import os
+import time
+from typing import Dict, List, Tuple
+
+import networkx as nx
+import numpy as np
+import pandas as pd
+import tqdm
+
+from ..lib import quick_algo
+from .utils.hash import get_sha256
+from .embedding_store import EmbeddingManager, EmbeddingStoreItem
+from .config import (
+ ENT_NAMESPACE,
+ PG_NAMESPACE,
+ RAG_ENT_CNT_NAMESPACE,
+ RAG_GRAPH_NAMESPACE,
+ RAG_PG_HASH_NAMESPACE,
+ global_config,
+)
+
+from .global_logger import logger
+
+class KGManager:
+    """Knowledge graph over entity and paragraph nodes.
+
+    Holds the directed graph, the per-entity appearance counts and the set of
+    already processed paragraph hashes, and offers persistence, incremental
+    construction and a Personalized-PageRank based search.
+    """
+
+    def __init__(self):
+        # --- persisted state ---
+        # Hashes of stored paragraphs, used for de-duplication
+        self.stored_paragraph_hashes = set()
+        # Entity node key -> appearance count
+        self.ent_appear_cnt = dict()
+        # The knowledge graph
+        self.graph = nx.DiGraph()
+
+        # --- persistence paths ---
+        self.dir_path = global_config["persistence"]["rag_data_dir"]
+        self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphmlz"
+        self.ent_cnt_data_path = (
+            self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
+        )
+        self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
+
+    def save_to_file(self):
+        """Write the graph, the entity counts and the paragraph hashes to disk."""
+        os.makedirs(self.dir_path, exist_ok=True)
+
+        # Graph
+        nx.write_graphml(self.graph, path=self.graph_data_path, encoding="utf-8")
+
+        # Entity appearance counts
+        ent_cnt_df = pd.DataFrame(
+            [{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()]
+        )
+        ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False)
+
+        # Paragraph hashes
+        with open(self.pg_hash_file_path, "w", encoding="utf-8") as f:
+            data = {"stored_paragraph_hashes": list(self.stored_paragraph_hashes)}
+            f.write(json.dumps(data, ensure_ascii=False, indent=4))
+
+    def load_from_file(self):
+        """Load the paragraph hashes, the entity counts and the graph from disk.
+
+        Raises:
+            Exception: if one of the three files is missing, or if the loaded
+                graph is not a directed graph.
+        """
+        if not os.path.exists(self.pg_hash_file_path):
+            raise Exception(f"KG段落hash文件{self.pg_hash_file_path}不存在")
+        if not os.path.exists(self.ent_cnt_data_path):
+            raise Exception(f"KG实体计数文件{self.ent_cnt_data_path}不存在")
+        if not os.path.exists(self.graph_data_path):
+            raise Exception(f"KG图文件{self.graph_data_path}不存在")
+
+        # Paragraph hashes
+        with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+            self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"])
+
+        # Entity appearance counts
+        ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow")
+        self.ent_appear_cnt = {
+            row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()
+        }
+
+        # Graph
+        self.graph = nx.read_graphml(self.graph_data_path)
+        if not isinstance(self.graph, nx.DiGraph):
+            raise Exception("KG图文件存在问题")
+
+    def _build_edges_between_ent(
+        self,
+        node_to_node: Dict[Tuple[str, str], float],
+        triple_list_data: Dict[str, List[List[str]]],
+    ):
+        """Add entity-to-entity edge weights and update the appearance counts.
+
+        Every triple adds 1.0 to the edge between its subject and object in both
+        directions. Every entity seen in a triple list adds 1.0 to its
+        appearance count once per triple list.
+        """
+        for triple_list in triple_list_data.values():
+            entity_set = set()
+            for triple in triple_list:
+                if triple[0] == triple[2]:
+                    # Skip self-loops
+                    continue
+                hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
+                hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
+                node_to_node[(hash_key1, hash_key2)] = (
+                    node_to_node.get((hash_key1, hash_key2), 0) + 1.0
+                )
+                node_to_node[(hash_key2, hash_key1)] = (
+                    node_to_node.get((hash_key2, hash_key1), 0) + 1.0
+                )
+                entity_set.add(hash_key1)
+                entity_set.add(hash_key2)
+
+            for hash_key in entity_set:
+                self.ent_appear_cnt[hash_key] = (
+                    self.ent_appear_cnt.get(hash_key, 0) + 1.0
+                )
+
+    @staticmethod
+    def _build_edges_between_ent_pg(
+        node_to_node: Dict[Tuple[str, str], float],
+        triple_list_data: Dict[str, List[List[str]]],
+    ):
+        """Add edge weights from the subject entity of each triple to its paragraph."""
+        # NOTE(review): only the subject (triple[0]) is linked to the paragraph,
+        # the object (triple[2]) is not - confirm this is intended.
+        for idx in triple_list_data:
+            for triple in triple_list_data[idx]:
+                ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
+                pg_hash_key = PG_NAMESPACE + "-" + str(idx)
+                node_to_node[(ent_hash_key, pg_hash_key)] = (
+                    node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
+                )
+
+    @staticmethod
+    def _synonym_connect(
+        node_to_node: Dict[Tuple[str, str], float],
+        triple_list_data: Dict[str, List[List[str]]],
+        embedding_manager: EmbeddingManager,
+    ) -> int:
+        """Connect every entity with its sufficiently similar entities.
+
+        Returns:
+            The number of similar-entity pairs that were connected.
+        """
+        new_edge_cnt = 0
+        # Keys of all entities that occur in the given triples
+        ent_hash_list = set()
+        for triple_list in triple_list_data.values():
+            for triple in triple_list:
+                ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0]))
+                ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2]))
+        ent_hash_list = list(ent_hash_list)
+
+        synonym_hash_set = set()
+
+        synonym_result = dict()
+
+        for ent_hash in tqdm.tqdm(ent_hash_list):
+            if ent_hash in synonym_hash_set:
+                # Already connected as a synonym within this batch
+                continue
+            ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
+            if ent is None:
+                # No stored embedding for this entity. This check has to come
+                # before the type assertion, otherwise it can never be reached.
+                continue
+            assert isinstance(ent, EmbeddingStoreItem)
+            similar_ents = embedding_manager.entities_embedding_store.search_top_k(
+                ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
+            )
+            res_ent = []  # Debug
+            for res_ent_hash, similarity in similar_ents:
+                if res_ent_hash == ent_hash:
+                    # Skip self-loops
+                    continue
+                if similarity < global_config["rag"]["params"]["synonym_threshold"]:
+                    # Below the similarity threshold
+                    continue
+                node_to_node[(res_ent_hash, ent_hash)] = similarity
+                node_to_node[(ent_hash, res_ent_hash)] = similarity
+                synonym_hash_set.add(res_ent_hash)
+                new_edge_cnt += 1
+                res_ent.append(
+                    (
+                        embedding_manager.entities_embedding_store.store[
+                            res_ent_hash
+                        ].str,
+                        similarity,
+                    )
+                )  # Debug
+            synonym_result[ent.str] = res_ent
+
+        for k, v in synonym_result.items():
+            print(f'"{k}"的相似实体为:{v}')
+        return new_edge_cnt
+
+    def _update_graph(
+        self,
+        node_to_node: Dict[Tuple[str, str], float],
+        embedding_manager: EmbeddingManager,
+    ):
+        """Merge the collected edges into the graph.
+
+        1. Edges that are new are added with their weight and timestamps; for
+           edges that already exist the weight is increased and the update
+           time refreshed.
+        2. Nodes that did not exist before receive their "content" and "type"
+           attributes from the embedding stores.
+        """
+        # Snapshots taken before the update. Sets keep every membership test
+        # constant-time instead of scanning a list per edge and per node.
+        existed_nodes = {str(node) for node in self.graph.nodes}
+        existed_edges = {str((edge[0], edge[1])) for edge in self.graph.edges}
+
+        now_time = time.time()
+
+        for src_tgt, weight in node_to_node.items():
+            if str(src_tgt) not in existed_edges:
+                self.graph.add_edge(
+                    src_tgt[0],
+                    src_tgt[1],
+                    weight=weight,
+                    create_time=now_time,
+                    update_time=now_time,
+                )
+            else:
+                self.graph.edges[src_tgt[0], src_tgt[1]]["weight"] += weight
+                self.graph.edges[src_tgt[0], src_tgt[1]]["update_time"] = now_time
+
+        for src_tgt in node_to_node:
+            for node_hash in src_tgt:
+                if node_hash in existed_nodes:
+                    continue
+                if node_hash.startswith(ENT_NAMESPACE):
+                    # New entity node
+                    node = embedding_manager.entities_embedding_store.store[node_hash]
+                    assert isinstance(node, EmbeddingStoreItem)
+                    self.graph.nodes[node_hash]["content"] = node.str
+                    self.graph.nodes[node_hash]["type"] = "ent"
+                elif node_hash.startswith(PG_NAMESPACE):
+                    # New paragraph node; only a short preview of the text is kept
+                    node = embedding_manager.paragraphs_embedding_store.store[node_hash]
+                    assert isinstance(node, EmbeddingStoreItem)
+                    content = node.str.replace("\n", " ")
+                    self.graph.nodes[node_hash]["content"] = (
+                        content if len(content) < 8 else content[:8] + "..."
+                    )
+                    self.graph.nodes[node_hash]["type"] = "pg"
+
+    def build_kg(
+        self,
+        triple_list_data: Dict[str, List[List[str]]],
+        embedding_manager: EmbeddingManager,
+    ):
+        """Incrementally extend the graph with new triples.
+
+        The graph is not saved here; call ``save_to_file`` afterwards.
+
+        Args:
+            triple_list_data: Triples per paragraph key.
+            embedding_manager: EmbeddingManager that holds the embeddings of
+                the entities and paragraphs involved.
+        """
+        # Pending edges: (source key, target key) -> weight
+        node_to_node = dict()
+
+        logger.info("正在构建KG实体节点之间的关系,同时统计实体出现次数")
+        self._build_edges_between_ent(node_to_node, triple_list_data)
+
+        logger.info("正在构建KG实体节点与文段节点之间的关系")
+        self._build_edges_between_ent_pg(node_to_node, triple_list_data)
+
+        logger.info("正在进行近义词扩展链接")
+        self._synonym_connect(node_to_node, triple_list_data, embedding_manager)
+
+        self._update_graph(node_to_node, embedding_manager)
+
+        # Remember which paragraphs have been processed
+        for idx in triple_list_data:
+            self.stored_paragraph_hashes.add(str(idx))
+
+    def kg_search(
+        self,
+        relation_search_result: List[Tuple[str, float, float]],
+        paragraph_search_result: List[Tuple[str, float]],
+        embed_manager: EmbeddingManager,
+    ):
+        """Rank paragraphs with Personalized PageRank seeded by the search hits.
+
+        Args:
+            relation_search_result: Relation hits; each item is unpacked as
+                (relation_hash, similarity, <third value, ignored here>).
+            paragraph_search_result: Paragraph hits as (paragraph_hash, similarity).
+            embed_manager: EmbeddingManager used to resolve relation hashes.
+
+        Returns:
+            (passage_node_res, ppr_node_weights): the paragraph nodes with their
+            PageRank score, best first, and the personalization weights used.
+        """
+        import ast  # local import: only needed to decode stored relation strings
+
+        existed_nodes = {str(node) for node in self.graph.nodes}
+
+        # ---- entity weights ----
+        # Every relation hit contributes its similarity to its subject and object.
+        ent_sim_scores = {}
+        for relation_hash, similarity, _ in relation_search_result:
+            relation_item = embed_manager.relation_embedding_store.store.get(
+                relation_hash
+            )
+            if relation_item is None:
+                logger.warning(f"关系{relation_hash}不在关系嵌入库中,已跳过")
+                continue
+            relation = relation_item.str
+            # Relations are stored as str(tuple). literal_eval restores the tuple
+            # even when an element contains quotes; strings that are not valid
+            # literals fall back to the plain slicing.
+            try:
+                triple = ast.literal_eval(relation)
+            except (ValueError, SyntaxError):
+                triple = relation[2:-2].split("', '")
+            if not isinstance(triple, (tuple, list)) or len(triple) < 3:
+                continue
+            for ent in (triple[0], triple[2]):
+                ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent)
+                if ent_hash in existed_nodes:  # the entity must be part of the KG
+                    ent_sim_scores.setdefault(ent_hash, []).append(similarity)
+
+        ent_weights = {}
+        ent_mean_scores = {}  # mean similarity per entity, for the top-k filter
+        for ent_hash, scores in ent_sim_scores.items():
+            # Weight = summed similarity / appearance count. Entities that only
+            # occur in self-referencing triples are never counted by
+            # _build_edges_between_ent, hence the default of 1.
+            appear_cnt = self.ent_appear_cnt.get(ent_hash, 1.0)
+            ent_weights[ent_hash] = float(np.sum(scores)) / appear_cnt
+            ent_mean_scores[ent_hash] = float(np.mean(scores))
+
+        if ent_weights:  # without relation hits there is nothing to scale
+            ent_weights_max = max(ent_weights.values())
+            ent_weights_min = min(ent_weights.values())
+            if ent_weights_max == ent_weights_min:
+                # A single distinct weight: use 1.0 for all
+                for ent_hash in ent_weights:
+                    ent_weights[ent_hash] = 1.0
+            else:
+                down_edge = global_config["qa"]["params"]["paragraph_node_weight"]
+                # Rescale into [down_edge, 1]
+                for ent_hash, score in ent_weights.items():
+                    ent_weights[ent_hash] = (
+                        (score - ent_weights_min)
+                        * (1 - down_edge)
+                        / (ent_weights_max - ent_weights_min)
+                    ) + down_edge
+
+        # Keep only the top_k entities by mean similarity
+        top_k = global_config["qa"]["params"]["ent_filter_top_k"]
+        if len(ent_mean_scores) > top_k:
+            ranked = sorted(ent_mean_scores, key=ent_mean_scores.get, reverse=True)
+            # Drop only the entities ranked below the top_k, not all of them
+            for ent_hash in ranked[top_k:]:
+                del ent_weights[ent_hash]
+
+        # ---- paragraph weights ----
+        pg_sim_scores = {}
+        pg_sim_score_max = 0.0
+        pg_sim_score_min = 1.0
+        for pg_hash, similarity in paragraph_search_result:
+            pg_sim_score_max = max(pg_sim_score_max, similarity)
+            pg_sim_score_min = min(pg_sim_score_min, similarity)
+            pg_sim_scores[pg_hash] = similarity
+
+        pg_node_weight = global_config["qa"]["params"]["paragraph_node_weight"]
+        pg_sim_span = pg_sim_score_max - pg_sim_score_min
+        pg_weights = {}
+        for pg_hash, similarity in pg_sim_scores.items():
+            if pg_sim_span > 0:
+                normalized = (similarity - pg_sim_score_min) / pg_sim_span
+            else:
+                # All similarities are equal: avoid dividing by zero
+                normalized = 1.0
+            # Paragraph weight = normalised similarity * paragraph node weight
+            pg_weights[pg_hash] = normalized * pg_node_weight
+
+        # Final personalization = entity weights + paragraph weights
+        ppr_node_weights = {**ent_weights, **pg_weights}
+
+        # Personalized PageRank
+        ppr_res = quick_algo.pagerank(
+            self.graph,
+            personalized=ppr_node_weights,
+            max_iter=1000,
+            alpha=global_config["qa"]["params"]["ppr_damping"],
+        )
+
+        # Keep the paragraph nodes only, best score first
+        passage_node_res = sorted(
+            (
+                (node_key, score)
+                for node_key, score in ppr_res
+                if node_key.startswith(PG_NAMESPACE)
+            ),
+            key=lambda item: item[1],
+            reverse=True,
+        )
+
+        return passage_node_res, ppr_node_weights
diff --git a/src/plugins/knowledge/src/llm_client.py b/src/plugins/knowledge/src/llm_client.py
new file mode 100644
index 00000000..3036662c
--- /dev/null
+++ b/src/plugins/knowledge/src/llm_client.py
@@ -0,0 +1,53 @@
+from openai import OpenAI
+
+
+class LLMMessage:
+    """One chat message, made of a role and its content."""
+
+    def __init__(self, role, content):
+        self.role, self.content = role, content
+
+    def to_dict(self):
+        """Return the message as a dict with the keys "role" and "content"."""
+        return dict(role=self.role, content=self.content)
+
+
+class LLMClient:
+    """LLM client bound to a single API provider."""
+
+    # Opening/closing markers that wrap reasoning text inlined in a reply.
+    # Built by concatenation so that tooling which strips literal reasoning
+    # tags from text cannot silently turn these separators into empty strings
+    # (str.split raises ValueError on an empty separator).
+    # NOTE(review): assumes the "think" tag convention of models that inline
+    # their reasoning - confirm against the providers in use.
+    _THINK_OPEN = "<" + "think>"
+    _THINK_CLOSE = "</" + "think>"
+
+    def __init__(self, url, api_key):
+        self.client = OpenAI(
+            base_url=url,
+            api_key=api_key,
+        )
+
+    def send_chat_request(self, model, messages):
+        """Send a non-streaming chat request and wait for the reply.
+
+        Returns:
+            (reasoning_content, content). ``reasoning_content`` is None when
+            the reply has neither a separate reasoning field nor exactly one
+            inlined reasoning block.
+        """
+        response = self.client.chat.completions.create(
+            model=model, messages=messages, stream=False
+        )
+        message = response.choices[0].message
+        if hasattr(message, "reasoning_content"):
+            # The API returns the reasoning in a field of its own
+            return message.reasoning_content, message.content
+
+        # No separate field: split the reasoning block off the content
+        parts = message.content.split(self._THINK_OPEN)[-1].split(self._THINK_CLOSE)
+        if len(parts) == 2:
+            return parts[0], parts[1]
+        return None, parts[0]
+
+    def send_embedding_request(self, model, text):
+        """Send an embedding request and return the embedding vector."""
+        # Line breaks are replaced by spaces before the text is embedded
+        text = text.replace("\n", " ")
+        return (
+            self.client.embeddings.create(input=[text], model=model).data[0].embedding
+        )
diff --git a/src/plugins/knowledge/src/mem_active_manager.py b/src/plugins/knowledge/src/mem_active_manager.py
new file mode 100644
index 00000000..9f5847e2
--- /dev/null
+++ b/src/plugins/knowledge/src/mem_active_manager.py
@@ -0,0 +1,36 @@
+from .config import global_config
+from .embedding_store import EmbeddingManager
+from .llm_client import LLMClient
+from .utils.dyn_topk import dyn_select_top_k
+
+
+class MemoryActiveManager:
+    """Computes how strongly a question activates the stored relations."""
+
+    def __init__(
+        self,
+        embed_manager: EmbeddingManager,
+        llm_client_embedding: LLMClient,
+    ):
+        self.embed_manager = embed_manager
+        self.embedding_client = llm_client_embedding
+
+    def get_activation(self, question: str) -> float:
+        """Return the memory activation of ``question``.
+
+        Returns:
+            0.0 when no relation is found or the best hit is below
+            ``relation_threshold``; otherwise ten times the sum of the third
+            element of every result selected by ``dyn_select_top_k``.
+        """
+        # Embed the question with the same model the stores were built with;
+        # vectors produced by a different model would not be comparable.
+        question_embedding = self.embedding_client.send_embedding_request(
+            global_config["embedding"]["model"], question
+        )
+        # Look up similar relations
+        rel_search_res = self.embed_manager.relation_embedding_store.search_top_k(
+            question_embedding, 10
+        )
+        if not rel_search_res:
+            # Empty relation store: nothing can be activated
+            return 0.0
+
+        # Dynamic threshold filtering
+        rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0)
+        if (
+            not rel_scores
+            or rel_scores[0][1] < global_config["qa"]["params"]["relation_threshold"]
+        ):
+            # No sufficiently related relation found
+            return 0.0
+
+        activation = sum(item[2] for item in rel_scores) * 10
+
+        return activation
diff --git a/src/plugins/knowledge/src/open_ie.py b/src/plugins/knowledge/src/open_ie.py
new file mode 100644
index 00000000..3b5d704d
--- /dev/null
+++ b/src/plugins/knowledge/src/open_ie.py
@@ -0,0 +1,147 @@
+import json
+from typing import Any, Dict, List
+
+
+from .config import INVALID_ENTITY, global_config
+
+
+def _filter_invalid_entities(entities: List[str]) -> List[str]:
+    """Return the distinct valid entities.
+
+    Non-strings, blank strings and members of ``INVALID_ENTITY`` are dropped.
+    The result comes from a set, so its order is not the input order.
+    """
+    kept = set()
+    for entity in entities:
+        is_valid = (
+            isinstance(entity, str)
+            and entity.strip() != ""
+            and entity not in INVALID_ENTITY
+        )
+        if is_valid:
+            # Adding to a set also takes care of duplicates
+            kept.add(entity)
+
+    return list(kept)
+
+
+def _filter_invalid_triples(triples: List[List[str]]) -> List[List[str]]:
+    """Return the valid triples without duplicates, in their original order.
+
+    A triple is valid when it has exactly three elements and each of them is
+    a non-blank string.
+    """
+    seen = set()
+    result = []
+
+    for triple in triples:
+        if len(triple) != 3:
+            continue
+        if any(not isinstance(part, str) or part.strip() == "" for part in triple):
+            continue
+
+        key = tuple(str(part) for part in triple)
+        if key in seen:
+            continue
+        seen.add(key)
+        result.append(list(key))
+
+    return result
+
+
+class OpenIE:
+    """
+    Container for OpenIE data, which has the following format:
+    {
+        "docs": [
+            {
+                "idx": "unique id of the document (usually the SHA256 hash of the text)",
+                "passage": "original text of the document",
+                "extracted_entities": ["entity 1", "entity 2", ...],
+                "extracted_triples": [["subject", "predicate", "object"], ...]
+            },
+            ...
+        ],
+        "avg_ent_chars": "average number of characters per entity",
+        "avg_ent_words": "average number of words per entity"
+    }
+    """
+
+    def __init__(
+        self,
+        docs: List[Dict[str, Any]],
+        avg_ent_chars,
+        avg_ent_words,
+    ):
+        self.docs = docs
+        self.avg_ent_chars = avg_ent_chars
+        self.avg_ent_words = avg_ent_words
+
+        # The entity and triple lists of every doc are filtered in place
+        for doc in self.docs:
+            entities = doc["extracted_entities"]
+            doc["extracted_entities"] = _filter_invalid_entities(entities)
+            triples = doc["extracted_triples"]
+            doc["extracted_triples"] = _filter_invalid_triples(triples)
+
+    @staticmethod
+    def _from_dict(data):
+        """Build an OpenIE object from its dict form."""
+        return OpenIE(
+            docs=data["docs"],
+            avg_ent_chars=data["avg_ent_chars"],
+            avg_ent_words=data["avg_ent_words"],
+        )
+
+    def _to_dict(self):
+        """Return the dict form of this object."""
+        return dict(
+            docs=self.docs,
+            avg_ent_chars=self.avg_ent_chars,
+            avg_ent_words=self.avg_ent_words,
+        )
+
+    @staticmethod
+    def load() -> "OpenIE":
+        """Load the OpenIE data from the configured ``openie_data_path``."""
+        openie_path = global_config["persistence"]["openie_data_path"]
+        with open(openie_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+
+        return OpenIE._from_dict(data)
+
+    @staticmethod
+    def save(openie_data: "OpenIE"):
+        """Write the OpenIE data to the configured ``openie_data_path``."""
+        openie_path = global_config["persistence"]["openie_data_path"]
+        with open(openie_path, "w", encoding="utf-8") as f:
+            f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
+
+    def extract_entity_dict(self):
+        """Map doc idx to its entity list, for docs that have entities."""
+        ner_output_dict = {}
+        for doc_item in self.docs:
+            if len(doc_item["extracted_entities"]) > 0:
+                ner_output_dict[doc_item["idx"]] = doc_item["extracted_entities"]
+        return ner_output_dict
+
+    def extract_triple_dict(self):
+        """Map doc idx to its triple list, for docs that have triples."""
+        triple_output_dict = {}
+        for doc_item in self.docs:
+            if len(doc_item["extracted_triples"]) > 0:
+                triple_output_dict[doc_item["idx"]] = doc_item["extracted_triples"]
+        return triple_output_dict
+
+    def extract_raw_paragraph_dict(self):
+        """Map doc idx to the original passage of every doc."""
+        raw_paragraph_dict = {}
+        for doc_item in self.docs:
+            raw_paragraph_dict[doc_item["idx"]] = doc_item["passage"]
+        return raw_paragraph_dict
diff --git a/src/plugins/knowledge/src/prompt_template.py b/src/plugins/knowledge/src/prompt_template.py
new file mode 100644
index 00000000..ab6ac08c
--- /dev/null
+++ b/src/plugins/knowledge/src/prompt_template.py
@@ -0,0 +1,73 @@
+from typing import List
+
+from .llm_client import LLMMessage
+
+entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体,并以JSON列表的形式输出。
+
+输出格式示例:
+[ "实体A", "实体B", "实体C" ]
+
+请注意以下要求:
+- 将代词(如“你”、“我”、“他”、“她”、“它”等)转化为对应的实体命名,以避免指代不清。
+- 尽可能多的提取出段落中的全部实体;
+"""
+
+
+def build_entity_extract_context(paragraph: str) -> List[dict]:
+    """Build the chat messages (system prompt + paragraph) for entity extraction.
+
+    Returns:
+        The messages in the dict form produced by ``LLMMessage.to_dict()``.
+    """
+    return [
+        LLMMessage("system", entity_extract_system_prompt).to_dict(),
+        LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
+    ]
+
+
+rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
+
+请使用JSON回复,使用三元组的JSON列表输出RDF图中的关系(每个三元组代表一个关系)。
+
+输出格式示例:
+[
+ ["某实体","关系","某属性"],
+ ["某实体","关系","某实体"],
+ ["某资源","关系","某属性"]
+]
+
+请注意以下要求:
+- 每个三元组应包含每个段落的实体命名列表中的至少一个命名实体,但最好是两个。
+- 将代词(如“你”、“我”、“他”、“她”、“它”等)转化为对应的实体命名,以避免指代不清。
+"""
+
+
+def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[dict]:
+    """Build the chat messages for RDF triple extraction.
+
+    Args:
+        paragraph: The paragraph text.
+        entities: The entity list, already rendered as a string.
+
+    Returns:
+        The messages in the dict form produced by ``LLMMessage.to_dict()``.
+    """
+    return [
+        LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
+        LLMMessage(
+            "user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```"""
+        ).to_dict(),
+    ]
+
+
+qa_system_prompt = """
+你是一个性能优异的QA系统。请根据给定的问题和一些可能对你有帮助的信息作出回答。
+
+请注意以下要求:
+- 你可以使用给定的信息来回答问题,但请不要直接引用它们。
+- 你的回答应该简洁明了,避免冗长的解释。
+- 如果你无法回答问题,请直接说“我不知道”。
+"""
+
+
+def build_qa_context(question: str, knowledge: List[tuple]) -> List[dict]:
+    """Build the chat messages for answering a question with retrieved knowledge.
+
+    Args:
+        question: The question to answer.
+        knowledge: Retrieved items; element 0 of each item is rendered as its
+            relevance and element 1 as its text, further elements are ignored.
+
+    Returns:
+        The messages in the dict form produced by ``LLMMessage.to_dict()``.
+    """
+    # A separate name is used so the parameter keeps its list value
+    knowledge_text = "\n".join(
+        f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)
+    )
+    return [
+        LLMMessage("system", qa_system_prompt).to_dict(),
+        LLMMessage(
+            "user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge_text}"
+        ).to_dict(),
+    ]
diff --git a/src/plugins/knowledge/src/qa_manager.py b/src/plugins/knowledge/src/qa_manager.py
new file mode 100644
index 00000000..d243d59f
--- /dev/null
+++ b/src/plugins/knowledge/src/qa_manager.py
@@ -0,0 +1,116 @@
+import time
+from typing import Tuple, List, Dict
+
+from .global_logger import logger
+from .embedding_store import EmbeddingManager
+from .llm_client import LLMClient
+from .kg_manager import KGManager
+from .config import global_config
+from .utils.dyn_topk import dyn_select_top_k
+
+
+class QAManager:
+    """Retrieve knowledge for a question via embedding search and knowledge-graph search."""
+
+    def __init__(
+        self,
+        embed_manager: EmbeddingManager,
+        kg_manager: KGManager,
+        llm_client_embedding: LLMClient,
+        llm_client_filter: LLMClient,
+        llm_client_qa: LLMClient,
+    ):
+        """Store the collaborators used by the query pipeline.
+
+        Args:
+            embed_manager: Provides the relation and paragraph embedding stores.
+            kg_manager: Provides ``kg_search`` for the graph-based retrieval.
+            llm_client_embedding: Client stored under the "embedding" key.
+            llm_client_filter: Client stored under the "filter" key.
+            llm_client_qa: Client stored under the "qa" key.
+        """
+        self.embed_manager = embed_manager
+        self.kg_manager = kg_manager
+        self.llm_client_list = {
+            "embedding": llm_client_embedding,
+            "filter": llm_client_filter,
+            "qa": llm_client_qa,
+        }
+
+    def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Dict[str, float] | None]:
+        """Run the retrieval pipeline for ``question``.
+
+        Steps: embed the question, search the relation store, search the
+        paragraph store, then either run the knowledge-graph search (when
+        relevant relations were found) or fall back to the paragraph search
+        results, and finally apply the dynamic top-k filter.
+
+        Args:
+            question: The question text.
+
+        Returns:
+            A ``(result, ppr_node_weights)`` pair. ``result`` holds the selected
+            entries (index 0 is the paragraph-store key, index 1 the score);
+            ``ppr_node_weights`` is the second value returned by ``kg_search``,
+            or ``None`` when the graph search was not used.
+        """
+        # Embed the question.
+        part_start_time = time.perf_counter()
+        question_embedding = self.llm_client_list["embedding"].send_embedding_request(
+            global_config["embedding"]["model"], question
+        )
+        part_end_time = time.perf_counter()
+        logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
+
+        # Search the relation embedding store with the question embedding.
+        part_start_time = time.perf_counter()
+        relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
+            question_embedding,
+            global_config["qa"]["params"]["relation_search_top_k"],
+        )
+        # Dynamic threshold: keep only the significant results when the scores
+        # differ noticeably. Skipped for an empty result, which has nothing to
+        # rank.
+        if len(relation_search_res) != 0:
+            relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
+        # Guard the [0] access: the search or the filter may leave nothing.
+        if (
+            len(relation_search_res) == 0
+            or relation_search_res[0][1]
+            < global_config["qa"]["params"]["relation_threshold"]
+        ):
+            # No relevant relation found.
+            relation_search_res = []
+
+        part_end_time = time.perf_counter()
+        logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
+
+        for res in relation_search_res:
+            rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
+            print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
+
+        # TODO: filter the triple results with an LLM.
+
+        # Search the paragraph embedding store with the question embedding.
+        part_start_time = time.perf_counter()
+        paragraph_search_res = (
+            self.embed_manager.paragraphs_embedding_store.search_top_k(
+                question_embedding,
+                global_config["qa"]["params"]["paragraph_search_top_k"],
+            )
+        )
+        part_end_time = time.perf_counter()
+        logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
+
+        if len(relation_search_res) != 0:
+            logger.info("找到相关关系,将使用RAG进行检索")
+            # Knowledge-graph search seeded with both search results.
+            part_start_time = time.perf_counter()
+            result, ppr_node_weights = self.kg_manager.kg_search(
+                relation_search_res, paragraph_search_res, self.embed_manager
+            )
+            part_end_time = time.perf_counter()
+            logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
+        else:
+            logger.info("未找到相关关系,将使用文段检索结果")
+            result = paragraph_search_res
+            ppr_node_weights = None
+
+        # Final dynamic filter; an empty result is passed through unchanged.
+        if len(result) != 0:
+            result = dyn_select_top_k(result, 0.5, 1.0)
+
+        for res in result:
+            raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[
+                res[0]
+            ].str
+            print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
+
+        return result, ppr_node_weights
+
+    def get_knowledge(self, question: str) -> str:
+        """Return the retrieved knowledge for ``question`` as numbered text.
+
+        Each entry is rendered as ``"<n>. 相关性:<score>"`` followed by the
+        paragraph text on the next line; entries are joined with newlines.
+
+        Args:
+            question: The question text.
+
+        Returns:
+            The formatted knowledge text; empty when nothing was retrieved.
+        """
+        query_res, _ = self.process_query(question)
+
+        paragraph_store = self.embed_manager.paragraphs_embedding_store.store
+        # The score goes after the relevance label and the paragraph text on
+        # the following line (the two fields used to be swapped).
+        return "\n".join(
+            f"{i + 1}. 相关性:{res[1]}\n{paragraph_store[res[0]].str}"
+            for i, res in enumerate(query_res)
+        )
\ No newline at end of file
diff --git a/src/plugins/knowledge/src/raw_processing.py b/src/plugins/knowledge/src/raw_processing.py
new file mode 100644
index 00000000..3670a705
--- /dev/null
+++ b/src/plugins/knowledge/src/raw_processing.py
@@ -0,0 +1,46 @@
+import json
+import os
+
+from .global_logger import logger
+from .config import global_config
+from .utils.hash import get_sha256
+
+
+def load_raw_data() -> tuple[list[str], list[str]]:
+    """Load the raw paragraph file, dropping invalid and duplicate entries.
+
+    Reads the JSON file at ``global_config["persistence"]["raw_data_path"]``
+    and iterates its items. Items that are not strings are skipped with a
+    warning, as are paragraphs whose SHA256 hash was already seen.
+
+    Returns:
+        A ``(sha256_list, raw_data)`` pair of parallel lists: the SHA256 hash
+        of every kept paragraph and the paragraph texts, both in file order.
+
+    Raises:
+        Exception: If the raw data file does not exist.
+    """
+    raw_data_path = global_config["persistence"]["raw_data_path"]
+    if not os.path.exists(raw_data_path):
+        raise Exception("原始数据文件读取失败")
+
+    with open(raw_data_path, "r", encoding="utf-8") as f:
+        import_json = json.load(f)
+
+    # Example file content:
+    # [
+    #     "The capital of China is Beijing. The capital of France is Paris.",
+    # ]
+
+    # Insertion-ordered mapping hash -> paragraph; doubles as the "seen" set.
+    unique_paragraphs = {}
+    for item in import_json:
+        if not isinstance(item, str):
+            logger.warning("数据类型错误:{}".format(item))
+            continue
+        pg_hash = get_sha256(item)
+        if pg_hash in unique_paragraphs:
+            logger.warning("重复数据:{}".format(item))
+            continue
+        unique_paragraphs[pg_hash] = item
+
+    sha256_list = list(unique_paragraphs.keys())
+    raw_data = list(unique_paragraphs.values())
+    logger.info("共读取到{}条数据".format(len(raw_data)))
+
+    return sha256_list, raw_data
diff --git a/src/plugins/knowledge/src/scripts/__init__.py b/src/plugins/knowledge/src/scripts/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/plugins/knowledge/src/scripts/text_pre_process.py b/src/plugins/knowledge/src/scripts/text_pre_process.py
new file mode 100644
index 00000000..788cb953
--- /dev/null
+++ b/src/plugins/knowledge/src/scripts/text_pre_process.py
@@ -0,0 +1,23 @@
+import json
+
+
+# Read the whole raw text file.
+with open("raw.txt", "r", encoding="utf-8") as f:
+    raw = f.read()
+
+# Group consecutive non-blank lines into paragraphs; a blank line ends the
+# current paragraph. Every kept line retains a trailing newline.
+paragraphs = []
+pending_lines = []
+for line in raw.split("\n"):
+    if line.strip() != "":
+        pending_lines.append(line + "\n")
+        continue
+    if pending_lines:
+        paragraphs.append("".join(pending_lines))
+        pending_lines = []
+
+# Flush the last paragraph when the file does not end with a blank line.
+if pending_lines:
+    paragraphs.append("".join(pending_lines))
+
+# Write the paragraphs as a JSON list, keeping non-ASCII text readable.
+with open("import.json", "w", encoding="utf-8") as f:
+    json.dump(paragraphs, f, ensure_ascii=False, indent=4)
diff --git a/src/plugins/knowledge/src/utils/__init__.py b/src/plugins/knowledge/src/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/plugins/knowledge/src/utils/data_loader.py b/src/plugins/knowledge/src/utils/data_loader.py
new file mode 100644
index 00000000..aab20993
--- /dev/null
+++ b/src/plugins/knowledge/src/utils/data_loader.py
@@ -0,0 +1,58 @@
+import jsonlines
+from pathlib import Path
+from typing import List, Dict, Any, Union, Optional
+from ..config import global_config as config
+
+
+class DataLoader:
+    """Helper for loading data files from a data directory."""
+
+    def __init__(self, custom_data_dir: Optional[Union[str, Path]] = None):
+        """
+        Initialise the loader and check that the data directory exists.
+
+        Args:
+            custom_data_dir: Optional data directory. When omitted (or falsy),
+                ``config["persistence"]["data_root_path"]`` is used instead.
+
+        Raises:
+            FileNotFoundError: If the chosen directory does not exist.
+        """
+        if custom_data_dir:
+            chosen_dir = Path(custom_data_dir)
+        else:
+            chosen_dir = Path(config["persistence"]["data_root_path"])
+        self.data_dir = chosen_dir
+        if not self.data_dir.exists():
+            raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在")
+
+    def _resolve_file_path(self, filename: str) -> Path:
+        """
+        Resolve ``filename`` against the data directory.
+
+        Args:
+            filename: File name relative to the data directory.
+
+        Returns:
+            The full path of the file.
+
+        Raises:
+            FileNotFoundError: If the file does not exist.
+        """
+        candidate = self.data_dir / filename
+        if candidate.exists():
+            return candidate
+        raise FileNotFoundError(f"文件 {filename} 不存在")
+
+    def load_jsonl(self, filename: str) -> List[Dict[str, Any]]:
+        """
+        Load a JSONL file from the data directory.
+
+        Args:
+            filename: File name relative to the data directory.
+
+        Returns:
+            A list with every object read from the file, in file order.
+
+        Raises:
+            FileNotFoundError: If the file does not exist.
+        """
+        with jsonlines.open(self._resolve_file_path(filename)) as reader:
+            return list(reader)
diff --git a/src/plugins/knowledge/src/utils/dyn_topk.py b/src/plugins/knowledge/src/utils/dyn_topk.py
new file mode 100644
index 00000000..02bc5e3e
--- /dev/null
+++ b/src/plugins/knowledge/src/utils/dyn_topk.py
@@ -0,0 +1,51 @@
+from typing import List, Any, Tuple
+
+
+def dyn_select_top_k(
+    score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
+) -> List[Tuple[Any, float, float]]:
+    """Select a dynamically sized top-k subset of scored items.
+
+    The items are sorted by score (descending) and the scores are min-max
+    normalised to [0, 1]. The cut-off threshold mixes two terms:
+
+    * the normalised score found right after the largest drop between two
+      neighbouring items (the "jump"), weighted by ``jmp_factor``;
+    * the mean plus ``var_factor`` times the variance of the normalised
+      scores, weighted by ``1 - jmp_factor``.
+
+    Only items whose normalised score is strictly above the threshold are kept.
+
+    Args:
+        score: Scored items; index 0 is the item, index 1 its score. Any
+            further fields are not carried over to the result.
+        jmp_factor: Weight of the jump term in the threshold.
+        var_factor: Multiplier applied to the variance.
+
+    Returns:
+        ``(item, score, normalised_score)`` tuples in descending score order.
+        An empty input yields an empty list. When all scores are equal (which
+        includes a single item) every item is returned with a normalised
+        score of 1.0, because there is no difference to filter on.
+    """
+    if len(score) == 0:
+        return []
+
+    # Sort by score, highest first.
+    sorted_score = sorted(score, key=lambda x: x[1], reverse=True)
+
+    # Min-max normalisation.
+    max_score = sorted_score[0][1]
+    min_score = sorted_score[-1][1]
+    score_range = max_score - min_score
+    if score_range == 0:
+        # No spread at all: normalising would divide by zero, so keep everything.
+        return [(item[0], item[1], 1.0) for item in sorted_score]
+
+    normalized_score = [
+        (item[0], item[1], (item[1] - min_score) / score_range)
+        for item in sorted_score
+    ]
+
+    # Find the jump: the index right after the largest drop between
+    # neighbours. A non-zero range guarantees at least two items. The search
+    # starts from the first real pair; comparing against index -1 would wrap
+    # around to the full range (1.0), which no neighbouring gap can exceed.
+    jump_idx = 1
+    max_gap = normalized_score[0][2] - normalized_score[1][2]
+    for i in range(2, len(normalized_score)):
+        gap = normalized_score[i - 1][2] - normalized_score[i][2]
+        if gap > max_gap:
+            max_gap = gap
+            jump_idx = i
+    jump_threshold = normalized_score[jump_idx][2]
+
+    # Mean and (population) variance of the normalised scores.
+    count = len(normalized_score)
+    mean_score = sum(s[2] for s in normalized_score) / count
+    var_score = sum((s[2] - mean_score) ** 2 for s in normalized_score) / count
+
+    # Dynamic threshold.
+    threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * (
+        mean_score + var_factor * var_score
+    )
+
+    return [s for s in normalized_score if s[2] > threshold]
diff --git a/src/plugins/knowledge/src/utils/hash.py b/src/plugins/knowledge/src/utils/hash.py
new file mode 100644
index 00000000..b3e12b87
--- /dev/null
+++ b/src/plugins/knowledge/src/utils/hash.py
@@ -0,0 +1,8 @@
+import hashlib
+
+
+def get_sha256(string: str) -> str:
+    """Return the hexadecimal SHA256 digest of ``string`` encoded as UTF-8."""
+    return hashlib.sha256(string.encode("utf-8")).hexdigest()
diff --git a/src/plugins/knowledge/src/utils/json_fix.py b/src/plugins/knowledge/src/utils/json_fix.py
new file mode 100644
index 00000000..672fb1f8
--- /dev/null
+++ b/src/plugins/knowledge/src/utils/json_fix.py
@@ -0,0 +1,79 @@
+import json
+
+
+def _find_unclosed(json_str):
+    """
+    Collect the braces and brackets that are still open at the end of the text.
+
+    Braces and brackets inside string literals are ignored, and backslash
+    escapes inside strings are honoured. A closing character that does not
+    match the most recently opened one is skipped.
+
+    Args:
+        json_str (str): The JSON string to analyze.
+
+    Returns:
+        list: The unclosed opening characters, in the order they were opened.
+    """
+    opener_for = {"}": "{", "]": "["}
+    open_stack = []
+    in_string = False
+    escaped = False
+
+    for ch in json_str:
+        if in_string:
+            if escaped:
+                # The escaped character is consumed whatever it is.
+                escaped = False
+            elif ch == "\\":
+                escaped = True
+            elif ch == '"':
+                in_string = False
+            continue
+
+        if ch == '"':
+            in_string = True
+        elif ch in "{[":
+            open_stack.append(ch)
+        elif ch in "}]":
+            if open_stack and open_stack[-1] == opener_for[ch]:
+                open_stack.pop()
+
+    return open_stack
+
+
+# The following code is used to fix a broken JSON string.
+# From HippoRAG2 (GitHub: OSU-NLP-Group/HippoRAG)
+def fix_broken_generated_json(json_str: str) -> str:
+    """
+    Repair a truncated JSON string.
+
+    A string that ``json.loads()`` already accepts is returned unchanged.
+    Otherwise the text is cut at its last comma (the comma and everything
+    after it are dropped) and the closers for every brace or bracket still
+    open are appended, innermost first. Braces and brackets inside string
+    literals are not counted.
+
+    Args:
+        json_str (str): The malformed JSON string to be fixed.
+
+    Returns:
+        str: The corrected JSON string.
+    """
+    try:
+        json.loads(json_str)
+    except json.JSONDecodeError:
+        pass
+    else:
+        # Already valid: hand it back untouched.
+        return json_str
+
+    # Step 1: drop the last comma and whatever follows it.
+    cut_at = json_str.rfind(",")
+    repaired = json_str if cut_at == -1 else json_str[:cut_at]
+
+    # Step 2: find what is still open, then close it in reverse order of opening.
+    closing_map = {"{": "}", "[": "]"}
+    closers = "".join(
+        closing_map[open_char] for open_char in reversed(_find_unclosed(repaired))
+    )
+
+    return repaired + closers
diff --git a/src/plugins/knowledge/src/utils/visualize_graph.py b/src/plugins/knowledge/src/utils/visualize_graph.py
new file mode 100644
index 00000000..845e18a2
--- /dev/null
+++ b/src/plugins/knowledge/src/utils/visualize_graph.py
@@ -0,0 +1,17 @@
+import networkx as nx
+from matplotlib import pyplot as plt
+
+
+def draw_graph_and_show(graph):
+    """Draw ``graph`` with networkx and show the figure.
+
+    The figure is 12.8 x 12.8 inches at 100 dpi (a 1280 x 1280 canvas). Nodes
+    are labelled with their "content" attribute.
+
+    Args:
+        graph: The graph to draw; passed to ``nx.draw_networkx``.
+    """
+    canvas = plt.figure(1, figsize=(12.8, 12.8), dpi=100)
+    node_labels = nx.get_node_attributes(graph, "content")
+    draw_options = {
+        "node_size": 100,
+        "width": 0.5,
+        "with_labels": True,
+        "labels": node_labels,
+        "font_family": "Sarasa Mono SC",
+        "font_size": 8,
+    }
+    nx.draw_networkx(graph, **draw_options)
+    canvas.show()
\ No newline at end of file
diff --git a/src/plugins/zhishi/knowledge_library.py b/src/plugins/zhishi/knowledge_library.py
deleted file mode 100644
index a95a096e..00000000
--- a/src/plugins/zhishi/knowledge_library.py
+++ /dev/null
@@ -1,350 +0,0 @@
-import os
-import sys
-import requests
-from dotenv import load_dotenv
-import hashlib
-from datetime import datetime
-from tqdm import tqdm
-from rich.console import Console
-from rich.table import Table
-
-# 添加项目根目录到 Python 路径
-root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
-sys.path.append(root_path)
-
-# 现在可以导入src模块
-from src.common.database import db # noqa E402
-
-# 加载根目录下的env.edv文件
-env_path = os.path.join(root_path, ".env")
-if not os.path.exists(env_path):
- raise FileNotFoundError(f"配置文件不存在: {env_path}")
-load_dotenv(env_path)
-
-
-class KnowledgeLibrary:
- def __init__(self):
- self.raw_info_dir = "data/raw_info"
- self._ensure_dirs()
- self.api_key = os.getenv("SILICONFLOW_KEY")
- if not self.api_key:
- raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
- self.console = Console()
-
- def _ensure_dirs(self):
- """确保必要的目录存在"""
- os.makedirs(self.raw_info_dir, exist_ok=True)
-
- def read_file(self, file_path: str) -> str:
- """读取文件内容"""
- with open(file_path, "r", encoding="utf-8") as f:
- return f.read()
-
- def split_content(self, content: str, max_length: int = 512) -> list:
- """将内容分割成适当大小的块,保持段落完整性
-
- Args:
- content: 要分割的文本内容
- max_length: 每个块的最大长度
-
- Returns:
- list: 分割后的文本块列表
- """
- # 首先按段落分割
- paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
- chunks = []
- current_chunk = []
- current_length = 0
-
- for para in paragraphs:
- para_length = len(para)
-
- # 如果单个段落就超过最大长度
- if para_length > max_length:
- # 如果当前chunk不为空,先保存
- if current_chunk:
- chunks.append("\n".join(current_chunk))
- current_chunk = []
- current_length = 0
-
- # 将长段落按句子分割
- sentences = [
- s.strip()
- for s in para.replace("。", "。\n").replace("!", "!\n").replace("?", "?\n").split("\n")
- if s.strip()
- ]
- temp_chunk = []
- temp_length = 0
-
- for sentence in sentences:
- sentence_length = len(sentence)
- if sentence_length > max_length:
- # 如果单个句子超长,强制按长度分割
- if temp_chunk:
- chunks.append("\n".join(temp_chunk))
- temp_chunk = []
- temp_length = 0
- for i in range(0, len(sentence), max_length):
- chunks.append(sentence[i : i + max_length])
- elif temp_length + sentence_length + 1 <= max_length:
- temp_chunk.append(sentence)
- temp_length += sentence_length + 1
- else:
- chunks.append("\n".join(temp_chunk))
- temp_chunk = [sentence]
- temp_length = sentence_length
-
- if temp_chunk:
- chunks.append("\n".join(temp_chunk))
-
- # 如果当前段落加上现有chunk不超过最大长度
- elif current_length + para_length + 1 <= max_length:
- current_chunk.append(para)
- current_length += para_length + 1
- else:
- # 保存当前chunk并开始新的chunk
- chunks.append("\n".join(current_chunk))
- current_chunk = [para]
- current_length = para_length
-
- # 添加最后一个chunk
- if current_chunk:
- chunks.append("\n".join(current_chunk))
-
- return chunks
-
- def get_embedding(self, text: str) -> list:
- """获取文本的embedding向量"""
- url = "https://api.siliconflow.cn/v1/embeddings"
- payload = {"model": "BAAI/bge-m3", "input": text, "encoding_format": "float"}
- headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-
- response = requests.post(url, json=payload, headers=headers)
- if response.status_code != 200:
- print(f"获取embedding失败: {response.text}")
- return None
-
- return response.json()["data"][0]["embedding"]
-
- def process_files(self, knowledge_length: int = 512):
- """处理raw_info目录下的所有txt文件"""
- txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith(".txt")]
-
- if not txt_files:
- self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir))
- self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]")
- return
-
- total_stats = {"processed_files": 0, "total_chunks": 0, "failed_files": [], "skipped_files": []}
-
- self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]")
-
- for filename in tqdm(txt_files, desc="处理文件进度"):
- file_path = os.path.join(self.raw_info_dir, filename)
- result = self.process_single_file(file_path, knowledge_length)
- self._update_stats(total_stats, result, filename)
-
- self._display_processing_results(total_stats)
-
- def process_single_file(self, file_path: str, knowledge_length: int = 512):
- """处理单个文件"""
- result = {"status": "success", "chunks_processed": 0, "error": None}
-
- try:
- current_hash = self.calculate_file_hash(file_path)
- processed_record = db.processed_files.find_one({"file_path": file_path})
-
- if processed_record:
- if processed_record.get("hash") == current_hash:
- if knowledge_length in processed_record.get("split_by", []):
- result["status"] = "skipped"
- return result
-
- content = self.read_file(file_path)
- chunks = self.split_content(content, knowledge_length)
-
- for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False):
- embedding = self.get_embedding(chunk)
- if embedding:
- knowledge = {
- "content": chunk,
- "embedding": embedding,
- "source_file": file_path,
- "split_length": knowledge_length,
- "created_at": datetime.now(),
- }
- db.knowledges.insert_one(knowledge)
- result["chunks_processed"] += 1
-
- split_by = processed_record.get("split_by", []) if processed_record else []
- if knowledge_length not in split_by:
- split_by.append(knowledge_length)
-
- db.knowledges.processed_files.update_one(
- {"file_path": file_path},
- {"$set": {"hash": current_hash, "last_processed": datetime.now(), "split_by": split_by}},
- upsert=True,
- )
-
- except Exception as e:
- result["status"] = "failed"
- result["error"] = str(e)
-
- return result
-
- def _update_stats(self, total_stats, result, filename):
- """更新总体统计信息"""
- if result["status"] == "success":
- total_stats["processed_files"] += 1
- total_stats["total_chunks"] += result["chunks_processed"]
- elif result["status"] == "failed":
- total_stats["failed_files"].append((filename, result["error"]))
- elif result["status"] == "skipped":
- total_stats["skipped_files"].append(filename)
-
- def _display_processing_results(self, stats):
- """显示处理结果统计"""
- self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]")
-
- table = Table(show_header=True, header_style="bold magenta")
- table.add_column("统计项", style="dim")
- table.add_column("数值")
-
- table.add_row("成功处理文件数", str(stats["processed_files"]))
- table.add_row("处理的知识块总数", str(stats["total_chunks"]))
- table.add_row("跳过的文件数", str(len(stats["skipped_files"])))
- table.add_row("失败的文件数", str(len(stats["failed_files"])))
-
- self.console.print(table)
-
- if stats["failed_files"]:
- self.console.print("\n[bold red]处理失败的文件:[/bold red]")
- for filename, error in stats["failed_files"]:
- self.console.print(f"[red]- {filename}: {error}[/red]")
-
- if stats["skipped_files"]:
- self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]")
- for filename in stats["skipped_files"]:
- self.console.print(f"[yellow]- {filename}[/yellow]")
-
- def calculate_file_hash(self, file_path):
- """计算文件的MD5哈希值"""
- hash_md5 = hashlib.md5()
- with open(file_path, "rb") as f:
- for chunk in iter(lambda: f.read(4096), b""):
- hash_md5.update(chunk)
- return hash_md5.hexdigest()
-
- def search_similar_segments(self, query: str, limit: int = 5) -> list:
- """搜索与查询文本相似的片段"""
- query_embedding = self.get_embedding(query)
- if not query_embedding:
- return []
-
- # 使用余弦相似度计算
- pipeline = [
- {
- "$addFields": {
- "dotProduct": {
- "$reduce": {
- "input": {"$range": [0, {"$size": "$embedding"}]},
- "initialValue": 0,
- "in": {
- "$add": [
- "$$value",
- {
- "$multiply": [
- {"$arrayElemAt": ["$embedding", "$$this"]},
- {"$arrayElemAt": [query_embedding, "$$this"]},
- ]
- },
- ]
- },
- }
- },
- "magnitude1": {
- "$sqrt": {
- "$reduce": {
- "input": "$embedding",
- "initialValue": 0,
- "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
- }
- }
- },
- "magnitude2": {
- "$sqrt": {
- "$reduce": {
- "input": query_embedding,
- "initialValue": 0,
- "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
- }
- }
- },
- }
- },
- {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
- {"$sort": {"similarity": -1}},
- {"$limit": limit},
- {"$project": {"content": 1, "similarity": 1, "file_path": 1}},
- ]
-
- results = list(db.knowledges.aggregate(pipeline))
- return results
-
-
-# 创建单例实例
-knowledge_library = KnowledgeLibrary()
-
-if __name__ == "__main__":
- console = Console()
- console.print("[bold green]知识库处理工具[/bold green]")
-
- while True:
- console.print("\n请选择要执行的操作:")
- console.print("[1] 麦麦开始学习")
- console.print("[2] 麦麦全部忘光光(仅知识)")
- console.print("[q] 退出程序")
-
- choice = input("\n请输入选项: ").strip()
-
- if choice.lower() == "q":
- console.print("[yellow]程序退出[/yellow]")
- sys.exit(0)
- elif choice == "2":
- confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
- if confirm == "y":
- db.knowledges.delete_many({})
- console.print("[green]已清空所有知识![/green]")
- continue
- elif choice == "1":
- if not os.path.exists(knowledge_library.raw_info_dir):
- console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]")
- os.makedirs(knowledge_library.raw_info_dir, exist_ok=True)
-
- # 询问分割长度
- while True:
- try:
- length_input = input("请输入知识分割长度(默认512,输入q退出,回车使用默认值): ").strip()
- if length_input.lower() == "q":
- break
- if not length_input: # 如果直接回车,使用默认值
- knowledge_length = 512
- break
- knowledge_length = int(length_input)
- if knowledge_length <= 0:
- print("分割长度必须大于0,请重新输入")
- continue
- break
- except ValueError:
- print("请输入有效的数字")
- continue
-
- if length_input.lower() == "q":
- continue
-
- # 测试知识库功能
- print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...")
- knowledge_library.process_files(knowledge_length=knowledge_length)
- else:
- console.print("[red]无效的选项,请重新选择[/red]")
- continue
diff --git a/template/lpmm_config_template.toml b/template/lpmm_config_template.toml
new file mode 100644
index 00000000..ef94c96e
--- /dev/null
+++ b/template/lpmm_config_template.toml
@@ -0,0 +1,54 @@
+# LLM API 服务提供商,可配置多个
+[[llm_providers]]
+name = "localhost"
+base_url = "http://127.0.0.1:8888/v1/"
+api_key = "lm_studio"
+
+[[llm_providers]]
+name = "siliconflow"
+base_url = "https://api.siliconflow.cn/v1/"
+api_key = ""
+
+[entity_extract.llm]
+# 设置用于实体提取的LLM模型
+provider = "localhost" # 服务提供商
+model = "deepseek-r1-distill-llama-8b" # 模型名称
+
+[rdf_build.llm]
+# 设置用于RDF构建的LLM模型
+provider = "localhost" # 服务提供商
+model = "deepseek-r1-distill-llama-8b" # 模型名称
+
+[embedding]
+# 设置用于文本嵌入的Embedding模型
+provider = "localhost" # 服务提供商
+model = "text-embedding-bge-m3" # 模型名称
+dimension = 1024 # 嵌入维度
+
+[rag.params]
+# RAG参数配置
+synonym_search_top_k = 10 # 同义词搜索TopK
+synonym_threshold = 0.8 # 同义词阈值(相似度高于此阈值的词语会被认为是同义词)
+
+[qa.llm]
+# 设置用于QA的LLM模型
+provider = "localhost" # 服务提供商
+model = "deepseek-r1-distill-llama-8b" # 模型名称
+
+[qa.params]
+# QA参数配置
+relation_search_top_k = 10 # 关系搜索TopK
+relation_threshold = 0.5 # 关系阈值(相似度高于此阈值的关系会被认为是相关的关系)
+paragraph_search_top_k = 1000 # 段落搜索TopK(不能过小,可能影响搜索结果)
+paragraph_node_weight = 0.05 # 段落节点权重(在图搜索&PPR计算中的权重,当搜索仅使用DPR时,此参数不起作用)
+ent_filter_top_k = 10 # 实体过滤TopK
+ppr_damping = 0.8 # PPR阻尼系数
+res_top_k = 3 # 最终提供的文段TopK
+
+[persistence]
+# 持久化配置(存储中间数据,防止重复计算)
+data_root_path = "data" # 数据根目录
+raw_data_path = "data/import.json" # 原始数据路径
+openie_data_path = "data/openie.json" # OpenIE数据路径
+embedding_data_dir = "data/embedding" # 嵌入数据目录
+rag_data_dir = "data/rag" # RAG数据目录