大力修复

pull/789/head
UnCLAS-Prommer 2025-04-18 12:34:16 +08:00
parent d72b34709c
commit bf7827a571
14 changed files with 30 additions and 552 deletions

View File

@ -815,7 +815,7 @@ def parse_text_timestamps(text: str, mode: str = "normal") -> str:
# 从后向前替换,避免位置改变 # 从后向前替换,避免位置改变
converted_timestamps.reverse() converted_timestamps.reverse()
for ts, match, readable_time in converted_timestamps: for _, match, readable_time in converted_timestamps:
pattern_instance = re.escape(match.group(0)) pattern_instance = re.escape(match.group(0))
if readable_time in readable_time_used: if readable_time in readable_time_used:
# 如果相同格式的时间已存在,替换为空字符串 # 如果相同格式的时间已存在,替换为空字符串

View File

@ -9,11 +9,13 @@ from ...moods.moods import MoodManager
from ....individuality.individuality import Individuality from ....individuality.individuality import Individuality
from ...memory_system.Hippocampus import HippocampusManager from ...memory_system.Hippocampus import HippocampusManager
from ...schedule.schedule_generator import bot_schedule from ...schedule.schedule_generator import bot_schedule
from ...config.config import global_config from src.config.config import global_config
from ...person_info.relationship_manager import relationship_manager from ...person_info.relationship_manager import relationship_manager
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugins.knowledge.src.qa_manager import qa_manager from src.plugins.knowledge.knowledge_lib import qa_manager
from src.plugins.chat.chat_stream import ChatStream
logger = get_module_logger("prompt") logger = get_module_logger("prompt")
@ -54,7 +56,7 @@ class PromptBuilder:
self.activate_messages = "" self.activate_messages = ""
async def _build_prompt( async def _build_prompt(
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None self, chat_stream: ChatStream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
) -> tuple[str, str]: ) -> tuple[str, str]:
# 开始构建prompt # 开始构建prompt
prompt_personality = "" prompt_personality = ""
@ -102,16 +104,14 @@ class PromptBuilder:
related_memory = await HippocampusManager.get_instance().get_memory_from_text( related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
) )
related_memory_info = ""
if related_memory: if related_memory:
related_memory_info = ""
for memory in related_memory: for memory in related_memory:
related_memory_info += memory[1] related_memory_info += memory[1]
# memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆不一定是目前聊天里的人说的也不一定是现在发生的事情请记住。\n" # memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆不一定是目前聊天里的人说的也不一定是现在发生的事情请记住。\n"
memory_prompt = await global_prompt_manager.format_prompt( memory_prompt = await global_prompt_manager.format_prompt(
"memory_prompt", related_memory_info=related_memory_info "memory_prompt", related_memory_info=related_memory_info
) )
else:
related_memory_info = ""
# print(f"相关记忆:{related_memory_info}") # print(f"相关记忆:{related_memory_info}")
@ -230,224 +230,6 @@ class PromptBuilder:
related_info += qa_manager.get_knowledge(message) related_info += qa_manager.get_knowledge(message)
return related_info return related_info
# start_time = time.time()
# related_info = ""
# logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# # 1. 先从LLM获取主题类似于记忆系统的做法
# topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
# if not topics:
# logger.info("未能提取到任何主题,使用整个消息进行查询")
# embedding = await get_embedding(message, request_type="prompt_build")
# if not embedding:
# logger.error("获取消息嵌入向量失败")
# return ""
# related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
# logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
# return related_info
# # 2. 对每个主题进行知识库查询
# logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# # 优化批量获取嵌入向量减少API调用
# embeddings = {}
# topics_batch = [topic for topic in topics if len(topic) > 0]
# if message: # 确保消息非空
# topics_batch.append(message)
# 批量获取嵌入向量
# embed_start_time = time.time()
# for text in topics_batch:
# if not text or len(text.strip()) == 0:
# continue
# try:
# embedding = await get_embedding(text, request_type="prompt_build")
# if embedding:
# embeddings[text] = embedding
# else:
# logger.warning(f"获取'{text}'的嵌入向量失败")
# except Exception as e:
# logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
# logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
# if not embeddings:
# logger.error("所有嵌入向量获取失败")
# return ""
# # 3. 对每个主题进行知识库查询
# all_results = []
# query_start_time = time.time()
# # 首先添加原始消息的查询结果
# if message in embeddings:
# original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
# if original_results:
# for result in original_results:
# result["topic"] = "原始消息"
# all_results.extend(original_results)
# logger.info(f"原始消息查询到{len(original_results)}条结果")
# # 然后添加每个主题的查询结果
# for topic in topics:
# if not topic or topic not in embeddings:
# continue
# try:
# # topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
# # if topic_results:
# # # 添加主题标记
# # for result in topic_results:
# # result["topic"] = topic
# # all_results.extend(topic_results)
# # logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
# except Exception as e:
# logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
# logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# # 4. 去重和过滤
# process_start_time = time.time()
# unique_contents = set()
# filtered_results = []
# for result in all_results:
# content = result["content"]
# if content not in unique_contents:
# unique_contents.add(content)
# filtered_results.append(result)
# # 5. 按相似度排序
# filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# # 6. 限制总数量最多10条
# filtered_results = filtered_results[:10]
# logger.info(
# f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
# )
# # 7. 格式化输出
# if filtered_results:
# format_start_time = time.time()
# grouped_results = {}
# for result in filtered_results:
# topic = result["topic"]
# if topic not in grouped_results:
# grouped_results[topic] = []
# grouped_results[topic].append(result)
# 按主题组织输出
# for topic, results in grouped_results.items():
# related_info += f"【主题: {topic}】\n"
# for _i, result in enumerate(results, 1):
# _similarity = result["similarity"]
# content = result["content"].strip()
# # 调试:为内容添加序号和相似度信息
# # related_info += f"{i}. [{similarity:.2f}] {content}\n"
# related_info += f"{content}\n"
# related_info += "\n"
# # logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
# logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
# return related_info
# def get_info_from_db(
# self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
# ) -> Union[str, list]:
# if not query_embedding:
# return "" if not return_raw else []
# # 使用余弦相似度计算
# pipeline = [
# {
# "$addFields": {
# "dotProduct": {
# "$reduce": {
# "input": {"$range": [0, {"$size": "$embedding"}]},
# "initialValue": 0,
# "in": {
# "$add": [
# "$$value",
# {
# "$multiply": [
# {"$arrayElemAt": ["$embedding", "$$this"]},
# {"$arrayElemAt": [query_embedding, "$$this"]},
# ]
# },
# ]
# },
# }
# },
# "magnitude1": {
# "$sqrt": {
# "$reduce": {
# "input": "$embedding",
# "initialValue": 0,
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
# }
# }
# },
# "magnitude2": {
# "$sqrt": {
# "$reduce": {
# "input": query_embedding,
# "initialValue": 0,
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
# }
# }
# },
# }
# },
# {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
# {
# "$match": {
# "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
# }
# },
# {"$sort": {"similarity": -1}},
# {"$limit": limit},
# {"$project": {"content": 1, "similarity": 1}},
# ]
# results = list(db.knowledges.aggregate(pipeline))
# logger.debug(f"知识库查询结果数量: {len(results)}")
# if not results:
# return "" if not return_raw else []
# if return_raw:
# return results
# else:
# # 返回所有找到的内容,用换行分隔
# return "\n".join(str(result["content"]) for result in results)
init_prompt() init_prompt()

