mirror of https://github.com/Mai-with-u/MaiBot.git
修改了知识库检索的逻辑,增强了知识库检索的能力,修改内容在prompt_builder.py最下方
parent
8ce971f230
commit
df53b5394c
|
|
@ -239,14 +239,49 @@ class PromptBuilder:
|
||||||
async def get_prompt_info(self, message: str, threshold: float):
    """Gather knowledge-base text related to ``message``, one lookup per topic.

    Topics are obtained from ``hippocampus._identify_topics(message)``. Each
    topic is embedded with ``get_embedding`` and looked up through
    ``self.get_info_from_db`` with ``limit=2`` and the given ``threshold``.
    An exception raised while handling one topic is logged and the remaining
    topics are still processed.

    Args:
        message: The incoming message text to find related knowledge for.
        threshold: Similarity threshold forwarded to ``get_info_from_db``.

    Returns:
        The non-empty per-topic results joined with newlines, or an empty
        string when no topic produced any result.
    """
    related_info = ""
    logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")

    # Identify the topics of the message.
    topics = await hippocampus._identify_topics(message)
    logger.info(f"[知识库查询] 识别出的主题:{topics}")
    topic_total = len(topics)
    logger.info(f"[知识库查询] 主题数量:{topic_total}")

    # Query the knowledge base once per topic, keeping only non-empty hits.
    collected = []
    for index, topic in enumerate(topics, 1):
        logger.info(f"[知识库查询] 正在处理第 {index}/{topic_total} 个主题:{topic}")
        try:
            embedding = await get_embedding(topic)
            if not embedding:
                logger.warning(f"[知识库查询] 主题「{topic}」获取embedding失败")
                continue
            logger.debug(f"[知识库查询] 主题「{topic}」成功获取embedding")

            topic_info = self.get_info_from_db(embedding, limit=2, threshold=threshold)
            if not topic_info:
                logger.debug(f"[知识库查询] 主题「{topic}」未找到相关知识")
                continue

            logger.info(f"[知识库查询] 主题「{topic}」找到相关知识:\n{topic_info}")
            collected.append(topic_info)
        except Exception as e:
            # Best effort: a failure on one topic must not abort the others.
            logger.error(f"[知识库查询] 处理主题「{topic}」时发生错误:{str(e)}")

    # Merge the knowledge found for all topics.
    if not collected:
        logger.debug("[知识库查询] 未找到任何相关知识")
        return related_info

    related_info = "\n".join(collected)
    logger.info(f"[知识库查询] 最终合并的知识内容:\n{related_info}")
    logger.info(f"[知识库查询] 成功处理的主题数量:{len(collected)}/{topic_total}")
    return related_info
|
||||||
|
|
||||||
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
|
def get_info_from_db(self, query_embedding: list, limit: int = 5, threshold: float = 0.5) -> str:
|
||||||
if not query_embedding:
|
if not query_embedding:
|
||||||
|
logger.warning("[知识库查询] 查询向量为空")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
logger.debug(f"[知识库查询] 开始查询,相似度阈值:{threshold},返回数量限制:{limit}")
|
||||||
|
|
||||||
# 使用余弦相似度计算
|
# 使用余弦相似度计算
|
||||||
pipeline = [
|
pipeline = [
|
||||||
{
|
{
|
||||||
|
|
@ -300,11 +335,15 @@ class PromptBuilder:
|
||||||
]
|
]
|
||||||
|
|
||||||
results = list(db.knowledges.aggregate(pipeline))
|
results = list(db.knowledges.aggregate(pipeline))
|
||||||
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
|
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
|
logger.debug(f"[知识库查询] 未找到相似度大于 {threshold} 的结果")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# 记录每个结果的相似度
|
||||||
|
for result in results:
|
||||||
|
logger.debug(f"[知识库查询] 找到相似内容 [相似度: {result['similarity']:.3f}]:\n{result['content']}")
|
||||||
|
|
||||||
# 返回所有找到的内容,用换行分隔
|
# 返回所有找到的内容,用换行分隔
|
||||||
return "\n".join(str(result["content"]) for result in results)
|
return "\n".join(str(result["content"]) for result in results)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue