mirror of https://github.com/Mai-with-u/MaiBot.git
137 lines
5.1 KiB
Python
137 lines
5.1 KiB
Python
from src.do_tool.tool_can_use.base_tool import BaseTool
|
||
from src.plugins.chat.utils import get_embedding
|
||
from src.common.database import db
|
||
from src.common.logger import get_module_logger
|
||
from typing import Dict, Any, Union
|
||
|
||
logger = get_module_logger("get_knowledge_tool")
|
||
|
||
|
||
class SearchKnowledgeTool(BaseTool):
    """Tool that searches the knowledge base for information related to a query.

    The query text is turned into an embedding via ``get_embedding`` and then
    compared, by cosine similarity, against the ``embedding`` field of the
    documents in ``db.knowledges``.
    """

    name = "search_knowledge"
    description = "从知识库中搜索相关信息"
    parameters = {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "搜索查询关键词"},
            "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
        },
        "required": ["query"],
    }

    async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
        """Run a knowledge-base search.

        Args:
            function_args: Tool arguments. ``query`` is the search text;
                ``threshold`` is optional (default 0.4) and is the minimum
                cosine similarity a document must reach to be returned.

        Returns:
            Dict: ``{"name": "search_knowledge", "content": <message>}``.
            Failures are reported inside ``content`` instead of being raised.
        """
        default_threshold = 0.4
        try:
            query = function_args.get("query")
            if not query:
                # Nothing to embed: skip the embedding request entirely and
                # report it the same way as a failed embedding lookup.
                return {"name": "search_knowledge", "content": f"无法获取关于'{query}'的嵌入向量"}

            # Tool-call arguments are not guaranteed to be well typed (the
            # value may be None or a numeric string). A non-numeric value
            # placed in the $match stage would silently match nothing.
            try:
                threshold = float(function_args.get("threshold", default_threshold))
            except (TypeError, ValueError):
                threshold = default_threshold

            # Query the knowledge base.
            embedding = await get_embedding(query, request_type="info_retrieval")
            if not embedding:
                return {"name": "search_knowledge", "content": f"无法获取关于'{query}'的嵌入向量"}

            knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
            if knowledge_info:
                content = f"你知道这些知识: {knowledge_info}"
            else:
                content = f"你不太了解有关{query}的知识"
            return {"name": "search_knowledge", "content": content}
        except Exception as e:
            # Tool boundary: never propagate, always answer with a result dict.
            logger.error(f"知识库搜索工具执行失败: {str(e)}")
            return {"name": "search_knowledge", "content": f"知识库搜索失败: {str(e)}"}

    @staticmethod
    def get_info_from_db(
        query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
    ) -> Union[str, list]:
        """Fetch the knowledge entries most similar to an embedding.

        Args:
            query_embedding: Embedding vector of the query.
            limit: Maximum number of results to return.
            threshold: Minimum cosine similarity for a result to be kept.
            return_raw: When true, return the raw result documents
                (``content`` and ``similarity`` fields) instead of text.

        Returns:
            Union[str, list]: The matching ``content`` values joined with
            newlines, or the list of raw documents when ``return_raw`` is
            true. Empty (``""`` / ``[]``) when nothing matches.
        """
        import math

        empty_result = [] if return_raw else ""

        # $limit only accepts a positive integer, so bail out before querying.
        if not query_embedding or limit <= 0:
            return empty_result

        # The norm of the query vector is the same for every document, so it
        # is computed once here instead of once per document in the pipeline.
        query_norm = math.sqrt(sum(component * component for component in query_embedding))
        if query_norm == 0:
            # Cosine similarity is undefined for a zero vector.
            return empty_result

        # Cosine similarity, computed inside the aggregation pipeline.
        pipeline = [
            {
                "$addFields": {
                    "dotProduct": {
                        "$reduce": {
                            "input": {"$range": [0, {"$size": "$embedding"}]},
                            "initialValue": 0,
                            "in": {
                                "$add": [
                                    "$$value",
                                    {
                                        "$multiply": [
                                            {"$arrayElemAt": ["$embedding", "$$this"]},
                                            {"$arrayElemAt": [query_embedding, "$$this"]},
                                        ]
                                    },
                                ]
                            },
                        }
                    },
                    "magnitude1": {
                        "$sqrt": {
                            "$reduce": {
                                "input": "$embedding",
                                "initialValue": 0,
                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
                            }
                        }
                    },
                }
            },
            {
                "$addFields": {
                    "similarity": {
                        # $divide fails the whole aggregation on a zero
                        # divisor; a zero-norm stored embedding gets a null
                        # similarity instead, which the $match below drops.
                        "$cond": [
                            {"$gt": ["$magnitude1", 0]},
                            {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", query_norm]}]},
                            None,
                        ]
                    }
                }
            },
            # Keep only results whose similarity is at least the threshold.
            {"$match": {"similarity": {"$gte": threshold}}},
            {"$sort": {"similarity": -1}},
            {"$limit": limit},
            {"$project": {"content": 1, "similarity": 1}},
        ]

        # NOTE(review): this is a synchronous call that is reached from the
        # async execute(); if `db` is a blocking driver it will stall the
        # event loop while the aggregation runs — confirm.
        results = list(db.knowledges.aggregate(pipeline))
        logger.debug(f"知识库查询结果数量: {len(results)}")

        if return_raw:
            return results
        # All matching contents, separated by newlines ("" when none matched).
        return "\n".join(str(result["content"]) for result in results)
|
||
|
||
|
||
# 注册工具
|
||
# register_tool(SearchKnowledgeTool)
|