View File

@ -1,122 +0,0 @@
import os
import toml
from .global_logger import logger
# Namespace labels (plain string constants). How they are consumed is not
# visible in this module.
PG_NAMESPACE = "paragraph"
ENT_NAMESPACE = "entity"
REL_NAMESPACE = "relation"

RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"

# Invalid entities: strings that should not be treated as entities.
# NOTE(review): the five empty-string entries below are identical; they may be
# single-character pronouns lost to an encoding/extraction problem -- confirm
# against the repository copy before relying on this list.
INVALID_ENTITY = [
    "",
    "",
    "",
    "",
    "",
    "我们",
    "你们",
    "他们",
    "她们",
    "它们",
]
def _load_config(config, config_file_path):
"""读取TOML格式的配置文件"""
if not os.path.exists(config_file_path):
return
with open(config_file_path, "r", encoding="utf-8") as f:
file_config = toml.load(f)
if "llm_providers" in file_config:
for provider in file_config["llm_providers"]:
if provider["name"] not in config["llm_providers"]:
config["llm_providers"][provider["name"]] = dict()
config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"]
config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"]
if "entity_extract" in file_config:
config["entity_extract"] = file_config["entity_extract"]
if "rdf_build" in file_config:
config["rdf_build"] = file_config["rdf_build"]
if "embedding" in file_config:
config["embedding"] = file_config["embedding"]
if "rag" in file_config:
config["rag"] = file_config["rag"]
if "qa" in file_config:
config["qa"] = file_config["qa"]
if "persistence" in file_config:
config["persistence"] = file_config["persistence"]
logger.info(f"Configurations loaded from file: {config_file_path}")
# Built-in default settings. Sections found in the TOML file later override
# these (see the _load_config call at the bottom of this module).
global_config = {
    # Known LLM endpoints, keyed by provider name.
    "llm_providers": {
        "localhost": {
            "base_url": "http://localhost:8000",
            "api_key": "",
        },
    },
    "entity_extract": {
        "llm": {
            "provider": "localhost",
            "model": "entity-extract",
        },
    },
    "rdf_build": {
        "llm": {
            "provider": "localhost",
            "model": "rdf-build",
        },
    },
    "embedding": {
        "provider": "localhost",
        "model": "embed",
        "dimension": 1024,
    },
    "rag": {
        "params": {
            "synonym_search_top_k": 10,
            "synonym_threshold": 0.75,
        },
    },
    "qa": {
        "params": {
            "relation_search_top_k": 10,
            "relation_threshold": 0.75,
            "paragraph_search_top_k": 10,
            "paragraph_node_weight": 0.05,
            "ent_filter_top_k": 10,
            "ppr_damping": 0.8,
            "res_top_k": 10,
        },
        "llm": {
            "provider": "localhost",
            "model": "qa",
        },
    },
    # Default locations of persisted data (relative paths).
    "persistence": {
        "data_root_path": "data",
        "raw_data_path": "data/raw.json",
        "openie_data_path": "data/openie.json",
        "embedding_data_dir": "data/embedding",
        "rag_data_dir": "data/rag",
    },
}
# Resolve config/lpmm_config.toml under the directory four levels above this
# module's own directory, then merge that file into the defaults at import time.
file_path = os.path.abspath(__file__)
dir_path = os.path.dirname(file_path)
root_path = os.path.join(dir_path, *([os.pardir] * 4))
config_path = os.path.join(root_path, "config", "lpmm_config.toml")
_load_config(global_config, config_path)

