MaiBot/src/tools/tool_can_use/lpmm_get_knowledge.py

163 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from src.tools.tool_can_use.base_tool import BaseTool
from src.chat.utils.utils import get_embedding
# from src.common.database import db
from src.common.logger_manager import get_logger
from typing import Dict, Any
from src.chat.knowledge.knowledge_lib import qa_manager
logger = get_logger("lpmm_get_knowledge_tool")
class SearchKnowledgeFromLPMMTool(BaseTool):
"""从LPMM知识库中搜索相关信息的工具"""
name = "lpmm_search_knowledge"
description = "Search relevant information from knowledge base, use this tool if you need knowledge"
parameters = {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query keywords"},
"threshold": {"type": "number", "description": "Similarity threshold, between 0.0 and 1.0"},
},
"required": ["query"],
}
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
"""执行知识库搜索
Args:
function_args: 工具参数
Returns:
Dict: 工具执行结果
"""
try:
query = function_args.get("query")
# threshold = function_args.get("threshold", 0.4)
# 调用知识库搜索
embedding = await get_embedding(query, request_type="info_retrieval")
if embedding:
knowledge_info = qa_manager.get_knowledge(query)
logger.debug(f"知识库查询结果: {knowledge_info}")
if knowledge_info:
content = f"你知道这些知识: {knowledge_info}"
else:
content = f"你不太了解有关{query}的知识"
return {"type": "lpmm_knowledge", "id": query, "content": content}
# 如果获取嵌入失败
return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量你lpmm知识库炸了"}
except Exception as e:
logger.error(f"知识库搜索工具执行失败: {str(e)}")
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
query_id = query if "query" in locals() else "unknown_query"
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败炸了: {str(e)}"}
# def get_info_from_db(
# self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
# ) -> Union[str, list]:
# """从数据库中获取相关信息
# Args:
# query_embedding: 查询的嵌入向量
# limit: 最大返回结果数
# threshold: 相似度阈值
# return_raw: 是否返回原始结果
# Returns:
# Union[str, list]: 格式化的信息字符串或原始结果列表
# """
# if not query_embedding:
# return "" if not return_raw else []
# # 使用余弦相似度计算
# pipeline = [
# {
# "$addFields": {
# "dotProduct": {
# "$reduce": {
# "input": {"$range": [0, {"$size": "$embedding"}]},
# "initialValue": 0,
# "in": {
# "$add": [
# "$$value",
# {
# "$multiply": [
# {"$arrayElemAt": ["$embedding", "$$this"]},
# {"$arrayElemAt": [query_embedding, "$$this"]},
# ]
# },
# ]
# },
# }
# },
# "magnitude1": {
# "$sqrt": {
# "$reduce": {
# "input": "$embedding",
# "initialValue": 0,
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
# }
# }
# },
# "magnitude2": {
# "$sqrt": {
# "$reduce": {
# "input": query_embedding,
# "initialValue": 0,
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
# }
# }
# },
# }
# },
# {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
# {
# "$match": {
# "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
# }
# },
# {"$sort": {"similarity": -1}},
# {"$limit": limit},
# {"$project": {"content": 1, "similarity": 1}},
# ]
# results = list(db.knowledges.aggregate(pipeline))
# logger.debug(f"知识库查询结果数量: {len(results)}")
# if not results:
# return "" if not return_raw else []
# if return_raw:
# return results
# else:
# # 返回所有找到的内容,用换行分隔
# return "\n".join(str(result["content"]) for result in results)
def _format_results(self, results: list) -> str:
"""格式化结果"""
if not results:
return "未找到相关知识。"
formatted_string = "我找到了一些相关知识:\n"
for i, result in enumerate(results):
# chunk_id = result.get("chunk_id")
text = result.get("text", "")
source = result.get("source", "未知来源")
source_type = result.get("source_type", "未知类型")
similarity = result.get("similarity", 0.0)
formatted_string += (
f"{i + 1}. (相似度: {similarity:.2f}) 类型: {source_type}, 来源: {source} \n内容片段: {text}\n\n"
)
# 暂时去掉chunk_id
# formatted_string += f"{i + 1}. (相似度: {similarity:.2f}) 类型: {source_type}, 来源: {source}, Chunk ID: {chunk_id} \n内容片段: {text}\n\n"
return formatted_string
# 注册工具
# register_tool(SearchKnowledgeTool)