mirror of https://github.com/Mai-with-u/MaiBot.git
修改了知识库检索的逻辑,增强了知识库检索的能力,修改内容在prompt_builder.py最下方
parent
8ce971f230
commit
df53b5394c
|
|
@ -239,14 +239,49 @@ class PromptBuilder:
|
|||
async def get_prompt_info(self, message: str, threshold: float) -> str:
    """Collect knowledge-base snippets related to ``message``.

    The message is reduced to a list of topics with
    ``hippocampus._identify_topics``.  Each topic is embedded with
    ``get_embedding`` and looked up through ``self.get_info_from_db``
    (``limit=2`` per topic).  The per-topic results are joined with
    newlines; a result string that is identical to one already collected
    is not added a second time.

    Args:
        message: Message text to search knowledge for.
        threshold: Similarity threshold forwarded to
            ``get_info_from_db``.

    Returns:
        The merged knowledge text, or an empty string when no topic
        produced a result.
    """
    related_info = ""
    logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")

    # Identify the topics. A falsy result (e.g. None) is treated as
    # "no topics" so the len() calls below cannot raise.
    topics = (await hippocampus._identify_topics(message)) or []
    logger.info(f"[知识库查询] 识别出的主题:{topics}")
    logger.info(f"[知识库查询] 主题数量:{len(topics)}")

    # Query the knowledge base once per topic.
    all_related_info = []
    matched_topics = 0  # topics that yielded knowledge, duplicates included
    for i, topic in enumerate(topics, 1):
        logger.info(f"[知识库查询] 正在处理第 {i}/{len(topics)} 个主题:{topic}")
        try:
            embedding = await get_embedding(topic)
            if not embedding:
                logger.warning(f"[知识库查询] 主题「{topic}」获取embedding失败")
                continue
            logger.debug(f"[知识库查询] 主题「{topic}」成功获取embedding")
            topic_info = self.get_info_from_db(embedding, limit=2, threshold=threshold)
        except Exception as e:
            # Best effort: one failing topic must not abort the others.
            logger.error(f"[知识库查询] 处理主题「{topic}」时发生错误:{str(e)}")
            continue

        if not topic_info:
            logger.debug(f"[知识库查询] 主题「{topic}」未找到相关知识")
            continue

        logger.info(f"[知识库查询] 主题「{topic}」找到相关知识:\n{topic_info}")
        matched_topics += 1
        # Overlapping topics can return the very same snippet; keep it once
        # so the merged text does not repeat itself.
        if topic_info not in all_related_info:
            all_related_info.append(topic_info)

    # Merge the knowledge found for all topics.
    if all_related_info:
        related_info = "\n".join(all_related_info)
        logger.info(f"[知识库查询] 最终合并的知识内容:\n{related_info}")
        logger.info(f"[知识库查询] 成功处理的主题数量:{matched_topics}/{len(topics)}")
    else:
        logger.debug("[知识库查询] 未找到任何相关知识")

    return related_info
|
||||
|
||||
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
|
||||
def get_info_from_db(self, query_embedding: list, limit: int = 5, threshold: float = 0.5) -> str:
|
||||
if not query_embedding:
|
||||
logger.warning("[知识库查询] 查询向量为空")
|
||||
return ""
|
||||
|
||||
logger.debug(f"[知识库查询] 开始查询,相似度阈值:{threshold},返回数量限制:{limit}")
|
||||
|
||||
# 使用余弦相似度计算
|
||||
pipeline = [
|
||||
{
|
||||
|
|
@ -300,11 +335,15 @@ class PromptBuilder:
|
|||
]
|
||||
|
||||
results = list(db.knowledges.aggregate(pipeline))
|
||||
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
|
||||
|
||||
|
||||
if not results:
|
||||
logger.debug(f"[知识库查询] 未找到相似度大于 {threshold} 的结果")
|
||||
return ""
|
||||
|
||||
# 记录每个结果的相似度
|
||||
for result in results:
|
||||
logger.debug(f"[知识库查询] 找到相似内容 [相似度: {result['similarity']:.3f}]:\n{result['content']}")
|
||||
|
||||
# 返回所有找到的内容,用换行分隔
|
||||
return "\n".join(str(result["content"]) for result in results)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue