mirror of https://github.com/Mai-with-u/MaiBot.git
better:优化记忆检索占用
parent
85864c7013
commit
26784b00a5
|
|
@ -390,6 +390,7 @@ class ChatHistorySummarizer:
|
|||
"theme": theme,
|
||||
"keywords": json.dumps(keywords, ensure_ascii=False),
|
||||
"summary": summary,
|
||||
"count": 0,
|
||||
}
|
||||
|
||||
# 使用db_save存储(使用start_time和chat_id作为唯一标识)
|
||||
|
|
|
|||
|
|
@ -377,6 +377,7 @@ class ChatHistory(BaseModel):
|
|||
theme = TextField() # 主题:这段对话的主要内容,一个简短的标题
|
||||
keywords = TextField() # 关键词:这段对话的关键词,JSON格式存储
|
||||
summary = TextField() # 概括:对这段话的平文本概括
|
||||
count = IntegerField(default=0) # 被检索次数
|
||||
|
||||
class Meta:
|
||||
table_name = "chat_history"
|
||||
|
|
|
|||
|
|
@ -280,12 +280,12 @@ async def _react_agent_solve_question(
|
|||
return False, "未找到相关信息", thinking_steps
|
||||
|
||||
|
||||
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 3600.0) -> str:
|
||||
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str:
|
||||
"""获取最近一段时间内的查询历史
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
time_window_seconds: 时间窗口(秒),默认1小时
|
||||
time_window_seconds: 时间窗口(秒),默认10分钟
|
||||
|
||||
Returns:
|
||||
str: 格式化的查询历史字符串
|
||||
|
|
@ -302,7 +302,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 3600.0)
|
|||
(ThinkingBack.update_time >= start_time)
|
||||
)
|
||||
.order_by(ThinkingBack.update_time.desc())
|
||||
.limit(8) # 最多返回10条最近的记录
|
||||
.limit(5) # 最多返回5条最近的记录
|
||||
)
|
||||
|
||||
if not records.exists():
|
||||
|
|
@ -314,7 +314,8 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 3600.0)
|
|||
for record in records:
|
||||
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
|
||||
answer_preview = ""
|
||||
if record.answer:
|
||||
# 只有找到答案时才显示答案内容
|
||||
if record.found_answer and record.answer:
|
||||
# 截取答案前100字符
|
||||
answer_preview = record.answer[:100]
|
||||
if len(record.answer) > 100:
|
||||
|
|
@ -554,7 +555,7 @@ async def build_memory_retrieval_prompt(
|
|||
chat_id = chat_stream.stream_id
|
||||
|
||||
# 获取最近查询历史(最近1小时内的查询)
|
||||
recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=600.0)
|
||||
recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=300.0)
|
||||
if not recent_query_history:
|
||||
recent_query_history = "最近没有查询记录。"
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,7 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.utils import parse_keywords_string
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
from .tool_utils import parse_datetime_to_timestamp, parse_time_range
|
||||
|
|
@ -116,10 +114,19 @@ async def query_chat_history(
|
|||
return f"未找到包含关键词'{keywords_str}'的聊天记录概述"
|
||||
|
||||
records = filtered_records
|
||||
|
||||
# 对即将返回的记录增加使用计数
|
||||
records_to_use = records[:3]
|
||||
for record in records_to_use:
|
||||
try:
|
||||
ChatHistory.update(count=ChatHistory.count + 1).where(ChatHistory.id == record.id).execute()
|
||||
record.count = (record.count or 0) + 1
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新聊天记录概述计数失败: {update_error}")
|
||||
|
||||
# 构建结果文本
|
||||
results = []
|
||||
for record in records[:10]: # 最多返回10条记录
|
||||
for record in records_to_use: # 最多返回3条记录
|
||||
result_parts = []
|
||||
|
||||
# 添加主题
|
||||
|
|
@ -146,66 +153,11 @@ async def query_chat_history(
|
|||
if not results:
|
||||
return "未找到相关聊天记录概述"
|
||||
|
||||
# 如果只有一条记录,直接返回
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
|
||||
# 多条记录,使用LLM总结
|
||||
try:
|
||||
llm_request = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="chat_history_analysis"
|
||||
)
|
||||
|
||||
query_desc = []
|
||||
if keyword:
|
||||
# 解析关键词列表用于显示
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if keywords_list:
|
||||
keywords_str = "、".join(keywords_list)
|
||||
query_desc.append(f"关键词:{keywords_str}")
|
||||
else:
|
||||
query_desc.append(f"关键词:{keyword}")
|
||||
if time_range:
|
||||
if " - " in time_range:
|
||||
query_desc.append(f"时间范围:{time_range}")
|
||||
else:
|
||||
query_desc.append(f"时间点:{time_range}")
|
||||
|
||||
query_info = ",".join(query_desc) if query_desc else "聊天记录概述"
|
||||
|
||||
combined_results = "\n\n---\n\n".join(results)
|
||||
|
||||
analysis_prompt = f"""请根据以下聊天记录概述,总结与查询条件相关的信息。请输出一段平文本,不要有特殊格式。
|
||||
查询条件:{query_info}
|
||||
|
||||
聊天记录概述:
|
||||
{combined_results}
|
||||
|
||||
请仔细分析聊天记录概述,提取与查询条件相关的信息并给出总结。如果概述中没有相关信息,输出"无有效信息"即可,不要输出其他内容。
|
||||
|
||||
总结:"""
|
||||
|
||||
response, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async(
|
||||
prompt=analysis_prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=512
|
||||
)
|
||||
|
||||
logger.info(f"查询聊天历史概述提示词: {analysis_prompt}")
|
||||
logger.info(f"查询聊天历史概述响应: {response}")
|
||||
logger.info(f"查询聊天历史概述推理: {reasoning}")
|
||||
logger.info(f"查询聊天历史概述模型: {model_name}")
|
||||
|
||||
if "无有效信息" in response:
|
||||
return "无有效信息"
|
||||
|
||||
return response
|
||||
|
||||
except Exception as llm_error:
|
||||
logger.error(f"LLM分析聊天记录概述失败: {llm_error}")
|
||||
# 如果LLM分析失败,返回前3条记录的摘要
|
||||
return "\n\n---\n\n".join(results[:3])
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
if len(records) > len(records_to_use):
|
||||
omitted_count = len(records) - len(records_to_use)
|
||||
response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
|
||||
return response_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询聊天历史概述失败: {e}")
|
||||
|
|
|
|||
|
|
@ -1,391 +0,0 @@
|
|||
"""
|
||||
StyleLearner 数据库测试脚本
|
||||
使用数据库中的expression数据测试style_learner功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Tuple
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import precision_recall_fscore_support
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.common.database.database_model import Expression, db
|
||||
from src.express.style_learner import StyleLearnerManager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("style_learner_test")
|
||||
|
||||
|
||||
class StyleLearnerDatabaseTest:
|
||||
"""使用数据库数据测试StyleLearner"""
|
||||
|
||||
def __init__(self, random_state: int = 42):
|
||||
self.random_state = random_state
|
||||
self.manager = StyleLearnerManager(model_save_path="data/test_style_models")
|
||||
|
||||
# 测试结果
|
||||
self.test_results = {
|
||||
"total_samples": 0,
|
||||
"train_samples": 0,
|
||||
"test_samples": 0,
|
||||
"unique_styles": 0,
|
||||
"unique_chat_ids": 0,
|
||||
"accuracy": 0.0,
|
||||
"precision": 0.0,
|
||||
"recall": 0.0,
|
||||
"f1_score": 0.0,
|
||||
"predictions": [],
|
||||
"ground_truth": [],
|
||||
"model_save_success": False,
|
||||
"model_save_path": self.manager.model_save_path
|
||||
}
|
||||
|
||||
def load_data_from_database(self) -> List[Dict]:
|
||||
"""
|
||||
从数据库加载expression数据
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含up_content, style, chat_id的数据列表
|
||||
"""
|
||||
try:
|
||||
# 连接数据库
|
||||
db.connect(reuse_if_open=True)
|
||||
|
||||
# 查询所有expression数据
|
||||
expressions = Expression.select().where(
|
||||
(Expression.up_content.is_null(False)) &
|
||||
(Expression.style.is_null(False)) &
|
||||
(Expression.chat_id.is_null(False)) &
|
||||
(Expression.type == "style")
|
||||
)
|
||||
|
||||
data = []
|
||||
for expr in expressions:
|
||||
if expr.up_content and expr.style and expr.chat_id:
|
||||
data.append({
|
||||
"up_content": expr.up_content,
|
||||
"style": expr.style,
|
||||
"chat_id": expr.chat_id,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"context": expr.context,
|
||||
"situation": expr.situation
|
||||
})
|
||||
|
||||
logger.info(f"从数据库加载了 {len(data)} 条expression数据")
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载数据失败: {e}")
|
||||
return []
|
||||
|
||||
def preprocess_data(self, data: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
数据预处理
|
||||
|
||||
Args:
|
||||
data: 原始数据
|
||||
|
||||
Returns:
|
||||
List[Dict]: 预处理后的数据
|
||||
"""
|
||||
# 过滤掉空值或过短的数据
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
up_content = item["up_content"].strip()
|
||||
style = item["style"].strip()
|
||||
|
||||
if len(up_content) >= 2 and len(style) >= 2:
|
||||
filtered_data.append({
|
||||
"up_content": up_content,
|
||||
"style": style,
|
||||
"chat_id": item["chat_id"],
|
||||
"last_active_time": item["last_active_time"],
|
||||
"context": item["context"],
|
||||
"situation": item["situation"]
|
||||
})
|
||||
|
||||
logger.info(f"预处理后剩余 {len(filtered_data)} 条数据")
|
||||
return filtered_data
|
||||
|
||||
def split_data(self, data: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
|
||||
"""
|
||||
分割训练集和测试集
|
||||
训练集使用所有数据,测试集从训练集中随机选择5%
|
||||
|
||||
Args:
|
||||
data: 预处理后的数据
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], List[Dict]]: (训练集, 测试集)
|
||||
"""
|
||||
# 训练集使用所有数据
|
||||
train_data = data.copy()
|
||||
|
||||
# 测试集从训练集中随机选择5%
|
||||
test_size = 0.05 # 5%
|
||||
test_data = train_test_split(
|
||||
train_data, test_size=test_size, random_state=self.random_state
|
||||
)[1] # 只取测试集部分
|
||||
|
||||
logger.info(f"数据分割完成: 训练集 {len(train_data)} 条, 测试集 {len(test_data)} 条")
|
||||
logger.info(f"训练集使用所有数据,测试集从训练集中随机选择 {test_size*100:.1f}%")
|
||||
return train_data, test_data
|
||||
|
||||
def train_model(self, train_data: List[Dict]) -> None:
|
||||
"""
|
||||
训练模型
|
||||
|
||||
Args:
|
||||
train_data: 训练数据
|
||||
"""
|
||||
logger.info("开始训练模型...")
|
||||
|
||||
# 统计信息
|
||||
chat_ids = set()
|
||||
styles = set()
|
||||
|
||||
for item in train_data:
|
||||
chat_id = item["chat_id"]
|
||||
up_content = item["up_content"]
|
||||
style = item["style"]
|
||||
|
||||
chat_ids.add(chat_id)
|
||||
styles.add(style)
|
||||
|
||||
# 学习映射关系
|
||||
success = self.manager.learn_mapping(chat_id, up_content, style)
|
||||
if not success:
|
||||
logger.warning(f"学习失败: {chat_id} - {up_content} -> {style}")
|
||||
|
||||
self.test_results["train_samples"] = len(train_data)
|
||||
self.test_results["unique_styles"] = len(styles)
|
||||
self.test_results["unique_chat_ids"] = len(chat_ids)
|
||||
|
||||
logger.info(f"训练完成: {len(train_data)} 个样本, {len(styles)} 种风格, {len(chat_ids)} 个聊天室")
|
||||
|
||||
# 保存训练好的模型
|
||||
logger.info("开始保存训练好的模型...")
|
||||
save_success = self.manager.save_all_models()
|
||||
self.test_results["model_save_success"] = save_success
|
||||
|
||||
if save_success:
|
||||
logger.info(f"所有模型已成功保存到: {self.manager.model_save_path}")
|
||||
print(f"✅ 模型已保存到: {self.manager.model_save_path}")
|
||||
else:
|
||||
logger.warning("部分模型保存失败")
|
||||
print(f"⚠️ 模型保存失败,请检查路径: {self.manager.model_save_path}")
|
||||
|
||||
def test_model(self, test_data: List[Dict]) -> None:
|
||||
"""
|
||||
测试模型
|
||||
|
||||
Args:
|
||||
test_data: 测试数据
|
||||
"""
|
||||
logger.info("开始测试模型...")
|
||||
|
||||
predictions = []
|
||||
ground_truth = []
|
||||
correct_predictions = 0
|
||||
|
||||
for item in test_data:
|
||||
chat_id = item["chat_id"]
|
||||
up_content = item["up_content"]
|
||||
true_style = item["style"]
|
||||
|
||||
# 预测风格
|
||||
predicted_style, scores = self.manager.predict_style(chat_id, up_content, top_k=1)
|
||||
|
||||
predictions.append(predicted_style)
|
||||
ground_truth.append(true_style)
|
||||
|
||||
# 检查预测是否正确
|
||||
if predicted_style == true_style:
|
||||
correct_predictions += 1
|
||||
|
||||
# 记录详细预测结果
|
||||
self.test_results["predictions"].append({
|
||||
"chat_id": chat_id,
|
||||
"up_content": up_content,
|
||||
"true_style": true_style,
|
||||
"predicted_style": predicted_style,
|
||||
"scores": scores
|
||||
})
|
||||
|
||||
# 计算准确率
|
||||
accuracy = correct_predictions / len(test_data) if test_data else 0
|
||||
|
||||
# 计算其他指标(需要处理None值)
|
||||
valid_predictions = [p for p in predictions if p is not None]
|
||||
valid_ground_truth = [gt for p, gt in zip(predictions, ground_truth, strict=False) if p is not None]
|
||||
|
||||
if valid_predictions:
|
||||
precision, recall, f1, _ = precision_recall_fscore_support(
|
||||
valid_ground_truth, valid_predictions, average='weighted', zero_division=0
|
||||
)
|
||||
else:
|
||||
precision = recall = f1 = 0.0
|
||||
|
||||
self.test_results["test_samples"] = len(test_data)
|
||||
self.test_results["accuracy"] = accuracy
|
||||
self.test_results["precision"] = precision
|
||||
self.test_results["recall"] = recall
|
||||
self.test_results["f1_score"] = f1
|
||||
|
||||
logger.info(f"测试完成: 准确率 {accuracy:.4f}, 精确率 {precision:.4f}, 召回率 {recall:.4f}, F1分数 {f1:.4f}")
|
||||
|
||||
def analyze_results(self) -> None:
|
||||
"""分析测试结果"""
|
||||
logger.info("=== 测试结果分析 ===")
|
||||
|
||||
print("\n📊 数据统计:")
|
||||
print(f" 总样本数: {self.test_results['total_samples']}")
|
||||
print(f" 训练样本数: {self.test_results['train_samples']}")
|
||||
print(f" 测试样本数: {self.test_results['test_samples']}")
|
||||
print(f" 唯一风格数: {self.test_results['unique_styles']}")
|
||||
print(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}")
|
||||
|
||||
print("\n🎯 模型性能:")
|
||||
print(f" 准确率: {self.test_results['accuracy']:.4f}")
|
||||
print(f" 精确率: {self.test_results['precision']:.4f}")
|
||||
print(f" 召回率: {self.test_results['recall']:.4f}")
|
||||
print(f" F1分数: {self.test_results['f1_score']:.4f}")
|
||||
|
||||
print("\n💾 模型保存:")
|
||||
save_status = "成功" if self.test_results['model_save_success'] else "失败"
|
||||
print(f" 保存状态: {save_status}")
|
||||
print(f" 保存路径: {self.test_results['model_save_path']}")
|
||||
|
||||
# 分析各聊天室的性能
|
||||
chat_performance = {}
|
||||
for pred in self.test_results["predictions"]:
|
||||
chat_id = pred["chat_id"]
|
||||
if chat_id not in chat_performance:
|
||||
chat_performance[chat_id] = {"correct": 0, "total": 0}
|
||||
|
||||
chat_performance[chat_id]["total"] += 1
|
||||
if pred["predicted_style"] == pred["true_style"]:
|
||||
chat_performance[chat_id]["correct"] += 1
|
||||
|
||||
print("\n📈 各聊天室性能:")
|
||||
for chat_id, perf in chat_performance.items():
|
||||
accuracy = perf["correct"] / perf["total"] if perf["total"] > 0 else 0
|
||||
print(f" {chat_id}: {accuracy:.4f} ({perf['correct']}/{perf['total']})")
|
||||
|
||||
# 分析风格分布
|
||||
style_counts = {}
|
||||
for pred in self.test_results["predictions"]:
|
||||
style = pred["true_style"]
|
||||
style_counts[style] = style_counts.get(style, 0) + 1
|
||||
|
||||
print("\n🎨 风格分布 (前10个):")
|
||||
sorted_styles = sorted(style_counts.items(), key=lambda x: x[1], reverse=True)
|
||||
for style, count in sorted_styles[:10]:
|
||||
print(f" {style}: {count} 次")
|
||||
|
||||
def show_sample_predictions(self, num_samples: int = 10) -> None:
|
||||
"""显示样本预测结果"""
|
||||
print(f"\n🔍 样本预测结果 (前{num_samples}个):")
|
||||
|
||||
for i, pred in enumerate(self.test_results["predictions"][:num_samples]):
|
||||
status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗"
|
||||
print(f"\n {i+1}. {status}")
|
||||
print(f" 聊天室: {pred['chat_id']}")
|
||||
print(f" 输入内容: {pred['up_content']}")
|
||||
print(f" 真实风格: {pred['true_style']}")
|
||||
print(f" 预测风格: {pred['predicted_style']}")
|
||||
if pred["scores"]:
|
||||
top_scores = dict(list(pred["scores"].items())[:3])
|
||||
print(f" 分数: {top_scores}")
|
||||
|
||||
def save_results(self, output_file: str = "style_learner_test_results.txt") -> None:
|
||||
"""保存测试结果到文件"""
|
||||
try:
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
f.write("StyleLearner 数据库测试结果\n")
|
||||
f.write("=" * 50 + "\n\n")
|
||||
|
||||
f.write("数据统计:\n")
|
||||
f.write(f" 总样本数: {self.test_results['total_samples']}\n")
|
||||
f.write(f" 训练样本数: {self.test_results['train_samples']}\n")
|
||||
f.write(f" 测试样本数: {self.test_results['test_samples']}\n")
|
||||
f.write(f" 唯一风格数: {self.test_results['unique_styles']}\n")
|
||||
f.write(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}\n\n")
|
||||
|
||||
f.write("模型性能:\n")
|
||||
f.write(f" 准确率: {self.test_results['accuracy']:.4f}\n")
|
||||
f.write(f" 精确率: {self.test_results['precision']:.4f}\n")
|
||||
f.write(f" 召回率: {self.test_results['recall']:.4f}\n")
|
||||
f.write(f" F1分数: {self.test_results['f1_score']:.4f}\n\n")
|
||||
|
||||
f.write("模型保存:\n")
|
||||
save_status = "成功" if self.test_results['model_save_success'] else "失败"
|
||||
f.write(f" 保存状态: {save_status}\n")
|
||||
f.write(f" 保存路径: {self.test_results['model_save_path']}\n\n")
|
||||
|
||||
f.write("详细预测结果:\n")
|
||||
for i, pred in enumerate(self.test_results["predictions"]):
|
||||
status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗"
|
||||
f.write(f"{i+1}. {status} [{pred['chat_id']}] {pred['up_content']} -> {pred['predicted_style']} (真实: {pred['true_style']})\n")
|
||||
|
||||
logger.info(f"测试结果已保存到 {output_file}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存测试结果失败: {e}")
|
||||
|
||||
def run_test(self) -> None:
|
||||
"""运行完整测试"""
|
||||
logger.info("开始StyleLearner数据库测试...")
|
||||
|
||||
# 1. 加载数据
|
||||
raw_data = self.load_data_from_database()
|
||||
if not raw_data:
|
||||
logger.error("没有加载到数据,测试终止")
|
||||
return
|
||||
|
||||
# 2. 数据预处理
|
||||
processed_data = self.preprocess_data(raw_data)
|
||||
if not processed_data:
|
||||
logger.error("预处理后没有数据,测试终止")
|
||||
return
|
||||
|
||||
self.test_results["total_samples"] = len(processed_data)
|
||||
|
||||
# 3. 分割数据
|
||||
train_data, test_data = self.split_data(processed_data)
|
||||
|
||||
# 4. 训练模型
|
||||
self.train_model(train_data)
|
||||
|
||||
# 5. 测试模型
|
||||
self.test_model(test_data)
|
||||
|
||||
# 6. 分析结果
|
||||
self.analyze_results()
|
||||
|
||||
# 7. 显示样本预测
|
||||
self.show_sample_predictions(10)
|
||||
|
||||
# 8. 保存结果
|
||||
self.save_results()
|
||||
|
||||
logger.info("StyleLearner数据库测试完成!")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("StyleLearner 数据库测试脚本")
|
||||
print("=" * 50)
|
||||
|
||||
# 创建测试实例
|
||||
test = StyleLearnerDatabaseTest(random_state=42)
|
||||
|
||||
# 运行测试
|
||||
test.run_test()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue