diff --git a/src/chat/utils/chat_history_summarizer.py b/src/chat/utils/chat_history_summarizer.py index abe02022..4e5a67bd 100644 --- a/src/chat/utils/chat_history_summarizer.py +++ b/src/chat/utils/chat_history_summarizer.py @@ -390,6 +390,7 @@ class ChatHistorySummarizer: "theme": theme, "keywords": json.dumps(keywords, ensure_ascii=False), "summary": summary, + "count": 0, } # 使用db_save存储(使用start_time和chat_id作为唯一标识) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 40e50b21..186bb65c 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -377,6 +377,7 @@ class ChatHistory(BaseModel): theme = TextField() # 主题:这段对话的主要内容,一个简短的标题 keywords = TextField() # 关键词:这段对话的关键词,JSON格式存储 summary = TextField() # 概括:对这段话的平文本概括 + count = IntegerField(default=0) # 被检索次数 class Meta: table_name = "chat_history" diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 0b608d07..e466ef9d 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -280,12 +280,12 @@ async def _react_agent_solve_question( return False, "未找到相关信息", thinking_steps -def _get_recent_query_history(chat_id: str, time_window_seconds: float = 3600.0) -> str: +def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str: """获取最近一段时间内的查询历史 Args: chat_id: 聊天ID - time_window_seconds: 时间窗口(秒),默认1小时 + time_window_seconds: 时间窗口(秒),默认5分钟 Returns: str: 格式化的查询历史字符串 @@ -302,7 +302,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 3600.0) (ThinkingBack.update_time >= start_time) ) .order_by(ThinkingBack.update_time.desc()) - .limit(8) # 最多返回10条最近的记录 + .limit(5) # 最多返回5条最近的记录 ) if not records.exists(): @@ -314,7 +314,8 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 3600.0) for record in records: status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案" answer_preview = "" - if 
record.answer: + # 只有找到答案时才显示答案内容 + if record.found_answer and record.answer: # 截取答案前100字符 answer_preview = record.answer[:100] if len(record.answer) > 100: @@ -554,7 +555,7 @@ async def build_memory_retrieval_prompt( chat_id = chat_stream.stream_id # 获取最近查询历史(最近1小时内的查询) - recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=600.0) + recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=300.0) if not recent_query_history: recent_query_history = "最近没有查询记录。" diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py index 4be8ccab..f95ee266 100644 --- a/src/memory_system/retrieval_tools/query_chat_history.py +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -6,9 +6,7 @@ import json from typing import Optional from src.common.logger import get_logger -from src.config.config import model_config from src.common.database.database_model import ChatHistory -from src.llm_models.utils_model import LLMRequest from src.chat.utils.utils import parse_keywords_string from .tool_registry import register_memory_retrieval_tool from .tool_utils import parse_datetime_to_timestamp, parse_time_range @@ -116,10 +114,19 @@ async def query_chat_history( return f"未找到包含关键词'{keywords_str}'的聊天记录概述" records = filtered_records + + # 对即将返回的记录增加使用计数 + records_to_use = records[:3] + for record in records_to_use: + try: + ChatHistory.update(count=ChatHistory.count + 1).where(ChatHistory.id == record.id).execute() + record.count = (record.count or 0) + 1 + except Exception as update_error: + logger.error(f"更新聊天记录概述计数失败: {update_error}") # 构建结果文本 results = [] - for record in records[:10]: # 最多返回10条记录 + for record in records_to_use: # 最多返回3条记录 result_parts = [] # 添加主题 @@ -146,66 +153,11 @@ async def query_chat_history( if not results: return "未找到相关聊天记录概述" - # 如果只有一条记录,直接返回 - if len(results) == 1: - return results[0] - - # 多条记录,使用LLM总结 - try: - llm_request = LLMRequest( - 
model_set=model_config.model_task_config.utils_small, - request_type="chat_history_analysis" - ) - - query_desc = [] - if keyword: - # 解析关键词列表用于显示 - keywords_list = parse_keywords_string(keyword) - if keywords_list: - keywords_str = "、".join(keywords_list) - query_desc.append(f"关键词:{keywords_str}") - else: - query_desc.append(f"关键词:{keyword}") - if time_range: - if " - " in time_range: - query_desc.append(f"时间范围:{time_range}") - else: - query_desc.append(f"时间点:{time_range}") - - query_info = ",".join(query_desc) if query_desc else "聊天记录概述" - - combined_results = "\n\n---\n\n".join(results) - - analysis_prompt = f"""请根据以下聊天记录概述,总结与查询条件相关的信息。请输出一段平文本,不要有特殊格式。 -查询条件:{query_info} - -聊天记录概述: -{combined_results} - -请仔细分析聊天记录概述,提取与查询条件相关的信息并给出总结。如果概述中没有相关信息,输出"无有效信息"即可,不要输出其他内容。 - -总结:""" - - response, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async( - prompt=analysis_prompt, - temperature=0.3, - max_tokens=512 - ) - - logger.info(f"查询聊天历史概述提示词: {analysis_prompt}") - logger.info(f"查询聊天历史概述响应: {response}") - logger.info(f"查询聊天历史概述推理: {reasoning}") - logger.info(f"查询聊天历史概述模型: {model_name}") - - if "无有效信息" in response: - return "无有效信息" - - return response - - except Exception as llm_error: - logger.error(f"LLM分析聊天记录概述失败: {llm_error}") - # 如果LLM分析失败,返回前3条记录的摘要 - return "\n\n---\n\n".join(results[:3]) + response_text = "\n\n---\n\n".join(results) + if len(records) > len(records_to_use): + omitted_count = len(records) - len(records_to_use) + response_text += f"\n\n(还有{omitted_count}条历史记录已省略)" + return response_text except Exception as e: logger.error(f"查询聊天历史概述失败: {e}") diff --git a/test_style_learner_db.py b/test_style_learner_db.py deleted file mode 100644 index ba1e2023..00000000 --- a/test_style_learner_db.py +++ /dev/null @@ -1,391 +0,0 @@ -""" -StyleLearner 数据库测试脚本 -使用数据库中的expression数据测试style_learner功能 -""" - -import os -import sys -from typing import List, Dict, Tuple -from sklearn.model_selection import train_test_split -from 
sklearn.metrics import precision_recall_fscore_support - -# 添加项目根目录到Python路径 -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from src.common.database.database_model import Expression, db -from src.express.style_learner import StyleLearnerManager -from src.common.logger import get_logger - -logger = get_logger("style_learner_test") - - -class StyleLearnerDatabaseTest: - """使用数据库数据测试StyleLearner""" - - def __init__(self, random_state: int = 42): - self.random_state = random_state - self.manager = StyleLearnerManager(model_save_path="data/test_style_models") - - # 测试结果 - self.test_results = { - "total_samples": 0, - "train_samples": 0, - "test_samples": 0, - "unique_styles": 0, - "unique_chat_ids": 0, - "accuracy": 0.0, - "precision": 0.0, - "recall": 0.0, - "f1_score": 0.0, - "predictions": [], - "ground_truth": [], - "model_save_success": False, - "model_save_path": self.manager.model_save_path - } - - def load_data_from_database(self) -> List[Dict]: - """ - 从数据库加载expression数据 - - Returns: - List[Dict]: 包含up_content, style, chat_id的数据列表 - """ - try: - # 连接数据库 - db.connect(reuse_if_open=True) - - # 查询所有expression数据 - expressions = Expression.select().where( - (Expression.up_content.is_null(False)) & - (Expression.style.is_null(False)) & - (Expression.chat_id.is_null(False)) & - (Expression.type == "style") - ) - - data = [] - for expr in expressions: - if expr.up_content and expr.style and expr.chat_id: - data.append({ - "up_content": expr.up_content, - "style": expr.style, - "chat_id": expr.chat_id, - "last_active_time": expr.last_active_time, - "context": expr.context, - "situation": expr.situation - }) - - logger.info(f"从数据库加载了 {len(data)} 条expression数据") - return data - - except Exception as e: - logger.error(f"从数据库加载数据失败: {e}") - return [] - - def preprocess_data(self, data: List[Dict]) -> List[Dict]: - """ - 数据预处理 - - Args: - data: 原始数据 - - Returns: - List[Dict]: 预处理后的数据 - """ - # 过滤掉空值或过短的数据 - filtered_data = [] - for item in data: - up_content 
= item["up_content"].strip() - style = item["style"].strip() - - if len(up_content) >= 2 and len(style) >= 2: - filtered_data.append({ - "up_content": up_content, - "style": style, - "chat_id": item["chat_id"], - "last_active_time": item["last_active_time"], - "context": item["context"], - "situation": item["situation"] - }) - - logger.info(f"预处理后剩余 {len(filtered_data)} 条数据") - return filtered_data - - def split_data(self, data: List[Dict]) -> Tuple[List[Dict], List[Dict]]: - """ - 分割训练集和测试集 - 训练集使用所有数据,测试集从训练集中随机选择5% - - Args: - data: 预处理后的数据 - - Returns: - Tuple[List[Dict], List[Dict]]: (训练集, 测试集) - """ - # 训练集使用所有数据 - train_data = data.copy() - - # 测试集从训练集中随机选择5% - test_size = 0.05 # 5% - test_data = train_test_split( - train_data, test_size=test_size, random_state=self.random_state - )[1] # 只取测试集部分 - - logger.info(f"数据分割完成: 训练集 {len(train_data)} 条, 测试集 {len(test_data)} 条") - logger.info(f"训练集使用所有数据,测试集从训练集中随机选择 {test_size*100:.1f}%") - return train_data, test_data - - def train_model(self, train_data: List[Dict]) -> None: - """ - 训练模型 - - Args: - train_data: 训练数据 - """ - logger.info("开始训练模型...") - - # 统计信息 - chat_ids = set() - styles = set() - - for item in train_data: - chat_id = item["chat_id"] - up_content = item["up_content"] - style = item["style"] - - chat_ids.add(chat_id) - styles.add(style) - - # 学习映射关系 - success = self.manager.learn_mapping(chat_id, up_content, style) - if not success: - logger.warning(f"学习失败: {chat_id} - {up_content} -> {style}") - - self.test_results["train_samples"] = len(train_data) - self.test_results["unique_styles"] = len(styles) - self.test_results["unique_chat_ids"] = len(chat_ids) - - logger.info(f"训练完成: {len(train_data)} 个样本, {len(styles)} 种风格, {len(chat_ids)} 个聊天室") - - # 保存训练好的模型 - logger.info("开始保存训练好的模型...") - save_success = self.manager.save_all_models() - self.test_results["model_save_success"] = save_success - - if save_success: - logger.info(f"所有模型已成功保存到: {self.manager.model_save_path}") - print(f"✅ 模型已保存到: 
{self.manager.model_save_path}") - else: - logger.warning("部分模型保存失败") - print(f"⚠️ 模型保存失败,请检查路径: {self.manager.model_save_path}") - - def test_model(self, test_data: List[Dict]) -> None: - """ - 测试模型 - - Args: - test_data: 测试数据 - """ - logger.info("开始测试模型...") - - predictions = [] - ground_truth = [] - correct_predictions = 0 - - for item in test_data: - chat_id = item["chat_id"] - up_content = item["up_content"] - true_style = item["style"] - - # 预测风格 - predicted_style, scores = self.manager.predict_style(chat_id, up_content, top_k=1) - - predictions.append(predicted_style) - ground_truth.append(true_style) - - # 检查预测是否正确 - if predicted_style == true_style: - correct_predictions += 1 - - # 记录详细预测结果 - self.test_results["predictions"].append({ - "chat_id": chat_id, - "up_content": up_content, - "true_style": true_style, - "predicted_style": predicted_style, - "scores": scores - }) - - # 计算准确率 - accuracy = correct_predictions / len(test_data) if test_data else 0 - - # 计算其他指标(需要处理None值) - valid_predictions = [p for p in predictions if p is not None] - valid_ground_truth = [gt for p, gt in zip(predictions, ground_truth, strict=False) if p is not None] - - if valid_predictions: - precision, recall, f1, _ = precision_recall_fscore_support( - valid_ground_truth, valid_predictions, average='weighted', zero_division=0 - ) - else: - precision = recall = f1 = 0.0 - - self.test_results["test_samples"] = len(test_data) - self.test_results["accuracy"] = accuracy - self.test_results["precision"] = precision - self.test_results["recall"] = recall - self.test_results["f1_score"] = f1 - - logger.info(f"测试完成: 准确率 {accuracy:.4f}, 精确率 {precision:.4f}, 召回率 {recall:.4f}, F1分数 {f1:.4f}") - - def analyze_results(self) -> None: - """分析测试结果""" - logger.info("=== 测试结果分析 ===") - - print("\n📊 数据统计:") - print(f" 总样本数: {self.test_results['total_samples']}") - print(f" 训练样本数: {self.test_results['train_samples']}") - print(f" 测试样本数: {self.test_results['test_samples']}") - print(f" 唯一风格数: 
{self.test_results['unique_styles']}") - print(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}") - - print("\n🎯 模型性能:") - print(f" 准确率: {self.test_results['accuracy']:.4f}") - print(f" 精确率: {self.test_results['precision']:.4f}") - print(f" 召回率: {self.test_results['recall']:.4f}") - print(f" F1分数: {self.test_results['f1_score']:.4f}") - - print("\n💾 模型保存:") - save_status = "成功" if self.test_results['model_save_success'] else "失败" - print(f" 保存状态: {save_status}") - print(f" 保存路径: {self.test_results['model_save_path']}") - - # 分析各聊天室的性能 - chat_performance = {} - for pred in self.test_results["predictions"]: - chat_id = pred["chat_id"] - if chat_id not in chat_performance: - chat_performance[chat_id] = {"correct": 0, "total": 0} - - chat_performance[chat_id]["total"] += 1 - if pred["predicted_style"] == pred["true_style"]: - chat_performance[chat_id]["correct"] += 1 - - print("\n📈 各聊天室性能:") - for chat_id, perf in chat_performance.items(): - accuracy = perf["correct"] / perf["total"] if perf["total"] > 0 else 0 - print(f" {chat_id}: {accuracy:.4f} ({perf['correct']}/{perf['total']})") - - # 分析风格分布 - style_counts = {} - for pred in self.test_results["predictions"]: - style = pred["true_style"] - style_counts[style] = style_counts.get(style, 0) + 1 - - print("\n🎨 风格分布 (前10个):") - sorted_styles = sorted(style_counts.items(), key=lambda x: x[1], reverse=True) - for style, count in sorted_styles[:10]: - print(f" {style}: {count} 次") - - def show_sample_predictions(self, num_samples: int = 10) -> None: - """显示样本预测结果""" - print(f"\n🔍 样本预测结果 (前{num_samples}个):") - - for i, pred in enumerate(self.test_results["predictions"][:num_samples]): - status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗" - print(f"\n {i+1}. 
{status}") - print(f" 聊天室: {pred['chat_id']}") - print(f" 输入内容: {pred['up_content']}") - print(f" 真实风格: {pred['true_style']}") - print(f" 预测风格: {pred['predicted_style']}") - if pred["scores"]: - top_scores = dict(list(pred["scores"].items())[:3]) - print(f" 分数: {top_scores}") - - def save_results(self, output_file: str = "style_learner_test_results.txt") -> None: - """保存测试结果到文件""" - try: - with open(output_file, "w", encoding="utf-8") as f: - f.write("StyleLearner 数据库测试结果\n") - f.write("=" * 50 + "\n\n") - - f.write("数据统计:\n") - f.write(f" 总样本数: {self.test_results['total_samples']}\n") - f.write(f" 训练样本数: {self.test_results['train_samples']}\n") - f.write(f" 测试样本数: {self.test_results['test_samples']}\n") - f.write(f" 唯一风格数: {self.test_results['unique_styles']}\n") - f.write(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}\n\n") - - f.write("模型性能:\n") - f.write(f" 准确率: {self.test_results['accuracy']:.4f}\n") - f.write(f" 精确率: {self.test_results['precision']:.4f}\n") - f.write(f" 召回率: {self.test_results['recall']:.4f}\n") - f.write(f" F1分数: {self.test_results['f1_score']:.4f}\n\n") - - f.write("模型保存:\n") - save_status = "成功" if self.test_results['model_save_success'] else "失败" - f.write(f" 保存状态: {save_status}\n") - f.write(f" 保存路径: {self.test_results['model_save_path']}\n\n") - - f.write("详细预测结果:\n") - for i, pred in enumerate(self.test_results["predictions"]): - status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗" - f.write(f"{i+1}. {status} [{pred['chat_id']}] {pred['up_content']} -> {pred['predicted_style']} (真实: {pred['true_style']})\n") - - logger.info(f"测试结果已保存到 {output_file}") - - except Exception as e: - logger.error(f"保存测试结果失败: {e}") - - def run_test(self) -> None: - """运行完整测试""" - logger.info("开始StyleLearner数据库测试...") - - # 1. 加载数据 - raw_data = self.load_data_from_database() - if not raw_data: - logger.error("没有加载到数据,测试终止") - return - - # 2. 
数据预处理 - processed_data = self.preprocess_data(raw_data) - if not processed_data: - logger.error("预处理后没有数据,测试终止") - return - - self.test_results["total_samples"] = len(processed_data) - - # 3. 分割数据 - train_data, test_data = self.split_data(processed_data) - - # 4. 训练模型 - self.train_model(train_data) - - # 5. 测试模型 - self.test_model(test_data) - - # 6. 分析结果 - self.analyze_results() - - # 7. 显示样本预测 - self.show_sample_predictions(10) - - # 8. 保存结果 - self.save_results() - - logger.info("StyleLearner数据库测试完成!") - - -def main(): - """主函数""" - print("StyleLearner 数据库测试脚本") - print("=" * 50) - - # 创建测试实例 - test = StyleLearnerDatabaseTest(random_state=42) - - # 运行测试 - test.run_test() - - -if __name__ == "__main__": - main()