From 273b36a073fe1301f18719b6d43df5cbdb2a6799 Mon Sep 17 00:00:00 2001
From: 墨梓柒 <1787882683@qq.com>
Date: Mon, 21 Apr 2025 22:54:10 +0800
Subject: [PATCH] feat: add knowledge base plugin and related features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds a knowledge base plugin covering knowledge extraction, storage, and retrieval. The main new modules are:

- `knowledge_lib.py`: knowledge base initialization and management
- `qa_manager.py`: question-answering management
- `mem_active_manager.py`: memory activation management
- `ie_process.py`: information extraction
- `open_ie.py`: OpenIE data handling
- `lpmmconfig.py`: configuration file management
- `prompt_template.py`: prompt template management
- `utils/`: utility modules, including JSON repair, dynamic top-k selection, and data loading

Also adds the companion tool `lpmm_get_knowledge.py` for retrieving information from the knowledge base, and updates the configuration templates `lpmm_config_template.toml` and `bot_config_template.toml` to support the new features.

These changes are intended to strengthen 麦麦's knowledge-processing ability, so it can better understand and respond to complex questions.
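For reviewers, a minimal sketch of the intended import-then-query flow. The three entry-point scripts and the `qa_manager` singleton are taken from this patch; the sample query and the assumption that `get_knowledge` returns a printable string (as its use in `reasoning_prompt_builder.py` suggests) are illustrative only.

    # Offline pipeline (separate CLI scripts added by this patch; run in order,
    # each asks for confirmation before doing anything expensive):
    #   python raw_data_preprocessor.py   # data/lpmm_raw_data/*.txt -> data/import.json
    #   python info_extraction.py         # entity/triple extraction -> data/openie.json
    #   python import_openie.py           # embeddings + knowledge-graph import
    #
    # Online retrieval then goes through the qa_manager singleton, as in
    # src/do_tool/tool_can_use/lpmm_get_knowledge.py:
    from src.plugins.knowledge.knowledge_lib import qa_manager

    knowledge = qa_manager.get_knowledge("什么是LPMM?")  # illustrative query
    print(knowledge if knowledge else "知识库中没有相关内容")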
---
 .gitignore | 1 +
 README.md | 9 +-
 import_openie.py | 171 +++++
 info_extraction.py | 175 +++++
 raw_data_preprocessor.py | 84 +++
 src/config/config.py | 70 +-
 .../tool_can_use/lpmm_get_knowledge.py | 138 ++++
 src/do_tool/tool_use.py | 1 +
 src/heart_flow/heartflow.py | 101 +--
 src/heart_flow/observation.py | 2 +-
 src/heart_flow/sub_heartflow.py | 25 +-
 src/main.py | 2 +-
 src/plugins/chat/bot.py | 61 +-
 .../heartFC_chat/heartFC_controler.py | 96 +--
 .../heartFC_chat/heartFC_processor.py | 16 +-
 .../chat_module/heartFC_chat/interest.py | 12 -
 .../chat_module/heartFC_chat/messagesender.py | 3 +-
 .../chat_module/heartFC_chat/pf_chatting.py | 7 +-
 .../reasoning_chat/reasoning_chat.py | 43 +-
 .../reasoning_chat/reasoning_generator.py | 2 +-
 .../reasoning_prompt_builder.py | 230 +-----
 .../reasoning_prompt_builder.py.bak | 454 ++++++++++++
 src/plugins/knowledge/LICENSE | 674 ++++++++++++++++++
 src/plugins/knowledge/__init__.py | 0
 src/plugins/knowledge/knowledge_lib.py | 65 ++
 src/plugins/knowledge/src/__init__.py | 0
 src/plugins/knowledge/src/embedding_store.py | 251 +++++++
 src/plugins/knowledge/src/global_logger.py | 14 +
 src/plugins/knowledge/src/ie_process.py | 108 +++
 src/plugins/knowledge/src/kg_manager.py | 436 +++++++++++
 src/plugins/knowledge/src/llm_client.py | 53 ++
 src/plugins/knowledge/src/lpmmconfig.py | 143 ++++
 .../knowledge/src/mem_active_manager.py | 36 +
 src/plugins/knowledge/src/open_ie.py | 147 ++++
 src/plugins/knowledge/src/prompt_template.py | 73 ++
 src/plugins/knowledge/src/qa_manager.py | 117 +++
 src/plugins/knowledge/src/raw_processing.py | 46 ++
 src/plugins/knowledge/src/utils/__init__.py | 0
 .../knowledge/src/utils/data_loader.py | 58 ++
 src/plugins/knowledge/src/utils/dyn_topk.py | 51 ++
 src/plugins/knowledge/src/utils/hash.py | 8 +
 src/plugins/knowledge/src/utils/json_fix.py | 79 ++
 .../knowledge/src/utils/visualize_graph.py | 17 +
 src/plugins/memory_system/Hippocampus.py | 9 +-
 template/bot_config_template.toml | 79 +-
 template/lpmm_config_template.toml | 54 ++
 template/template.env | 20 +-
 麦麦开始学习.bat | 46 ++
 (临时版)麦麦开始学习.bat | 56 --
 49 files changed, 3743 insertions(+), 600 deletions(-)
 create mode 100644 import_openie.py
 create mode 100644 info_extraction.py
 create mode 100644 raw_data_preprocessor.py
 create mode 100644 src/do_tool/tool_can_use/lpmm_get_knowledge.py
 create mode 100644 src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py.bak
 create mode 100644 src/plugins/knowledge/LICENSE
 create mode 100644 src/plugins/knowledge/__init__.py
 create mode 100644 src/plugins/knowledge/knowledge_lib.py
 create mode 100644 src/plugins/knowledge/src/__init__.py
 create mode 100644 src/plugins/knowledge/src/embedding_store.py
 create mode 100644 src/plugins/knowledge/src/global_logger.py
 create mode 100644 src/plugins/knowledge/src/ie_process.py
 create mode 100644 src/plugins/knowledge/src/kg_manager.py
 create mode 100644 src/plugins/knowledge/src/llm_client.py
 create mode 100644 src/plugins/knowledge/src/lpmmconfig.py
 create mode 100644 src/plugins/knowledge/src/mem_active_manager.py
 create mode 100644 src/plugins/knowledge/src/open_ie.py
 create mode 100644 src/plugins/knowledge/src/prompt_template.py
 create mode 100644 src/plugins/knowledge/src/qa_manager.py
 create mode 100644 src/plugins/knowledge/src/raw_processing.py
 create mode 100644 src/plugins/knowledge/src/utils/__init__.py
 create mode 100644 src/plugins/knowledge/src/utils/data_loader.py
 create mode 100644 src/plugins/knowledge/src/utils/dyn_topk.py
 create mode 100644 src/plugins/knowledge/src/utils/hash.py
 create mode 100644 src/plugins/knowledge/src/utils/json_fix.py
 create mode 100644 src/plugins/knowledge/src/utils/visualize_graph.py
 create mode 100644 template/lpmm_config_template.toml
 create mode 100644 麦麦开始学习.bat
 delete mode 100644 (临时版)麦麦开始学习.bat

diff --git a/.gitignore b/.gitignore
index 9bf54a1d..45809b99 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,7 @@ mongodb/
 NapCat.Framework.Windows.Once/
 log/
 logs/
+temp/
 run_ad.bat
 MaiBot-Napcat-Adapter-main
 MaiBot-Napcat-Adapter
diff --git a/README.md b/README.md
index 7eca2260..656f536a 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@

- Logo + Logo
@@ -34,6 +34,7 @@ · 提出新特性

+

## 新版0.6.x部署前先阅读:https://docs.mai-mai.org/manual/usage/mmc_q_a @@ -52,7 +53,7 @@
- 麦麦演示视频 + 麦麦演示视频
👆 点击观看麦麦演示视频 👆
@@ -98,7 +99,7 @@
-

📚 文档

+

📚 文档

### (部分内容可能过时,请注意版本对应) @@ -185,7 +186,7 @@ MaiCore是一个开源项目,我们非常欢迎你的参与。你的贡献, 感谢各位大佬! - contributors + **也感谢每一位给麦麦发展提出宝贵意见与建议的用户,感谢陪伴麦麦走到现在的你们**
diff --git a/import_openie.py b/import_openie.py
new file mode 100644
index 00000000..537187db
--- /dev/null
+++ b/import_openie.py
@@ -0,0 +1,171 @@
+# try:
+#     import src.plugins.knowledge.lib.quick_algo
+# except ImportError:
+#     print("未找到quick_algo库,无法使用quick_algo算法")
+#     print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace")
+
+
+from typing import Dict, List
+
+from src.plugins.knowledge.src.lpmmconfig import PG_NAMESPACE, global_config
+from src.plugins.knowledge.src.embedding_store import EmbeddingManager
+from src.plugins.knowledge.src.llm_client import LLMClient
+from src.plugins.knowledge.src.open_ie import OpenIE
+from src.plugins.knowledge.src.kg_manager import KGManager
+from src.common.logger import get_module_logger
+from src.plugins.knowledge.src.utils.hash import get_sha256
+
+# 添加在现有导入之后
+import sys
+
+logger = get_module_logger("LPMM知识库-OpenIE导入")
+
+def hash_deduplicate(
+    raw_paragraphs: Dict[str, str],
+    triple_list_data: Dict[str, List[List[str]]],
+    stored_pg_hashes: set,
+    stored_paragraph_hashes: set,
+):
+    """Hash去重
+
+    Args:
+        raw_paragraphs: 索引的段落原文
+        triple_list_data: 索引的三元组列表
+        stored_pg_hashes: Embedding库中已存储的段落hash集合
+        stored_paragraph_hashes: KG中已存储的段落hash集合
+
+    Returns:
+        new_raw_paragraphs: 去重后的段落
+        new_triple_list_data: 去重后的三元组
+    """
+    # 保存去重后的段落
+    new_raw_paragraphs = dict()
+    # 保存去重后的三元组
+    new_triple_list_data = dict()
+
+    for raw_paragraph, triple_list in zip(
+        raw_paragraphs.values(), triple_list_data.values()
+    ):
+        # 段落hash
+        paragraph_hash = get_sha256(raw_paragraph)
+        if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (
+            paragraph_hash in stored_paragraph_hashes
+        ):
+            continue
+        new_raw_paragraphs[paragraph_hash] = raw_paragraph
+        new_triple_list_data[paragraph_hash] = triple_list
+
+    return new_raw_paragraphs, new_triple_list_data
+
+
+def handle_import_openie(
+    openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager
+) -> bool:
+    # 从OpenIE数据中提取段落原文与三元组列表
+    # 索引的段落原文
+    raw_paragraphs = openie_data.extract_raw_paragraph_dict()
+    # 索引的实体列表
+    entity_list_data = openie_data.extract_entity_dict()
+    # 索引的三元组列表
+    triple_list_data = openie_data.extract_triple_dict()
+    if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(
+        triple_list_data
+    ):
+        logger.error("OpenIE数据存在异常")
+        return False
+    # 将索引换为对应段落的hash值
+    logger.info("正在进行段落去重与重索引")
+    raw_paragraphs, triple_list_data = hash_deduplicate(
+        raw_paragraphs,
+        triple_list_data,
+        embed_manager.stored_pg_hashes,
+        kg_manager.stored_paragraph_hashes,
+    )
+    if len(raw_paragraphs) != 0:
+        # 获取嵌入并保存
+        logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}")
+        logger.info("开始Embedding")
+        embed_manager.store_new_data_set(raw_paragraphs, triple_list_data)
+        # Embedding-Faiss重索引
+        logger.info("正在重新构建向量索引")
+        embed_manager.rebuild_faiss_index()
+        logger.info("向量索引构建完成")
+        embed_manager.save_to_file()
+        logger.info("Embedding完成")
+        # 构建新段落的RAG
+        logger.info("开始构建RAG")
+        kg_manager.build_kg(triple_list_data, embed_manager)
+        kg_manager.save_to_file()
+        logger.info("RAG构建完成")
+    else:
+        logger.info("无新段落需要处理")
+    return True
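+# Reviewer note (illustrative sketch, not executed anywhere): hash_deduplicate and
+# handle_import_openie above key every paragraph by the SHA-256 of its content, so
+# re-running the import on the same OpenIE file is a no-op. A paragraph is skipped
+# only when BOTH stores already contain it:
+#
+#     h = get_sha256(paragraph_text)      # content hash
+#     key = PG_NAMESPACE + "-" + h        # namespaced key in the embedding store
+#     skip = (key in embed_manager.stored_pg_hashes
+#             and h in kg_manager.stored_paragraph_hashes)
+#
+# Requiring both means a paragraph missing from either store is re-processed, which
+# lets a partially failed import heal itself on the next run.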
print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G") + confirm = input("确认继续执行?(y/n): ").strip().lower() + if confirm != 'y': + logger.info("用户取消操作") + print("操作已取消") + sys.exit(1) + print("\n" + "="*40 + "\n") + + logger.info("----开始导入openie数据----\n") + + logger.info("创建LLM客户端") + llm_client_list = dict() + for key in global_config["llm_providers"]: + llm_client_list[key] = LLMClient( + global_config["llm_providers"][key]["base_url"], + global_config["llm_providers"][key]["api_key"], + ) + + # 初始化Embedding库 + embed_manager = embed_manager = EmbeddingManager( + llm_client_list[global_config["embedding"]["provider"]] + ) + logger.info("正在从文件加载Embedding库") + try: + embed_manager.load_from_file() + except Exception as e: + logger.error("从文件加载Embedding库时发生错误:{}".format(e)) + logger.info("Embedding库加载完成") + # 初始化KG + kg_manager = KGManager() + logger.info("正在从文件加载KG") + try: + kg_manager.load_from_file() + except Exception as e: + logger.error("从文件加载KG时发生错误:{}".format(e)) + logger.info("KG加载完成") + + logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") + logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}") + + # 数据比对:Embedding库与KG的段落hash集合 + for pg_hash in kg_manager.stored_paragraph_hashes: + key = PG_NAMESPACE + "-" + pg_hash + if key not in embed_manager.stored_pg_hashes: + logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") + + logger.info("正在导入OpenIE数据文件") + try: + openie_data = OpenIE.load() + except Exception as e: + logger.error("导入OpenIE数据文件时发生错误:{}".format(e)) + return False + if handle_import_openie(openie_data, embed_manager, kg_manager) is False: + logger.error("处理OpenIE数据时发生错误") + return False + + +if __name__ == "__main__": + main() diff --git a/info_extraction.py b/info_extraction.py new file mode 100644 index 00000000..04f2dd4f --- /dev/null +++ b/info_extraction.py @@ -0,0 +1,175 @@ +import json +import os +import signal +from concurrent.futures import ThreadPoolExecutor, as_completed +from threading import Lock, Event +import sys + +import tqdm + +from src.common.logger import get_module_logger +from src.plugins.knowledge.src.lpmmconfig import global_config +from src.plugins.knowledge.src.ie_process import info_extract_from_str +from src.plugins.knowledge.src.llm_client import LLMClient +from src.plugins.knowledge.src.open_ie import OpenIE +from src.plugins.knowledge.src.raw_processing import load_raw_data + +logger = get_module_logger("LPMM知识库-信息提取") + +TEMP_DIR = "./temp" + +# 创建一个线程安全的锁,用于保护文件操作和共享数据 +file_lock = Lock() +open_ie_doc_lock = Lock() + +# 创建一个事件标志,用于控制程序终止 +shutdown_event = Event() + + +def process_single_text(pg_hash, raw_data, llm_client_list): + """处理单个文本的函数,用于线程池""" + temp_file_path = f"{TEMP_DIR}/{pg_hash}.json" + + # 使用文件锁检查和读取缓存文件 + with file_lock: + if os.path.exists(temp_file_path): + try: + # 存在对应的提取结果 + logger.info(f"找到缓存的提取结果:{pg_hash}") + with open(temp_file_path, "r", encoding="utf-8") as f: + return json.load(f), None + except json.JSONDecodeError: + # 如果JSON文件损坏,删除它并重新处理 + logger.warning(f"缓存文件损坏,重新处理:{pg_hash}") + os.remove(temp_file_path) + + entity_list, rdf_triple_list = info_extract_from_str( + llm_client_list[global_config["entity_extract"]["llm"]["provider"]], + llm_client_list[global_config["rdf_build"]["llm"]["provider"]], + raw_data, + ) + if entity_list is None or rdf_triple_list is None: + return None, pg_hash + else: + doc_item = { + "idx": pg_hash, + "passage": raw_data, + "extracted_entities": entity_list, + "extracted_triples": rdf_triple_list, + } + # 保存临时提取结果 + with file_lock: + try: + with open(temp_file_path, 
"w", encoding="utf-8") as f: + json.dump(doc_item, f, ensure_ascii=False, indent=4) + except Exception as e: + logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}") + # 如果保存失败,确保不会留下损坏的文件 + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + # 设置shutdown_event以终止程序 + shutdown_event.set() + return None, pg_hash + return doc_item, None + + +def signal_handler(signum, frame): + """处理Ctrl+C信号""" + logger.info("\n接收到中断信号,正在优雅地关闭程序...") + shutdown_event.set() + + +def main(): + # 设置信号处理器 + signal.signal(signal.SIGINT, signal_handler) + + # 新增用户确认提示 + print("=== 重要操作确认 ===") + print("实体提取操作将会花费较多资金和时间,建议在空闲时段执行。") + print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。") + print("建议使用硅基流动的非Pro模型") + print("或者使用可以用赠金抵扣的Pro模型") + print("请确保账户余额充足,并且在执行前确认无误。") + confirm = input("确认继续执行?(y/n): ").strip().lower() + if confirm != 'y': + logger.info("用户取消操作") + print("操作已取消") + sys.exit(1) + print("\n" + "="*40 + "\n") + + logger.info("--------进行信息提取--------\n") + + logger.info("创建LLM客户端") + llm_client_list = dict() + for key in global_config["llm_providers"]: + llm_client_list[key] = LLMClient( + global_config["llm_providers"][key]["base_url"], + global_config["llm_providers"][key]["api_key"], + ) + + logger.info("正在加载原始数据") + sha256_list, raw_datas = load_raw_data() + logger.info("原始数据加载完成\n") + + # 创建临时目录 + if not os.path.exists(f"{TEMP_DIR}"): + os.makedirs(f"{TEMP_DIR}") + + failed_sha256 = [] + open_ie_doc = [] + + # 创建线程池,最大线程数为50 + workers = global_config["info_extraction"]["workers"] + with ThreadPoolExecutor(max_workers=workers) as executor: + # 提交所有任务到线程池 + future_to_hash = { + executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash + for pg_hash, raw_data in zip(sha256_list, raw_datas) + } + + # 使用tqdm显示进度 + with tqdm.tqdm(total=len(future_to_hash), postfix="正在进行提取:") as pbar: + # 处理完成的任务 + try: + for future in as_completed(future_to_hash): + if shutdown_event.is_set(): + # 取消所有未完成的任务 + for f in future_to_hash: + if not f.done(): + f.cancel() + break + + doc_item, failed_hash = future.result() + if failed_hash: + failed_sha256.append(failed_hash) + logger.error(f"提取失败:{failed_hash}") + elif doc_item: + with open_ie_doc_lock: + open_ie_doc.append(doc_item) + pbar.update(1) + except KeyboardInterrupt: + # 如果在这里捕获到KeyboardInterrupt,说明signal_handler可能没有正常工作 + logger.info("\n接收到中断信号,正在优雅地关闭程序...") + shutdown_event.set() + # 取消所有未完成的任务 + for f in future_to_hash: + if not f.done(): + f.cancel() + + # 保存信息提取结果 + sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) + sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) + num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc]) + openie_obj = OpenIE( + open_ie_doc, + round(sum_phrase_chars / num_phrases, 4), + round(sum_phrase_words / num_phrases, 4), + ) + OpenIE.save(openie_obj) + + logger.info("--------信息提取完成--------") + logger.info(f"提取失败的文段SHA256:{failed_sha256}") + + +if __name__ == "__main__": + main() diff --git a/raw_data_preprocessor.py b/raw_data_preprocessor.py new file mode 100644 index 00000000..abb1ad64 --- /dev/null +++ b/raw_data_preprocessor.py @@ -0,0 +1,84 @@ +import json +import os +from pathlib import Path +import sys # 新增系统模块导入 + +def check_and_create_dirs(): + """检查并创建必要的目录""" + required_dirs = [ + "data/lpmm_raw_data", + "data/imported_lpmm_data" + ] + + for dir_path in required_dirs: + if not os.path.exists(dir_path): + os.makedirs(dir_path) + print(f"已创建目录: {dir_path}") + +def 
process_text_file(file_path): + """处理单个文本文件,返回段落列表""" + with open(file_path, "r", encoding="utf-8") as f: + raw = f.read() + + paragraphs = [] + paragraph = "" + for line in raw.split("\n"): + if line.strip() == "": + if paragraph != "": + paragraphs.append(paragraph.strip()) + paragraph = "" + else: + paragraph += line + "\n" + + if paragraph != "": + paragraphs.append(paragraph.strip()) + + return paragraphs + +def main(): + # 新增用户确认提示 + print("=== 重要操作确认 ===") + print("如果你并非第一次导入知识") + print("请先删除data/import.json文件,备份data/openie.json文件") + print("在进行知识库导入之前") + print("请修改config/lpmm_config.toml中的配置项") + confirm = input("确认继续执行?(y/n): ").strip().lower() + if confirm != 'y': + print("操作已取消") + sys.exit(1) + print("\n" + "="*40 + "\n") + + # 检查并创建必要的目录 + check_and_create_dirs() + + # 检查输出文件是否存在 + if os.path.exists("data/import.json"): + print("错误: data/import.json 已存在,请先处理或删除该文件") + sys.exit(1) + + if os.path.exists("data/openie.json"): + print("错误: data/openie.json 已存在,请先处理或删除该文件") + sys.exit(1) + + # 获取所有原始文本文件 + raw_files = list(Path("data/lpmm_raw_data").glob("*.txt")) + if not raw_files: + print("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") + sys.exit(1) + + # 处理所有文件 + all_paragraphs = [] + for file in raw_files: + print(f"正在处理文件: {file.name}") + paragraphs = process_text_file(file) + all_paragraphs.extend(paragraphs) + + # 保存合并后的结果 + output_path = "data/import.json" + with open(output_path, "w", encoding="utf-8") as f: + json.dump(all_paragraphs, f, ensure_ascii=False, indent=4) + + print(f"处理完成,结果已保存到: {output_path}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/config/config.py b/src/config/config.py index 83e47837..0dae0244 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -186,18 +186,12 @@ class BotConfig: ban_words = set() ban_msgs_regex = set() - # [heartflow] # 启用启用heart_flowC(心流聊天)模式时生效, 需要填写token消耗量巨大的相关模型 - # 启用后麦麦会自主选择进入heart_flowC模式(持续一段时间), 进行长时间高质量的聊天 - enable_heart_flowC: bool = True # 是否启用heart_flowC(心流聊天, HFC)模式 - reply_trigger_threshold: float = 3.0 # 心流聊天触发阈值,越低越容易触发 - probability_decay_factor_per_second: float = 0.2 # 概率衰减因子,越大衰减越快 - default_decay_rate_per_second: float = 0.98 # 默认衰减率,越大衰减越慢 - initial_duration: int = 60 # 初始持续时间,越大心流聊天持续的时间越长 - - # sub_heart_flow_update_interval: int = 60 # 子心流更新频率,间隔 单位秒 - # sub_heart_flow_freeze_time: int = 120 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒 + # heartflow + # enable_heartflow: bool = False # 是否启用心流 + sub_heart_flow_update_interval: int = 60 # 子心流更新频率,间隔 单位秒 + sub_heart_flow_freeze_time: int = 120 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒 sub_heart_flow_stop_time: int = 600 # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒 - # heart_flow_update_interval: int = 300 # 心流更新频率,间隔 单位秒 + heart_flow_update_interval: int = 300 # 心流更新频率,间隔 单位秒 observation_context_size: int = 20 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩 compressed_length: int = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5 compress_length_limit: int = 5 # 最多压缩份数,超过该数值的压缩上下文会被删除 @@ -213,8 +207,8 @@ class BotConfig: # response response_mode: str = "heart_flow" # 回复策略 - model_reasoning_probability: float = 0.7 # 麦麦回答时选择推理模型(主要)模型概率 - model_normal_probability: float = 0.3 # 麦麦回答时选择一般模型(次要)模型概率 + MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率 + MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率 # MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率 # emoji @@ -407,34 +401,29 @@ class BotConfig: def response(parent: dict): response_config = parent["response"] - config.model_reasoning_probability = response_config.get( - 
"model_reasoning_probability", config.model_reasoning_probability - ) - config.model_normal_probability = response_config.get( - "model_normal_probability", config.model_normal_probability - ) - - # 添加 enable_heart_flowC 的加载逻辑 (假设它在 [response] 部分) - if config.INNER_VERSION in SpecifierSet(">=1.4.0"): - config.enable_heart_flowC = response_config.get("enable_heart_flowC", config.enable_heart_flowC) + config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY) + config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY) + # config.MODEL_R1_DISTILL_PROBABILITY = response_config.get( + # "model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY + # ) + config.max_response_length = response_config.get("max_response_length", config.max_response_length) + if config.INNER_VERSION in SpecifierSet(">=1.0.4"): + config.response_mode = response_config.get("response_mode", config.response_mode) def heartflow(parent: dict): heartflow_config = parent["heartflow"] - # 加载新增的 heartflowC 参数 - - # 加载原有的 heartflow 参数 - # config.sub_heart_flow_update_interval = heartflow_config.get( - # "sub_heart_flow_update_interval", config.sub_heart_flow_update_interval - # ) - # config.sub_heart_flow_freeze_time = heartflow_config.get( - # "sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time - # ) + config.sub_heart_flow_update_interval = heartflow_config.get( + "sub_heart_flow_update_interval", config.sub_heart_flow_update_interval + ) + config.sub_heart_flow_freeze_time = heartflow_config.get( + "sub_heart_flow_freeze_time", config.sub_heart_flow_freeze_time + ) config.sub_heart_flow_stop_time = heartflow_config.get( "sub_heart_flow_stop_time", config.sub_heart_flow_stop_time ) - # config.heart_flow_update_interval = heartflow_config.get( - # "heart_flow_update_interval", config.heart_flow_update_interval - # ) + config.heart_flow_update_interval = heartflow_config.get( + "heart_flow_update_interval", config.heart_flow_update_interval + ) if config.INNER_VERSION in SpecifierSet(">=1.3.0"): config.observation_context_size = heartflow_config.get( "observation_context_size", config.observation_context_size @@ -443,17 +432,6 @@ class BotConfig: config.compress_length_limit = heartflow_config.get( "compress_length_limit", config.compress_length_limit ) - if config.INNER_VERSION in SpecifierSet(">=1.4.0"): - config.reply_trigger_threshold = heartflow_config.get( - "reply_trigger_threshold", config.reply_trigger_threshold - ) - config.probability_decay_factor_per_second = heartflow_config.get( - "probability_decay_factor_per_second", config.probability_decay_factor_per_second - ) - config.default_decay_rate_per_second = heartflow_config.get( - "default_decay_rate_per_second", config.default_decay_rate_per_second - ) - config.initial_duration = heartflow_config.get("initial_duration", config.initial_duration) def willing(parent: dict): willing_config = parent["willing"] diff --git a/src/do_tool/tool_can_use/lpmm_get_knowledge.py b/src/do_tool/tool_can_use/lpmm_get_knowledge.py new file mode 100644 index 00000000..5c70adc1 --- /dev/null +++ b/src/do_tool/tool_can_use/lpmm_get_knowledge.py @@ -0,0 +1,138 @@ +from src.do_tool.tool_can_use.base_tool import BaseTool +from src.plugins.chat.utils import get_embedding +# from src.common.database import db +from src.common.logger import get_module_logger +from typing import Dict, Any +from src.plugins.knowledge.knowledge_lib import qa_manager + + +logger = 
get_module_logger("lpmm_get_knowledge_tool") + + +class SearchKnowledgeFromLPMMTool(BaseTool): + """从LPMM知识库中搜索相关信息的工具""" + + name = "lpmm_search_knowledge" + description = "从知识库中搜索相关信息" + parameters = { + "type": "object", + "properties": { + "query": {"type": "string", "description": "搜索查询关键词"}, + "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"}, + }, + "required": ["query"], + } + + async def execute(self, function_args: Dict[str, Any], message_txt: str = "") -> Dict[str, Any]: + """执行知识库搜索 + + Args: + function_args: 工具参数 + message_txt: 原始消息文本 + + Returns: + Dict: 工具执行结果 + """ + try: + query = function_args.get("query", message_txt) + # threshold = function_args.get("threshold", 0.4) + + # 调用知识库搜索 + embedding = await get_embedding(query, request_type="info_retrieval") + if embedding: + knowledge_info = qa_manager.get_knowledge(query) + logger.debug(f"知识库查询结果: {knowledge_info}") + if knowledge_info: + content = f"你知道这些知识: {knowledge_info}" + else: + content = f"你不太了解有关{query}的知识" + return {"name": "search_knowledge", "content": content} + return {"name": "search_knowledge", "content": f"无法获取关于'{query}'的嵌入向量"} + except Exception as e: + logger.error(f"知识库搜索工具执行失败: {str(e)}") + return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"} + + # def get_info_from_db( + # self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False + # ) -> Union[str, list]: + # """从数据库中获取相关信息 + + # Args: + # query_embedding: 查询的嵌入向量 + # limit: 最大返回结果数 + # threshold: 相似度阈值 + # return_raw: 是否返回原始结果 + + # Returns: + # Union[str, list]: 格式化的信息字符串或原始结果列表 + # """ + # if not query_embedding: + # return "" if not return_raw else [] + + # # 使用余弦相似度计算 + # pipeline = [ + # { + # "$addFields": { + # "dotProduct": { + # "$reduce": { + # "input": {"$range": [0, {"$size": "$embedding"}]}, + # "initialValue": 0, + # "in": { + # "$add": [ + # "$$value", + # { + # "$multiply": [ + # {"$arrayElemAt": ["$embedding", "$$this"]}, + # {"$arrayElemAt": [query_embedding, "$$this"]}, + # ] + # }, + # ] + # }, + # } + # }, + # "magnitude1": { + # "$sqrt": { + # "$reduce": { + # "input": "$embedding", + # "initialValue": 0, + # "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + # } + # } + # }, + # "magnitude2": { + # "$sqrt": { + # "$reduce": { + # "input": query_embedding, + # "initialValue": 0, + # "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + # } + # } + # }, + # } + # }, + # {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, + # { + # "$match": { + # "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 + # } + # }, + # {"$sort": {"similarity": -1}}, + # {"$limit": limit}, + # {"$project": {"content": 1, "similarity": 1}}, + # ] + + # results = list(db.knowledges.aggregate(pipeline)) + # logger.debug(f"知识库查询结果数量: {len(results)}") + + # if not results: + # return "" if not return_raw else [] + + # if return_raw: + # return results + # else: + # # 返回所有找到的内容,用换行分隔 + # return "\n".join(str(result["content"]) for result in results) + + +# 注册工具 +# register_tool(SearchKnowledgeTool) diff --git a/src/do_tool/tool_use.py b/src/do_tool/tool_use.py index 52c26f80..938dde16 100644 --- a/src/do_tool/tool_use.py +++ b/src/do_tool/tool_use.py @@ -47,6 +47,7 @@ class ToolUser: prompt += message_txt # prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n" prompt += f"注意你就是{bot_name},{bot_name}是你的名字。根据之前的聊天记录补充问题信息,搜索时避开你的名字。\n" + prompt += "必须调用 'lpmm_get_knowledge' 工具来获取知识。\n" prompt += 
"你现在需要对群里的聊天内容进行回复,现在选择工具来对消息和你的回复进行处理,你是否需要额外的信息,比如回忆或者搜寻已有的知识,改变关系和情感,或者了解你现在正在做什么。" prompt = await relationship_manager.convert_all_person_sign_to_person_name(prompt) diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index 50f0a735..793f406f 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -1,4 +1,5 @@ -from .sub_heartflow import SubHeartflow, ChattingObservation +from .sub_heartflow import SubHeartflow +from .observation import ChattingObservation from src.plugins.moods.moods import MoodManager from src.plugins.models.utils_model import LLMRequest from src.config.config import global_config @@ -9,8 +10,7 @@ from src.common.logger import get_module_logger, LogConfig, HEARTFLOW_STYLE_CONF from src.individuality.individuality import Individuality import time import random -from typing import Dict, Any, Optional -import traceback +from typing import Dict, Any heartflow_config = LogConfig( # 使用海马体专用样式 @@ -45,8 +45,6 @@ class CurrentState: def __init__(self): self.current_state_info = "" - self.chat_status = "IDLE" - self.mood_manager = MoodManager() self.mood = self.mood_manager.get_prompt() @@ -72,27 +70,20 @@ class Heartflow: """定期清理不活跃的子心流""" while True: current_time = time.time() - inactive_subheartflows_ids = [] # 修改变量名以清晰表示存储的是ID + inactive_subheartflows = [] # 检查所有子心流 - # 使用 list(self._subheartflows.items()) 避免在迭代时修改字典 - for subheartflow_id, subheartflow in list(self._subheartflows.items()): + for subheartflow_id, subheartflow in self._subheartflows.items(): if ( current_time - subheartflow.last_active_time > global_config.sub_heart_flow_stop_time ): # 10分钟 = 600秒 - logger.info(f"发现不活跃的子心流: {subheartflow_id}, 准备清理。") - # 1. 标记子心流让其后台任务停止 - subheartflow.should_stop = True - # 2. 将ID添加到待清理列表 - inactive_subheartflows_ids.append(subheartflow_id) + inactive_subheartflows.append(subheartflow_id) + logger.info(f"发现不活跃的子心流: {subheartflow_id}") - # 清理不活跃的子心流 (从字典中移除) - for subheartflow_id in inactive_subheartflows_ids: - if subheartflow_id in self._subheartflows: - del self._subheartflows[subheartflow_id] - logger.info(f"已从主心流移除子心流: {subheartflow_id}") - else: - logger.warning(f"尝试移除子心流 {subheartflow_id} 时发现其已被移除。") + # 清理不活跃的子心流 + for subheartflow_id in inactive_subheartflows: + del self._subheartflows[subheartflow_id] + logger.info(f"已清理不活跃的子心流: {subheartflow_id}") await asyncio.sleep(30) # 每分钟检查一次 @@ -104,10 +95,8 @@ class Heartflow: await asyncio.sleep(30) # 每分钟检查一次是否有新的子心流 continue - # await self.do_a_thinking() - # await asyncio.sleep(global_config.heart_flow_update_interval * 3) # 5分钟思考一次 - - await asyncio.sleep(300) + await self.do_a_thinking() + await asyncio.sleep(global_config.heart_flow_update_interval) # 5分钟思考一次 async def heartflow_start_working(self): # 启动清理任务 @@ -121,7 +110,7 @@ class Heartflow: print("TODO") async def do_a_thinking(self): - # logger.debug("麦麦大脑袋转起来了") + logger.debug("麦麦大脑袋转起来了") self.current_state.update_current_state_info() # 开始构建prompt @@ -227,55 +216,33 @@ class Heartflow: return response - async def create_subheartflow(self, subheartflow_id: Any) -> Optional[SubHeartflow]: + async def create_subheartflow(self, subheartflow_id): """ - 获取或创建一个新的SubHeartflow实例。 - - 如果实例已存在,则直接返回。 - 如果不存在,则创建实例、观察对象、启动后台任务,并返回新实例。 - 创建过程中发生任何错误将返回 None。 - - Args: - subheartflow_id: 用于标识子心流的ID (例如群聊ID)。 - - Returns: - 对应的 SubHeartflow 实例,如果创建失败则返回 None。 + 创建一个新的SubHeartflow实例 + 添加一个SubHeartflow实例到self._subheartflows字典中 + 并根据subheartflow_id为子心流创建一个观察对象 """ - # 检查是否已存在 - existing_subheartflow = 
self._subheartflows.get(subheartflow_id) - if existing_subheartflow: - logger.debug(f"返回已存在的 subheartflow: {subheartflow_id}") - return existing_subheartflow - # 如果不存在,则创建新的 - logger.info(f"尝试创建新的 subheartflow: {subheartflow_id}") try: - subheartflow = SubHeartflow(subheartflow_id) - - # 创建并初始化观察对象 - logger.debug(f"为 {subheartflow_id} 创建 observation") - observation = ChattingObservation(subheartflow_id) - await observation.initialize() # 等待初始化完成 - subheartflow.add_observation(observation) - logger.debug(f"为 {subheartflow_id} 添加 observation 成功") - - # 创建并存储后台任务 - subheartflow.task = asyncio.create_task(subheartflow.subheartflow_start_working()) - logger.debug(f"为 {subheartflow_id} 创建后台任务成功") - - # 添加到管理字典 - self._subheartflows[subheartflow_id] = subheartflow - logger.info(f"添加 subheartflow {subheartflow_id} 成功") - return subheartflow - + if subheartflow_id not in self._subheartflows: + subheartflow = SubHeartflow(subheartflow_id) + # 创建一个观察对象,目前只可以用chat_id创建观察对象 + logger.debug(f"创建 observation: {subheartflow_id}") + observation = ChattingObservation(subheartflow_id) + await observation.initialize() + subheartflow.add_observation(observation) + logger.debug("添加 observation 成功") + # 创建异步任务 + asyncio.create_task(subheartflow.subheartflow_start_working()) + logger.debug("创建异步任务 成功") + self._subheartflows[subheartflow_id] = subheartflow + logger.info("添加 subheartflow 成功") + return self._subheartflows[subheartflow_id] except Exception as e: - # 记录详细错误信息 - logger.error(f"创建 subheartflow {subheartflow_id} 失败: {e}") - logger.error(traceback.format_exc()) # 记录完整的 traceback - # 考虑是否需要更具体的错误处理或资源清理逻辑 + logger.error(f"创建 subheartflow 失败: {e}") return None - def get_subheartflow(self, observe_chat_id: Any) -> Optional[SubHeartflow]: + def get_subheartflow(self, observe_chat_id) -> SubHeartflow: """获取指定ID的SubHeartflow实例""" return self._subheartflows.get(observe_chat_id) diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py index 49efe7eb..9903b184 100644 --- a/src/heart_flow/observation.py +++ b/src/heart_flow/observation.py @@ -139,7 +139,7 @@ class ChattingObservation(Observation): # traceback.print_exc() # 记录详细堆栈 # print(f"处理后self.talking_message:{self.talking_message}") - self.talking_message_str = await build_readable_messages(messages=self.talking_message, timestamp_mode="normal") + self.talking_message_str = await build_readable_messages(self.talking_message) logger.trace( f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}" diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index 9087b576..439b2a3f 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -4,7 +4,8 @@ from src.plugins.moods.moods import MoodManager from src.plugins.models.utils_model import LLMRequest from src.config.config import global_config import time -from typing import Optional, List +from typing import Optional +from datetime import datetime import traceback from src.plugins.chat.utils import parse_text_timestamps @@ -64,7 +65,7 @@ class SubHeartflow: def __init__(self, subheartflow_id): self.subheartflow_id = subheartflow_id - self.current_mind = "你什么也没想" + self.current_mind = "" self.past_mind = [] self.current_state: CurrentState = CurrentState() self.llm_model = LLMRequest( @@ -76,13 +77,15 @@ class SubHeartflow: self.main_heartflow_info = "" + self.last_reply_time = time.time() self.last_active_time = time.time() # 添加最后激活时间 - self.should_stop = False # 添加停止标志 - self.task: Optional[asyncio.Task] = None # 添加 
task 属性 + + if not self.current_mind: + self.current_mind = "你什么也没想" self.is_active = False - self.observations: List[ChattingObservation] = [] # 使用 List 类型提示 + self.observations: list[ChattingObservation] = [] self.running_knowledges = [] @@ -90,13 +93,19 @@ class SubHeartflow: async def subheartflow_start_working(self): while True: + current_time = time.time() # --- 调整后台任务逻辑 --- # # 这个后台循环现在主要负责检查是否需要自我销毁 # 不再主动进行思考或状态更新,这些由 HeartFC_Chat 驱动 - # 检查是否被主心流标记为停止 - if self.should_stop: - logger.info(f"子心流 {self.subheartflow_id} 被标记为停止,正在退出后台任务...") + # 检查是否超过指定时间没有激活 (例如,没有被调用进行思考) + if current_time - self.last_active_time > global_config.sub_heart_flow_stop_time: # 例如 5 分钟 + logger.info( + f"子心流 {self.subheartflow_id} 超过 {global_config.sub_heart_flow_stop_time} 秒没有激活,正在销毁..." + f" (Last active: {datetime.fromtimestamp(self.last_active_time).strftime('%Y-%m-%d %H:%M:%S')})" + ) + # 在这里添加实际的销毁逻辑,例如从主 Heartflow 管理器中移除自身 + # heartflow.remove_subheartflow(self.subheartflow_id) # 假设有这样的方法 break # 退出循环以停止任务 await asyncio.sleep(global_config.sub_heart_flow_update_interval) # 定期检查销毁条件 diff --git a/src/main.py b/src/main.py index 99578591..929cff7d 100644 --- a/src/main.py +++ b/src/main.py @@ -117,7 +117,7 @@ class MainSystem: await interest_manager.start_background_tasks() logger.success("兴趣管理器后台任务启动成功") - # 初始化并独立启动 HeartFCController + # 初始化并独立启动 HeartFC_Chat HeartFCController() heartfc_chat_instance = HeartFCController.get_instance() if heartfc_chat_instance: diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 12ee7b6d..cfe4238e 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -105,24 +105,53 @@ class ChatBot: template_group_name = None async def preprocess(): - if groupinfo is None: - if global_config.enable_friend_chat: - if global_config.enable_pfc_chatting: - userinfo = message.message_info.user_info - messageinfo = message.message_info - # 创建聊天流 - chat = await chat_manager.get_or_create_stream( - platform=messageinfo.platform, - user_info=userinfo, - group_info=groupinfo, - ) - message.update_chat_stream(chat) - await self.only_process_chat.process_message(message) - await self._create_pfc_chat(message) + if global_config.enable_pfc_chatting: + try: + if groupinfo is None: + if global_config.enable_friend_chat: + userinfo = message.message_info.user_info + messageinfo = message.message_info + # 创建聊天流 + chat = await chat_manager.get_or_create_stream( + platform=messageinfo.platform, + user_info=userinfo, + group_info=groupinfo, + ) + message.update_chat_stream(chat) + await self.only_process_chat.process_message(message) + await self._create_pfc_chat(message) else: - await self.heartFC_processor.process_message(message_data) + if groupinfo.group_id in global_config.talk_allowed_groups: + # logger.debug(f"开始群聊模式{str(message_data)[:50]}...") + if global_config.response_mode == "heart_flow": + # logger.info(f"启动最新最好的思维流FC模式{str(message_data)[:50]}...") + await self.heartFC_processor.process_message(message_data) + elif global_config.response_mode == "reasoning": + # logger.debug(f"开始推理模式{str(message_data)[:50]}...") + await self.reasoning_chat.process_message(message_data) + else: + logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}") + except Exception as e: + logger.error(f"处理PFC消息失败: {e}") else: - await self.heartFC_processor.process_message(message_data) + if groupinfo is None: + if global_config.enable_friend_chat: + # 私聊处理流程 + # await self._handle_private_chat(message) + if global_config.response_mode == "heart_flow": + await 
self.heartFC_processor.process_message(message_data) + elif global_config.response_mode == "reasoning": + await self.reasoning_chat.process_message(message_data) + else: + logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}") + else: # 群聊处理 + if groupinfo.group_id in global_config.talk_allowed_groups: + if global_config.response_mode == "heart_flow": + await self.heartFC_processor.process_message(message_data) + elif global_config.response_mode == "reasoning": + await self.reasoning_chat.process_message(message_data) + else: + logger.error(f"未知的回复模式,请检查配置文件!!: {global_config.response_mode}") if template_group_name: async with global_prompt_manager.async_message_scope(template_group_name): diff --git a/src/plugins/chat_module/heartFC_chat/heartFC_controler.py b/src/plugins/chat_module/heartFC_chat/heartFC_controler.py index 4dd49e2d..a217f978 100644 --- a/src/plugins/chat_module/heartFC_chat/heartFC_controler.py +++ b/src/plugins/chat_module/heartFC_chat/heartFC_controler.py @@ -1,7 +1,7 @@ import traceback from typing import Optional, Dict import asyncio -import threading # 导入 threading +from asyncio import Lock from ...moods.moods import MoodManager from ...chat.emoji_manager import emoji_manager from .heartFC_generator import ResponseGenerator @@ -14,7 +14,6 @@ from .interest import InterestManager from src.plugins.chat.chat_stream import chat_manager from .pf_chatting import PFChatting - # 定义日志配置 chat_config = LogConfig( console_format=CHAT_STYLE_CONFIG["console_format"], @@ -27,81 +26,44 @@ logger = get_module_logger("HeartFCController", config=chat_config) INTEREST_MONITOR_INTERVAL_SECONDS = 1 -# 合并后的版本:使用 __new__ + threading.Lock 实现线程安全单例,类名为 HeartFCController class HeartFCController: - _instance = None - _lock = threading.Lock() # 使用 threading.Lock 保证 __new__ 线程安全 - _initialized = False - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - with cls._lock: - # Double-checked locking - if cls._instance is None: - logger.debug("创建 HeartFCController 单例实例...") - cls._instance = super().__new__(cls) - return cls._instance + _instance = None # For potential singleton access if needed by MessageManager def __init__(self): - # 使用 _initialized 标志确保 __init__ 只执行一次 - if self._initialized: + # --- Updated Init --- + if HeartFCController._instance is not None: + # Prevent re-initialization if used as a singleton return - # 虽然 __new__ 保证了只有一个实例,但为了防止意外重入或多线程下的初始化竞争, - # 再次使用类锁保护初始化过程是更严谨的做法。 - # 如果确定 __init__ 逻辑本身是幂等的或非关键的,可以省略这里的锁。 - # 但为了保持原始逻辑的意图(防止重复初始化),这里保留检查。 - with self.__class__._lock: # 确保初始化逻辑线程安全 - if self._initialized: # 再次检查,防止锁等待期间其他线程已完成初始化 - return - - logger.info("正在初始化 HeartFCController 单例...") - self.gpt = ResponseGenerator() - self.mood_manager = MoodManager.get_instance() - # 注意:mood_manager 的 start_mood_update 可能需要在应用主循环启动后调用, - # 或者确保其内部实现是安全的。这里保持原状。 - self.mood_manager.start_mood_update() - self.tool_user = ToolUser() - # 注意:InterestManager() 可能是另一个单例或需要特定初始化。 - # 假设 InterestManager() 返回的是正确配置的实例。 - self.interest_manager = InterestManager() - self._interest_monitor_task: Optional[asyncio.Task] = None - self.pf_chatting_instances: Dict[str, PFChatting] = {} - # _pf_chatting_lock 用于保护 pf_chatting_instances 的异步操作 - self._pf_chatting_lock = asyncio.Lock() # 这个是 asyncio.Lock,用于异步上下文 - self.emoji_manager = emoji_manager # 假设是全局或已初始化的实例 - self.relationship_manager = relationship_manager # 假设是全局或已初始化的实例 - # MessageManager 可能是类本身或单例实例,根据其设计确定 - self.MessageManager = MessageManager - self._initialized = True - logger.info("HeartFCController 单例初始化完成。") + 
self.gpt = ResponseGenerator() + self.mood_manager = MoodManager.get_instance() + self.mood_manager.start_mood_update() + self.tool_user = ToolUser() + self.interest_manager = InterestManager() + self._interest_monitor_task: Optional[asyncio.Task] = None + # --- New PFChatting Management --- + self.pf_chatting_instances: Dict[str, PFChatting] = {} + self._pf_chatting_lock = Lock() + # --- End New PFChatting Management --- + HeartFCController._instance = self # Register instance + # --- End Updated Init --- + # --- Make dependencies accessible for PFChatting --- + # These are accessed via the passed instance in PFChatting + self.emoji_manager = emoji_manager + self.relationship_manager = relationship_manager + self.MessageManager = MessageManager # Pass the class/singleton access + # --- End dependencies --- + # --- Added Class Method for Singleton Access --- @classmethod def get_instance(cls): - """获取 HeartFCController 的单例实例。""" - # 如果实例尚未创建,调用构造函数(这将触发 __new__ 和 __init__) if cls._instance is None: - # 在首次调用 get_instance 时创建实例。 - # __new__ 中的锁会确保线程安全。 - cls() - # 添加日志记录,说明实例是在 get_instance 调用时创建的 - logger.info("HeartFCController 实例在首次 get_instance 时创建。") - elif not cls._initialized: - # 实例已创建但可能未初始化完成(理论上不太可能发生,除非 __init__ 异常) - logger.warning("HeartFCController 实例存在但尚未完成初始化。") + # This might indicate an issue if called before initialization + logger.warning("HeartFCController get_instance called before initialization.") + # Optionally, initialize here if a strict singleton pattern is desired + # cls._instance = cls() return cls._instance - # --- 新增:检查 PFChatting 状态的方法 --- # - def is_pf_chatting_active(self, stream_id: str) -> bool: - """检查指定 stream_id 的 PFChatting 循环是否处于活动状态。""" - # 注意:这里直接访问字典,不加锁,因为读取通常是安全的, - # 并且 PFChatting 实例的 _loop_active 状态由其自身的异步循环管理。 - # 如果需要更强的保证,可以在访问 pf_instance 前获取 _pf_chatting_lock - pf_instance = self.pf_chatting_instances.get(stream_id) - if pf_instance and pf_instance._loop_active: # 直接检查 PFChatting 实例的 _loop_active 属性 - return True - return False - - # --- 结束新增 --- # + # --- End Added Class Method --- async def start(self): """启动异步任务,如回复启动器""" diff --git a/src/plugins/chat_module/heartFC_chat/heartFC_processor.py b/src/plugins/chat_module/heartFC_chat/heartFC_processor.py index f907a8be..44849f82 100644 --- a/src/plugins/chat_module/heartFC_chat/heartFC_processor.py +++ b/src/plugins/chat_module/heartFC_chat/heartFC_processor.py @@ -13,7 +13,6 @@ from ...chat.message_buffer import message_buffer from ...utils.timer_calculater import Timer from .interest import InterestManager from src.plugins.person_info.relationship_manager import relationship_manager -from .reasoning_chat import ReasoningChat # 定义日志配置 processor_config = LogConfig( @@ -30,7 +29,7 @@ class HeartFCProcessor: def __init__(self): self.storage = MessageStorage() self.interest_manager = InterestManager() - self.reasoning_chat = ReasoningChat.get_instance() + # self.chat_instance = chat_instance # 持有 HeartFC_Chat 实例 async def process_message(self, message_data: str) -> None: """处理接收到的原始消息数据,完成消息解析、缓冲、过滤、存储、兴趣度计算与更新等核心流程。 @@ -73,11 +72,11 @@ class HeartFCProcessor: user_info=userinfo, group_info=groupinfo, ) - - # --- 添加兴趣追踪启动 --- - # 在获取到 chat 对象后,启动对该聊天流的兴趣监控 - await self.reasoning_chat.start_monitoring_interest(chat) - # --- 结束添加 --- + if not chat: + logger.error( + f"无法为消息创建或获取聊天流: user {userinfo.user_id}, group {groupinfo.group_id if groupinfo else 'None'}" + ) + return message.update_chat_stream(chat) @@ -91,6 +90,7 @@ class HeartFCProcessor: message.raw_message, chat, userinfo ): return 
+ logger.trace(f"过滤词/正则表达式过滤成功: {message.processed_plain_text}") # 查询缓冲器结果 buffer_result = await message_buffer.query_buffer_result(message) @@ -152,8 +152,6 @@ class HeartFCProcessor: f"使用激活率 {interested_rate:.2f} 更新后 (通过缓冲后),当前兴趣度: {current_interest:.2f}" ) - self.interest_manager.add_interest_dict(message, interested_rate, is_mentioned) - except Exception as e: logger.error(f"更新兴趣度失败: {e}") # 调整日志消息 logger.error(traceback.format_exc()) diff --git a/src/plugins/chat_module/heartFC_chat/interest.py b/src/plugins/chat_module/heartFC_chat/interest.py index 4ac5498a..5a961e91 100644 --- a/src/plugins/chat_module/heartFC_chat/interest.py +++ b/src/plugins/chat_module/heartFC_chat/interest.py @@ -6,7 +6,6 @@ import json # 引入 json import os # 引入 os from typing import Optional # <--- 添加导入 import random # <--- 添加导入 random -from src.plugins.chat.message import MessageRecv from src.common.logger import get_module_logger, LogConfig, DEFAULT_CONFIG # 引入 DEFAULT_CONFIG from src.plugins.chat.chat_stream import chat_manager # *** Import ChatManager *** @@ -67,13 +66,6 @@ class InterestChatting: self.is_above_threshold: bool = False # 标记兴趣值是否高于阈值 # --- 结束:概率回复相关属性 --- - # 记录激发兴趣对(消息id,激活值) - self.interest_dict = {} - - def add_interest_dict(self, message: MessageRecv, interest_value: float, is_mentioned: bool): - # Store the MessageRecv object and the interest value as a tuple - self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) - def _calculate_decay(self, current_time: float): """计算从上次更新到现在的衰减""" time_delta = current_time - self.last_update_time @@ -453,10 +445,6 @@ class InterestManager: stream_name = chat_manager.get_stream_name(stream_id) or stream_id # 获取流名称 logger.warning(f"尝试降低不存在的聊天流 {stream_name} 的兴趣度") - def add_interest_dict(self, message: MessageRecv, interest_value: float, is_mentioned: bool): - interest_chatting = self._get_or_create_interest_chatting(message.chat_stream.stream_id) - interest_chatting.add_interest_dict(message, interest_value, is_mentioned) - def cleanup_inactive_chats(self, max_age_seconds=INACTIVE_THRESHOLD_SECONDS): """ 清理长时间不活跃的聊天流记录 diff --git a/src/plugins/chat_module/heartFC_chat/messagesender.py b/src/plugins/chat_module/heartFC_chat/messagesender.py index 897bc45f..fb295bed 100644 --- a/src/plugins/chat_module/heartFC_chat/messagesender.py +++ b/src/plugins/chat_module/heartFC_chat/messagesender.py @@ -220,9 +220,10 @@ class MessageManager: await asyncio.sleep(typing_time) logger.debug(f"\n{message_earliest.processed_plain_text},{typing_time},等待输入时间结束\n") - await MessageSender().send_message(message_earliest) await self.storage.store_message(message_earliest, message_earliest.chat_stream) + await MessageSender().send_message(message_earliest) + container.remove_message(message_earliest) async def start_processor(self): diff --git a/src/plugins/chat_module/heartFC_chat/pf_chatting.py b/src/plugins/chat_module/heartFC_chat/pf_chatting.py index 92e3da54..2bb89987 100644 --- a/src/plugins/chat_module/heartFC_chat/pf_chatting.py +++ b/src/plugins/chat_module/heartFC_chat/pf_chatting.py @@ -15,9 +15,6 @@ from src.config.config import global_config from src.plugins.chat.utils_image import image_path_to_base64 # Local import needed after move from src.plugins.utils.timer_calculater import Timer # <--- Import Timer -INITIAL_DURATION = 60.0 - - # 定义日志配置 (使用 loguru 格式) interest_log_config = LogConfig( console_format=PFC_STYLE_CONFIG["console_format"], # 使用默认控制台格式 @@ -70,7 +67,7 @@ class PFChatting: Args: chat_id: The identifier for 
the chat stream (e.g., stream_id). - heartfc_controller_instance: 访问共享资源和方法的主HeartFCController实例。 + heartfc_controller_instance: 访问共享资源和方法的主HeartFC_Controller实例。 """ self.heartfc_controller = heartfc_controller_instance # Store the controller instance self.stream_id: str = chat_id @@ -94,7 +91,7 @@ class PFChatting: self._loop_active: bool = False # Is the loop currently running? self._loop_task: Optional[asyncio.Task] = None # Stores the main loop task self._trigger_count_this_activation: int = 0 # Counts triggers within an active period - self._initial_duration: float = INITIAL_DURATION # 首次触发增加的时间 + self._initial_duration: float = 60.0 # 首次触发增加的时间 self._last_added_duration: float = self._initial_duration # <--- 新增:存储上次增加的时间 def _get_log_prefix(self) -> str: diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py index 50613a98..2eb56c83 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_chat.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_chat.py @@ -157,17 +157,17 @@ class ReasoningChat: # 消息加入缓冲池 await message_buffer.start_caching_messages(message) + # logger.info("使用推理聊天模式") + # 创建聊天流 chat = await chat_manager.get_or_create_stream( platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo, ) - message.update_chat_stream(chat) await message.process() - logger.trace(f"消息处理成功: {message.processed_plain_text}") # 过滤词/正则表达式过滤 if self._check_ban_words(message.processed_plain_text, chat, userinfo) or self._check_ban_regex( @@ -175,13 +175,27 @@ class ReasoningChat: ): return + await self.storage.store_message(message, chat) + + # 记忆激活 + with Timer("记忆激活", timing_results): + interested_rate = await HippocampusManager.get_instance().get_activate_from_text( + message.processed_plain_text, fast_retrieval=True + ) + # 查询缓冲器结果,会整合前面跳过的消息,改变processed_plain_text buffer_result = await message_buffer.query_buffer_result(message) + # 处理提及 + is_mentioned, reply_probability = is_mentioned_bot_in_message(message) + + # 意愿管理器:设置当前message信息 + willing_manager.setup(message, chat, is_mentioned, interested_rate) + # 处理缓冲器结果 if not buffer_result: - # await willing_manager.bombing_buffer_message_handle(message.message_info.message_id) - # willing_manager.delete(message.message_info.message_id) + await willing_manager.bombing_buffer_message_handle(message.message_info.message_id) + willing_manager.delete(message.message_info.message_id) f_type = "seglist" if message.message_segment.type != "seglist": f_type = message.message_segment.type @@ -200,27 +214,6 @@ class ReasoningChat: logger.info("触发缓冲,已炸飞消息列") return - try: - await self.storage.store_message(message, chat) - logger.trace(f"存储成功 (通过缓冲后): {message.processed_plain_text}") - except Exception as e: - logger.error(f"存储消息失败: {e}") - logger.error(traceback.format_exc()) - # 存储失败可能仍需考虑是否继续,暂时返回 - return - - is_mentioned, reply_probability = is_mentioned_bot_in_message(message) - # 记忆激活 - with Timer("记忆激活", timing_results): - interested_rate = await HippocampusManager.get_instance().get_activate_from_text( - message.processed_plain_text, fast_retrieval=True - ) - - # 处理提及 - - # 意愿管理器:设置当前message信息 - willing_manager.setup(message, chat, is_mentioned, interested_rate) - # 获取回复概率 is_willing = False if reply_probability != 1: diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py index 2f4ba06e..dda4e7c7 100644 --- 
a/src/plugins/chat_module/reasoning_chat/reasoning_generator.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_generator.py @@ -44,7 +44,7 @@ class ResponseGenerator: async def generate_response(self, message: MessageThinking, thinking_id: str) -> Optional[Union[str, List[str]]]: """根据当前模型类型选择对应的生成函数""" # 从global_config中获取模型概率值并选择模型 - if random.random() < global_config.model_reasoning_probability: + if random.random() < global_config.MODEL_R1_PROBABILITY: self.current_model_type = "深深地" current_model = self.model_reasoning else: diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py index d37d6545..6d7c9ca1 100644 --- a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py +++ b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py @@ -1,9 +1,9 @@ import random import time -from typing import Optional, Union +from typing import Optional -from ....common.database import db -from ...chat.utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker +# from ....common.database import db +from ...chat.utils import get_recent_group_detailed_plain_text, get_recent_group_speaker from ...chat.chat_stream import chat_manager from ...moods.moods import MoodManager from ....individuality.individuality import Individuality @@ -13,6 +13,8 @@ from ....config.config import global_config from ...person_info.relationship_manager import relationship_manager from src.common.logger import get_module_logger from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager +from src.plugins.knowledge.knowledge_lib import qa_manager +from src.plugins.chat.chat_stream import ChatStream logger = get_module_logger("prompt") @@ -53,7 +55,7 @@ class PromptBuilder: self.activate_messages = "" async def _build_prompt( - self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None + self, chat_stream: ChatStream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None ) -> tuple[str, str]: # 开始构建prompt prompt_personality = "你" @@ -101,14 +103,16 @@ class PromptBuilder: related_memory = await HippocampusManager.get_instance().get_memory_from_text( text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False ) - related_memory_info = "" if related_memory: + related_memory_info = "" for memory in related_memory: related_memory_info += memory[1] # memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n" memory_prompt = await global_prompt_manager.format_prompt( "memory_prompt", related_memory_info=related_memory_info ) + else: + related_memory_info = "" # print(f"相关记忆:{related_memory_info}") @@ -160,6 +164,7 @@ class PromptBuilder: # 知识构建 start_time = time.time() + prompt_info = "" prompt_info = await self.get_prompt_info(message_txt, threshold=0.38) if prompt_info: # prompt_info = f"""\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n""" @@ -221,225 +226,12 @@ class PromptBuilder: return prompt async def get_prompt_info(self, message: str, threshold: float): - start_time = time.time() related_info = "" logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") + related_info += qa_manager.get_knowledge(message) - # 1. 
先从LLM获取主题,类似于记忆系统的做法 - topics = [] - # try: - # # 先尝试使用记忆系统的方法获取主题 - # hippocampus = HippocampusManager.get_instance()._hippocampus - # topic_num = min(5, max(1, int(len(message) * 0.1))) - # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) - - # # 提取关键词 - # topics = re.findall(r"<([^>]+)>", topics_response[0]) - # if not topics: - # topics = [] - # else: - # topics = [ - # topic.strip() - # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") - # if topic.strip() - # ] - - # logger.info(f"从LLM提取的主题: {', '.join(topics)}") - # except Exception as e: - # logger.error(f"从LLM提取主题失败: {str(e)}") - # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 - # words = jieba.cut(message) - # topics = [word for word in words if len(word) > 1][:5] - # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") - - # 如果无法提取到主题,直接使用整个消息 - if not topics: - logger.info("未能提取到任何主题,使用整个消息进行查询") - embedding = await get_embedding(message, request_type="prompt_build") - if not embedding: - logger.error("获取消息嵌入向量失败") - return "" - - related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒") - return related_info - - # 2. 对每个主题进行知识库查询 - logger.info(f"开始处理{len(topics)}个主题的知识库查询") - - # 优化:批量获取嵌入向量,减少API调用 - embeddings = {} - topics_batch = [topic for topic in topics if len(topic) > 0] - if message: # 确保消息非空 - topics_batch.append(message) - - # 批量获取嵌入向量 - embed_start_time = time.time() - for text in topics_batch: - if not text or len(text.strip()) == 0: - continue - - try: - embedding = await get_embedding(text, request_type="prompt_build") - if embedding: - embeddings[text] = embedding - else: - logger.warning(f"获取'{text}'的嵌入向量失败") - except Exception as e: - logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}") - - logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒") - - if not embeddings: - logger.error("所有嵌入向量获取失败") - return "" - - # 3. 对每个主题进行知识库查询 - all_results = [] - query_start_time = time.time() - - # 首先添加原始消息的查询结果 - if message in embeddings: - original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True) - if original_results: - for result in original_results: - result["topic"] = "原始消息" - all_results.extend(original_results) - logger.info(f"原始消息查询到{len(original_results)}条结果") - - # 然后添加每个主题的查询结果 - for topic in topics: - if not topic or topic not in embeddings: - continue - - try: - topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True) - if topic_results: - # 添加主题标记 - for result in topic_results: - result["topic"] = topic - all_results.extend(topic_results) - logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果") - except Exception as e: - logger.error(f"查询主题'{topic}'时发生错误: {str(e)}") - - logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果") - - # 4. 去重和过滤 - process_start_time = time.time() - unique_contents = set() - filtered_results = [] - for result in all_results: - content = result["content"] - if content not in unique_contents: - unique_contents.add(content) - filtered_results.append(result) - - # 5. 按相似度排序 - filtered_results.sort(key=lambda x: x["similarity"], reverse=True) - - # 6. 限制总数量(最多10条) - filtered_results = filtered_results[:10] - logger.info( - f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果" - ) - - # 7. 
格式化输出 - if filtered_results: - format_start_time = time.time() - grouped_results = {} - for result in filtered_results: - topic = result["topic"] - if topic not in grouped_results: - grouped_results[topic] = [] - grouped_results[topic].append(result) - - # 按主题组织输出 - for topic, results in grouped_results.items(): - related_info += f"【主题: {topic}】\n" - for _i, result in enumerate(results, 1): - _similarity = result["similarity"] - content = result["content"].strip() - # 调试:为内容添加序号和相似度信息 - # related_info += f"{i}. [{similarity:.2f}] {content}\n" - related_info += f"{content}\n" - related_info += "\n" - - logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒") - - logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒") return related_info - @staticmethod - def get_info_from_db( - query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False - ) -> Union[str, list]: - if not query_embedding: - return "" if not return_raw else [] - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - { - "$match": { - "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}}, - ] - - results = list(db.knowledges.aggregate(pipeline)) - logger.debug(f"知识库查询结果数量: {len(results)}") - - if not results: - return "" if not return_raw else [] - - if return_raw: - return results - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) - init_prompt() prompt_builder = PromptBuilder() diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py.bak b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py.bak new file mode 100644 index 00000000..acfa20c6 --- /dev/null +++ b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py.bak @@ -0,0 +1,454 @@ +import random +import time +from typing import Optional + +# from ....common.database import db +from ...chat.utils import get_recent_group_detailed_plain_text, get_recent_group_speaker +from ...chat.chat_stream import chat_manager +from ...moods.moods import MoodManager +from ....individuality.individuality import Individuality +from ...memory_system.Hippocampus import HippocampusManager +from ...schedule.schedule_generator import bot_schedule +from ...config.config import global_config +from ...person_info.relationship_manager import relationship_manager +from src.common.logger import get_module_logger +from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager +from src.plugins.knowledge.knowledge_lib import qa_manager + +logger = get_module_logger("prompt") + + +def init_prompt(): + Prompt( + """ +{relation_prompt_all} +{memory_prompt} +{prompt_info} +{schedule_prompt} +{chat_target} 
+{chat_talking_prompt} +现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言发言或者回复这条消息。\n +你的网名叫{bot_name},有人也叫你{bot_other_names},{prompt_personality}。 +你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},然后给出日常且口语化的回复,平淡一些, +尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger} +请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 +请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 +{moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""", + "reasoning_prompt_main", + ) + Prompt( + "{relation_prompt}关系等级越大,关系越好,请分析聊天记录,根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。", + "relationship_prompt", + ) + Prompt( + "你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n", + "memory_prompt", + ) + Prompt("你现在正在做的事情是:{schedule_info}", "schedule_prompt") + Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") + + +class PromptBuilder: + def __init__(self): + self.prompt_built = "" + self.activate_messages = "" + + async def _build_prompt( + self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None + ) -> tuple[str, str]: + # 开始构建prompt + prompt_personality = "你" + # person + individuality = Individuality.get_instance() + + personality_core = individuality.personality.personality_core + prompt_personality += personality_core + + personality_sides = individuality.personality.personality_sides + random.shuffle(personality_sides) + prompt_personality += f",{personality_sides[0]}" + + identity_detail = individuality.identity.identity_detail + random.shuffle(identity_detail) + prompt_personality += f",{identity_detail[0]}" + + # 关系 + who_chat_in_group = [ + (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname) + ] + who_chat_in_group += get_recent_group_speaker( + stream_id, + (chat_stream.user_info.platform, chat_stream.user_info.user_id), + limit=global_config.MAX_CONTEXT_SIZE, + ) + + relation_prompt = "" + for person in who_chat_in_group: + relation_prompt += await relationship_manager.build_relationship_info(person) + + # relation_prompt_all = ( + # f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录," + # f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。" + # ) + + # 心情 + mood_manager = MoodManager.get_instance() + mood_prompt = mood_manager.get_prompt() + + # logger.info(f"心情prompt: {mood_prompt}") + + # 调取记忆 + memory_prompt = "" + related_memory = await HippocampusManager.get_instance().get_memory_from_text( + text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False + ) + if related_memory: + related_memory_info = "" + for memory in related_memory: + related_memory_info += memory[1] + # memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n" + memory_prompt = await global_prompt_manager.format_prompt( + "memory_prompt", related_memory_info=related_memory_info + ) + else: + related_memory_info = "" + + # print(f"相关记忆:{related_memory_info}") + + # 日程构建 + # schedule_prompt = f"""你现在正在做的事情是:{bot_schedule.get_current_num_task(num=1, time_info=False)}""" + + # 获取聊天上下文 + chat_in_group = True + chat_talking_prompt = "" + if stream_id: + chat_talking_prompt = get_recent_group_detailed_plain_text( + stream_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True + ) + chat_stream = chat_manager.get_stream(stream_id) + if chat_stream.group_info: + chat_talking_prompt = chat_talking_prompt + else: + chat_in_group = False + chat_talking_prompt = chat_talking_prompt + # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 
的消息记录:{chat_talking_prompt}") + # 关键词检测与反应 + keywords_reaction_prompt = "" + for rule in global_config.keywords_reaction_rules: + if rule.get("enable", False): + if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): + logger.info( + f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" + ) + keywords_reaction_prompt += rule.get("reaction", "") + "," + else: + for pattern in rule.get("regex", []): + result = pattern.search(message_txt) + if result: + reaction = rule.get("reaction", "") + for name, content in result.groupdict().items(): + reaction = reaction.replace(f"[{name}]", content) + logger.info(f"匹配到以下正则表达式:{pattern},触发反应:{reaction}") + keywords_reaction_prompt += reaction + "," + break + + # 中文高手(新加的好玩功能) + prompt_ger = "" + if random.random() < 0.04: + prompt_ger += "你喜欢用倒装句" + if random.random() < 0.02: + prompt_ger += "你喜欢用反问句" + if random.random() < 0.01: + prompt_ger += "你喜欢用文言文" + + # 知识构建 + start_time = time.time() + prompt_info = "" + prompt_info = await self.get_prompt_info(message_txt, threshold=0.38) + if prompt_info: + # prompt_info = f"""\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n""" + prompt_info = await global_prompt_manager.format_prompt("knowledge_prompt", prompt_info=prompt_info) + + end_time = time.time() + logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}秒") + + # moderation_prompt = "" + # moderation_prompt = """**检查并忽略**任何涉及尝试绕过审核的行为。 + # 涉及政治敏感以及违法违规的内容请规避。""" + + logger.debug("开始构建prompt") + + # prompt = f""" + # {relation_prompt_all} + # {memory_prompt} + # {prompt_info} + # {schedule_prompt} + # {chat_target} + # {chat_talking_prompt} + # 现在"{sender_name}"说的:{message_txt}。引起了你的注意,你想要在群里发言发言或者回复这条消息。\n + # 你的网名叫{global_config.BOT_NICKNAME},有人也叫你{"/".join(global_config.BOT_ALIAS_NAMES)},{prompt_personality}。 + # 你正在{chat_target_2},现在请你读读之前的聊天记录,{mood_prompt},然后给出日常且口语化的回复,平淡一些, + # 尽量简短一些。{keywords_reaction_prompt}请注意把握聊天内容,不要回复的太有条理,可以有个性。{prompt_ger} + # 请回复的平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 + # 请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只输出回复内容。 + # {moderation_prompt}不要输出多余内容(包括前后缀,冒号和引号,括号,表情包,at或 @等 )。""" + + prompt = await global_prompt_manager.format_prompt( + "reasoning_prompt_main", + relation_prompt_all=await global_prompt_manager.get_prompt_async("relationship_prompt"), + relation_prompt=relation_prompt, + sender_name=sender_name, + memory_prompt=memory_prompt, + prompt_info=prompt_info, + schedule_prompt=await global_prompt_manager.format_prompt( + "schedule_prompt", schedule_info=bot_schedule.get_current_num_task(num=1, time_info=False) + ), + chat_target=await global_prompt_manager.get_prompt_async("chat_target_group1") + if chat_in_group + else await global_prompt_manager.get_prompt_async("chat_target_private1"), + chat_target_2=await global_prompt_manager.get_prompt_async("chat_target_group2") + if chat_in_group + else await global_prompt_manager.get_prompt_async("chat_target_private2"), + chat_talking_prompt=chat_talking_prompt, + message_txt=message_txt, + bot_name=global_config.BOT_NICKNAME, + bot_other_names="/".join( + global_config.BOT_ALIAS_NAMES, + ), + prompt_personality=prompt_personality, + mood_prompt=mood_prompt, + keywords_reaction_prompt=keywords_reaction_prompt, + prompt_ger=prompt_ger, + moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"), + ) + + return prompt + + async def get_prompt_info(self, message: str, threshold: float): + related_info = "" + logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") + related_info += 
qa_manager.get_knowledge(message) + + return related_info + # start_time = time.time() + # related_info = "" + # logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") + + # # 1. 先从LLM获取主题,类似于记忆系统的做法 + # topics = [] + # try: + # # 先尝试使用记忆系统的方法获取主题 + # hippocampus = HippocampusManager.get_instance()._hippocampus + # topic_num = min(5, max(1, int(len(message) * 0.1))) + # topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num)) + + # # 提取关键词 + # topics = re.findall(r"<([^>]+)>", topics_response[0]) + # if not topics: + # topics = [] + # else: + # topics = [ + # topic.strip() + # for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") + # if topic.strip() + # ] + + # logger.info(f"从LLM提取的主题: {', '.join(topics)}") + # except Exception as e: + # logger.error(f"从LLM提取主题失败: {str(e)}") + # # 如果LLM提取失败,使用jieba分词提取关键词作为备选 + # words = jieba.cut(message) + # topics = [word for word in words if len(word) > 1][:5] + # logger.info(f"使用jieba提取的主题: {', '.join(topics)}") + + # 如果无法提取到主题,直接使用整个消息 + # if not topics: + # logger.info("未能提取到任何主题,使用整个消息进行查询") + # embedding = await get_embedding(message, request_type="prompt_build") + # if not embedding: + # logger.error("获取消息嵌入向量失败") + # return "" + + # related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) + # logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒") + # return related_info + + # # 2. 对每个主题进行知识库查询 + # logger.info(f"开始处理{len(topics)}个主题的知识库查询") + + # # 优化:批量获取嵌入向量,减少API调用 + # embeddings = {} + # topics_batch = [topic for topic in topics if len(topic) > 0] + # if message: # 确保消息非空 + # topics_batch.append(message) + + # 批量获取嵌入向量 + # embed_start_time = time.time() + # for text in topics_batch: + # if not text or len(text.strip()) == 0: + # continue + + # try: + # embedding = await get_embedding(text, request_type="prompt_build") + # if embedding: + # embeddings[text] = embedding + # else: + # logger.warning(f"获取'{text}'的嵌入向量失败") + # except Exception as e: + # logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}") + + # logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒") + + # if not embeddings: + # logger.error("所有嵌入向量获取失败") + # return "" + + # # 3. 对每个主题进行知识库查询 + # all_results = [] + # query_start_time = time.time() + + # # 首先添加原始消息的查询结果 + # if message in embeddings: + # original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True) + # if original_results: + # for result in original_results: + # result["topic"] = "原始消息" + # all_results.extend(original_results) + # logger.info(f"原始消息查询到{len(original_results)}条结果") + + # # 然后添加每个主题的查询结果 + # for topic in topics: + # if not topic or topic not in embeddings: + # continue + + # try: + + # # topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True) + # # if topic_results: + # # # 添加主题标记 + # # for result in topic_results: + # # result["topic"] = topic + # # all_results.extend(topic_results) + # # logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果") + # except Exception as e: + # logger.error(f"查询主题'{topic}'时发生错误: {str(e)}") + + # logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果") + + # # 4. 
去重和过滤 + # process_start_time = time.time() + # unique_contents = set() + # filtered_results = [] + # for result in all_results: + # content = result["content"] + # if content not in unique_contents: + # unique_contents.add(content) + # filtered_results.append(result) + + # # 5. 按相似度排序 + # filtered_results.sort(key=lambda x: x["similarity"], reverse=True) + + # # 6. 限制总数量(最多10条) + # filtered_results = filtered_results[:10] + # logger.info( + # f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果" + # ) + + # # 7. 格式化输出 + # if filtered_results: + # format_start_time = time.time() + # grouped_results = {} + # for result in filtered_results: + # topic = result["topic"] + # if topic not in grouped_results: + # grouped_results[topic] = [] + # grouped_results[topic].append(result) + + # 按主题组织输出 + # for topic, results in grouped_results.items(): + # related_info += f"【主题: {topic}】\n" + # for _i, result in enumerate(results, 1): + # _similarity = result["similarity"] + # content = result["content"].strip() + # # 调试:为内容添加序号和相似度信息 + # # related_info += f"{i}. [{similarity:.2f}] {content}\n" + # related_info += f"{content}\n" + # related_info += "\n" + + # # logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒") + + # logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒") + # return related_info + + # def get_info_from_db( + # self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False + # ) -> Union[str, list]: + # if not query_embedding: + # return "" if not return_raw else [] + # # 使用余弦相似度计算 + # pipeline = [ + # { + # "$addFields": { + # "dotProduct": { + # "$reduce": { + # "input": {"$range": [0, {"$size": "$embedding"}]}, + # "initialValue": 0, + # "in": { + # "$add": [ + # "$$value", + # { + # "$multiply": [ + # {"$arrayElemAt": ["$embedding", "$$this"]}, + # {"$arrayElemAt": [query_embedding, "$$this"]}, + # ] + # }, + # ] + # }, + # } + # }, + # "magnitude1": { + # "$sqrt": { + # "$reduce": { + # "input": "$embedding", + # "initialValue": 0, + # "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + # } + # } + # }, + # "magnitude2": { + # "$sqrt": { + # "$reduce": { + # "input": query_embedding, + # "initialValue": 0, + # "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, + # } + # } + # }, + # } + # }, + # {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, + # { + # "$match": { + # "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 + # } + # }, + # {"$sort": {"similarity": -1}}, + # {"$limit": limit}, + # {"$project": {"content": 1, "similarity": 1}}, + # ] + + # results = list(db.knowledges.aggregate(pipeline)) + # logger.debug(f"知识库查询结果数量: {len(results)}") + + # if not results: + # return "" if not return_raw else [] + + # if return_raw: + # return results + # else: + # # 返回所有找到的内容,用换行分隔 + # return "\n".join(str(result["content"]) for result in results) + + +init_prompt() +prompt_builder = PromptBuilder() diff --git a/src/plugins/knowledge/LICENSE b/src/plugins/knowledge/LICENSE new file mode 100644 index 00000000..f288702d --- /dev/null +++ b/src/plugins/knowledge/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. 
+ + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. 
+ + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. 
For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. 
+ + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. 
You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. 
+ + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. 
+ + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. 
+ + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. 
+
+  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+  16. Limitation of Liability.
+
+  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+  17. Interpretation of Sections 15 and 16.
+
+  If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+                     END OF TERMS AND CONDITIONS
+
+            How to Apply These Terms to Your New Programs
+
+  If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+  To do so, attach the following notices to the program.  It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+  If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program>  Copyright (C) <year>  <name of author>
+    This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+    This is free software, and you are welcome to redistribute it
+    under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License.  Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+  You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+  The GNU General Public License does not permit incorporating your program
+into proprietary programs.  If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library.  If this is what you want to do, use the GNU Lesser General
+Public License instead of this License.  But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/src/plugins/knowledge/__init__.py b/src/plugins/knowledge/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/plugins/knowledge/knowledge_lib.py b/src/plugins/knowledge/knowledge_lib.py
new file mode 100644
index 00000000..31167391
--- /dev/null
+++ b/src/plugins/knowledge/knowledge_lib.py
@@ -0,0 +1,65 @@
+from .src.lpmmconfig import PG_NAMESPACE, global_config
+from .src.embedding_store import EmbeddingManager
+from .src.llm_client import LLMClient
+from .src.mem_active_manager import MemoryActiveManager
+from .src.qa_manager import QAManager
+from .src.kg_manager import KGManager
+from .src.global_logger import logger
+# try:
+#     import quick_algo
+# except ImportError:
+#     print("quick_algo not found, please install it first")
+
+logger.info("正在初始化Mai-LPMM\n")
+logger.info("创建LLM客户端")
+llm_client_list = dict()
+for key in global_config["llm_providers"]:
+    llm_client_list[key] = LLMClient(
+        global_config["llm_providers"][key]["base_url"],
+        global_config["llm_providers"][key]["api_key"],
+    )
+
+# 初始化Embedding库
+embed_manager = EmbeddingManager(
+    llm_client_list[global_config["embedding"]["provider"]]
+)
+logger.info("正在从文件加载Embedding库")
+try:
+    embed_manager.load_from_file()
+except Exception as e:
+    logger.error("从文件加载Embedding库时发生错误:{}".format(e))
+logger.info("Embedding库加载完成")
+# 初始化KG
+kg_manager = KGManager()
+logger.info("正在从文件加载KG")
+try:
+    kg_manager.load_from_file()
+except Exception as e:
+    logger.error("从文件加载KG时发生错误:{}".format(e))
+logger.info("KG加载完成")
+
+logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
+logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
+
+
+# 数据比对:Embedding库与KG的段落hash集合
+for pg_hash in kg_manager.stored_paragraph_hashes:
+    key = PG_NAMESPACE + "-" + pg_hash
+    if key not in embed_manager.stored_pg_hashes:
+        logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
+
+# 问答系统(用于知识库)
+qa_manager = QAManager(
+    embed_manager,
+    kg_manager,
+    llm_client_list[global_config["embedding"]["provider"]],
+    llm_client_list[global_config["qa"]["llm"]["provider"]],
+    llm_client_list[global_config["qa"]["llm"]["provider"]],
+)
+
+# 记忆激活(用于记忆库)
+inspire_manager = MemoryActiveManager(
+    embed_manager,
+    llm_client_list[global_config["embedding"]["provider"]],
+)
+
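A minimal, illustrative sketch of the consuming side of this module, mirroring the `get_prompt_info` change earlier in this patch; the empty-result guard is an assumption added for illustration, not a behavior `knowledge_lib` is shown to guarantee:

    # Illustrative only; qa_manager is the module-level instance created above.
    from src.plugins.knowledge.knowledge_lib import qa_manager

    def fetch_knowledge(message: str) -> str:
        # get_knowledge is shown being string-concatenated in the patch;
        # the None/"" fallback here is an assumption, not guaranteed by the API.
        knowledge = qa_manager.get_knowledge(message)
        return knowledge or ""
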
diff --git a/src/plugins/knowledge/src/__init__.py b/src/plugins/knowledge/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py
new file mode 100644
index 00000000..59c804fa
--- /dev/null
+++ b/src/plugins/knowledge/src/embedding_store.py
@@ -0,0 +1,251 @@
+import json
+import os
+from typing import Dict, List, Tuple
+
+import numpy as np
+import pandas as pd
+import tqdm
+import faiss
+
+from .llm_client import LLMClient
+from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
+from .utils.hash import get_sha256
+from .global_logger import logger
+
+
+class EmbeddingStoreItem:
+    """嵌入库中的项"""
+
+    # 说明:此类自定义了__init__且未声明dataclass字段,原有的@dataclass装饰器不起作用,已移除
+    def __init__(self, item_hash: str, embedding: List[float], content: str):
+        self.hash = item_hash
+        self.embedding = embedding
+        self.str = content
+
+    def to_dict(self) -> dict:
+        """转为dict"""
+        return {
+            "hash": self.hash,
+            "embedding": self.embedding,
+            "str": self.str,
+        }
+
+
+class EmbeddingStore:
+    def __init__(self, llm_client: LLMClient, namespace: str, dir_path: str):
+        self.namespace = namespace
+        self.llm_client = llm_client
+        self.dir = dir_path
+        self.embedding_file_path = dir_path + "/" + namespace + ".parquet"
+        self.index_file_path = dir_path + "/" + namespace + ".index"
+        self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
+
+        self.store = dict()
+
+        self.faiss_index = None
+        self.idx2hash = None
+
+    def _get_embedding(self, s: str) -> List[float]:
+        return self.llm_client.send_embedding_request(
+            global_config["embedding"]["model"], s
+        )
+
+    def batch_insert_strs(self, strs: List[str]) -> None:
+        """向库中存入字符串"""
+        # 逐项处理
+        for s in tqdm.tqdm(strs, desc="存入嵌入库", unit="items"):
+            # 计算hash去重
+            item_hash = self.namespace + "-" + get_sha256(s)
+            if item_hash in self.store:
+                continue
+
+            # 获取embedding
+            embedding = self._get_embedding(s)
+
+            # 存入
+            self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+
+    def save_to_file(self) -> None:
+        """保存到文件"""
+        data = []
+        logger.info(f"正在保存{self.namespace}嵌入库到文件{self.embedding_file_path}")
+        for item in self.store.values():
+            data.append(item.to_dict())
+        data_frame = pd.DataFrame(data)
+
+        # to_parquet会自行创建文件,此处只需确保目录存在
+        os.makedirs(self.dir, exist_ok=True)
+
+        data_frame.to_parquet(self.embedding_file_path, engine="pyarrow", index=False)
+        logger.info(f"{self.namespace}嵌入库保存成功")
+
+        if self.faiss_index is not None and self.idx2hash is not None:
+            logger.info(
+                f"正在保存{self.namespace}嵌入库的FaissIndex到文件{self.index_file_path}"
+            )
+            faiss.write_index(self.faiss_index, self.index_file_path)
+            logger.info(f"{self.namespace}嵌入库的FaissIndex保存成功")
+            logger.info(
+                f"正在保存{self.namespace}嵌入库的idx2hash映射到文件{self.idx2hash_file_path}"
+            )
+            with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
+                f.write(json.dumps(self.idx2hash, ensure_ascii=False, indent=4))
+            logger.info(f"{self.namespace}嵌入库的idx2hash映射保存成功")
+
+    def load_from_file(self) -> None:
+        """从文件中加载"""
+        if not os.path.exists(self.embedding_file_path):
+            raise Exception(f"文件{self.embedding_file_path}不存在")
+
+        logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
+        data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
+        for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
+            self.store[row["hash"]] = EmbeddingStoreItem(
+                row["hash"], row["embedding"], row["str"]
+            )
+        logger.info(f"{self.namespace}嵌入库加载成功")
+
+        try:
+            if os.path.exists(self.index_file_path):
+                logger.info(
+                    f"正在从文件{self.index_file_path}中加载{self.namespace}嵌入库的FaissIndex"
+                )
+                self.faiss_index = faiss.read_index(self.index_file_path)
+                logger.info(f"{self.namespace}嵌入库的FaissIndex加载成功")
+            else:
+                raise Exception(f"文件{self.index_file_path}不存在")
+            if os.path.exists(self.idx2hash_file_path):
+                logger.info(
+                    f"正在从文件{self.idx2hash_file_path}中加载{self.namespace}嵌入库的idx2hash映射"
+                )
+                # 与写入时保持一致,显式指定utf-8编码,避免依赖平台默认编码导致读取失败
+                with open(self.idx2hash_file_path, "r", encoding="utf-8") as f:
+                    self.idx2hash = json.load(f)
+                logger.info(f"{self.namespace}嵌入库的idx2hash映射加载成功")
+            else:
+                raise Exception(f"文件{self.idx2hash_file_path}不存在")
+        except Exception as e:
+            logger.error(f"加载{self.namespace}嵌入库的FaissIndex时发生错误:{e}")
+            logger.warning("正在重建Faiss索引")
+            self.build_faiss_index()
+            logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功")
+            self.save_to_file()
+
+    def build_faiss_index(self) -> None:
+        """重新构建Faiss索引,以余弦相似度为度量"""
+        # 获取所有的embedding
+        array = []
+        self.idx2hash = dict()
+        for key in self.store:
+            array.append(self.store[key].embedding)
+            self.idx2hash[str(len(array) - 1)] = key
+        embeddings = np.array(array, dtype=np.float32)
+        # L2归一化
+        faiss.normalize_L2(embeddings)
+        # 构建索引
+        self.faiss_index = faiss.IndexFlatIP(global_config["embedding"]["dimension"])
+        self.faiss_index.add(embeddings)
+
+    def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]:
+        """搜索最相似的k个项,以余弦相似度为度量
+        Args:
+            query: 查询的embedding
+            k: 返回的最相似的k个项
+        Returns:
+            result: 最相似的k个项的(hash, 余弦相似度)列表
+        """
+        if self.faiss_index is None:
+            raise Exception("Faiss索引尚未构建")
+        if self.idx2hash is None:
+            raise Exception("idx2hash映射尚未构建")
+
+        # L2归一化:normalize_L2是原地操作,必须保留归一化后的数组并用它执行检索,
+        # 否则实际检索用的是未归一化的向量,返回的就不是余弦相似度
+        query_arr = np.array([query], dtype=np.float32)
+        faiss.normalize_L2(query_arr)
+        # 搜索
+        distances, indices = self.faiss_index.search(query_arr, k)
+        # 整理结果(faiss在结果不足k个时会以-1填充索引,需要过滤)
+        indices = list(indices.flatten())
+        distances = list(distances.flatten())
+        result = [
+            (self.idx2hash[str(int(idx))], float(sim))
+            for (idx, sim) in zip(indices, distances)
+            if 0 <= idx < len(self.idx2hash)
+        ]
+
+        return result
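+
+
+# NOTE (illustrative comment, an editorial addition): the search above yields
+# cosine similarities because both stored and query vectors are L2-normalized
+# before the inner-product index (IndexFlatIP) is used; for unit vectors,
+# a dot b equals cos(a, b). A minimal numpy check of that identity:
+#
+#     import numpy as np
+#     a = np.array([3.0, 4.0]); b = np.array([4.0, 3.0])
+#     cos = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))  # 24/25 = 0.96
+#     an, bn = a / np.linalg.norm(a), b / np.linalg.norm(b)
+#     assert abs(an @ bn - cos) < 1e-12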
Exception(f"文件{self.idx2hash_file_path}不存在") + except Exception as e: + logger.error(f"加载{self.namespace}嵌入库的FaissIndex时发生错误:{e}") + logger.warning("正在重建Faiss索引") + self.build_faiss_index() + logger.info(f"{self.namespace}嵌入库的FaissIndex重建成功") + self.save_to_file() + + def build_faiss_index(self) -> None: + """重新构建Faiss索引,以余弦相似度为度量""" + # 获取所有的embedding + array = [] + self.idx2hash = dict() + for key in self.store: + array.append(self.store[key].embedding) + self.idx2hash[str(len(array) - 1)] = key + embeddings = np.array(array, dtype=np.float32) + # L2归一化 + faiss.normalize_L2(embeddings) + # 构建索引 + self.faiss_index = faiss.IndexFlatIP(global_config["embedding"]["dimension"]) + self.faiss_index.add(embeddings) + + def search_top_k(self, query: List[float], k: int) -> List[Tuple[str, float]]: + """搜索最相似的k个项,以余弦相似度为度量 + Args: + query: 查询的embedding + k: 返回的最相似的k个项 + Returns: + result: 最相似的k个项的(hash, 余弦相似度)列表 + """ + if self.faiss_index is None: + raise Exception("Faiss索引尚未构建") + if self.idx2hash is None: + raise Exception("idx2hash映射尚未构建") + + # L2归一化 + faiss.normalize_L2(np.array([query], dtype=np.float32)) + # 搜索 + distances, indices = self.faiss_index.search(np.array([query]), k) + # 整理结果 + indices = list(indices.flatten()) + distances = list(distances.flatten()) + result = [ + (self.idx2hash[str(int(idx))], float(sim)) + for (idx, sim) in zip(indices, distances) + if idx in range(len(self.idx2hash)) + ] + + return result + + +class EmbeddingManager: + def __init__(self, llm_client: LLMClient): + self.paragraphs_embedding_store = EmbeddingStore( + llm_client, + PG_NAMESPACE, + global_config["persistence"]["embedding_data_dir"], + ) + self.entities_embedding_store = EmbeddingStore( + llm_client, + ENT_NAMESPACE, + global_config["persistence"]["embedding_data_dir"], + ) + self.relation_embedding_store = EmbeddingStore( + llm_client, + REL_NAMESPACE, + global_config["persistence"]["embedding_data_dir"], + ) + self.stored_pg_hashes = set() + + def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): + """将段落编码存入Embedding库""" + self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values())) + + def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): + """将实体编码存入Embedding库""" + entities = set() + for triple_list in triple_list_data.values(): + for triple in triple_list: + entities.add(triple[0]) + entities.add(triple[2]) + self.entities_embedding_store.batch_insert_strs(list(entities)) + + def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]): + """将关系编码存入Embedding库""" + graph_triples = [] # a list of unique relation triple (in tuple) from all chunks + for triples in triple_list_data.values(): + graph_triples.extend([tuple(t) for t in triples]) + graph_triples = list(set(graph_triples)) + self.relation_embedding_store.batch_insert_strs( + [str(triple) for triple in graph_triples] + ) + + def load_from_file(self): + """从文件加载""" + self.paragraphs_embedding_store.load_from_file() + self.entities_embedding_store.load_from_file() + self.relation_embedding_store.load_from_file() + # 从段落库中获取已存储的hash + self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys()) + + def store_new_data_set( + self, + raw_paragraphs: Dict[str, str], + triple_list_data: Dict[str, List[List[str]]], + ): + """存储新的数据集""" + self._store_pg_into_embedding(raw_paragraphs) + self._store_ent_into_embedding(triple_list_data) + self._store_rel_into_embedding(triple_list_data) + self.stored_pg_hashes.update(raw_paragraphs.keys()) + + def 
save_to_file(self): + """保存到文件""" + self.paragraphs_embedding_store.save_to_file() + self.entities_embedding_store.save_to_file() + self.relation_embedding_store.save_to_file() + + def rebuild_faiss_index(self): + """重建Faiss索引(请在添加新数据后调用)""" + self.paragraphs_embedding_store.build_faiss_index() + self.entities_embedding_store.build_faiss_index() + self.relation_embedding_store.build_faiss_index() diff --git a/src/plugins/knowledge/src/global_logger.py b/src/plugins/knowledge/src/global_logger.py new file mode 100644 index 00000000..0311db5f --- /dev/null +++ b/src/plugins/knowledge/src/global_logger.py @@ -0,0 +1,14 @@ +# Configure logger + +import logging + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +console_logging_handler = logging.StreamHandler() +console_logging_handler.setFormatter( + logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") +) +console_logging_handler.setLevel(logging.DEBUG) +logger.addHandler(console_logging_handler) diff --git a/src/plugins/knowledge/src/ie_process.py b/src/plugins/knowledge/src/ie_process.py new file mode 100644 index 00000000..5da9ad9e --- /dev/null +++ b/src/plugins/knowledge/src/ie_process.py @@ -0,0 +1,108 @@ +import json +import time +from typing import List + +from .global_logger import logger +from . import prompt_template +from .lpmmconfig import global_config, INVALID_ENTITY +from .llm_client import LLMClient +from .utils.json_fix import fix_broken_generated_json + + +def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]: + """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" + entity_extract_context = prompt_template.build_entity_extract_context(paragraph) + _, request_result = llm_client.send_chat_request( + global_config["entity_extract"]["llm"]["model"], entity_extract_context + ) + + # 去除‘{’前的内容(结果中可能有多个‘{’) + if "[" in request_result: + request_result = request_result[request_result.index("[") :] + + # 去除最后一个‘}’后的内容(结果中可能有多个‘}’) + if "]" in request_result: + request_result = request_result[: request_result.rindex("]") + 1] + + entity_extract_result = json.loads(fix_broken_generated_json(request_result)) + + entity_extract_result = [ + entity + for entity in entity_extract_result + if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY) + ] + + if len(entity_extract_result) == 0: + raise Exception("实体提取结果为空") + + return entity_extract_result + + +def _rdf_triple_extract( + llm_client: LLMClient, paragraph: str, entities: list +) -> List[List[str]]: + """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" + entity_extract_context = prompt_template.build_rdf_triple_extract_context( + paragraph, entities=json.dumps(entities, ensure_ascii=False) + ) + _, request_result = llm_client.send_chat_request( + global_config["rdf_build"]["llm"]["model"], entity_extract_context + ) + + # 去除‘{’前的内容(结果中可能有多个‘{’) + if "[" in request_result: + request_result = request_result[request_result.index("[") :] + + # 去除最后一个‘}’后的内容(结果中可能有多个‘}’) + if "]" in request_result: + request_result = request_result[: request_result.rindex("]") + 1] + + entity_extract_result = json.loads(fix_broken_generated_json(request_result)) + + for triple in entity_extract_result: + if ( + len(triple) != 3 + or (triple[0] is None or triple[1] is None or triple[2] is None) + or "" in triple + ): + raise Exception("RDF提取结果格式错误") + + return entity_extract_result + + +def info_extract_from_str( + llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str +) -> tuple[None, None] | tuple[list[str], list[list[str]]]: + try_count = 0 
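Both extraction helpers above clean the model reply the same way: cut it down to the outermost JSON array, run the repair pass, then parse. That cleaning step in isolation (fix_broken_generated_json lives in utils/json_fix.py further down; plain json.loads stands in for it here):

    import json

    def extract_json_array(reply: str) -> list:
        """Trim an LLM reply to its outermost JSON array and parse it."""
        if "[" in reply:
            reply = reply[reply.index("["):]        # drop prose before the array
        if "]" in reply:
            reply = reply[: reply.rindex("]") + 1]  # drop prose after the array
        return json.loads(reply)                    # raises if still malformed

    print(extract_json_array('答案如下:["实体A", "实体B"] 谢谢。'))
    # ['实体A', '实体B']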
+ while True: + try: + entity_extract_result = _entity_extract(llm_client_for_ner, paragraph) + break + except Exception as e: + logger.warning(f"实体提取失败,错误信息:{e}") + try_count += 1 + if try_count < 3: + logger.warning("将于5秒后重试") + time.sleep(5) + else: + logger.error("实体提取失败,已达最大重试次数") + return None, None + + try_count = 0 + while True: + try: + rdf_triple_extract_result = _rdf_triple_extract( + llm_client_for_rdf, paragraph, entity_extract_result + ) + break + except Exception as e: + logger.warning(f"实体提取失败,错误信息:{e}") + try_count += 1 + if try_count < 3: + logger.warning("将于5秒后重试") + time.sleep(5) + else: + logger.error("实体提取失败,已达最大重试次数") + return None, None + + return entity_extract_result, rdf_triple_extract_result diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py new file mode 100644 index 00000000..4fcdcf80 --- /dev/null +++ b/src/plugins/knowledge/src/kg_manager.py @@ -0,0 +1,436 @@ +import json +import os +import time +from typing import Dict, List, Tuple + +import numpy as np +import pandas as pd +import tqdm +from quick_algo import di_graph, pagerank + + +from .utils.hash import get_sha256 +from .embedding_store import EmbeddingManager, EmbeddingStoreItem +from .lpmmconfig import ( + ENT_NAMESPACE, + PG_NAMESPACE, + RAG_ENT_CNT_NAMESPACE, + RAG_GRAPH_NAMESPACE, + RAG_PG_HASH_NAMESPACE, + global_config, +) + +from .global_logger import logger + +class KGManager: + def __init__(self): + # 会被保存的字段 + # 存储段落的hash值,用于去重 + self.stored_paragraph_hashes = set() + # 实体出现次数 + self.ent_appear_cnt = dict() + # KG + self.graph = di_graph.DiGraph() + + # 持久化相关 + self.dir_path = global_config["persistence"]["rag_data_dir"] + self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml" + self.ent_cnt_data_path = ( + self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" + ) + self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json" + + def save_to_file(self): + """将KG数据保存到文件""" + # 确保目录存在 + if not os.path.exists(self.dir_path): + os.makedirs(self.dir_path, exist_ok=True) + + # 保存KG + di_graph.save_to_file(self.graph, self.graph_data_path) + + # 保存实体计数到文件 + ent_cnt_df = pd.DataFrame( + [{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()] + ) + ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False) + + # 保存段落hash到文件 + with open(self.pg_hash_file_path, "w", encoding="utf-8") as f: + data = {"stored_paragraph_hashes": list(self.stored_paragraph_hashes)} + f.write(json.dumps(data, ensure_ascii=False, indent=4)) + + def load_from_file(self): + """从文件加载KG数据""" + # 确保文件存在 + if not os.path.exists(self.pg_hash_file_path): + raise Exception(f"KG段落hash文件{self.pg_hash_file_path}不存在") + if not os.path.exists(self.ent_cnt_data_path): + raise Exception(f"KG实体计数文件{self.ent_cnt_data_path}不存在") + if not os.path.exists(self.graph_data_path): + raise Exception(f"KG图文件{self.graph_data_path}不存在") + + # 加载段落hash + with open(self.pg_hash_file_path, "r", encoding="utf-8") as f: + data = json.load(f) + self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"]) + + # 加载实体计数 + ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow") + self.ent_appear_cnt = dict( + {row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()} + ) + + # 加载KG + self.graph = di_graph.load_from_file(self.graph_data_path) + + def _build_edges_between_ent( + self, + node_to_node: Dict[Tuple[str, str], float], + triple_list_data: Dict[str, List[List[str]]], + ): + 
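The two retry loops in info_extract_from_str share one policy: three attempts, five seconds apart, then give up. Factored into a helper, the shape is roughly this (retry_call is an illustrative name, not part of the module):

    import time

    def retry_call(fn, attempts: int = 3, delay: float = 5.0):
        """Call fn(), retrying up to `attempts` times with a fixed delay."""
        for i in range(attempts):
            try:
                return fn()
            except Exception:
                if i + 1 == attempts:
                    raise             # out of attempts: let the caller decide
                time.sleep(delay)     # wait before the next attempt

    # usage sketch: entities = retry_call(lambda: _entity_extract(client, paragraph))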
"""构建实体节点之间的关系,同时统计实体出现次数""" + for triple_list in triple_list_data.values(): + entity_set = set() + for triple in triple_list: + if triple[0] == triple[2]: + # 避免自连接 + continue + # 一个triple就是一条边(同时构建双向联系) + hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0]) + hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2]) + node_to_node[(hash_key1, hash_key2)] = ( + node_to_node.get((hash_key1, hash_key2), 0) + 1.0 + ) + node_to_node[(hash_key2, hash_key1)] = ( + node_to_node.get((hash_key2, hash_key1), 0) + 1.0 + ) + entity_set.add(hash_key1) + entity_set.add(hash_key2) + + # 实体出现次数统计 + for hash_key in entity_set: + self.ent_appear_cnt[hash_key] = ( + self.ent_appear_cnt.get(hash_key, 0) + 1.0 + ) + + @staticmethod + def _build_edges_between_ent_pg( + node_to_node: Dict[Tuple[str, str], float], + triple_list_data: Dict[str, List[List[str]]], + ): + """构建实体节点与文段节点之间的关系""" + for idx in triple_list_data: + for triple in triple_list_data[idx]: + ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0]) + pg_hash_key = PG_NAMESPACE + "-" + str(idx) + node_to_node[(ent_hash_key, pg_hash_key)] = ( + node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 + ) + + @staticmethod + def _synonym_connect( + node_to_node: Dict[Tuple[str, str], float], + triple_list_data: Dict[str, List[List[str]]], + embedding_manager: EmbeddingManager, + ) -> int: + """同义词连接""" + new_edge_cnt = 0 + # 获取所有实体节点的hash值 + ent_hash_list = set() + for triple_list in triple_list_data.values(): + for triple in triple_list: + ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0])) + ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2])) + ent_hash_list = list(ent_hash_list) + + synonym_hash_set = set() + + synonym_result = dict() + + # 对每个实体节点,查找其相似的实体节点,建立扩展连接 + for ent_hash in tqdm.tqdm(ent_hash_list): + if ent_hash in synonym_hash_set: + # 避免同一批次内重复添加 + continue + ent = embedding_manager.entities_embedding_store.store.get(ent_hash) + assert isinstance(ent, EmbeddingStoreItem) + if ent is None: + continue + # 查询相似实体 + similar_ents = embedding_manager.entities_embedding_store.search_top_k( + ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"] + ) + res_ent = [] # Debug + for res_ent_hash, similarity in similar_ents: + if res_ent_hash == ent_hash: + # 避免自连接 + continue + if similarity < global_config["rag"]["params"]["synonym_threshold"]: + # 相似度阈值 + continue + node_to_node[(res_ent_hash, ent_hash)] = similarity + node_to_node[(ent_hash, res_ent_hash)] = similarity + synonym_hash_set.add(res_ent_hash) + new_edge_cnt += 1 + res_ent.append( + ( + embedding_manager.entities_embedding_store.store[ + res_ent_hash + ].str, + similarity, + ) + ) # Debug + synonym_result[ent.str] = res_ent + + for k, v in synonym_result.items(): + print(f'"{k}"的相似实体为:{v}') + return new_edge_cnt + + def _update_graph( + self, + node_to_node: Dict[Tuple[str, str], float], + embedding_manager: EmbeddingManager, + ): + """更新KG图结构 + + 流程: + 1. 更新图结构:遍历所有待添加的新边 + - 若是新边,则添加到图中 + - 若是已存在的边,则更新边的权重 + 2. 
更新新节点的属性 + """ + existed_nodes = self.graph.get_node_list() + existed_edges = [str((edge[0], edge[1])) for edge in self.graph.get_edge_list()] + + now_time = time.time() + + # 更新图结构 + for src_tgt, weight in node_to_node.items(): + key = str(src_tgt) + # 检查边是否已存在 + if key not in existed_edges: + # 新边 + self.graph.add_edge( + di_graph.DiEdge( + src_tgt[0], + src_tgt[1], + { + "weight": weight, + "create_time": now_time, + "update_time": now_time, + }, + ) + ) + else: + # 已存在的边 + edge_item = self.graph[src_tgt[0], src_tgt[1]] + edge_item["weight"] += weight + edge_item["update_time"] = now_time + self.graph.update_edge(edge_item) + + # 更新新节点属性 + for src_tgt in node_to_node.keys(): + for node_hash in src_tgt: + if node_hash not in existed_nodes: + if node_hash.startswith(ENT_NAMESPACE): + # 新增实体节点 + node = embedding_manager.entities_embedding_store.store[ + node_hash + ] + assert isinstance(node, EmbeddingStoreItem) + node_item = self.graph[node_hash] + node_item["content"] = node.str + node_item["type"] = "ent" + node_item["create_time"] = now_time + self.graph.update_node(node_item) + elif node_hash.startswith(PG_NAMESPACE): + # 新增文段节点 + node = embedding_manager.paragraphs_embedding_store.store[ + node_hash + ] + assert isinstance(node, EmbeddingStoreItem) + content = node.str.replace("\n", " ") + node_item = self.graph[node_hash] + node_item["content"] = ( + content if len(content) < 8 else content[:8] + "..." + ) + node_item["type"] = "pg" + node_item["create_time"] = now_time + self.graph.update_node(node_item) + + def build_kg( + self, + triple_list_data: Dict[str, List[List[str]]], + embedding_manager: EmbeddingManager, + ): + """增量式构建KG + + 注意:应当在调用该方法后保存KG + + Args: + triple_list_data: 三元组数据 + embedding_manager: EmbeddingManager对象 + """ + # 实体之间的联系 + node_to_node = dict() + + # 构建实体节点之间的关系,同时统计实体出现次数 + logger.info("正在构建KG实体节点之间的关系,同时统计实体出现次数") + # 从三元组提取实体对 + self._build_edges_between_ent(node_to_node, triple_list_data) + + # 构建实体节点与文段节点之间的关系 + logger.info("正在构建KG实体节点与文段节点之间的关系") + self._build_edges_between_ent_pg(node_to_node, triple_list_data) + + # 近义词扩展链接 + # 对每个实体节点,找到最相似的实体节点,建立扩展连接 + logger.info("正在进行近义词扩展链接") + self._synonym_connect(node_to_node, triple_list_data, embedding_manager) + + # 构建图 + self._update_graph(node_to_node, embedding_manager) + + # 记录已处理(存储)的段落hash + for idx in triple_list_data: + self.stored_paragraph_hashes.add(str(idx)) + + def kg_search( + self, + relation_search_result: List[Tuple[Tuple[str, str, str], float]], + paragraph_search_result: List[Tuple[str, float]], + embed_manager: EmbeddingManager, + ): + """RAG搜索与PageRank + + Args: + relation_search_result: RelationEmbedding的搜索结果(relation_tripple, similarity) + paragraph_search_result: ParagraphEmbedding的搜索结果(paragraph_hash, similarity) + embed_manager: EmbeddingManager对象 + """ + # 图中存在的节点总集 + existed_nodes = self.graph.get_node_list() + + # 准备PPR使用的数据 + # 节点权重:实体 + ent_weights = {} + # 节点权重:文段 + pg_weights = {} + + # 以下部分处理实体权重ent_weights + + # 针对每个关系,提取出其中的主宾短语作为两个实体,并记录对应的三元组的相似度作为权重依据 + ent_sim_scores = {} + for relation_hash, similarity, _ in relation_search_result: + # 提取主宾短语 + relation = embed_manager.relation_embedding_store.store.get( + relation_hash + ).str + assert relation is not None # 断言:relation不为空 + # 关系三元组 + triple = relation[2:-2].split("', '") + for ent in [(triple[0]), (triple[2])]: + ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent) + if ent_hash in existed_nodes: # 该实体需在KG中存在 + if ent_hash not in ent_sim_scores: # 尚未记录的实体 + ent_sim_scores[ent_hash] = [] + 
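Edge construction above funnels everything through the node_to_node dict: repeated co-occurrences accumulate weight under the same (source, target) key, and only afterwards does _update_graph commit edges to the graph. The accumulation idea in miniature (plain dicts, no quick_algo dependency):

    from collections import defaultdict

    edges: dict[tuple[str, str], float] = defaultdict(float)

    triples = [["北京", "是", "中国的首都"], ["北京", "位于", "华北"], ["北京", "是", "中国的首都"]]
    for subj, _rel, obj in triples:
        if subj == obj:
            continue                # skip self-loops, as build_kg does
        edges[(subj, obj)] += 1.0   # forward direction
        edges[(obj, subj)] += 1.0   # and the reverse, matching the code above

    print(edges[("北京", "中国的首都")])  # 2.0: the repeated triple accumulated weight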
ent_sim_scores[ent_hash].append(similarity) + + ent_mean_scores = {} # 记录实体的平均相似度 + for ent_hash, scores in ent_sim_scores.items(): + # 先对相似度进行累加,然后与实体计数相除获取最终权重 + ent_weights[ent_hash] = ( + float(np.sum(scores)) / self.ent_appear_cnt[ent_hash] + ) + # 记录实体的平均相似度,用于后续的top_k筛选 + ent_mean_scores[ent_hash] = float(np.mean(scores)) + del ent_sim_scores + + ent_weights_max = max(ent_weights.values()) + ent_weights_min = min(ent_weights.values()) + if ent_weights_max == ent_weights_min: + # 只有一个相似度,则全赋值为1 + for ent_hash in ent_weights.keys(): + ent_weights[ent_hash] = 1.0 + else: + down_edge = global_config["qa"]["params"]["paragraph_node_weight"] + # 缩放取值区间至[down_edge, 1] + for ent_hash, score in ent_weights.items(): + # 缩放相似度 + ent_weights[ent_hash] = ( + (score - ent_weights_min) + * (1 - down_edge) + / (ent_weights_max - ent_weights_min) + ) + down_edge + + # 取平均相似度的top_k实体 + top_k = global_config["qa"]["params"]["ent_filter_top_k"] + if len(ent_mean_scores) > top_k: + # 从大到小排序,取后len - k个 + ent_mean_scores = { + k: v + for k, v in sorted( + ent_mean_scores.items(), key=lambda item: item[1], reverse=True + ) + } + for ent_hash, _ in ent_mean_scores.items(): + # 删除被淘汰的实体节点权重设置 + del ent_weights[ent_hash] + del top_k, ent_mean_scores + + # 以下部分处理文段权重pg_weights + + # 将搜索结果中文段的相似度归一化作为权重 + pg_sim_scores = {} + pg_sim_score_max = 0.0 + pg_sim_score_min = 1.0 + for pg_hash, similarity in paragraph_search_result: + # 查找最大和最小值 + pg_sim_score_max = max(pg_sim_score_max, similarity) + pg_sim_score_min = min(pg_sim_score_min, similarity) + pg_sim_scores[pg_hash] = similarity + + # 归一化 + for pg_hash, similarity in pg_sim_scores.items(): + # 归一化相似度 + pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / ( + pg_sim_score_max - pg_sim_score_min + ) + del pg_sim_score_max, pg_sim_score_min + + for pg_hash, score in pg_sim_scores.items(): + pg_weights[pg_hash] = ( + score * global_config["qa"]["params"]["paragraph_node_weight"] + ) # 文段权重 = 归一化相似度 * 文段节点权重参数 + del pg_sim_scores + + # 最终权重数据 = 实体权重 + 文段权重 + ppr_node_weights = { + k: v for d in [ent_weights, pg_weights] for k, v in d.items() + } + del ent_weights, pg_weights + + # PersonalizedPageRank + ppr_res = pagerank.run_pagerank( + self.graph, + personalization=ppr_node_weights, + max_iter=100, + alpha=global_config["qa"]["params"]["ppr_damping"], + ) + + # 获取最终结果 + # 从搜索结果中提取文段节点的结果 + passage_node_res = [ + (node_key, score) + for node_key, score in ppr_res.items() + if node_key.startswith(PG_NAMESPACE) + ] + del ppr_res + + # 排序:按照分数从大到小 + passage_node_res = sorted( + passage_node_res, key=lambda item: item[1], reverse=True + ) + + return passage_node_res, ppr_node_weights diff --git a/src/plugins/knowledge/src/llm_client.py b/src/plugins/knowledge/src/llm_client.py new file mode 100644 index 00000000..3036662c --- /dev/null +++ b/src/plugins/knowledge/src/llm_client.py @@ -0,0 +1,53 @@ +from openai import OpenAI + + +class LLMMessage: + def __init__(self, role, content): + self.role = role + self.content = content + + def to_dict(self): + return {"role": self.role, "content": self.content} + + +class LLMClient: + """LLM客户端,对应一个API服务商""" + + def __init__(self, url, api_key): + self.client = OpenAI( + base_url=url, + api_key=api_key, + ) + + def send_chat_request(self, model, messages): + """发送对话请求,等待返回结果""" + response = self.client.chat.completions.create( + model=model, messages=messages, stream=False + ) + if hasattr(response.choices[0].message, "reasoning_content"): + # 有单独的推理内容块 + reasoning_content = 
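The entity weights above are then squeezed into [down_edge, 1] by a min-max rescale, so even the weakest surviving entity keeps some personalization mass. The formula as a standalone function (the 0.05 below only mirrors the default paragraph_node_weight):

    def rescale(score: float, lo: float, hi: float, down_edge: float) -> float:
        """Map a score from [lo, hi] onto [down_edge, 1.0]."""
        return (score - lo) * (1 - down_edge) / (hi - lo) + down_edge

    print(rescale(0.3, 0.3, 0.9, 0.05))  # 0.05: the minimum maps to down_edge
    print(rescale(0.9, 0.3, 0.9, 0.05))  # 1.0: the maximum maps to 1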
response.choices[0].message.reasoning_content + content = response.choices[0].message.content + else: + # 无单独的推理内容块 + response = ( + response.choices[0] + .message.content.split("")[-1] + .split("") + ) + # 如果有推理内容,则分割推理内容和内容 + if len(response) == 2: + reasoning_content = response[0] + content = response[1] + else: + reasoning_content = None + content = response[0] + + return reasoning_content, content + + def send_embedding_request(self, model, text): + """发送嵌入请求,等待返回结果""" + text = text.replace("\n", " ") + return ( + self.client.embeddings.create(input=[text], model=model).data[0].embedding + ) diff --git a/src/plugins/knowledge/src/lpmmconfig.py b/src/plugins/knowledge/src/lpmmconfig.py new file mode 100644 index 00000000..ff1ac8fa --- /dev/null +++ b/src/plugins/knowledge/src/lpmmconfig.py @@ -0,0 +1,143 @@ +import os +import toml +import sys +import argparse + +PG_NAMESPACE = "paragraph" +ENT_NAMESPACE = "entity" +REL_NAMESPACE = "relation" + +RAG_GRAPH_NAMESPACE = "rag-graph" +RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" +RAG_PG_HASH_NAMESPACE = "rag-pg-hash" + +# 无效实体 +INVALID_ENTITY = [ + "", + "你", + "他", + "她", + "它", + "我们", + "你们", + "他们", + "她们", + "它们", +] + + +def _load_config(config, config_file_path): + """读取TOML格式的配置文件""" + if not os.path.exists(config_file_path): + return + with open(config_file_path, "r", encoding="utf-8") as f: + file_config = toml.load(f) + + # Check if all top-level keys from default config exist in the file config + for key in config.keys(): + if key not in file_config: + print(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。") + sys.exit(1) + + if "llm_providers" in file_config: + for provider in file_config["llm_providers"]: + if provider["name"] not in config["llm_providers"]: + config["llm_providers"][provider["name"]] = dict() + config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"] + config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"] + + if "entity_extract" in file_config: + config["entity_extract"] = file_config["entity_extract"] + + if "rdf_build" in file_config: + config["rdf_build"] = file_config["rdf_build"] + + if "embedding" in file_config: + config["embedding"] = file_config["embedding"] + + if "rag" in file_config: + config["rag"] = file_config["rag"] + + if "qa" in file_config: + config["qa"] = file_config["qa"] + + if "persistence" in file_config: + config["persistence"] = file_config["persistence"] + print(config) + print("Configurations loaded from file: ", config_file_path) + + + +parser = argparse.ArgumentParser(description="Configurations for the pipeline") +parser.add_argument( + "--config_path", + type=str, + default="lpmm_config.toml", + help="Path to the configuration file", +) + +global_config = dict( + { + "llm_providers": { + "localhost": { + "base_url": "https://api.siliconflow.cn/v1", + "api_key": "sk-ospynxadyorf", + } + }, + "entity_extract": { + "llm": { + "provider": "localhost", + "model": "Pro/deepseek-ai/DeepSeek-V3", + } + }, + "rdf_build": { + "llm": { + "provider": "localhost", + "model": "Pro/deepseek-ai/DeepSeek-V3", + } + }, + "embedding": { + "provider": "localhost", + "model": "Pro/BAAI/bge-m3", + "dimension": 1024, + }, + "rag": { + "params": { + "synonym_search_top_k": 10, + "synonym_threshold": 0.75, + } + }, + "qa": { + "params": { + "relation_search_top_k": 10, + "relation_threshold": 0.75, + "paragraph_search_top_k": 10, + "paragraph_node_weight": 0.05, + "ent_filter_top_k": 10, + "ppr_damping": 0.8, + "res_top_k": 10, + }, + "llm": { + "provider": 
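When a provider exposes no reasoning_content field, the fallback above splits the message text itself; one widespread convention wraps the chain of thought in think-tags. A sketch of that split (the <think>...</think> delimiters are an assumption about the provider, not something this patch pins down):

    def split_reasoning(text: str) -> tuple[str | None, str]:
        """Separate <think>...</think> reasoning from the final answer, if present."""
        parts = text.split("</think>")
        if len(parts) == 2:
            reasoning = parts[0].removeprefix("<think>").strip()
            return reasoning, parts[1].strip()
        return None, text.strip()  # no tag pair: the whole text is the answer

    print(split_reasoning("<think>先想清楚问题</think>北京"))  # ('先想清楚问题', '北京')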
"localhost", + "model": "qa", + }, + }, + "persistence": { + "data_root_path": "data", + "raw_data_path": "data/raw.json", + "openie_data_path": "data/openie.json", + "embedding_data_dir": "data/embedding", + "rag_data_dir": "data/rag", + }, + "info_extraction":{ + "workers": 10, + } + } +) + +# _load_config(global_config, parser.parse_args().config_path) +file_path = os.path.abspath(__file__) +dir_path = os.path.dirname(file_path) +root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir) +config_path = os.path.join(root_path, "config", "lpmm_config.toml") +_load_config(global_config, config_path) diff --git a/src/plugins/knowledge/src/mem_active_manager.py b/src/plugins/knowledge/src/mem_active_manager.py new file mode 100644 index 00000000..073e3bda --- /dev/null +++ b/src/plugins/knowledge/src/mem_active_manager.py @@ -0,0 +1,36 @@ +from .lpmmconfig import global_config +from .embedding_store import EmbeddingManager +from .llm_client import LLMClient +from .utils.dyn_topk import dyn_select_top_k + + +class MemoryActiveManager: + def __init__( + self, + embed_manager: EmbeddingManager, + llm_client_embedding: LLMClient, + ): + self.embed_manager = embed_manager + self.embedding_client = llm_client_embedding + + def get_activation(self, question: str) -> float: + """获取记忆激活度""" + # 生成问题的Embedding + question_embedding = self.embedding_client.send_embedding_request( + "text-embedding", question + ) + # 查询关系库中的相似度 + rel_search_res = self.embed_manager.relation_embedding_store.search_top_k( + question_embedding, 10 + ) + + # 动态过滤阈值 + rel_scores = dyn_select_top_k(rel_search_res, 0.5, 1.0) + if rel_scores[0][1] < global_config["qa"]["params"]["relation_threshold"]: + # 未找到相关关系 + return 0.0 + + # 计算激活度 + activation = sum([item[2] for item in rel_scores]) * 10 + + return activation diff --git a/src/plugins/knowledge/src/open_ie.py b/src/plugins/knowledge/src/open_ie.py new file mode 100644 index 00000000..58259ef8 --- /dev/null +++ b/src/plugins/knowledge/src/open_ie.py @@ -0,0 +1,147 @@ +import json +from typing import Any, Dict, List + + +from .lpmmconfig import INVALID_ENTITY, global_config + + +def _filter_invalid_entities(entities: List[str]) -> List[str]: + """过滤无效的实体""" + valid_entities = set() + for entity in entities: + if ( + not isinstance(entity, str) + or entity.strip() == "" + or entity in INVALID_ENTITY + or entity in valid_entities + ): + # 非字符串/空字符串/在无效实体列表中/重复 + continue + valid_entities.add(entity) + + return list(valid_entities) + + +def _filter_invalid_triples(triples: List[List[str]]) -> List[List[str]]: + """过滤无效的三元组""" + unique_triples = set() + valid_triples = [] + + for triple in triples: + if len(triple) != 3 or ( + (not isinstance(triple[0], str) or triple[0].strip() == "") + or (not isinstance(triple[1], str) or triple[1].strip() == "") + or (not isinstance(triple[2], str) or triple[2].strip() == "") + ): + # 三元组长度不为3,或其中存在空值 + continue + + valid_triple = [str(item) for item in triple] + if tuple(valid_triple) not in unique_triples: + unique_triples.add(tuple(valid_triple)) + valid_triples.append(valid_triple) + + return valid_triples + + +class OpenIE: + """ + OpenIE规约的数据格式为如下 + { + "docs": [ + { + "idx": "文档的唯一标识符(通常是文本的SHA256哈希值)", + "passage": "文档的原始文本", + "extracted_entities": ["实体1", "实体2", ...], + "extracted_triples": [["主语", "谓语", "宾语"], ...] + }, + ... 
+ ], + "avg_ent_chars": "实体平均字符数", + "avg_ent_words": "实体平均词数" + } + """ + + def __init__( + self, + docs: List[Dict[str, Any]], + avg_ent_chars, + avg_ent_words, + ): + self.docs = docs + self.avg_ent_chars = avg_ent_chars + self.avg_ent_words = avg_ent_words + + for doc in self.docs: + # 过滤实体列表 + doc["extracted_entities"] = _filter_invalid_entities( + doc["extracted_entities"] + ) + # 过滤无效的三元组 + doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"]) + + @staticmethod + def _from_dict(data): + """从字典中获取OpenIE对象""" + return OpenIE( + docs=data["docs"], + avg_ent_chars=data["avg_ent_chars"], + avg_ent_words=data["avg_ent_words"], + ) + + def _to_dict(self): + """转换为字典""" + return { + "docs": self.docs, + "avg_ent_chars": self.avg_ent_chars, + "avg_ent_words": self.avg_ent_words, + } + + @staticmethod + def load() -> "OpenIE": + """从文件中加载OpenIE数据""" + with open( + global_config["persistence"]["openie_data_path"], "r", encoding="utf-8" + ) as f: + data = json.loads(f.read()) + + openie_data = OpenIE._from_dict(data) + + return openie_data + + @staticmethod + def save(openie_data: "OpenIE"): + """保存OpenIE数据到文件""" + with open( + global_config["persistence"]["openie_data_path"], "w", encoding="utf-8" + ) as f: + f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4)) + + def extract_entity_dict(self): + """提取实体列表""" + ner_output_dict = dict( + { + doc_item["idx"]: doc_item["extracted_entities"] + for doc_item in self.docs + if len(doc_item["extracted_entities"]) > 0 + } + ) + return ner_output_dict + + def extract_triple_dict(self): + """提取三元组列表""" + triple_output_dict = dict( + { + doc_item["idx"]: doc_item["extracted_triples"] + for doc_item in self.docs + if len(doc_item["extracted_triples"]) > 0 + } + ) + return triple_output_dict + + def extract_raw_paragraph_dict(self): + """提取原始段落""" + raw_paragraph_dict = dict( + {doc_item["idx"]: doc_item["passage"] for doc_item in self.docs} + ) + return raw_paragraph_dict diff --git a/src/plugins/knowledge/src/prompt_template.py b/src/plugins/knowledge/src/prompt_template.py new file mode 100644 index 00000000..ab6ac08c --- /dev/null +++ b/src/plugins/knowledge/src/prompt_template.py @@ -0,0 +1,73 @@ +from typing import List + +from .llm_client import LLMMessage + +entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体,并以JSON列表的形式输出。 + +输出格式示例: +[ "实体A", "实体B", "实体C" ] + +请注意以下要求: +- 将代词(如“你”、“我”、“他”、“她”、“它”等)转化为对应的实体命名,以避免指代不清。 +- 尽可能多的提取出段落中的全部实体; +""" + + +def build_entity_extract_context(paragraph: str) -> List[LLMMessage]: + messages = [ + LLMMessage("system", entity_extract_system_prompt).to_dict(), + LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(), + ] + return messages + + +rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。 + +请使用JSON回复,使用三元组的JSON列表输出RDF图中的关系(每个三元组代表一个关系)。 + +输出格式示例: +[ + ["某实体","关系","某属性"], + ["某实体","关系","某实体"], + ["某资源","关系","某属性"] +] + +请注意以下要求: +- 每个三元组应包含每个段落的实体命名列表中的至少一个命名实体,但最好是两个。 +- 将代词(如“你”、“我”、“他”、“她”、“它”等)转化为对应的实体命名,以避免指代不清。 +""" + + +def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]: + messages = [ + LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(), + LLMMessage( + "user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""" + ).to_dict(), + ] + return messages + + +qa_system_prompt = """ +你是一个性能优异的QA系统。请根据给定的问题和一些可能对你有帮助的信息作出回答。 + +请注意以下要求: +- 你可以使用给定的信息来回答问题,但请不要直接引用它们。 +- 你的回答应该简洁明了,避免冗长的解释。 +- 
如果你无法回答问题,请直接说“我不知道”。 +""" + + +def build_qa_context( + question: str, knowledge: list[(str, str, str)] +) -> List[LLMMessage]: + knowledge = "\n".join( + [f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)] + ) + messages = [ + LLMMessage("system", qa_system_prompt).to_dict(), + LLMMessage( + "user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}" + ).to_dict(), + ] + return messages diff --git a/src/plugins/knowledge/src/qa_manager.py b/src/plugins/knowledge/src/qa_manager.py new file mode 100644 index 00000000..986c8419 --- /dev/null +++ b/src/plugins/knowledge/src/qa_manager.py @@ -0,0 +1,117 @@ +import time +from typing import Tuple, List, Dict + +from .global_logger import logger +# from . import prompt_template +from .embedding_store import EmbeddingManager +from .llm_client import LLMClient +from .kg_manager import KGManager +from .lpmmconfig import global_config +from .utils.dyn_topk import dyn_select_top_k + + +class QAManager: + def __init__( + self, + embed_manager: EmbeddingManager, + kg_manager: KGManager, + llm_client_embedding: LLMClient, + llm_client_filter: LLMClient, + llm_client_qa: LLMClient, + ): + self.embed_manager = embed_manager + self.kg_manager = kg_manager + self.llm_client_list = { + "embedding": llm_client_embedding, + "filter": llm_client_filter, + "qa": llm_client_qa, + } + + def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Dict[str, float] | None]: + """处理查询""" + + # 生成问题的Embedding + part_start_time =time.perf_counter() + question_embedding = self.llm_client_list["embedding"].send_embedding_request( + global_config["embedding"]["model"], question + ) + part_end_time = time.perf_counter() + logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s") + + # 根据问题Embedding查询Relation Embedding库 + part_start_time =time.perf_counter() + relation_search_res = self.embed_manager.relation_embedding_store.search_top_k( + question_embedding, + global_config["qa"]["params"]["relation_search_top_k"], + ) + # 过滤阈值 + # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 + relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) + if ( + relation_search_res[0][1] + < global_config["qa"]["params"]["relation_threshold"] + ): + # 未找到相关关系 + relation_search_res = [] + + part_end_time = time.perf_counter() + logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s") + + for res in relation_search_res: + rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str + print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}") + + # TODO: 使用LLM过滤三元组结果 + # logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s") + # part_start_time = time.time() + + # 根据问题Embedding查询Paragraph Embedding库 + part_start_time =time.perf_counter() + paragraph_search_res = ( + self.embed_manager.paragraphs_embedding_store.search_top_k( + question_embedding, + global_config["qa"]["params"]["paragraph_search_top_k"], + ) + ) + part_end_time = time.perf_counter() + logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") + + if len(relation_search_res) != 0: + logger.info("找到相关关系,将使用RAG进行检索") + # 使用KG检索 + part_start_time =time.perf_counter() + result, ppr_node_weights = self.kg_manager.kg_search( + relation_search_res, paragraph_search_res, self.embed_manager + ) + part_end_time = time.perf_counter() + logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s") + else: + logger.info("未找到相关关系,将使用文段检索结果") + result = paragraph_search_res + ppr_node_weights = None + + # 过滤阈值 + result = dyn_select_top_k(result, 0.5, 1.0) + + for res in 
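Both filter points in process_query above go through dyn_select_top_k (utils/dyn_topk.py below), which keeps a variable number of hits based on the score distribution rather than a fixed k. A toy call with the same 0.5/1.0 factors (the import path assumes the repository root):

    from src.plugins.knowledge.src.utils.dyn_topk import dyn_select_top_k

    scores = [("p1", 0.92), ("p2", 0.90), ("p3", 0.41), ("p4", 0.40)]
    kept = dyn_select_top_k(scores, 0.5, 1.0)
    print([k[0] for k in kept])  # ['p1', 'p2']: the low-scoring tail is filtered out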
result: + raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[ + res[0] + ].str + print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") + + return result, ppr_node_weights + + def get_knowledge(self, question: str) -> str: + """获取知识""" + # 处理查询 + query_res, _ = self.process_query(question) + + knowledge = [ + ( + self.embed_manager.paragraphs_embedding_store.store[res[0]].str, + res[1], + ) + for res in query_res + ] + found_knowledge = "\n".join([f"第{i + 1}条知识:{k[1]}\n 该条知识对于问题的相关性:{k[0]}" for i, k in enumerate(knowledge)]) + return found_knowledge \ No newline at end of file diff --git a/src/plugins/knowledge/src/raw_processing.py b/src/plugins/knowledge/src/raw_processing.py new file mode 100644 index 00000000..75e4b863 --- /dev/null +++ b/src/plugins/knowledge/src/raw_processing.py @@ -0,0 +1,46 @@ +import json +import os + +from .global_logger import logger +from .lpmmconfig import global_config +from .utils.hash import get_sha256 + + +def load_raw_data() -> tuple[list[str], list[str]]: + """加载原始数据文件 + + 读取原始数据文件,将原始数据加载到内存中 + + Returns: + - raw_data: 原始数据字典 + - md5_set: 原始数据的SHA256集合 + """ + # 读取import.json文件 + if os.path.exists(global_config["persistence"]["raw_data_path"]) is True: + with open( + global_config["persistence"]["raw_data_path"], "r", encoding="utf-8" + ) as f: + import_json = json.loads(f.read()) + else: + raise Exception("原始数据文件读取失败") + # import_json内容示例: + # import_json = [ + # "The capital of China is Beijing. The capital of France is Paris.", + # ] + raw_data = [] + sha256_list = [] + sha256_set = set() + for item in import_json: + if not isinstance(item, str): + logger.warning("数据类型错误:{}".format(item)) + continue + pg_hash = get_sha256(item) + if pg_hash in sha256_set: + logger.warning("重复数据:{}".format(item)) + continue + sha256_set.add(pg_hash) + sha256_list.append(pg_hash) + raw_data.append(item) + logger.info("共读取到{}条数据".format(len(raw_data))) + + return sha256_list, raw_data diff --git a/src/plugins/knowledge/src/utils/__init__.py b/src/plugins/knowledge/src/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/plugins/knowledge/src/utils/data_loader.py b/src/plugins/knowledge/src/utils/data_loader.py new file mode 100644 index 00000000..08bdf414 --- /dev/null +++ b/src/plugins/knowledge/src/utils/data_loader.py @@ -0,0 +1,58 @@ +import jsonlines +from pathlib import Path +from typing import List, Dict, Any, Union, Optional +from src.config import global_config as config + + +class DataLoader: + """数据加载工具类,用于从/data目录下加载各种格式的数据文件""" + + def __init__(self, custom_data_dir: Optional[Union[str, Path]] = None): + """ + 初始化数据加载器 + + Args: + custom_data_dir: 可选的自定义数据目录路径,如果不提供则使用配置文件中的默认路径 + """ + self.data_dir = ( + Path(custom_data_dir) + if custom_data_dir + else Path(config["persistence"]["data_root_path"]) + ) + if not self.data_dir.exists(): + raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在") + + def _resolve_file_path(self, filename: str) -> Path: + """ + 解析文件路径 + + Args: + filename: 文件名 + + Returns: + 完整的文件路径 + + Raises: + FileNotFoundError: 当文件不存在时抛出 + """ + file_path = self.data_dir / filename + if not file_path.exists(): + raise FileNotFoundError(f"文件 {filename} 不存在") + return file_path + + def load_jsonl(self, filename: str) -> List[Dict[str, Any]]: + """ + 加载JSONL格式的文件 + + Args: + filename: 文件名 + + Returns: + 包含所有数据的列表 + """ + file_path = self._resolve_file_path(filename) + data = [] + with jsonlines.open(file_path) as reader: + for obj in reader: + data.append(obj) + return data diff --git 
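get_knowledge above flattens the retrieval hits into a numbered text block. With the (paragraph_text, score) tuples it builds, the fields line up as follows (a sketch of the intended layout: the paragraph text is element 0 and the relevance score element 1):

    knowledge = [("北京是中国的首都。", 0.91), ("巴黎是法国的首都。", 0.44)]
    lines = [
        f"第{i + 1}条知识:{text}\n该条知识对于问题的相关性:{score:.2f}"
        for i, (text, score) in enumerate(knowledge)
    ]
    print("\n".join(lines))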
a/src/plugins/knowledge/src/utils/dyn_topk.py b/src/plugins/knowledge/src/utils/dyn_topk.py new file mode 100644 index 00000000..02bc5e3e --- /dev/null +++ b/src/plugins/knowledge/src/utils/dyn_topk.py @@ -0,0 +1,51 @@ +from typing import List, Any, Tuple + + +def dyn_select_top_k( + score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float +) -> List[Tuple[Any, float, float]]: + """动态TopK选择""" + # 按照分数排序(降序) + sorted_score = sorted(score, key=lambda x: x[1], reverse=True) + + # 归一化 + max_score = sorted_score[0][1] + min_score = sorted_score[-1][1] + normalized_score = [] + for score_item in sorted_score: + normalized_score.append( + tuple( + [ + score_item[0], + score_item[1], + (score_item[1] - min_score) / (max_score - min_score), + ] + ) + ) + + # 寻找跳变点:score变化最大的位置 + jump_idx = 0 + for i in range(1, len(normalized_score)): + if abs(normalized_score[i][2] - normalized_score[i - 1][2]) > abs( + normalized_score[jump_idx][2] - normalized_score[jump_idx - 1][2] + ): + jump_idx = i + # 跳变阈值 + jump_threshold = normalized_score[jump_idx][2] + + # 计算均值 + mean_score = sum([s[2] for s in normalized_score]) / len(normalized_score) + # 计算方差 + var_score = sum([(s[2] - mean_score) ** 2 for s in normalized_score]) / len( + normalized_score + ) + + # 动态阈值 + threshold = jmp_factor * jump_threshold + (1 - jmp_factor) * ( + mean_score + var_factor * var_score + ) + + # 重新过滤 + res = [s for s in normalized_score if s[2] > threshold] + + return res diff --git a/src/plugins/knowledge/src/utils/hash.py b/src/plugins/knowledge/src/utils/hash.py new file mode 100644 index 00000000..b3e12b87 --- /dev/null +++ b/src/plugins/knowledge/src/utils/hash.py @@ -0,0 +1,8 @@ +import hashlib + + +def get_sha256(string: str) -> str: + """获取字符串的SHA256值""" + sha256 = hashlib.sha256() + sha256.update(string.encode("utf-8")) + return sha256.hexdigest() diff --git a/src/plugins/knowledge/src/utils/json_fix.py b/src/plugins/knowledge/src/utils/json_fix.py new file mode 100644 index 00000000..672fb1f8 --- /dev/null +++ b/src/plugins/knowledge/src/utils/json_fix.py @@ -0,0 +1,79 @@ +import json + + +def _find_unclosed(json_str): + """ + Identifies the unclosed braces and brackets in the JSON string. + + Args: + json_str (str): The JSON string to analyze. + + Returns: + list: A list of unclosed elements in the order they were opened. + """ + unclosed = [] + inside_string = False + escape_next = False + + for char in json_str: + if inside_string: + if escape_next: + escape_next = False + elif char == "\\": + escape_next = True + elif char == '"': + inside_string = False + else: + if char == '"': + inside_string = True + elif char in "{[": + unclosed.append(char) + elif char in "}]": + if unclosed and ( + (char == "}" and unclosed[-1] == "{") + or (char == "]" and unclosed[-1] == "[") + ): + unclosed.pop() + + return unclosed + + +# The following code is used to fix a broken JSON string. +# From HippoRAG2 (GitHub: OSU-NLP-Group/HippoRAG) +def fix_broken_generated_json(json_str: str) -> str: + """ + Fixes a malformed JSON string by: + - Removing the last comma and any trailing content. + - Iterating over the JSON string once to determine and fix unclosed braces or brackets. + - Ensuring braces and brackets inside string literals are not considered. + + If the original json_str string can be successfully loaded by json.loads(), will directly return it without any modification. + + Args: + json_str (str): The malformed JSON string to be fixed. + + Returns: + str: The corrected JSON string. 
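fix_broken_generated_json is deliberately conservative: already-valid input passes through untouched; otherwise everything after the last comma is dropped and the unclosed braces and brackets are closed in reverse opening order. A quick round-trip (the import path assumes the repository root):

    from src.plugins.knowledge.src.utils.json_fix import fix_broken_generated_json

    print(fix_broken_generated_json('{"k": "v"'))   # '{"k": "v"}' (brace closed)
    print(fix_broken_generated_json('["a", "b"]'))  # unchanged: already valid JSON
    print(fix_broken_generated_json('["a", "b'))    # '["a"]': the dangling tail is cut at the last comma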
+ """ + + try: + # Try to load the JSON to see if it is valid + json.loads(json_str) + return json_str # Return as-is if valid + except json.JSONDecodeError: + pass + + # Step 1: Remove trailing content after the last comma. + last_comma_index = json_str.rfind(",") + if last_comma_index != -1: + json_str = json_str[:last_comma_index] + + # Step 2: Identify unclosed braces and brackets. + unclosed_elements = _find_unclosed(json_str) + + # Step 3: Append the necessary closing elements in reverse order of opening. + closing_map = {"{": "}", "[": "]"} + for open_char in reversed(unclosed_elements): + json_str += closing_map[open_char] + + return json_str diff --git a/src/plugins/knowledge/src/utils/visualize_graph.py b/src/plugins/knowledge/src/utils/visualize_graph.py new file mode 100644 index 00000000..845e18a2 --- /dev/null +++ b/src/plugins/knowledge/src/utils/visualize_graph.py @@ -0,0 +1,17 @@ +import networkx as nx +from matplotlib import pyplot as plt + + +def draw_graph_and_show(graph): + """绘制图并显示,画布大小1280*1280""" + fig = plt.figure(1, figsize=(12.8, 12.8), dpi=100) + nx.draw_networkx( + graph, + node_size=100, + width=0.5, + with_labels=True, + labels=nx.get_node_attributes(graph, "content"), + font_family="Sarasa Mono SC", + font_size=8, + ) + fig.show() \ No newline at end of file diff --git a/src/plugins/memory_system/Hippocampus.py b/src/plugins/memory_system/Hippocampus.py index 738e47c4..f25f1d45 100644 --- a/src/plugins/memory_system/Hippocampus.py +++ b/src/plugins/memory_system/Hippocampus.py @@ -1509,14 +1509,19 @@ class HippocampusManager: return response async def get_memory_from_topic( - self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3 + self, + valid_keywords: list[str], + max_memory_num: int = 3, + max_memory_length: int = 2, + max_depth: int = 3, + fast_retrieval: bool = False, ) -> list: """从文本中获取相关记忆的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") try: response = await self._hippocampus.get_memory_from_topic( - valid_keywords, max_memory_num, max_memory_length, max_depth + valid_keywords, max_memory_num, max_memory_length, max_depth, fast_retrieval ) except Exception as e: logger.error(f"文本激活记忆失败: {e}") diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index e4e2a2a8..f0a52e76 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,7 +1,8 @@ [inner] -version = "1.4.0" +version = "1.3.1" -#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- + +#以下是给开发人员阅读的,一般用户不需要阅读 #如果你想要修改配置文件,请在修改后将version的值进行变更 #如果新增项目,请在BotConfig类下新增相应的变量 #1.如果你修改的是[]层级项目,例如你新增了 [memory],那么请在config.py的 load_config函数中的include_configs字典中新增"内容":{ @@ -18,12 +19,11 @@ version = "1.4.0" # 次版本号:当你做了向下兼容的功能性新增, # 修订号:当你做了向下兼容的问题修正。 # 先行版本号及版本编译信息可以加到“主版本号.次版本号.修订号”的后面,作为延伸。 -#----以上是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- [bot] -qq = 1145141919810 +qq = 114514 nickname = "麦麦" -alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效 +alias_names = ["麦叠", "牢麦"] [groups] talk_allowed = [ @@ -41,24 +41,23 @@ personality_sides = [ "用一句话或几句话描述人格的一些细节", "用一句话或几句话描述人格的一些细节", "用一句话或几句话描述人格的一些细节", -]# 条数任意,不能为0, 该选项还在调试中,可能未完全生效 +]# 条数任意 [identity] #アイデンティティがない 生まれないらららら # 兴趣爱好 未完善,有些条目未使用 identity_detail = [ "身份特点", "身份特点", -]# 条数任意,不能为0, 该选项还在调试中,可能未完全生效 +]# 条数任意 #外貌特征 -height = 170 # 身高 单位厘米 该选项还在调试中,暂时未生效 -weight = 50 # 体重 单位千克 该选项还在调试中,暂时未生效 -age = 20 # 年龄 单位岁 该选项还在调试中,暂时未生效 -gender = "男" # 性别 该选项还在调试中,暂时未生效 -appearance = "用几句话描述外貌特征" # 外貌特征 该选项还在调试中,暂时未生效 +height = 
170 # 身高 单位厘米 +weight = 50 # 体重 单位千克 +age = 20 # 年龄 单位岁 +gender = "男" # 性别 +appearance = "用几句话描述外貌特征" # 外貌特征 [schedule] -enable_schedule_gen = true # 是否启用日程表 -enable_schedule_interaction = true # 日程表是否影响回复模式 +enable_schedule_gen = true # 是否启用日程表(尚未完成) prompt_schedule_gen = "用几句话描述描述性格特点或行动规律,这个特征会用来生成日程表" schedule_doing_update_interval = 900 # 日程表更新间隔 单位秒 schedule_temperature = 0.1 # 日程表温度,建议0.1-0.5 @@ -68,25 +67,19 @@ time_zone = "Asia/Shanghai" # 给你的机器人设置时区,可以解决运 nonebot-qq="http://127.0.0.1:18002/api/message" [response] #群聊的回复策略 -enable_heart_flowC = true -# 该功能还在完善中 -# 是否启用heart_flowC(心流聊天,HFC)模式 -# 启用后麦麦会自主选择进入heart_flowC模式(持续一段时间),进行主动的观察和回复,并给出回复,比较消耗token +#reasoning:推理模式,麦麦会根据上下文进行推理,并给出回复 +#heart_flow:结合了PFC模式和心流模式,麦麦会进行主动的观察和回复,并给出回复 +response_mode = "heart_flow" # 回复策略,可选值:heart_flow(心流),reasoning(推理) -#一般回复参数 -model_reasoning_probability = 0.7 # 麦麦回答时选择推理模型 模型的概率 -model_normal_probability = 0.3 # 麦麦回答时选择一般模型 模型的概率 - -[heartflow] #启用启用heart_flowC(心流聊天)模式时生效,需要填写以下参数 -reply_trigger_threshold = 3.0 # 心流聊天触发阈值,越低越容易进入心流聊天 -probability_decay_factor_per_second = 0.2 # 概率衰减因子,越大衰减越快,越高越容易退出心流聊天 -default_decay_rate_per_second = 0.98 # 默认衰减率,越大衰减越快,越高越难进入心流聊天 -initial_duration = 60 # 初始持续时间,越大心流聊天持续的时间越长 +#推理回复参数 +model_r1_probability = 0.7 # 麦麦回答时选择主要回复模型1 模型的概率 +model_v3_probability = 0.3 # 麦麦回答时选择次要回复模型2 模型的概率 +[heartflow] # 注意:可能会消耗大量token,请谨慎开启,仅会使用v3模型 +sub_heart_flow_update_interval = 60 # 子心流更新频率,间隔 单位秒 +sub_heart_flow_freeze_time = 100 # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒 sub_heart_flow_stop_time = 500 # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒 -# sub_heart_flow_update_interval = 60 -# sub_heart_flow_freeze_time = 100 -# heart_flow_update_interval = 600 +heart_flow_update_interval = 600 # 心流更新频率,间隔 单位秒 observation_context_size = 20 # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩 compressed_length = 5 # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5 @@ -94,13 +87,11 @@ compress_length_limit = 5 #最多压缩份数,超过该数值的压缩上下 [message] -max_context_size = 12 # 麦麦回复时获得的上文数量,建议12,太短太长都会导致脑袋尖尖 -emoji_chance = 0.2 # 麦麦一般回复时使用表情包的概率,设置为1让麦麦自己决定发不发 -thinking_timeout = 100 # 麦麦最长思考时间,超过这个时间的思考会放弃(往往是api反应太慢) -max_response_length = 256 # 麦麦单次回答的最大token数 +max_context_size = 12 # 麦麦获得的上文数量,建议12,太短太长都会导致脑袋尖尖 +emoji_chance = 0.2 # 麦麦使用表情包的概率,设置为1让麦麦自己决定发不发 +thinking_timeout = 60 # 麦麦最长思考时间,超过这个时间的思考会放弃 +max_response_length = 256 # 麦麦回答的最大token数 message_buffer = true # 启用消息缓冲器?启用此项以解决消息的拆分问题,但会使麦麦的回复延迟 - -# 以下是消息过滤,可以根据规则过滤特定消息,将不会读取这些消息 ban_words = [ # "403","张三" ] @@ -112,23 +103,22 @@ ban_msgs_regex = [ # "\\[CQ:at,qq=\\d+\\]" # 匹配@ ] -[willing] # 一般回复模式的回复意愿设置 +[willing] willing_mode = "classical" # 回复意愿模式 —— 经典模式:classical,动态模式:dynamic,mxp模式:mxp,自定义模式:custom(需要你自己实现) response_willing_amplifier = 1 # 麦麦回复意愿放大系数,一般为1 response_interested_rate_amplifier = 1 # 麦麦回复兴趣度放大系数,听到记忆里的内容时放大系数 down_frequency_rate = 3 # 降低回复频率的群组回复意愿降低系数 除法 -emoji_response_penalty = 0 # 表情包回复惩罚系数,设为0为不回复单个表情包,减少单独回复表情包的概率 +emoji_response_penalty = 0.1 # 表情包回复惩罚系数,设为0为不回复单个表情包,减少单独回复表情包的概率 mentioned_bot_inevitable_reply = false # 提及 bot 必然回复 at_bot_inevitable_reply = false # @bot 必然回复 [emoji] -max_emoji_num = 90 # 表情包最大数量 +max_emoji_num = 120 # 表情包最大数量 max_reach_deletion = true # 开启则在达到最大数量时删除表情包,关闭则达到最大数量时不删除,只是不会继续收集表情包 check_interval = 30 # 检查表情包(注册,破损,删除)的时间间隔(分钟) auto_save = true # 是否保存表情包和图片 - -enable_check = false # 是否启用表情包过滤,只有符合该要求的表情包才会被保存 -check_prompt = "符合公序良俗" # 表情包过滤要求,只有符合该要求的表情包才会被保存 +enable_check = false # 是否启用表情包过滤 +check_prompt = "符合公序良俗" # 表情包过滤要求 [memory] build_memory_interval = 2000 # 记忆构建间隔 单位秒 
间隔越低,麦麦学习越多,但是冗余信息也会增多 @@ -141,8 +131,7 @@ forget_memory_interval = 1000 # 记忆遗忘间隔 单位秒 间隔越低, memory_forget_time = 24 #多长时间后的记忆会被遗忘 单位小时 memory_forget_percentage = 0.01 # 记忆遗忘比例 控制记忆遗忘程度 越大遗忘越多 建议保持默认 -#不希望记忆的词,已经记忆的不会受到影响 -memory_ban_words = [ +memory_ban_words = [ #不希望记忆的词 # "403","张三" ] @@ -178,7 +167,7 @@ word_replace_rate=0.006 # 整词替换概率 [response_splitter] enable_response_splitter = true # 是否启用回复分割器 -response_max_length = 256 # 回复允许的最大长度 +response_max_length = 100 # 回复允许的最大长度 response_max_sentence_num = 4 # 回复允许的最大句子数 [remote] #发送统计信息,主要是看全球有多少只麦麦 diff --git a/template/lpmm_config_template.toml b/template/lpmm_config_template.toml new file mode 100644 index 00000000..14ea3a2e --- /dev/null +++ b/template/lpmm_config_template.toml @@ -0,0 +1,54 @@ +# LLM API 服务提供商,可配置多个 +[[llm_providers]] +name = "localhost" +base_url = "http://127.0.0.1:8888/v1/" +api_key = "lm_studio" + +[[llm_providers]] +name = "siliconflow" +base_url = "https://api.siliconflow.cn/v1/" +api_key = "" + +[entity_extract.llm] +# 设置用于实体提取的LLM模型 +provider = "siliconflow" # 服务提供商 +model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" # 模型名称 + +[rdf_build.llm] +# 设置用于RDF构建的LLM模型 +provider = "siliconflow" # 服务提供商 +model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" # 模型名称 + +[embedding] +# 设置用于文本嵌入的Embedding模型 +provider = "localhost" # 服务提供商 +model = "text-embedding-bge-m3" # 模型名称 +dimension = 1024 # 嵌入维度 + +[rag.params] +# RAG参数配置 +synonym_search_top_k = 10 # 同义词搜索TopK +synonym_threshold = 0.8 # 同义词阈值(相似度高于此阈值的词语会被认为是同义词) + +[qa.llm] +# 设置用于QA的LLM模型 +provider = "siliconflow" # 服务提供商 +model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" # 模型名称 + +[qa.params] +# QA参数配置 +relation_search_top_k = 10 # 关系搜索TopK +relation_threshold = 0.5 # 关系阈值(相似度高于此阈值的关系会被认为是相关的关系) +paragraph_search_top_k = 1000 # 段落搜索TopK(不能过小,可能影响搜索结果) +paragraph_node_weight = 0.05 # 段落节点权重(在图搜索&PPR计算中的权重,当搜索仅使用DPR时,此参数不起作用) +ent_filter_top_k = 10 # 实体过滤TopK +ppr_damping = 0.8 # PPR阻尼系数 +res_top_k = 3 # 最终提供的文段TopK + +[persistence] +# 持久化配置(存储中间数据,防止重复计算) +data_root_path = "data" # 数据根目录 +raw_data_path = "data/import.json" # 原始数据路径 +openie_data_path = "data/openie.json" # OpenIE数据路径 +embedding_data_dir = "data/embedding" # 嵌入数据目录 +rag_data_dir = "data/rag" # RAG数据目录 diff --git a/template/template.env b/template/template.env index c1a6dd0d..06e9b07e 100644 --- a/template/template.env +++ b/template/template.env @@ -29,18 +29,8 @@ CHAT_ANY_WHERE_KEY= SILICONFLOW_KEY= # 定义日志相关配置 - -# 精简控制台输出格式 -SIMPLE_OUTPUT=true - -# 自定义日志的默认控制台输出日志级别 -CONSOLE_LOG_LEVEL=INFO - -# 自定义日志的默认文件输出日志级别 -FILE_LOG_LEVEL=DEBUG - -# 原生日志的控制台输出日志级别(nonebot就是这一类) -DEFAULT_CONSOLE_LOG_LEVEL=SUCCESS - -# 原生日志的默认文件输出日志级别(nonebot就是这一类) -DEFAULT_FILE_LOG_LEVEL=DEBUG +SIMPLE_OUTPUT=true # 精简控制台输出格式 +CONSOLE_LOG_LEVEL=INFO # 自定义日志的默认控制台输出日志级别 +FILE_LOG_LEVEL=DEBUG # 自定义日志的默认文件输出日志级别 +DEFAULT_CONSOLE_LOG_LEVEL=SUCCESS # 原生日志的控制台输出日志级别(nonebot就是这一类) +DEFAULT_FILE_LOG_LEVEL=DEBUG # 原生日志的默认文件输出日志级别(nonebot就是这一类) \ No newline at end of file diff --git a/麦麦开始学习.bat b/麦麦开始学习.bat new file mode 100644 index 00000000..53bda5d0 --- /dev/null +++ b/麦麦开始学习.bat @@ -0,0 +1,46 @@ +@echo off +CHCP 65001 > nul +setlocal enabledelayedexpansion + +@REM REM 查找venv虚拟环境 +@REM set "venv_path=%~dp0venv\Scripts\activate.bat" +@REM if not exist "%venv_path%" ( +@REM echo 错误: 未找到虚拟环境,请确保venv目录存在 +@REM pause +@REM exit /b 1 +@REM ) + +@REM REM 激活虚拟环境 +@REM call "%venv_path%" +@REM if %ERRORLEVEL% neq 0 ( +@REM echo 错误: 虚拟环境激活失败 +@REM pause +@REM exit /b 1 +@REM ) + +REM 运行预处理脚本 +python "%~dp0raw_data_preprocessor.py" +if 
%ERRORLEVEL% neq 0 ( + echo 错误: raw_data_preprocessor.py 执行失败 + pause + exit /b 1 +) + +REM 运行信息提取脚本 +python "%~dp0info_extraction.py" +if %ERRORLEVEL% neq 0 ( + echo 错误: info_extraction.py 执行失败 + pause + exit /b 1 +) + +REM 运行OpenIE导入脚本 +python "%~dp0import_openie.py" +if %ERRORLEVEL% neq 0 ( + echo 错误: import_openie.py 执行失败 + pause + exit /b 1 +) + +echo 所有处理步骤完成! +pause \ No newline at end of file diff --git a/(临时版)麦麦开始学习.bat b/(临时版)麦麦开始学习.bat deleted file mode 100644 index f96d7cfd..00000000 --- a/(临时版)麦麦开始学习.bat +++ /dev/null @@ -1,56 +0,0 @@ -@echo off -chcp 65001 > nul -setlocal enabledelayedexpansion -cd /d %~dp0 - -title 麦麦学习系统 - -cls -echo ====================================== -echo 警告提示 -echo ====================================== -echo 1.这是一个demo系统,不完善不稳定,仅用于体验/不要塞入过长过大的文本,这会导致信息提取迟缓 -echo ====================================== - -echo. -echo ====================================== -echo 请选择Python环境: -echo 1 - venv (推荐) -echo 2 - conda -echo ====================================== -choice /c 12 /n /m "请输入数字选择(1或2): " - -if errorlevel 2 ( - echo ====================================== - set "CONDA_ENV=" - set /p CONDA_ENV="请输入要激活的 conda 环境名称: " - - :: 检查输入是否为空 - if "!CONDA_ENV!"=="" ( - echo 错误:环境名称不能为空 - pause - exit /b 1 - ) - - call conda activate !CONDA_ENV! - if errorlevel 1 ( - echo 激活 conda 环境失败 - pause - exit /b 1 - ) - - echo Conda 环境 "!CONDA_ENV!" 激活成功 - python src/plugins/zhishi/knowledge_library.py -) else ( - if exist "venv\Scripts\python.exe" ( - venv\Scripts\python src/plugins/zhishi/knowledge_library.py - ) else ( - echo ====================================== - echo 错误: venv环境不存在,请先创建虚拟环境 - pause - exit /b 1 - ) -) - -endlocal -pause
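For non-Windows setups, the three-step pipeline that 麦麦开始学习.bat drives can be reproduced with a small launcher (a sketch; it assumes the three scripts sit in the repository root as in this patch and that the active interpreter already has the project dependencies):

    import subprocess
    import sys

    STEPS = [
        "raw_data_preprocessor.py",  # 1) normalize raw text into data/import.json
        "info_extraction.py",        # 2) run entity and RDF triple extraction
        "import_openie.py",          # 3) import OpenIE output into the embedding and KG stores
    ]

    for script in STEPS:
        print(f"运行 {script} ...")
        if subprocess.run([sys.executable, script]).returncode != 0:
            sys.exit(f"错误: {script} 执行失败")  # fail fast, mirroring the .bat

    print("所有处理步骤完成!")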