diff --git a/scripts/build_io_pairs.py b/scripts/build_io_pairs.py
new file mode 100644
index 00000000..2c36dc28
--- /dev/null
+++ b/scripts/build_io_pairs.py
@@ -0,0 +1,352 @@
+import argparse
+import json
+import random
+import sys
+import os
+from datetime import datetime
+from typing import Dict, Iterable, List, Optional, Tuple
+
+# Make the script runnable from any working directory: put the project root
+# (the parent of scripts/) on sys.path before the project imports below.
+PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+if PROJECT_ROOT not in sys.path:
+    sys.path.insert(0, PROJECT_ROOT)
+
+from src.common.data_models.database_data_model import DatabaseMessages
+from src.common.message_repository import find_messages
+from src.chat.utils.chat_message_builder import build_readable_messages, build_readable_messages_anonymized
+
+
+SECONDS_5_MINUTES = 5 * 60
+
+
+def parse_datetime_to_timestamp(value: str) -> float:
+    """
+    Convert a date/time string in one of several common layouts to a timestamp in seconds.
+
+    Supported examples:
+    - 2025-09-29
+    - 2025-09-29 00:00:00
+    - 2025/09/29 00:00
+    - 2025-09-29T00:00:00
+
+    The parsed datetime is naive, so datetime.timestamp() interprets it as local time.
+
+    Raises:
+        ValueError: if the value matches none of the supported formats.
+    """
+    value = value.strip()
+    fmts = [
+        "%Y-%m-%d %H:%M:%S",
+        "%Y-%m-%d %H:%M",
+        "%Y/%m/%d %H:%M:%S",
+        "%Y/%m/%d %H:%M",
+        "%Y-%m-%d",
+        "%Y/%m/%d",
+        "%Y-%m-%dT%H:%M:%S",
+        "%Y-%m-%dT%H:%M",
+    ]
+    last_err: Optional[Exception] = None
+    for fmt in fmts:
+        try:
+            return datetime.strptime(value, fmt).timestamp()
+        except (ValueError, OverflowError, OSError) as e:
+            # ValueError: format mismatch; OverflowError/OSError: timestamp() out of range.
+            last_err = e
+    raise ValueError(f"无法解析时间: {value} ({last_err})")
+
+
+def fetch_messages_between(
+    start_ts: float,
+    end_ts: float,
+    platform: Optional[str] = None,
+) -> List[DatabaseMessages]:
+    """
+    Fetch the messages whose time lies strictly between start_ts and end_ts via find_messages.
+
+    When platform is given, the query also filters on chat_info_platform.
+    An ascending sort on time is requested.
+    """
+    filter_query: Dict[str, object] = {"time": {"$gt": start_ts, "$lt": end_ts}}
+    if platform:
+        filter_query["chat_info_platform"] = platform
+    # Original author's note: find_messages applies `sort` when limit == 0.
+    return find_messages(message_filter=filter_query, sort=[("time", 1)], limit=0)
+
+
+def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
+    """Group messages by chat_id; each group is sorted by time ascending (a missing time sorts as 0)."""
+    groups: Dict[str, List[DatabaseMessages]] = {}
+    for msg in messages:
+        groups.setdefault(msg.chat_id, []).append(msg)
+    # Only the lists are needed for sorting, so iterate values instead of items.
+    for msgs in groups.values():
+        msgs.sort(key=lambda m: m.time or 0)
+    return groups
+
+
+def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
+    """
+    Merge a bucket of adjacent messages (same user_id, within the 5-minute window) into one message.
+
+    The non-empty processed_plain_text values are joined with newlines; every
+    other field is copied from the last message in the bucket (the newest one,
+    given time-ascending input).
+
+    Raises:
+        ValueError: if bucket is empty.
+    """
+    if not bucket:
+        raise ValueError("bucket 为空,无法合并")
+
+    latest = bucket[-1]
+    merged_texts: List[str] = [m.processed_plain_text for m in bucket if m.processed_plain_text]
+
+    merged = DatabaseMessages(
+        # All remaining fields come from the latest message.
+        message_id=latest.message_id,
+        time=latest.time,
+        chat_id=latest.chat_id,
+        reply_to=latest.reply_to,
+        interest_value=latest.interest_value,
+        key_words=latest.key_words,
+        key_words_lite=latest.key_words_lite,
+        is_mentioned=latest.is_mentioned,
+        is_at=latest.is_at,
+        reply_probability_boost=latest.reply_probability_boost,
+        processed_plain_text="\n".join(merged_texts) if merged_texts else latest.processed_plain_text,
+        # NOTE(review): display_message is taken from the latest message only, not merged - confirm intended.
+        display_message=latest.display_message,
+        priority_mode=latest.priority_mode,
+        priority_info=latest.priority_info,
+        additional_config=latest.additional_config,
+        is_emoji=latest.is_emoji,
+        is_picid=latest.is_picid,
+        is_command=latest.is_command,
+        is_notify=latest.is_notify,
+        selected_expressions=latest.selected_expressions,
+        user_id=latest.user_info.user_id,
+        user_nickname=latest.user_info.user_nickname,
+        user_cardname=latest.user_info.user_cardname,
+        user_platform=latest.user_info.platform,
+        chat_info_group_id=(latest.group_info.group_id if latest.group_info else None),
+        chat_info_group_name=(latest.group_info.group_name if latest.group_info else None),
+        chat_info_group_platform=(latest.group_info.group_platform if latest.group_info else None),
+        chat_info_user_id=latest.chat_info.user_info.user_id,
+        chat_info_user_nickname=latest.chat_info.user_info.user_nickname,
+        chat_info_user_cardname=latest.chat_info.user_info.user_cardname,
+        chat_info_user_platform=latest.chat_info.user_info.platform,
+        chat_info_stream_id=latest.chat_info.stream_id,
+        chat_info_platform=latest.chat_info.platform,
+        chat_info_create_time=latest.chat_info.create_time,
+        chat_info_last_active_time=latest.chat_info.last_active_time,
+    )
+    return merged
+
+
+def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
+    """
+    Merge runs of adjacent messages from the same user_id into single messages.
+
+    A message joins the current run when it has the same user_id as the run's
+    last message and is at most SECONDS_5_MINUTES later than it. The input must
+    be sorted by time ascending.
+    """
+    if not messages:
+        return []
+
+    merged: List[DatabaseMessages] = []
+    bucket: List[DatabaseMessages] = []
+
+    def flush_bucket() -> None:
+        nonlocal bucket
+        if bucket:
+            merged.append(_merge_bucket_to_message(bucket))
+            bucket = []
+
+    for msg in messages:
+        if not bucket:
+            bucket = [msg]
+            continue
+
+        last = bucket[-1]
+        same_user = msg.user_info.user_id == last.user_info.user_id
+        close_enough = (msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES
+
+        if same_user and close_enough:
+            bucket.append(msg)
+        else:
+            flush_bucket()
+            bucket = [msg]
+
+    flush_bucket()
+    return merged
+
+
+def build_pairs_for_chat(
+    merged_messages: List[DatabaseMessages],
+    min_ctx: int,
+    max_ctx: int,
+) -> List[Tuple[str, str, str]]:
+    """
+    Build one (input, output, message_id) triple per message.
+
+    For each message a window size is drawn at random from [min_ctx, max_ctx]
+    and up to that many preceding messages form the context; messages near the
+    start of the list get fewer (possibly zero) context messages.
+
+    input: the context rendered by build_readable_messages_anonymized.
+    output: the message's processed_plain_text, used as-is.
+    """
+    pairs: List[Tuple[str, str, str]] = []
+
+    for i, target in enumerate(merged_messages):
+        # Pick the context window size for this message.
+        window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
+        start = max(0, i - window)
+        context_msgs = merged_messages[start:i]
+
+        # The name mapping returned alongside the text is not needed here.
+        input_str, _ = build_readable_messages_anonymized(
+            messages=context_msgs,
+            timestamp_mode="relative",
+            show_actions=False,
+            show_pic=True,
+        )
+
+        pairs.append((input_str, target.processed_plain_text or "", target.message_id or ""))
+
+    return pairs
+
+
+def build_pairs(
+    start_ts: float,
+    end_ts: float,
+    platform: Optional[str],
+    min_ctx: int,
+    max_ctx: int,
+) -> List[Tuple[str, str, str]]:
+    """Fetch the messages in the time range, then build the pairs chat by chat and concatenate them."""
+    messages = fetch_messages_between(start_ts, end_ts, platform)
+    groups = group_by_chat(messages)
+
+    all_pairs: List[Tuple[str, str, str]] = []
+    for msgs in groups.values():
+        merged = merge_adjacent_same_user(msgs)
+        all_pairs.extend(build_pairs_for_chat(merged, min_ctx, max_ctx))
+
+    return all_pairs
+
+
+def _write_pairs(pairs: List[Tuple[str, str, str]], output_path: Optional[str]) -> None:
+    """Write the pairs as JSON Lines to output_path, or print them to stdout when no path is given."""
+    lines = [
+        json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False)
+        for input_str, output_str, message_id in pairs
+    ]
+    if output_path:
+        with open(output_path, "w", encoding="utf-8") as f:
+            f.writelines(line + "\n" for line in lines)
+        print(f"已保存 {len(pairs)} 条到 {output_path}")
+    else:
+        for line in lines:
+            print(line)
+
+
+def main(argv: Optional[List[str]] = None) -> int:
+    """
+    Command-line entry point; returns the process exit code.
+
+    With no arguments at all, the interactive prompt is used instead of argparse.
+
+    Raises:
+        ValueError: if a time cannot be parsed, end is not after start, or max_ctx < min_ctx.
+    """
+    if argv is None:
+        argv = sys.argv[1:]
+
+    if not argv:
+        return run_interactive()
+
+    parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表")
+    parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
+    parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
+    parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
+    parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数,默认20")
+    parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数,默认30")
+    parser.add_argument(
+        "--output",
+        default=None,
+        help="输出保存路径,支持 .jsonl(每行 {input, output}),若不指定则打印到stdout",
+    )
+
+    args = parser.parse_args(argv)
+
+    start_ts = parse_datetime_to_timestamp(args.start)
+    end_ts = parse_datetime_to_timestamp(args.end)
+    if end_ts <= start_ts:
+        raise ValueError("结束时间必须大于起始时间")
+
+    if args.max_ctx < args.min_ctx:
+        raise ValueError("max_ctx 不能小于 min_ctx")
+
+    pairs = build_pairs(start_ts, end_ts, args.platform, args.min_ctx, args.max_ctx)
+    _write_pairs(pairs, args.output)
+
+    return 0
+
+
+def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
+    """Prompt on stdin; an empty answer yields default when default is not None."""
+    suffix = f"[{default}]" if default not in (None, "") else ""
+    value = input(f"{prompt_text}{' ' + suffix if suffix else ''}: ").strip()
+    if value == "" and default is not None:
+        return default
+    return value
+
+
+def run_interactive() -> int:
+    """Collect the parameters interactively and run the build; returns 0 on success, 2 on invalid input."""
+    print("进入交互模式(直接回车采用默认值)。时间格式例如:2025-09-28 00:00:00 或 2025-09-28")
+    start_str = _prompt_with_default("请输入起始时间", None)
+    end_str = _prompt_with_default("请输入结束时间", None)
+    platform = _prompt_with_default("平台(可留空表示不限)", "")
+    try:
+        min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
+        max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
+    except ValueError:
+        print("上下文条数输入有误,使用默认 20/30")
+        min_ctx, max_ctx = 20, 30
+    output_path = _prompt_with_default("输出路径(.jsonl,可留空打印到控制台)", "")
+
+    if not start_str or not end_str:
+        print("必须提供起始与结束时间。")
+        return 2
+
+    try:
+        start_ts = parse_datetime_to_timestamp(start_str)
+        end_ts = parse_datetime_to_timestamp(end_str)
+    except ValueError as e:
+        print(f"时间解析失败:{e}")
+        return 2
+
+    if end_ts <= start_ts:
+        print("结束时间必须大于起始时间。")
+        return 2
+
+    if max_ctx < min_ctx:
+        print("最多条数不能小于最少条数。")
+        return 2
+
+    platform_val = platform if platform != "" else None
+    pairs = build_pairs(start_ts, end_ts, platform_val, min_ctx, max_ctx)
+
+    _write_pairs(pairs, output_path)
+    if not output_path:
+        print(f"总计 {len(pairs)} 条。")
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py
index b02a9164..6cf0feab 100644
--- a/src/chat/utils/chat_message_builder.py
+++ b/src/chat/utils/chat_message_builder.py
@@ -2,7 +2,7 @@
 import time
 import random
 import re