View File

@ -4,7 +4,7 @@ from typing import List
from .global_logger import logger from .global_logger import logger
from . import prompt_template from . import prompt_template
from .config import global_config, INVALID_ENTITY from .lpmmconfig import global_config, INVALID_ENTITY
from .llm_client import LLMClient from .llm_client import LLMClient
from .utils.json_fix import fix_broken_generated_json from .utils.json_fix import fix_broken_generated_json

View File

@ -11,7 +11,7 @@ import tqdm
from ..lib import quick_algo from ..lib import quick_algo
from .utils.hash import get_sha256 from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from .config import ( from .lpmmconfig import (
ENT_NAMESPACE, ENT_NAMESPACE,
PG_NAMESPACE, PG_NAMESPACE,
RAG_ENT_CNT_NAMESPACE, RAG_ENT_CNT_NAMESPACE,

View File

@ -1,6 +1,6 @@
import os import os
import toml import toml
import argparse from .global_logger import logger
PG_NAMESPACE = "paragraph" PG_NAMESPACE = "paragraph"
ENT_NAMESPACE = "entity" ENT_NAMESPACE = "entity"
@ -56,42 +56,33 @@ def _load_config(config, config_file_path):
if "persistence" in file_config: if "persistence" in file_config:
config["persistence"] = file_config["persistence"] config["persistence"] = file_config["persistence"]
print(config)
print("Configurations loaded from file: ", config_file_path)
logger.debug("Configurations loaded from file: ", config_file_path)
parser = argparse.ArgumentParser(description="Configurations for the pipeline")
parser.add_argument(
"--config_path",
type=str,
default="lpmm_config.toml",
help="Path to the configuration file",
)
global_config = dict( global_config = dict(
{ {
"llm_providers": { "llm_providers": {
"localhost": { "localhost": {
"base_url": "https://api.siliconflow.cn/v1", "base_url": "",
"api_key": "sk-ospynxadyorf", "api_key": "",
} }
}, },
"entity_extract": { "entity_extract": {
"llm": { "llm": {
"provider": "localhost", "provider": "",
"model": "Pro/deepseek-ai/DeepSeek-V3", "model": "",
} }
}, },
"rdf_build": { "rdf_build": {
"llm": { "llm": {
"provider": "localhost", "provider": "",
"model": "Pro/deepseek-ai/DeepSeek-V3", "model": "",
} }
}, },
"embedding": { "embedding": {
"provider": "localhost", "provider": "",
"model": "Pro/BAAI/bge-m3", "model": "",
"dimension": 1024, "dimension": 1024,
}, },
"rag": { "rag": {
@ -111,24 +102,23 @@ global_config = dict(
"res_top_k": 10, "res_top_k": 10,
}, },
"llm": { "llm": {
"provider": "localhost", "provider": "",
"model": "qa", "model": "",
}, },
}, },
"persistence": { "persistence": {
"data_root_path": "data", "data_root_path": "",
"raw_data_path": "data/raw.json", "raw_data_path": "",
"openie_data_path": "data/openie.json", "openie_data_path": "",
"embedding_data_dir": "data/embedding", "embedding_data_dir": "",
"rag_data_dir": "data/rag", "rag_data_dir": "",
}, },
"info_extraction":{ "info_extraction": {
"workers": 10, "workers": 10,
} },
} }
) )
# _load_config(global_config, parser.parse_args().config_path)
file_path = os.path.abspath(__file__) file_path = os.path.abspath(__file__)
dir_path = os.path.dirname(file_path) dir_path = os.path.dirname(file_path)
root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir) root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)

