修改了知识库检索的逻辑,增强了知识库检索的能力,修改内容在prompt_builder.py最下方

pull/618/head
Voyager1 2025-03-30 18:54:29 +08:00
parent 8ce971f230
commit df53b5394c
1 changed files with 45 additions and 6 deletions

View File

@ -239,14 +239,49 @@ class PromptBuilder:
async def get_prompt_info(self, message: str, threshold: float): async def get_prompt_info(self, message: str, threshold: float):
related_info = "" related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
embedding = await get_embedding(message)
related_info += self.get_info_from_db(embedding, threshold=threshold) # 识别主题
topics = await hippocampus._identify_topics(message)
logger.info(f"[知识库查询] 识别出的主题:{topics}")
logger.info(f"[知识库查询] 主题数量:{len(topics)}")
# 对每个主题进行知识库查询
all_related_info = []
for i, topic in enumerate(topics, 1):
logger.info(f"[知识库查询] 正在处理第 {i}/{len(topics)} 个主题:{topic}")
try:
embedding = await get_embedding(topic)
if embedding:
logger.debug(f"[知识库查询] 主题「{topic}」成功获取embedding")
topic_info = self.get_info_from_db(embedding, limit=2, threshold=threshold)
if topic_info:
logger.info(f"[知识库查询] 主题「{topic}」找到相关知识:\n{topic_info}")
all_related_info.append(topic_info)
else:
logger.debug(f"[知识库查询] 主题「{topic}」未找到相关知识")
else:
logger.warning(f"[知识库查询] 主题「{topic}」获取embedding失败")
except Exception as e:
logger.error(f"[知识库查询] 处理主题「{topic}」时发生错误:{str(e)}")
continue
# 合并所有相关主题的知识
if all_related_info:
related_info = "\n".join(all_related_info)
logger.info(f"[知识库查询] 最终合并的知识内容:\n{related_info}")
logger.info(f"[知识库查询] 成功处理的主题数量:{len(all_related_info)}/{len(topics)}")
else:
logger.debug("[知识库查询] 未找到任何相关知识")
return related_info return related_info
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str: def get_info_from_db(self, query_embedding: list, limit: int = 5, threshold: float = 0.5) -> str:
if not query_embedding: if not query_embedding:
logger.warning("[知识库查询] 查询向量为空")
return "" return ""
logger.debug(f"[知识库查询] 开始查询,相似度阈值:{threshold},返回数量限制:{limit}")
# 使用余弦相似度计算 # 使用余弦相似度计算
pipeline = [ pipeline = [
{ {
@ -300,11 +335,15 @@ class PromptBuilder:
] ]
results = list(db.knowledges.aggregate(pipeline)) results = list(db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results: if not results:
logger.debug(f"[知识库查询] 未找到相似度大于 {threshold} 的结果")
return "" return ""
# 记录每个结果的相似度
for result in results:
logger.debug(f"[知识库查询] 找到相似内容 [相似度: {result['similarity']:.3f}]\n{result['content']}")
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results) return "\n".join(str(result["content"]) for result in results)