-from typing import List, Dict, Any, Tuple, Optional, Callable
+from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
 from rich.traceback import install
 
 from src.config.config import global_config
@@ -895,6 +895,182 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
     return formatted_string
 
 
+def build_readable_messages_anonymized(
+    messages: List[DatabaseMessages],
+    timestamp_mode: str = "relative",
+    show_actions: bool = False,
+    show_pic: bool = True,
+    replace_bot_name: bool = True,
+) -> Tuple[str, Dict[str, str]]:
+    """
+    Build an anonymized, human-readable transcript, modelled on build_readable_messages:
+    - every user name becomes 用户A, 用户B, ..., 用户Z, 用户AA, 用户AB, ...
+    - reply / @ references inside message content are rewritten to the anonymous names
+
+    Args:
+        messages: the messages to render.
+        timestamp_mode: passed to translate_timestamp_to_human_readable as ``mode``.
+        show_actions: when True, ActionRecords for the first message's chat_id are merged in.
+        show_pic: when True ``[picid:...]`` markers become ``[图片]``; otherwise they are removed.
+        replace_bot_name: when True the bot is rendered as its nickname plus ``(你)`` instead of a letter label.
+
+    Returns:
+        formatted_string: the formatted chat transcript
+        mapping: original display name -> anonymous name
+    """
+    if not messages:
+        return "", {}
+
+    # Anonymous label suffixes: A..Z, AA..ZZ, AAA.. and so on. The generator is
+    # unbounded, so next() cannot raise StopIteration however many users appear.
+    def alphabet_labels() -> Iterable[str]:
+        import itertools
+        import string
+
+        for width in itertools.count(1):
+            for combo in itertools.product(string.ascii_uppercase, repeat=width):
+                yield "".join(combo)
+
+    label_iter = iter(alphabet_labels())
+    user_to_label: Dict[Tuple[str, str], str] = {}
+    name_mapping: Dict[str, str] = {}
+
+    def get_display_name(platform: str, user_id: str, user_nickname: str, user_cardname: Optional[str]) -> str:
+        # Best known original name for a user, or "" when nothing is known.
+        # Truthiness checks (not an f-string) so a None nickname never becomes the text "None".
+        person = Person(platform=platform, user_id=user_id)
+        if person.person_name:
+            return person.person_name
+        if user_nickname:
+            return user_nickname
+        return f"昵称:{user_cardname}" if user_cardname else ""
+
+    def get_label(platform: str, user_id: str) -> str:
+        # Anonymous name only; never touches name_mapping.
+        if replace_bot_name and user_id == global_config.bot.qq_account:
+            return f"{global_config.bot.nickname}(你)"
+        key = (platform or "", user_id or "")
+        if key not in user_to_label:
+            user_to_label[key] = f"用户{next(label_iter)}"
+        return user_to_label[key]
+
+    def get_anon_name(platform: str, user_id: str, user_nickname: str, user_cardname: Optional[str]) -> str:
+        anon = get_label(platform, user_id)
+        original_display = get_display_name(platform, user_id, user_nickname, user_cardname)
+        # The first registration of a display name wins. Unknown names are skipped so that
+        # no placeholder ends up in the mapping and gets replaced inside message text.
+        if original_display and original_display not in name_mapping:
+            name_mapping[original_display] = anon
+        return anon
+
+    # Convert DatabaseMessages into the internal model; action records may be merged in below.
+    copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
+
+    if show_actions and copy_messages:
+        min_time = min(msg.time or 0 for msg in copy_messages)
+        max_time = max(msg.time or 0 for msg in copy_messages)
+        chat_id = messages[0].chat_id
+
+        actions_in_range = (
+            ActionRecords.select()
+            .where((ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id))
+            .order_by(ActionRecords.time)
+        )
+        action_after_latest = (
+            ActionRecords.select()
+            .where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
+            .order_by(ActionRecords.time)
+            .limit(1)
+        )
+
+        actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
+        for action in actions:
+            if action.action_build_into_prompt:
+                action_msg = MessageAndActionModel(
+                    time=float(action.time),  # type: ignore
+                    user_id=global_config.bot.qq_account,
+                    user_platform=global_config.bot.platform,
+                    user_nickname=global_config.bot.nickname,
+                    user_cardname="",
+                    processed_plain_text=f"{action.action_prompt_display}",
+                    display_message=f"{action.action_prompt_display}",
+                    chat_info_platform=str(action.chat_info_platform),
+                    is_action_record=True,
+                    action_name=str(action.action_name),
+                )
+                copy_messages.append(action_msg)
+
+        copy_messages.sort(key=lambda x: x.time or 0)
+
+    # [picid:...] markers become "[图片]" when show_pic is set, otherwise they are dropped.
+    pic_pattern = re.compile(r"\[picid:([^\]]+)\]")
+    pic_replacement = "[图片]" if show_pic else ""
+
+    def process_pic_ids(content: Optional[str]) -> str:
+        if content is None:
+            return ""
+        return pic_pattern.sub(pic_replacement, content)
+
+    # Resolver handed to replace_user_references: maps a referenced (platform, user_id)
+    # to its anonymous name.
+    def anon_name_resolver(platform: str, user_id: str) -> str:
+        try:
+            return get_anon_name(platform, user_id, "", None)
+        except Exception:
+            # Best effort: a failed lookup must not abort rendering.
+            return "用户?"
+
+    detailed: List[Tuple[float, str, str, bool]] = []
+
+    for m in copy_messages:
+        if m.is_action_record:
+            content = process_pic_ids(m.display_message)
+            detailed.append((m.time or 0.0, "", content, True))
+            continue
+
+        platform = m.user_platform
+        content = process_pic_ids(m.display_message or m.processed_plain_text or "")
+        anon_name = get_anon_name(platform, m.user_id, m.user_nickname, m.user_cardname)
+        try:
+            content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False)
+        except Exception:
+            # Best effort: keep the content with its references unreplaced rather than fail.
+            pass
+
+        detailed.append((m.time or 0.0, anon_name, content, False))
+
+    if not detailed:
+        return "", name_mapping
+
+    detailed.sort(key=lambda x: x[0])
+
+    output_lines: List[str] = []
+    for ts, name, content, is_action in detailed:
+        readable_time = translate_timestamp_to_human_readable(ts, mode=timestamp_mode)
+        if is_action:
+            output_lines.append(f"{readable_time}, {content}")
+        else:
+            output_lines.append(f"{readable_time}, {name}: {content}")
+
+    formatted_string = "\n".join(output_lines).strip()
+
+    # Final pass: replace original display names that appear verbatim in the text.
+    # This is done in a single regex pass, with every anonymous name mapped to itself,
+    # so already-inserted names are never rewritten again (sequential str.replace turned
+    # "NICK(你)" into "NICK(你)(你)" whenever the bot's display name was its nickname).
+    replacements: Dict[str, str] = {original: anon for original, anon in name_mapping.items() if original}
+    protected = set(name_mapping.values()) | set(user_to_label.values())
+    if replace_bot_name:
+        protected.add(f"{global_config.bot.nickname}(你)")
+    replacements.update((anon, anon) for anon in protected if anon)
+    if replacements:
+        # Longest alternative first, so the longest name wins at any given position.
+        names_pattern = re.compile("|".join(re.escape(n) for n in sorted(replacements, key=len, reverse=True)))
+        formatted_string = names_pattern.sub(lambda match: replacements[match.group(0)], formatted_string)
+
+    return formatted_string, name_mapping
+
+
 async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
     """
     从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。