View File

@ -2,7 +2,7 @@ import json
from typing import Any, Dict, List from typing import Any, Dict, List
from .config import INVALID_ENTITY, global_config from .lpmmconfig import INVALID_ENTITY, global_config
def _filter_invalid_entities(entities: List[str]) -> List[str]: def _filter_invalid_entities(entities: List[str]) -> List[str]:

View File

@ -1,73 +0,0 @@
from typing import List
from .llm_client import LLMMessage
# System prompt (Chinese) telling the model to act as an entity-extraction
# system: extract all entities from a paragraph and answer with a JSON list.
# The text is sent to the model verbatim, so any edit changes runtime behaviour.
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体并以JSON列表的形式输出。
输出格式示例
[ "实体A", "实体B", "实体C" ]
请注意以下要求
- 将代词转化为对应的实体命名以避免指代不清
- 尽可能多的提取出段落中的全部实体
"""
def build_entity_extract_context(paragraph: str) -> List[LLMMessage]:
    """Build the two-message chat context for an entity-extraction request.

    Args:
        paragraph: Text to extract entities from; it is embedded in a fenced
            block inside the user message.

    Returns:
        A list with the system message followed by the user message, each
        converted with ``LLMMessage.to_dict()``.
        NOTE(review): the annotation says ``List[LLMMessage]`` but the items
        are whatever ``to_dict()`` returns -- confirm the intended type.
    """
    system_entry = LLMMessage("system", entity_extract_system_prompt).to_dict()
    user_entry = LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict()
    return [system_entry, user_entry]
# System prompt (Chinese) telling the model to build an RDF graph from a
# paragraph plus an entity list and to answer with a JSON list of triples.
# The text is sent to the model verbatim, so any edit changes runtime behaviour.
# NOTE(review): the first sentence reads as if brackets/punctuation were lost
# somewhere -- confirm against the repository copy.
rdf_triple_extract_system_prompt = """你是一个性能优异的RDF资源描述框架由节点和边组成节点表示实体/资源、属性边则表示了实体和实体之间的关系以及实体和属性的关系。构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
请使用JSON回复使用三元组的JSON列表输出RDF图中的关系每个三元组代表一个关系
输出格式示例
[
["某实体","关系","某属性"],
["某实体","关系","某实体"],
["某资源","关系","某属性"]
]
请注意以下要求
- 每个三元组应包含每个段落的实体命名列表中的至少一个命名实体但最好是两个
- 将代词转化为对应的实体命名以避免指代不清
"""
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
    """Build the two-message chat context for an RDF-triple extraction request.

    Args:
        paragraph: Source text; embedded in a fenced block in the user message.
        entities: Entity list, already rendered as a string; embedded in a
            second fenced block after the paragraph.

    Returns:
        A list with the system message followed by the user message, each
        converted with ``LLMMessage.to_dict()``.
        NOTE(review): the annotation says ``List[LLMMessage]`` but the items
        are whatever ``to_dict()`` returns -- confirm the intended type.
    """
    system_entry = LLMMessage("system", rdf_triple_extract_system_prompt).to_dict()
    user_text = f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```"""
    user_entry = LLMMessage("user", user_text).to_dict()
    return [system_entry, user_entry]
# System prompt (Chinese) telling the model to act as a QA system that answers
# a question using supplied helper information: do not quote it directly, keep
# the answer concise, and say "I don't know" when it cannot answer.
# The text is sent to the model verbatim, so any edit changes runtime behaviour.
qa_system_prompt = """
你是一个性能优异的QA系统请根据给定的问题和一些可能对你有帮助的信息作出回答
请注意以下要求
- 你可以使用给定的信息来回答问题但请不要直接引用它们
- 你的回答应该简洁明了避免冗长的解释
- 如果你无法回答问题请直接说我不知道
"""
def build_qa_context(
    question: str, knowledge: list[tuple[str, str, str]]
) -> List[LLMMessage]:
    """Build the two-message chat context for a question-answering request.

    Args:
        question: The question to answer.
        knowledge: Retrieved items. For each item, element 0 is rendered as
            its relevance and element 1 as its text; element 2 is not used
            here.

    Returns:
        A list with the system message followed by the user message, each
        converted with ``LLMMessage.to_dict()``.
        NOTE(review): the annotation says ``List[LLMMessage]`` but the items
        are whatever ``to_dict()`` returns -- confirm the intended type.
    """
    # Number the knowledge items from 1. A separate local is used so the
    # ``knowledge`` parameter is not rebound to a value of a different type.
    knowledge_text = "\n".join(
        f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)
    )
    system_entry = LLMMessage("system", qa_system_prompt).to_dict()
    user_entry = LLMMessage(
        "user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge_text}"
    ).to_dict()
    return [system_entry, user_entry]

View File

@ -2,7 +2,6 @@ import time
from typing import Tuple, List, Dict from typing import Tuple, List, Dict
from .global_logger import logger from .global_logger import logger
# from . import prompt_template
from .embedding_store import EmbeddingManager from .embedding_store import EmbeddingManager
from .llm_client import LLMClient from .llm_client import LLMClient
from .kg_manager import KGManager from .kg_manager import KGManager

View File

