mirror of https://github.com/Mai-with-u/MaiBot.git
大力修复
parent
d72b34709c
commit
bf7827a571
|
|
@ -815,7 +815,7 @@ def parse_text_timestamps(text: str, mode: str = "normal") -> str:
|
|||
|
||||
# 从后向前替换,避免位置改变
|
||||
converted_timestamps.reverse()
|
||||
for ts, match, readable_time in converted_timestamps:
|
||||
for _, match, readable_time in converted_timestamps:
|
||||
pattern_instance = re.escape(match.group(0))
|
||||
if readable_time in readable_time_used:
|
||||
# 如果相同格式的时间已存在,替换为空字符串
|
||||
|
|
|
|||
|
|
@ -9,11 +9,13 @@ from ...moods.moods import MoodManager
|
|||
from ....individuality.individuality import Individuality
|
||||
from ...memory_system.Hippocampus import HippocampusManager
|
||||
from ...schedule.schedule_generator import bot_schedule
|
||||
from ...config.config import global_config
|
||||
from src.config.config import global_config
|
||||
from ...person_info.relationship_manager import relationship_manager
|
||||
from src.common.logger import get_module_logger
|
||||
from src.plugins.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.plugins.knowledge.src.qa_manager import qa_manager
|
||||
from src.plugins.knowledge.knowledge_lib import qa_manager
|
||||
|
||||
from src.plugins.chat.chat_stream import ChatStream
|
||||
|
||||
logger = get_module_logger("prompt")
|
||||
|
||||
|
|
@ -54,7 +56,7 @@ class PromptBuilder:
|
|||
self.activate_messages = ""
|
||||
|
||||
async def _build_prompt(
|
||||
self, chat_stream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||
self, chat_stream: ChatStream, message_txt: str, sender_name: str = "某人", stream_id: Optional[int] = None
|
||||
) -> tuple[str, str]:
|
||||
# 开始构建prompt
|
||||
prompt_personality = "你"
|
||||
|
|
@ -102,16 +104,14 @@ class PromptBuilder:
|
|||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||
)
|
||||
related_memory_info = ""
|
||||
if related_memory:
|
||||
related_memory_info = ""
|
||||
for memory in related_memory:
|
||||
related_memory_info += memory[1]
|
||||
# memory_prompt = f"你想起你之前见过的事情:{related_memory_info}。\n以上是你的回忆,不一定是目前聊天里的人说的,也不一定是现在发生的事情,请记住。\n"
|
||||
memory_prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_prompt", related_memory_info=related_memory_info
|
||||
)
|
||||
else:
|
||||
related_memory_info = ""
|
||||
|
||||
# print(f"相关记忆:{related_memory_info}")
|
||||
|
||||
|
|
@ -230,224 +230,6 @@ class PromptBuilder:
|
|||
related_info += qa_manager.get_knowledge(message)
|
||||
|
||||
return related_info
|
||||
# start_time = time.time()
|
||||
# related_info = ""
|
||||
# logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
|
||||
# # 1. 先从LLM获取主题,类似于记忆系统的做法
|
||||
# topics = []
|
||||
# try:
|
||||
# # 先尝试使用记忆系统的方法获取主题
|
||||
# hippocampus = HippocampusManager.get_instance()._hippocampus
|
||||
# topic_num = min(5, max(1, int(len(message) * 0.1)))
|
||||
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
|
||||
|
||||
# # 提取关键词
|
||||
# topics = re.findall(r"<([^>]+)>", topics_response[0])
|
||||
# if not topics:
|
||||
# topics = []
|
||||
# else:
|
||||
# topics = [
|
||||
# topic.strip()
|
||||
# for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||
# if topic.strip()
|
||||
# ]
|
||||
|
||||
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
|
||||
# except Exception as e:
|
||||
# logger.error(f"从LLM提取主题失败: {str(e)}")
|
||||
# # 如果LLM提取失败,使用jieba分词提取关键词作为备选
|
||||
# words = jieba.cut(message)
|
||||
# topics = [word for word in words if len(word) > 1][:5]
|
||||
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
|
||||
|
||||
# 如果无法提取到主题,直接使用整个消息
|
||||
# if not topics:
|
||||
# logger.info("未能提取到任何主题,使用整个消息进行查询")
|
||||
# embedding = await get_embedding(message, request_type="prompt_build")
|
||||
# if not embedding:
|
||||
# logger.error("获取消息嵌入向量失败")
|
||||
# return ""
|
||||
|
||||
# related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
||||
# logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
|
||||
# return related_info
|
||||
|
||||
# # 2. 对每个主题进行知识库查询
|
||||
# logger.info(f"开始处理{len(topics)}个主题的知识库查询")
|
||||
|
||||
# # 优化:批量获取嵌入向量,减少API调用
|
||||
# embeddings = {}
|
||||
# topics_batch = [topic for topic in topics if len(topic) > 0]
|
||||
# if message: # 确保消息非空
|
||||
# topics_batch.append(message)
|
||||
|
||||
# 批量获取嵌入向量
|
||||
# embed_start_time = time.time()
|
||||
# for text in topics_batch:
|
||||
# if not text or len(text.strip()) == 0:
|
||||
# continue
|
||||
|
||||
# try:
|
||||
# embedding = await get_embedding(text, request_type="prompt_build")
|
||||
# if embedding:
|
||||
# embeddings[text] = embedding
|
||||
# else:
|
||||
# logger.warning(f"获取'{text}'的嵌入向量失败")
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
|
||||
|
||||
# logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
|
||||
|
||||
# if not embeddings:
|
||||
# logger.error("所有嵌入向量获取失败")
|
||||
# return ""
|
||||
|
||||
# # 3. 对每个主题进行知识库查询
|
||||
# all_results = []
|
||||
# query_start_time = time.time()
|
||||
|
||||
# # 首先添加原始消息的查询结果
|
||||
# if message in embeddings:
|
||||
# original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
|
||||
# if original_results:
|
||||
# for result in original_results:
|
||||
# result["topic"] = "原始消息"
|
||||
# all_results.extend(original_results)
|
||||
# logger.info(f"原始消息查询到{len(original_results)}条结果")
|
||||
|
||||
# # 然后添加每个主题的查询结果
|
||||
# for topic in topics:
|
||||
# if not topic or topic not in embeddings:
|
||||
# continue
|
||||
|
||||
# try:
|
||||
|
||||
# # topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
|
||||
# # if topic_results:
|
||||
# # # 添加主题标记
|
||||
# # for result in topic_results:
|
||||
# # result["topic"] = topic
|
||||
# # all_results.extend(topic_results)
|
||||
# # logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
|
||||
# except Exception as e:
|
||||
# logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
|
||||
|
||||
# logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
|
||||
|
||||
# # 4. 去重和过滤
|
||||
# process_start_time = time.time()
|
||||
# unique_contents = set()
|
||||
# filtered_results = []
|
||||
# for result in all_results:
|
||||
# content = result["content"]
|
||||
# if content not in unique_contents:
|
||||
# unique_contents.add(content)
|
||||
# filtered_results.append(result)
|
||||
|
||||
# # 5. 按相似度排序
|
||||
# filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
# # 6. 限制总数量(最多10条)
|
||||
# filtered_results = filtered_results[:10]
|
||||
# logger.info(
|
||||
# f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
|
||||
# )
|
||||
|
||||
# # 7. 格式化输出
|
||||
# if filtered_results:
|
||||
# format_start_time = time.time()
|
||||
# grouped_results = {}
|
||||
# for result in filtered_results:
|
||||
# topic = result["topic"]
|
||||
# if topic not in grouped_results:
|
||||
# grouped_results[topic] = []
|
||||
# grouped_results[topic].append(result)
|
||||
|
||||
# 按主题组织输出
|
||||
# for topic, results in grouped_results.items():
|
||||
# related_info += f"【主题: {topic}】\n"
|
||||
# for _i, result in enumerate(results, 1):
|
||||
# _similarity = result["similarity"]
|
||||
# content = result["content"].strip()
|
||||
# # 调试:为内容添加序号和相似度信息
|
||||
# # related_info += f"{i}. [{similarity:.2f}] {content}\n"
|
||||
# related_info += f"{content}\n"
|
||||
# related_info += "\n"
|
||||
|
||||
# # logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
|
||||
|
||||
# logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
||||
# return related_info
|
||||
|
||||
# def get_info_from_db(
|
||||
# self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
# ) -> Union[str, list]:
|
||||
# if not query_embedding:
|
||||
# return "" if not return_raw else []
|
||||
# # 使用余弦相似度计算
|
||||
# pipeline = [
|
||||
# {
|
||||
# "$addFields": {
|
||||
# "dotProduct": {
|
||||
# "$reduce": {
|
||||
# "input": {"$range": [0, {"$size": "$embedding"}]},
|
||||
# "initialValue": 0,
|
||||
# "in": {
|
||||
# "$add": [
|
||||
# "$$value",
|
||||
# {
|
||||
# "$multiply": [
|
||||
# {"$arrayElemAt": ["$embedding", "$$this"]},
|
||||
# {"$arrayElemAt": [query_embedding, "$$this"]},
|
||||
# ]
|
||||
# },
|
||||
# ]
|
||||
# },
|
||||
# }
|
||||
# },
|
||||
# "magnitude1": {
|
||||
# "$sqrt": {
|
||||
# "$reduce": {
|
||||
# "input": "$embedding",
|
||||
# "initialValue": 0,
|
||||
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
# }
|
||||
# }
|
||||
# },
|
||||
# "magnitude2": {
|
||||
# "$sqrt": {
|
||||
# "$reduce": {
|
||||
# "input": query_embedding,
|
||||
# "initialValue": 0,
|
||||
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||
# }
|
||||
# }
|
||||
# },
|
||||
# }
|
||||
# },
|
||||
# {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
|
||||
# {
|
||||
# "$match": {
|
||||
# "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
||||
# }
|
||||
# },
|
||||
# {"$sort": {"similarity": -1}},
|
||||
# {"$limit": limit},
|
||||
# {"$project": {"content": 1, "similarity": 1}},
|
||||
# ]
|
||||
|
||||
# results = list(db.knowledges.aggregate(pipeline))
|
||||
# logger.debug(f"知识库查询结果数量: {len(results)}")
|
||||
|
||||
# if not results:
|
||||
# return "" if not return_raw else []
|
||||
|
||||
# if return_raw:
|
||||
# return results
|
||||
# else:
|
||||
# # 返回所有找到的内容,用换行分隔
|
||||
# return "\n".join(str(result["content"]) for result in results)
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
|
|
|||
|
|
@ -1,122 +0,0 @@
|
|||
import os
|
||||
import toml
|
||||
from .global_logger import logger
|
||||
PG_NAMESPACE = "paragraph"
|
||||
ENT_NAMESPACE = "entity"
|
||||
REL_NAMESPACE = "relation"
|
||||
|
||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
||||
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
||||
|
||||
# 无效实体
|
||||
INVALID_ENTITY = [
|
||||
"",
|
||||
"你",
|
||||
"他",
|
||||
"她",
|
||||
"它",
|
||||
"我们",
|
||||
"你们",
|
||||
"他们",
|
||||
"她们",
|
||||
"它们",
|
||||
]
|
||||
|
||||
|
||||
def _load_config(config, config_file_path):
|
||||
"""读取TOML格式的配置文件"""
|
||||
if not os.path.exists(config_file_path):
|
||||
return
|
||||
with open(config_file_path, "r", encoding="utf-8") as f:
|
||||
file_config = toml.load(f)
|
||||
|
||||
if "llm_providers" in file_config:
|
||||
for provider in file_config["llm_providers"]:
|
||||
if provider["name"] not in config["llm_providers"]:
|
||||
config["llm_providers"][provider["name"]] = dict()
|
||||
config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"]
|
||||
config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"]
|
||||
|
||||
if "entity_extract" in file_config:
|
||||
config["entity_extract"] = file_config["entity_extract"]
|
||||
|
||||
if "rdf_build" in file_config:
|
||||
config["rdf_build"] = file_config["rdf_build"]
|
||||
|
||||
if "embedding" in file_config:
|
||||
config["embedding"] = file_config["embedding"]
|
||||
|
||||
if "rag" in file_config:
|
||||
config["rag"] = file_config["rag"]
|
||||
|
||||
if "qa" in file_config:
|
||||
config["qa"] = file_config["qa"]
|
||||
|
||||
if "persistence" in file_config:
|
||||
config["persistence"] = file_config["persistence"]
|
||||
|
||||
logger.info(f"Configurations loaded from file: {config_file_path}")
|
||||
|
||||
global_config = dict(
|
||||
{
|
||||
"llm_providers": {
|
||||
"localhost": {
|
||||
"base_url": "http://localhost:8000",
|
||||
"api_key": "",
|
||||
}
|
||||
},
|
||||
"entity_extract": {
|
||||
"llm": {
|
||||
"provider": "localhost",
|
||||
"model": "entity-extract",
|
||||
}
|
||||
},
|
||||
"rdf_build": {
|
||||
"llm": {
|
||||
"provider": "localhost",
|
||||
"model": "rdf-build",
|
||||
}
|
||||
},
|
||||
"embedding": {
|
||||
"provider": "localhost",
|
||||
"model": "embed",
|
||||
"dimension": 1024,
|
||||
},
|
||||
"rag": {
|
||||
"params": {
|
||||
"synonym_search_top_k": 10,
|
||||
"synonym_threshold": 0.75,
|
||||
}
|
||||
},
|
||||
"qa": {
|
||||
"params": {
|
||||
"relation_search_top_k": 10,
|
||||
"relation_threshold": 0.75,
|
||||
"paragraph_search_top_k": 10,
|
||||
"paragraph_node_weight": 0.05,
|
||||
"ent_filter_top_k": 10,
|
||||
"ppr_damping": 0.8,
|
||||
"res_top_k": 10,
|
||||
},
|
||||
"llm": {
|
||||
"provider": "localhost",
|
||||
"model": "qa",
|
||||
},
|
||||
},
|
||||
"persistence": {
|
||||
"data_root_path": "data",
|
||||
"raw_data_path": "data/raw.json",
|
||||
"openie_data_path": "data/openie.json",
|
||||
"embedding_data_dir": "data/embedding",
|
||||
"rag_data_dir": "data/rag",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# _load_config(global_config, parser.parse_args().config_path)
|
||||
file_path = os.path.abspath(__file__)
|
||||
dir_path = os.path.dirname(file_path)
|
||||
root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
|
||||
config_path = os.path.join(root_path, "config", "lpmm_config.toml")
|
||||
_load_config(global_config, config_path)
|
||||
|
|
@ -4,7 +4,7 @@ from typing import List
|
|||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
from .config import global_config, INVALID_ENTITY
|
||||
from .lpmmconfig import global_config, INVALID_ENTITY
|
||||
from .llm_client import LLMClient
|
||||
from .utils.json_fix import fix_broken_generated_json
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import tqdm
|
|||
from ..lib import quick_algo
|
||||
from .utils.hash import get_sha256
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from .config import (
|
||||
from .lpmmconfig import (
|
||||
ENT_NAMESPACE,
|
||||
PG_NAMESPACE,
|
||||
RAG_ENT_CNT_NAMESPACE,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import toml
|
||||
import argparse
|
||||
from .global_logger import logger
|
||||
|
||||
PG_NAMESPACE = "paragraph"
|
||||
ENT_NAMESPACE = "entity"
|
||||
|
|
@ -56,42 +56,33 @@ def _load_config(config, config_file_path):
|
|||
|
||||
if "persistence" in file_config:
|
||||
config["persistence"] = file_config["persistence"]
|
||||
print(config)
|
||||
print("Configurations loaded from file: ", config_file_path)
|
||||
|
||||
logger.debug("Configurations loaded from file: ", config_file_path)
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Configurations for the pipeline")
|
||||
parser.add_argument(
|
||||
"--config_path",
|
||||
type=str,
|
||||
default="lpmm_config.toml",
|
||||
help="Path to the configuration file",
|
||||
)
|
||||
|
||||
global_config = dict(
|
||||
{
|
||||
"llm_providers": {
|
||||
"localhost": {
|
||||
"base_url": "https://api.siliconflow.cn/v1",
|
||||
"api_key": "sk-ospynxadyorf",
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
}
|
||||
},
|
||||
"entity_extract": {
|
||||
"llm": {
|
||||
"provider": "localhost",
|
||||
"model": "Pro/deepseek-ai/DeepSeek-V3",
|
||||
"provider": "",
|
||||
"model": "",
|
||||
}
|
||||
},
|
||||
"rdf_build": {
|
||||
"llm": {
|
||||
"provider": "localhost",
|
||||
"model": "Pro/deepseek-ai/DeepSeek-V3",
|
||||
"provider": "",
|
||||
"model": "",
|
||||
}
|
||||
},
|
||||
"embedding": {
|
||||
"provider": "localhost",
|
||||
"model": "Pro/BAAI/bge-m3",
|
||||
"provider": "",
|
||||
"model": "",
|
||||
"dimension": 1024,
|
||||
},
|
||||
"rag": {
|
||||
|
|
@ -111,24 +102,23 @@ global_config = dict(
|
|||
"res_top_k": 10,
|
||||
},
|
||||
"llm": {
|
||||
"provider": "localhost",
|
||||
"model": "qa",
|
||||
"provider": "",
|
||||
"model": "",
|
||||
},
|
||||
},
|
||||
"persistence": {
|
||||
"data_root_path": "data",
|
||||
"raw_data_path": "data/raw.json",
|
||||
"openie_data_path": "data/openie.json",
|
||||
"embedding_data_dir": "data/embedding",
|
||||
"rag_data_dir": "data/rag",
|
||||
"data_root_path": "",
|
||||
"raw_data_path": "",
|
||||
"openie_data_path": "",
|
||||
"embedding_data_dir": "",
|
||||
"rag_data_dir": "",
|
||||
},
|
||||
"info_extraction":{
|
||||
"info_extraction": {
|
||||
"workers": 10,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# _load_config(global_config, parser.parse_args().config_path)
|
||||
file_path = os.path.abspath(__file__)
|
||||
dir_path = os.path.dirname(file_path)
|
||||
root_path = os.path.join(dir_path, os.pardir, os.pardir, os.pardir, os.pardir)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
from .config import INVALID_ENTITY, global_config
|
||||
from .lpmmconfig import INVALID_ENTITY, global_config
|
||||
|
||||
|
||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||
|
|
|
|||
|
|
@ -1,73 +0,0 @@
|
|||
from typing import List
|
||||
|
||||
from .llm_client import LLMMessage
|
||||
|
||||
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体,并以JSON列表的形式输出。
|
||||
|
||||
输出格式示例:
|
||||
[ "实体A", "实体B", "实体C" ]
|
||||
|
||||
请注意以下要求:
|
||||
- 将代词(如“你”、“我”、“他”、“她”、“它”等)转化为对应的实体命名,以避免指代不清。
|
||||
- 尽可能多的提取出段落中的全部实体;
|
||||
"""
|
||||
|
||||
|
||||
def build_entity_extract_context(paragraph: str) -> List[LLMMessage]:
|
||||
messages = [
|
||||
LLMMessage("system", entity_extract_system_prompt).to_dict(),
|
||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
|
||||
]
|
||||
return messages
|
||||
|
||||
|
||||
rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
|
||||
|
||||
请使用JSON回复,使用三元组的JSON列表输出RDF图中的关系(每个三元组代表一个关系)。
|
||||
|
||||
输出格式示例:
|
||||
[
|
||||
["某实体","关系","某属性"],
|
||||
["某实体","关系","某实体"],
|
||||
["某资源","关系","某属性"]
|
||||
]
|
||||
|
||||
请注意以下要求:
|
||||
- 每个三元组应包含每个段落的实体命名列表中的至少一个命名实体,但最好是两个。
|
||||
- 将代词(如“你”、“我”、“他”、“她”、“它”等)转化为对应的实体命名,以避免指代不清。
|
||||
"""
|
||||
|
||||
|
||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> List[LLMMessage]:
|
||||
messages = [
|
||||
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
|
||||
LLMMessage(
|
||||
"user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```"""
|
||||
).to_dict(),
|
||||
]
|
||||
return messages
|
||||
|
||||
|
||||
qa_system_prompt = """
|
||||
你是一个性能优异的QA系统。请根据给定的问题和一些可能对你有帮助的信息作出回答。
|
||||
|
||||
请注意以下要求:
|
||||
- 你可以使用给定的信息来回答问题,但请不要直接引用它们。
|
||||
- 你的回答应该简洁明了,避免冗长的解释。
|
||||
- 如果你无法回答问题,请直接说“我不知道”。
|
||||
"""
|
||||
|
||||
|
||||
def build_qa_context(
|
||||
question: str, knowledge: list[(str, str, str)]
|
||||
) -> List[LLMMessage]:
|
||||
knowledge = "\n".join(
|
||||
[f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]
|
||||
)
|
||||
messages = [
|
||||
LLMMessage("system", qa_system_prompt).to_dict(),
|
||||
LLMMessage(
|
||||
"user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}"
|
||||
).to_dict(),
|
||||
]
|
||||
return messages
|
||||
|
|
@ -2,7 +2,6 @@ import time
|
|||
from typing import Tuple, List, Dict
|
||||
|
||||
from .global_logger import logger
|
||||
# from . import prompt_template
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .llm_client import LLMClient
|
||||
from .kg_manager import KGManager
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
import os
|
||||
|
||||
from .global_logger import logger
|
||||
from .config import global_config
|
||||
from .lpmmconfig import global_config
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +0,0 @@
|
|||
import json
|
||||
|
||||
|
||||
with open("raw.txt", "r", encoding="utf-8") as f:
|
||||
raw = f.read()
|
||||
|
||||
# 一行一行的处理文件
|
||||
paragraphs = []
|
||||
paragraph = ""
|
||||
for line in raw.split("\n"):
|
||||
if line.strip() == "":
|
||||
# 有空行,表示段落结束
|
||||
if paragraph != "":
|
||||
paragraphs.append(paragraph)
|
||||
paragraph = ""
|
||||
else:
|
||||
paragraph += line + "\n"
|
||||
|
||||
if paragraph != "":
|
||||
paragraphs.append(paragraph)
|
||||
|
||||
with open("raw.json", "w", encoding="utf-8") as f:
|
||||
json.dump(paragraphs, f, ensure_ascii=False, indent=4)
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
import jsonlines
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Union, Optional
|
||||
from src.config import global_config as config
|
||||
|
||||
|
||||
class DataLoader:
|
||||
"""数据加载工具类,用于从/data目录下加载各种格式的数据文件"""
|
||||
|
||||
def __init__(self, custom_data_dir: Optional[Union[str, Path]] = None):
|
||||
"""
|
||||
初始化数据加载器
|
||||
|
||||
Args:
|
||||
custom_data_dir: 可选的自定义数据目录路径,如果不提供则使用配置文件中的默认路径
|
||||
"""
|
||||
self.data_dir = (
|
||||
Path(custom_data_dir)
|
||||
if custom_data_dir
|
||||
else Path(config["persistence"]["data_root_path"])
|
||||
)
|
||||
if not self.data_dir.exists():
|
||||
raise FileNotFoundError(f"数据目录 {self.data_dir} 不存在")
|
||||
|
||||
def _resolve_file_path(self, filename: str) -> Path:
|
||||
"""
|
||||
解析文件路径
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
完整的文件路径
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 当文件不存在时抛出
|
||||
"""
|
||||
file_path = self.data_dir / filename
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"文件 {filename} 不存在")
|
||||
return file_path
|
||||
|
||||
def load_jsonl(self, filename: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
加载JSONL格式的文件
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
|
||||
Returns:
|
||||
包含所有数据的列表
|
||||
"""
|
||||
file_path = self._resolve_file_path(filename)
|
||||
data = []
|
||||
with jsonlines.open(file_path) as reader:
|
||||
for obj in reader:
|
||||
data.append(obj)
|
||||
return data
|
||||
|
|
@ -1,17 +0,0 @@
|
|||
import networkx as nx
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
def draw_graph_and_show(graph):
|
||||
"""绘制图并显示,画布大小1280*1280"""
|
||||
fig = plt.figure(1, figsize=(12.8, 12.8), dpi=100)
|
||||
nx.draw_networkx(
|
||||
graph,
|
||||
node_size=100,
|
||||
width=0.5,
|
||||
with_labels=True,
|
||||
labels=nx.get_node_attributes(graph, "content"),
|
||||
font_family="Sarasa Mono SC",
|
||||
font_size=8,
|
||||
)
|
||||
fig.show()
|
||||
Loading…
Reference in New Issue