mirror of https://github.com/Mai-with-u/MaiBot.git
大力修复
parent
d72b34709c
commit
bf7827a571
|
|
@ -815,7 +815,7 @@ def parse_text_timestamps(text: str, mode: str = "normal") -> str:
|
||||||
|
|
||||||
# 从后向前替换,避免位置改变
|
# 从后向前替换,避免位置改变
|
||||||
converted_timestamps.reverse()
|
converted_timestamps.reverse()
|
||||||
for ts, match, readable_time in converted_timestamps:
|
for _, match, readable_time in converted_timestamps:
|
||||||
pattern_instance = re.escape(match.group(0))
|
pattern_instance = re.escape(match.group(0))
|
||||||
if readable_time in readable_time_used:
|
if readable_time in readable_time_used:
|
||||||
# 如果相同格式的时间已存在,替换为空字符串
|
# 如果相同格式的时间已存在,替换为空字符串
|
||||||
|
|
|
||||||
|
|
@ -9,11 +9,13 @@ from ...moods.moods import MoodManager
|
||||||
from ....individuality.individuality import Individuality
|
from ....individuality.individuality import Individuality
|
||||||
from ...memory_system.Hippocampus import HippocampusManager
|
from ...memory_system.Hippocampus import HippocampusManager
|
||||||
from ...schedule.schedule_generator import bot_schedule
|
from ...schedule.schedule_generator import bot_schedule
|
||||||
from ...config.config import global_config
|
from src.config.config import global_config
|
||||||
from ...person_info.relationship_manager import relationship_manager
|
from ...person_info.relationship_manager import relationship_manager
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.plugins.knowledge.src.qa_manager import qa_manager
|
from src.plugins.knowledge.knowledge_lib import qa_manager
|
||||||
|
|
||||||
|
from src.plugins.chat.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_module_logger("prompt")
|
logger = get_module_logger("prompt")
|
||||||
|
|
||||||
|
|
@ -54,7 +56,7 @@ class PromptBuilder:
|
||||||
self.activate_messages = ""
|
self.activate_messages = ""
|
||||||
|
|
||||||
async def _build_prompt(
|
async def _build_prompt(
|
||||||
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
self, chat_stream: ChatStream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
# 开始构建prompt
|
# 开始构建prompt
|
||||||
prompt_personality = "你"
|
prompt_personality = "你"
|
||||||
|
|
@ -102,16 +104,14 @@ class PromptBuilder:
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
)
|
)
|
||||||
|
related_memory_info = ""
|
||||||
if related_memory:
|
if related_memory:
|
||||||
related_memory_info = ""
|
|
||||||
for memory in related_memory:
|
for memory in related_memory:
|
||||||
related_memory_info += memory[1]
|
related_memory_info += memory[1]
|
||||||
# memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n"
|
# memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n"
|
||||||
memory_prompt = await global_prompt_manager.format_prompt(
|
memory_prompt = await global_prompt_manager.format_prompt(
|
||||||
"memory_prompt", related_memory_info=related_memory_info
|
"memory_prompt", related_memory_info=related_memory_info
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
related_memory_info = ""
|
|
||||||
|
|
||||||
# print(f"相关记忆:{related_memory_info}")
|
# print(f"相关记忆:{related_memory_info}")
|
||||||
|
|
||||||
|
|
@ -230,224 +230,6 @@ class PromptBuilder:
|
||||||
related_info += qa_manager.get_knowledge(message)
|
related_info += qa_manager.get_knowledge(message)
|
||||||
|
|
||||||
return related_info
|
return related_info
|
||||||
# start_time = time.time()
|
|
||||||
# related_info = ""
|
|
||||||
# logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
|
||||||
|
|
||||||
# # 1. 先从LLM获取主题,类似于记忆系统的做法
|
|
||||||
# topics = []
|
|
||||||
# try:
|
|
||||||
# # 先尝试使用记忆系统的方法获取主题
|
|
||||||
# hippocampus = HippocampusManager.get_instance()._hippocampus
|
|
||||||
# topic_num = min(5, max(1, int(len(message) * 0.1)))
|
|
||||||
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
|
|
||||||
|
|
||||||
# # 提取关键词
|
|
||||||
# topics = re.findall(r"<([^>]+)>", topics_response[0])
|
|
||||||
# if not topics:
|
|
||||||
# topics = []
|
|
||||||
# else:
|
|
||||||
# topics = [
|
|
||||||
# topic.strip()
|
|
||||||
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
|
||||||
# if topic.strip()
|
|
||||||
# ]
|
|
||||||
|
|
||||||
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"从LLM提取主题失败: {str(e)}")
|
|
||||||
# # 如果LLM提取失败,使用jieba分词提取关键词作为备选
|
|
||||||
# words = jieba.cut(message)
|
|
||||||
# topics = [word for word in words if len(word) > 1][:5]
|
|
||||||
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
|
|
||||||
|
|
||||||
# 如果无法提取到主题,直接使用整个消息
|
|
||||||
# if not topics:
|
|
||||||
# logger.info("未能提取到任何主题,使用整个消息进行查询")
|
|
||||||
# embedding = await get_embedding(message, request_type="prompt_build")
|
|
||||||
# if not embedding:
|
|
||||||
# logger.error("获取消息嵌入向量失败")
|
|
||||||
# return ""
|
|
||||||
|
|
||||||
# related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
|
||||||
# logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
|
|
||||||
# return related_info
|
|
||||||
|
|
||||||
# # 2. 对每个主题进行知识库查询
|
|
||||||
# logger.info(f"开始处理{len(topics)}个主题的知识库查询")
|
|
||||||
|
|
||||||
# # 优化:批量获取嵌入向量,减少API调用
|
|
||||||
# embeddings = {}
|
|
||||||
# topics_batch = [topic for topic in topics if len(topic) > 0]
|
|
||||||
# if message: # 确保消息非空
|
|
||||||
# topics_batch.append(message)
|
|
||||||
|
|
||||||
# 批量获取嵌入向量
|
|
||||||
# embed_start_time = time.time()
|
|
||||||
# for text in topics_batch:
|
|
||||||
# if not text or len(text.strip()) == 0:
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# embedding = await get_embedding(text, request_type="prompt_build")
|
|
||||||
# if embedding:
|
|
||||||
# embeddings[text] = embedding
|
|
||||||
# else:
|
|
||||||
# logger.warning(f"获取'{text}'的嵌入向量失败")
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
|
|
||||||
|
|
||||||
# logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
|
|
||||||
|
|
||||||
# if not embeddings:
|
|
||||||
# logger.error("所有嵌入向量获取失败")
|
|
||||||
# return ""
|
|
||||||
|
|
||||||
# # 3. 对每个主题进行知识库查询
|
|
||||||
# all_results = []
|
|
||||||
# query_start_time = time.time()
|
|
||||||
|
|
||||||
# # 首先添加原始消息的查询结果
|
|
||||||
# if message in embeddings:
|
|
||||||
# original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
|
|
||||||
# if original_results:
|
|
||||||
# for result in original_results:
|
|
||||||
# result["topic"] = "原始消息"
|
|
||||||
# all_results.extend(original_results)
|
|
||||||
# logger.info(f"原始消息查询到{len(original_results)}条结果")
|
|
||||||
|
|
||||||
# # 然后添加每个主题的查询结果
|
|
||||||
# for topic in topics:
|
|
||||||
# if not topic or topic not in embeddings:
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# try:
|
|
||||||
|
|
||||||
# # topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
|
|
||||||
# # if topic_results:
|
|
||||||
# # # 添加主题标记
|
|
||||||
# # for result in topic_results:
|
|
||||||
# # result["topic"] = topic
|
|
||||||
# # all_results.extend(topic_results)
|
|
||||||
# # logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
|
|
||||||
|
|
||||||
# logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
|
|
||||||
|
|
||||||
# # 4. 去重和过滤
|
|
||||||
# process_start_time = time.time()
|
|
||||||
# unique_contents = set()
|
|
||||||
# filtered_results = []
|
|
||||||
# for result in all_results:
|
|
||||||
# content = result["content"]
|
|
||||||
# if content not in unique_contents:
|
|
||||||
# unique_contents.add(content)
|
|
||||||
# filtered_results.append(result)
|
|
||||||
|
|
||||||
# # 5. 按相似度排序
|
|
||||||
# filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
|
|
||||||
|
|
||||||
# # 6. 限制总数量(最多10条)
|
|
||||||
# filtered_results = filtered_results[:10]
|
|
||||||
# logger.info(
|
|
||||||
# f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # 7. 格式化输出
|
|
||||||
# if filtered_results:
|
|
||||||
# format_start_time = time.time()
|
|
||||||
# grouped_results = {}
|
|
||||||
# for result in filtered_results:
|
|
||||||
# topic = result["topic"]
|
|
||||||
# if topic not in grouped_results:
|
|
||||||
# grouped_results[topic] = []
|
|
||||||
# grouped_results[topic].append(result)
|
|
||||||
|
|
||||||
# 按主题组织输出
|
|
||||||
# for topic, results in grouped_results.items():
|
|
||||||
# related_info += f"【主题: {topic}】\n"
|
|
||||||
# for _i, result in enumerate(results, 1):
|
|
||||||
# _similarity = result["similarity"]
|
|
||||||
# content = result["content"].strip()
|
|
||||||
# # 调试:为内容添加序号和相似度信息
|
|
||||||
# # related_info += f"{i}. [{similarity:.2f}] {content}\n"
|
|
||||||
# related_info += f"{content}\n"
|
|
||||||
# related_info += "\n"
|
|
||||||
|
|
||||||
# # logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
|
|
||||||
|
|
||||||
# logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
|
||||||
# return related_info
|
|
||||||
|
|
||||||
# def get_info_from_db(
|
|
||||||
# self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
|
||||||
# ) -> Union[str, list]:
|
|
||||||
# if not query_embedding:
|
|
||||||
# return "" if not return_raw else []
|
|
||||||
# # 使用余弦相似度计算
|
|
||||||
# pipeline = [
|
|
||||||
# {
|
|
||||||
# "$addFields": {
|
|
||||||
# "dotProduct": {
|
|
||||||
# "$reduce": {
|
|
||||||
# "input": {"$range": [0, {"$size": "$embedding"}]},
|
|
||||||
# "initialValue": 0,
|
|
||||||
# "in": {
|
|
||||||
# "$add": [
|
|
||||||
# "$$value",
|
|
||||||
# {
|
|
||||||
# "$multiply": [
|
|
||||||
# {"$arrayElemAt": ["$embedding", "$$this"]},
|
|
||||||
# {"$arrayElemAt": [query_embedding, "$$this"]},
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# "magnitude1": {
|
|
||||||
# "$sqrt": {
|
|
||||||
# "$reduce": {
|
|
||||||
# "input": "$embedding",
|
|
||||||
# "initialValue": 0,
|
|
||||||
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# "magnitude2": {
|
|
||||||
# "$sqrt": {
|
|
||||||
# "$reduce": {
|
|
||||||
# "input": query_embedding,
|
|
||||||
# "initialValue": 0,
|
|
||||||
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
|
|
||||||
# {
|
|
||||||
# "$match": {
|
|
||||||
# "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# {"$sort": {"similarity": -1}},
|
|
||||||
# {"$limit": limit},
|
|
||||||
# {"$project": {"content": 1, "similarity": 1}},
|
|
||||||
# ]
|
|
||||||
|
|
||||||
# results = list(db.knowledges.aggregate(pipeline))
|
|
||||||
# logger.debug(f"知识库查询结果数量: {len(results)}")
|
|
||||||
|
|
||||||
# if not results:
|
|
||||||
# return "" if not return_raw else []
|
|
||||||
|
|
||||||
# if return_raw:
|
|
||||||
# return results
|
|
||||||
# else:
|
|
||||||
# # 返回所有找到的内容,用换行分隔
|
|
||||||
# return "\n".join(str(result["content"]) for result in results)
|
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
|
||||||
|
|
@ -1,122 +0,0 @@
|
||||||
import os
|
|
||||||
import toml
|
|
||||||
from .global_logger import logger
|
|
||||||
PG_NAMESPACE = "paragraph"
|
|
||||||
ENT_NAMESPACE = "entity"
|
|
||||||
REL_NAMESPACE = "relation"
|
|
||||||
|
|
||||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
|
||||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
|
||||||
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
|
||||||
|
|
||||||
# 无效实体
|
|
||||||
INVALID_ENTITY = [
|
|
||||||
"",
|
|
||||||
"你",
|
|
||||||
"他",
|
|
||||||
"她",
|
|
||||||
"它",
|
|
||||||
"我们",
|
|
||||||
"你们",
|
|
||||||
"他们",
|
|
||||||
"她们",
|
|
||||||
"它们",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _load_config(config, config_file_path):
|
|
||||||
"""读取TOML格式的配置文件"""
|
|
||||||
if not os.path.exists(config_file_path):
|
|
||||||
return
|
|
||||||
with open(config_file_path, "r", encoding="utf-8") as f:
|
|
||||||
file_config = toml.load(f)
|
|
||||||
|
|
||||||
if "llm_providers" in file_config:
|
|
||||||
for provider in file_config["llm_providers"]:
|
|
||||||
if provider["name"] not in config["llm_providers"]:
|
|
||||||
config["llm_providers"][provider["name"]] = dict()
|
|
||||||
config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"]
|
|
||||||
config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"]
|
|
||||||
|
|
||||||
if "entity_extract" in file_config:
|
|
||||||
config["entity_extract"] = file_config["entity_extract"]
|
|
||||||
|
|
||||||
if "rdf_build" in file_config:
|
|
||||||
config["rdf_build"] = file_config["rdf_build"]
|
|
||||||
|
|
||||||
if "embedding" in file_config:
|
|
||||||
config["embedding"] = file_config["embedding"]
|
|
||||||
|
|
||||||
if "rag" in file_config:
|
|
||||||
config["rag"] = file_config["rag"]
|
|
||||||
|
|
||||||
if "qa" in file_config:
|
|
||||||
config["qa"] = file_config["qa"]
|
|
||||||
|
|
||||||
if "persistence" in file_config:
|
|
||||||
config["persistence"] = file_config["persistence"]
|
|
||||||
|
|
||||||
logger.info(f"Configurations loaded from file: {config_file_path}")
|
|
||||||
|
|
||||||
global_config = dict(
|
|
||||||
{
|
|
||||||
"llm_providers": {
|
|
||||||
"localhost": {
|
|
||||||
"base_url": "http://localhost:8000",
|
|
||||||
"api_key": "",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"entity_extract": {
|
|
||||||
"llm": {
|
|
||||||
"provider": "localhost",
|
|
||||||
"model": "entity-extract",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"rdf_build": {
|
|
||||||
"llm": {
|
|
||||||
"provider": "localhost",
|
|
||||||
"model": "rdf-build",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"embedding": {
|
|
||||||
"provider": "localhost",
|
|
||||||
"model": "embed",
|
|
||||||
"dimension": 1024,
|
|
||||||
},
|
|
||||||
"rag": {
|
|
||||||
"params": {
|
|
||||||
"synonym_search_top_k": 10,
|
|
||||||
"synonym_threshold": 0.75,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"qa": {
|
|
||||||
"params": {
|
|
||||||
"relation_search_top_k": 10,
|
|
||||||
"relation_threshold": 0.75,
|
|
||||||
"paragraph_search_top_k": 10,
|
|
||||||
"paragraph_node_weight": 0.05,
|
|
||||||
"ent_filter_top_k": 10,
|
|
||||||
"ppr_damping": 0.8,
|
|
||||||
"res_top_k": 10,
|
|
||||||
},
|
|
||||||
"llm": {
|
|
||||||
"provider": "localhost",
|
|
||||||
"model": "qa",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"persistence": {
|
|
||||||
"data_root_path": "data",
|
|
||||||
"raw_data_path": "data/raw.json",
|
|
||||||
"openie_data_path": "data/openie.json",
|
|
||||||
"embedding_data_dir": "data/embedding",
|
|
||||||
"rag_data_dir": "data/rag",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# _load_config(global_config, parser.parse_args().config_path)
|
|
||||||
file_path = os.path.abspath(__file__)
|
|
||||||
dir_path = os.path.dirname(file_path)
|
|
||||||
root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
|
|
||||||
config_path = os.path.join(root_path, "config", "lpmm_config.toml")
|
|
||||||
_load_config(global_config, config_path)
|
|
||||||
|
|
@ -4,7 +4,7 @@ from typing import List
|
||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
from . import prompt_template
|
from . import prompt_template
|
||||||
from .config import global_config, INVALID_ENTITY
|
from .lpmmconfig import global_config, INVALID_ENTITY
|
||||||
from .llm_client import LLMClient
|
from .llm_client import LLMClient
|
||||||
from .utils.json_fix import fix_broken_generated_json
|
from .utils.json_fix import fix_broken_generated_json
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import tqdm
|
||||||
from ..lib import quick_algo
|
from ..lib import quick_algo
|
||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||||
from .config import (
|
from .lpmmconfig import (
|
||||||
ENT_NAMESPACE,
|
ENT_NAMESPACE,
|
||||||
PG_NAMESPACE,
|
PG_NAMESPACE,
|
||||||
RAG_ENT_CNT_NAMESPACE,
|
RAG_ENT_CNT_NAMESPACE,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import toml
|
import toml
|
||||||
import argparse
|
from .global_logger import logger
|
||||||
|
|
||||||
PG_NAMESPACE = "paragraph"
|
PG_NAMESPACE = "paragraph"
|
||||||
ENT_NAMESPACE = "entity"
|
ENT_NAMESPACE = "entity"
|
||||||
|
|
@ -56,42 +56,33 @@ def _load_config(config, config_file_path):
|
||||||
|
|
||||||
if "persistence" in file_config:
|
if "persistence" in file_config:
|
||||||
config["persistence"] = file_config["persistence"]
|
config["persistence"] = file_config["persistence"]
|
||||||
print(config)
|
|
||||||
print("Configurations loaded from file: ", config_file_path)
|
|
||||||
|
|
||||||
|
|
||||||
|
logger.debug("Configurations loaded from file: ", config_file_path)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Configurations for the pipeline")
|
|
||||||
parser.add_argument(
|
|
||||||
"--config_path",
|
|
||||||
type=str,
|
|
||||||
default="lpmm_config.toml",
|
|
||||||
help="Path to the configuration file",
|
|
||||||
)
|
|
||||||
|
|
||||||
global_config = dict(
|
global_config = dict(
|
||||||
{
|
{
|
||||||
"llm_providers": {
|
"llm_providers": {
|
||||||
"localhost": {
|
"localhost": {
|
||||||
"base_url": "https://api.siliconflow.cn/v1",
|
"base_url": "",
|
||||||
"api_key": "sk-ospynxadyorf",
|
"api_key": "",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"entity_extract": {
|
"entity_extract": {
|
||||||
"llm": {
|
"llm": {
|
||||||
"provider": "localhost",
|
"provider": "",
|
||||||
"model": "Pro/deepseek-ai/DeepSeek-V3",
|
"model": "",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"rdf_build": {
|
"rdf_build": {
|
||||||
"llm": {
|
"llm": {
|
||||||
"provider": "localhost",
|
"provider": "",
|
||||||
"model": "Pro/deepseek-ai/DeepSeek-V3",
|
"model": "",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embedding": {
|
"embedding": {
|
||||||
"provider": "localhost",
|
"provider": "",
|
||||||
"model": "Pro/BAAI/bge-m3",
|
"model": "",
|
||||||
"dimension": 1024,
|
"dimension": 1024,
|
||||||
},
|
},
|
||||||
"rag": {
|
"rag": {
|
||||||
|
|
@ -111,24 +102,23 @@ global_config = dict(
|
||||||
"res_top_k": 10,
|
"res_top_k": 10,
|
||||||
},
|
},
|
||||||
"llm": {
|
"llm": {
|
||||||
"provider": "localhost",
|
"provider": "",
|
||||||
"model": "qa",
|
"model": "",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"persistence": {
|
"persistence": {
|
||||||
"data_root_path": "data",
|
"data_root_path": "",
|
||||||
"raw_data_path": "data/raw.json",
|
"raw_data_path": "",
|
||||||
"openie_data_path": "data/openie.json",
|
"openie_data_path": "",
|
||||||
"embedding_data_dir": "data/embedding",
|
"embedding_data_dir": "",
|
||||||
"rag_data_dir": "data/rag",
|
"rag_data_dir": "",
|
||||||
},
|
},
|
||||||
"info_extraction":{
|
"info_extraction": {
|
||||||
"workers": 10,
|
"workers": 10,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# _load_config(global_config, parser.parse_args().config_path)
|
|
||||||
file_path = os.path.abspath(__file__)
|
file_path = os.path.abspath(__file__)
|
||||||
dir_path = os.path.dirname(file_path)
|
dir_path = os.path.dirname(file_path)
|
||||||
root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
|
root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import json
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
from .config import INVALID_ENTITY, global_config
|
from .lpmmconfig import INVALID_ENTITY, global_config
|
||||||
|
|
||||||
|
|
||||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||||
|
|
|
||||||
|
|
@ -1,73 +0,0 @@
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from .llm_client import LLMMessage
|
|
||||||
|
|
||||||
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体,并以JSON列表的形式输出。
|
|
||||||
|
|
||||||
输出格式示例:
|
|
||||||
[ "实体A", "实体B", "实体C" ]
|
|
||||||
|
|
||||||
请注意以下要求:
|
|
||||||
- 将代词(如“你”、“我”、“他”、“她”、“它”等)转化为对应的实体命名,以避免指代不清。
|
|
||||||
- 尽可能多的提取出段落中的全部实体;
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def build_entity_extract_context(paragraph: str) -> List[LLMMessage]:
|
|
||||||
messages = [
|
|
||||||
LLMMessage("system", entity_extract_system_prompt).to_dict(),
|
|
||||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
|
|
||||||
]
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
|
|
||||||
|
|
||||||
请使用JSON回复,使用三元组的JSON列表输出RDF图中的关系(每个三元组代表一个关系)。
|
|
||||||
|
|
||||||
输出格式示例:
|
|
||||||
[
|
|
||||||
["某实体","关系","某属性"],
|
|
||||||
["某实体","关系","某实体"],
|
|
||||||
["某资源","关系","某属性"]
|
|
||||||
]
|
|
||||||
|
|
||||||
请注意以下要求:
|
|
||||||
- 每个三元组应包含每个段落的实体命名列表中的至少一个命名实体,但最好是两个。
|
|
||||||
- 将代词(如“你”、“我”、“他”、“她”、“它”等)转化为对应的实体命名,以避免指代不清。
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
|
|
||||||
messages = [
|
|
||||||
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
|
|
||||||
LLMMessage(
|
|
||||||
"user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```"""
|
|
||||||
).to_dict(),
|
|
||||||
]
|
|
||||||
return messages
|
|
||||||
|
|
||||||
|
|
||||||
qa_system_prompt = """
|
|
||||||
你是一个性能优异的QA系统。请根据给定的问题和一些可能对你有帮助的信息作出回答。
|
|
||||||
|
|
||||||
请注意以下要求:
|
|
||||||
- 你可以使用给定的信息来回答问题,但请不要直接引用它们。
|
|
||||||
- 你的回答应该简洁明了,避免冗长的解释。
|
|
||||||
- 如果你无法回答问题,请直接说“我不知道”。
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def build_qa_context(
|
|
||||||
question: str, knowledge: list[(str, str, str)]
|
|
||||||
) -> List[LLMMessage]:
|
|
||||||
knowledge = "\n".join(
|
|
||||||
[f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]
|
|
||||||
)
|
|
||||||
messages = [
|
|
||||||
LLMMessage("system", qa_system_prompt).to_dict(),
|
|
||||||
LLMMessage(
|
|
||||||
"user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}"
|
|
||||||
).to_dict(),
|
|
||||||
]
|
|
||||||
return messages
|
|
||||||
|
|
@ -2,7 +2,6 @@ import time
|
||||||
from typing import Tuple, List, Dict
|
from typing import Tuple, List, Dict
|
||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
# from . import prompt_template
|
|
||||||
from .embedding_store import EmbeddingManager
|
from .embedding_store import EmbeddingManager
|
||||||
from .llm_client import LLMClient
|
from .llm_client import LLMClient
|
||||||
from .kg_manager import KGManager
|
from .kg_manager import KGManager
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
from .config import global_config
|
from .lpmmconfig import global_config
|
||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,23 +0,0 @@
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
with open("raw.txt", "r", encoding="utf-8") as f:
|
|
||||||
raw = f.read()
|
|
||||||
|
|
||||||
# 一行一行的处理文件
|
|
||||||
paragraphs = []
|
|
||||||
paragraph = ""
|
|
||||||
for line in raw.split("\n"):
|
|
||||||
if line.strip() == "":
|
|
||||||
# 有空行,表示段落结束
|
|
||||||
if paragraph != "":
|
|
||||||
paragraphs.append(paragraph)
|
|
||||||
paragraph = ""
|
|
||||||
else:
|
|
||||||
paragraph += line + "\n"
|
|
||||||
|
|
||||||
if paragraph != "":
|
|
||||||
paragraphs.append(paragraph)
|
|
||||||
|
|
||||||
with open("raw.json", "w", encoding="utf-8") as f:
|
|
||||||
json.dump(paragraphs, f, ensure_ascii=False, indent=4)
|
|
||||||
|
|
@ -1,58 +0,0 @@
|
||||||
import jsonlines
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Dict, Any, Union, Optional
|
|
||||||
from src.config import global_config as config
|
|
||||||
|
|
||||||
|
|
||||||
class DataLoader:
|
|
||||||
"""数据加载工具类,用于从/data目录下加载各种格式的数据文件"""
|
|
||||||
|
|
||||||
def __init__(self, custom_data_dir: Optional[Union[str, Path]] = None):
|
|
||||||
"""
|
|
||||||
初始化数据加载器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
custom_data_dir: 可选的自定义数据目录路径,如果不提供则使用配置文件中的默认路径
|
|
||||||
"""
|
|
||||||
self.data_dir = (
|
|
||||||
Path(custom_data_dir)
|
|
||||||
if custom_data_dir
|
|
||||||
else Path(config["persistence"]["data_root_path"])
|
|
||||||
)
|
|
||||||
if not self.data_dir.exists():
|
|
||||||
raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在")
|
|
||||||
|
|
||||||
def _resolve_file_path(self, filename: str) -> Path:
|
|
||||||
"""
|
|
||||||
解析文件路径
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename: 文件名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
完整的文件路径
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FileNotFoundError: 当文件不存在时抛出
|
|
||||||
"""
|
|
||||||
file_path = self.data_dir / filename
|
|
||||||
if not file_path.exists():
|
|
||||||
raise FileNotFoundError(f"文件 {filename} 不存在")
|
|
||||||
return file_path
|
|
||||||
|
|
||||||
def load_jsonl(self, filename: str) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
加载JSONL格式的文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename: 文件名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含所有数据的列表
|
|
||||||
"""
|
|
||||||
file_path = self._resolve_file_path(filename)
|
|
||||||
data = []
|
|
||||||
with jsonlines.open(file_path) as reader:
|
|
||||||
for obj in reader:
|
|
||||||
data.append(obj)
|
|
||||||
return data
|
|
||||||
|
|
@ -1,17 +0,0 @@
|
||||||
import networkx as nx
|
|
||||||
from matplotlib import pyplot as plt
|
|
||||||
|
|
||||||
|
|
||||||
def draw_graph_and_show(graph):
|
|
||||||
"""绘制图并显示,画布大小1280*1280"""
|
|
||||||
fig = plt.figure(1, figsize=(12.8, 12.8), dpi=100)
|
|
||||||
nx.draw_networkx(
|
|
||||||
graph,
|
|
||||||
node_size=100,
|
|
||||||
width=0.5,
|
|
||||||
with_labels=True,
|
|
||||||
labels=nx.get_node_attributes(graph, "content"),
|
|
||||||
font_family="Sarasa Mono SC",
|
|
||||||
font_size=8,
|
|
||||||
)
|
|
||||||
fig.show()
|
|
||||||
Loading…
Reference in New Issue