@ -2,7 +2,7 @@ import json
import os import os
from .global_logger import logger from .global_logger import logger
from .config import global_config from .lpmmconfig import global_config
from .utils.hash import get_sha256 from .utils.hash import get_sha256

View File

@ -1,23 +0,0 @@
import json
# Convert raw.txt (paragraphs separated by blank lines) into raw.json, a JSON
# list with one string per paragraph. Both paths are relative to the current
# working directory.
with open("raw.txt", "r", encoding="utf-8") as f:
    raw = f.read()

# Walk the file line by line; a blank (whitespace-only) line closes the
# paragraph collected so far.
paragraphs = []
current_lines = []
for line in raw.split("\n"):
    if line.strip() == "":
        if current_lines:
            # Every line of a paragraph keeps a trailing newline, including the last.
            paragraphs.append("\n".join(current_lines) + "\n")
            current_lines = []
    else:
        # Collect lines and join once rather than growing a string with "+=".
        current_lines.append(line)

# Flush the final paragraph when the file does not end with a blank line.
if current_lines:
    paragraphs.append("\n".join(current_lines) + "\n")

with open("raw.json", "w", encoding="utf-8") as f:
    json.dump(paragraphs, f, ensure_ascii=False, indent=4)

View File

@ -1,58 +0,0 @@
import jsonlines
from pathlib import Path
from typing import List, Dict, Any, Union, Optional
from src.config import global_config as config
class DataLoader:
    """Helper for loading data files from a data directory.

    Only JSONL loading is implemented here.
    """

    def __init__(self, custom_data_dir: Optional[Union[str, Path]] = None):
        """Initialise the loader and validate its data directory.

        Args:
            custom_data_dir: Optional data directory. When omitted (or falsy),
                ``config["persistence"]["data_root_path"]`` is used instead.

        Raises:
            FileNotFoundError: If the selected directory does not exist.
        """
        # ``or`` short-circuits, so the config is only read when no custom
        # directory was supplied.
        self.data_dir = Path(custom_data_dir or config["persistence"]["data_root_path"])
        if not self.data_dir.exists():
            raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在")

    def _resolve_file_path(self, filename: str) -> Path:
        """Return the path of ``filename`` inside the data directory.

        Args:
            filename: Name of the file, relative to the data directory.

        Returns:
            ``self.data_dir / filename``.

        Raises:
            FileNotFoundError: If that path does not exist.
        """
        file_path = self.data_dir / filename
        if not file_path.exists():
            raise FileNotFoundError(f"文件 {filename} 不存在")
        return file_path

    def load_jsonl(self, filename: str) -> List[Dict[str, Any]]:
        """Load a JSONL file from the data directory.

        Args:
            filename: Name of the file, relative to the data directory.

        Returns:
            A list with every object read from the file, in file order.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        file_path = self._resolve_file_path(filename)
        with jsonlines.open(file_path) as reader:
            # The reader is iterable; list() drains it in one step.
            return list(reader)

View File

@ -1,17 +0,0 @@
import networkx as nx
from matplotlib import pyplot as plt
def draw_graph_and_show(graph):
    """Draw ``graph`` with networkx and show the figure on a 1280x1280 px canvas.

    Args:
        graph: Graph to draw (presumably a networkx graph -- confirm against
            callers). Each node's ``"content"`` attribute is used as its label.

    NOTE(review): the pyplot figure number is fixed at 1, so repeated calls
    appear to draw onto the same figure -- confirm this is intended.
    """
    # 12.8 in x 12.8 in at 100 dpi gives 1280 x 1280 pixels.
    canvas = plt.figure(1, figsize=(12.8, 12.8), dpi=100)
    node_labels = nx.get_node_attributes(graph, "content")
    draw_options = {
        "node_size": 100,
        "width": 0.5,
        "with_labels": True,
        "labels": node_labels,
        "font_family": "Sarasa Mono SC",
        "font_size": 8,
    }
    nx.draw_networkx(graph, **draw_options)
    canvas.show()