MaiBot/scripts/build_io_pairs.py

385 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import argparse
import json
import random
import re
import sys
import os
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple

# Make the script runnable from any working directory: put the project root
# (the parent directory of scripts/) on sys.path. This has to happen BEFORE
# the project-local `src.*` imports below; otherwise those imports may not be
# resolvable when the project root is not already on sys.path.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.common.data_models.database_data_model import DatabaseMessages  # noqa: E402
from src.common.message_repository import find_messages  # noqa: E402
from src.chat.utils.chat_message_builder import build_readable_messages  # noqa: E402

# Largest gap, in seconds, between two consecutive messages of the same user
# for them to still be merged into one output message.
SECONDS_5_MINUTES = 5 * 60
def clean_output_text(text: str) -> str:
    """Strip sticker and reply markers from an output text and tidy whitespace.

    Two kinds of marker are deleted: ``[表情包:...]`` sticker placeholders and
    the reply prefix matched by the second pattern below. Afterwards every run
    of whitespace is collapsed to a single space and the ends are trimmed.

    A falsy ``text`` (empty string or ``None``) is returned unchanged.
    """
    if not text:
        return text

    marker_patterns = (
        # Sticker placeholder: [表情包:...]
        r"\[表情包:[^\]]*\]",
        # Reply prefix: "[回复...],说:" followed by an "@name:" mention
        r"\[回复[^\]]*\],说:[^@]*@[^:]*:",
    )
    cleaned = text
    for pattern in marker_patterns:
        cleaned = re.sub(pattern, "", cleaned)

    # Collapse whitespace runs (including newlines) and trim both ends.
    collapsed = re.sub(r"\s+", " ", cleaned)
    return collapsed.strip()
def parse_datetime_to_timestamp(value: str) -> float:
    """Convert a date/time string in one of several common layouts to a timestamp in seconds.

    Surrounding whitespace is ignored. The layouts are tried in order and the
    first one that parses wins. The parsed datetime is naive, so
    ``datetime.timestamp()`` interprets it as local time.

    Supported examples:
    - 2025-09-29
    - 2025-09-29 00:00:00
    - 2025/09/29 00:00
    - 2025-09-29T00:00:00

    Args:
        value: The date/time text to parse.

    Returns:
        The corresponding timestamp as a float number of seconds.

    Raises:
        ValueError: If ``value`` matches none of the supported layouts (or the
            parsed date cannot be converted to a timestamp). The error of the
            last attempted layout is attached as ``__cause__``.
    """
    value = value.strip()
    fmts = (
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M",
        "%Y/%m/%d %H:%M:%S",
        "%Y/%m/%d %H:%M",
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M",
    )
    last_err: Optional[Exception] = None
    for fmt in fmts:
        try:
            return datetime.strptime(value, fmt).timestamp()
        except (ValueError, OverflowError, OSError) as e:
            # ValueError: layout mismatch. OverflowError/OSError: the date
            # parsed but is outside the range the platform can convert.
            last_err = e
    raise ValueError(f"无法解析时间: {value} ({last_err})") from last_err
def fetch_messages_between(
    start_ts: float,
    end_ts: float,
    platform: Optional[str] = None,
) -> List[DatabaseMessages]:
    """Query ``find_messages`` for the messages inside a time interval.

    Args:
        start_ts: Lower bound for the ``time`` field (passed as ``$gt``).
        end_ts: Upper bound for the ``time`` field (passed as ``$lt``).
        platform: When truthy, the query additionally requires
            ``chat_info_platform`` to equal this value.

    Returns:
        The result of ``find_messages``, requested sorted by ``time`` ascending.
    """
    time_window = {"$gt": start_ts, "$lt": end_ts}
    query: Dict[str, object] = {"time": time_window}
    if platform:
        query["chat_info_platform"] = platform

    # Per the original author's note, the sort takes effect when limit == 0;
    # ("time", 1) requests ascending order.
    ascending_by_time = [("time", 1)]
    return find_messages(message_filter=query, sort=ascending_by_time, limit=0)
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
    """Bucket messages by ``chat_id`` and sort every bucket by ``time`` ascending.

    A message whose ``time`` is falsy (e.g. ``None``) is sorted as if its time
    were 0. Buckets appear in the order their chat was first seen.
    """
    by_chat: Dict[str, List[DatabaseMessages]] = {}
    for message in messages:
        chat_bucket = by_chat.get(message.chat_id)
        if chat_bucket is None:
            chat_bucket = []
            by_chat[message.chat_id] = chat_bucket
        chat_bucket.append(message)

    # Guarantee ascending time order inside each bucket.
    for chat_bucket in by_chat.values():
        chat_bucket.sort(key=lambda item: item.time or 0)
    return by_chat
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
    """Collapse a bucket of messages into a single ``DatabaseMessages``.

    The non-empty ``processed_plain_text`` values of all messages are joined
    with newlines. Every other field is copied from the LAST element of the
    bucket, which is the most recent message when the bucket is in ascending
    time order.

    Raises:
        ValueError: If ``bucket`` is empty.
    """
    if not bucket:
        raise ValueError("bucket 为空,无法合并")

    newest = bucket[-1]
    texts: List[str] = [m.processed_plain_text for m in bucket if m.processed_plain_text]
    # Fall back to the newest message's own text when no message has any text.
    combined_text = "\n".join(texts) if texts else newest.processed_plain_text

    sender = newest.user_info
    group = newest.group_info
    chat = newest.chat_info
    chat_user = chat.user_info

    return DatabaseMessages(
        # Everything except the text comes from the newest message.
        message_id=newest.message_id,
        time=newest.time,
        chat_id=newest.chat_id,
        reply_to=newest.reply_to,
        interest_value=newest.interest_value,
        key_words=newest.key_words,
        key_words_lite=newest.key_words_lite,
        is_mentioned=newest.is_mentioned,
        is_at=newest.is_at,
        reply_probability_boost=newest.reply_probability_boost,
        processed_plain_text=combined_text,
        display_message=newest.display_message,
        priority_mode=newest.priority_mode,
        priority_info=newest.priority_info,
        additional_config=newest.additional_config,
        is_emoji=newest.is_emoji,
        is_picid=newest.is_picid,
        is_command=newest.is_command,
        is_notify=newest.is_notify,
        selected_expressions=newest.selected_expressions,
        user_id=sender.user_id,
        user_nickname=sender.user_nickname,
        user_cardname=sender.user_cardname,
        user_platform=sender.platform,
        # Group fields are None when the message has no group info.
        chat_info_group_id=(group.group_id if group else None),
        chat_info_group_name=(group.group_name if group else None),
        chat_info_group_platform=(group.group_platform if group else None),
        chat_info_user_id=chat_user.user_id,
        chat_info_user_nickname=chat_user.user_nickname,
        chat_info_user_cardname=chat_user.user_cardname,
        chat_info_user_platform=chat_user.platform,
        chat_info_stream_id=chat.stream_id,
        chat_info_platform=chat.platform,
        chat_info_create_time=chat.create_time,
        chat_info_last_active_time=chat.last_active_time,
    )
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
    """Merge runs of consecutive messages from the same user into single messages.

    A message extends the current run when its ``user_info.user_id`` equals
    that of the run's last message and it is at most ``SECONDS_5_MINUTES``
    seconds later (a falsy ``time`` counts as 0). Each finished run is
    collapsed with ``_merge_bucket_to_message``. The input must already be in
    ascending time order.
    """
    if not messages:
        return []

    result: List[DatabaseMessages] = []
    run: List[DatabaseMessages] = []
    for message in messages:
        if run:
            previous = run[-1]
            same_sender = message.user_info.user_id == previous.user_info.user_id
            gap = (message.time or 0) - (previous.time or 0)
            if not (same_sender and gap <= SECONDS_5_MINUTES):
                # The run is broken: emit it and start a fresh one.
                result.append(_merge_bucket_to_message(run))
                run = []
        run.append(message)

    # Emit the trailing run.
    if run:
        result.append(_merge_bucket_to_message(run))
    return result
def build_pairs_for_chat(
    original_messages: List[DatabaseMessages],
    merged_messages: List[DatabaseMessages],
    min_ctx: int,
    max_ctx: int,
    target_user_id: Optional[str] = None,
) -> List[Tuple[str, str, str]]:
    """Build ``(input, output, message_id)`` tuples for one chat.

    Every merged message is a candidate output. Its input is rendered with
    ``build_readable_messages`` from the un-merged messages that precede the
    original message carrying the same ``time`` as the merged one. The output
    is the merged message's ``processed_plain_text`` passed through
    ``clean_output_text``.

    Args:
        original_messages: Un-merged messages of the chat, ascending by time.
        merged_messages: Merged messages of the same chat, ascending by time.
        min_ctx: Smallest number of context messages to take.
        max_ctx: Largest number of context messages to take. When it is
            greater than ``min_ctx`` the size is drawn with ``random.randint``
            for every pair; otherwise ``min_ctx`` is used.
        target_user_id: When truthy, only merged messages from this user
            become outputs (the context still contains every user).

    Returns:
        The list of ``(input_str, output_text, message_id)`` tuples; empty
        when either input list is empty.
    """
    pairs: List[Tuple[str, str, str]] = []
    n_original = len(original_messages)
    if len(merged_messages) == 0 or n_original == 0:
        return pairs

    # Both lists are ascending by time, so one forward-moving cursor locates
    # the original message for each merged message. A falsy time is treated
    # as 0, matching group_by_chat and merge_adjacent_same_user, so a missing
    # timestamp no longer raises TypeError in the comparisons.
    cursor = 0
    for merged_msg in merged_messages:
        merged_time = merged_msg.time or 0
        while cursor < n_original and (original_messages[cursor].time or 0) < merged_time:
            cursor += 1

        # With a target user, only that user's messages are used as output.
        if target_user_id and merged_msg.user_info.user_id != target_user_id:
            continue

        # Skip merged messages with no original message at exactly this time.
        if cursor >= n_original or (original_messages[cursor].time or 0) != merged_time:
            continue

        # NOTE(review): _merge_bucket_to_message stamps a merged message with
        # the time of the LAST message of its bucket, so `cursor` points at
        # that last message (or at an earlier message sharing the same
        # timestamp). The context below can therefore include earlier
        # messages of the same bucket, whose text is also part of the output.
        # Confirm whether that overlap is intended.
        window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
        start = max(0, cursor - window)
        context_msgs = original_messages[start:cursor]

        # The input is rendered from the original, un-merged messages.
        input_str = build_readable_messages(
            messages=context_msgs,
            timestamp_mode="normal_no_YMD",
            show_actions=False,
            show_pic=True,
        )

        # The output is the merged text with sticker/reply markers removed.
        output_text = clean_output_text(merged_msg.processed_plain_text or "")
        output_id = merged_msg.message_id or ""
        pairs.append((input_str, output_text, output_id))

    return pairs
def build_pairs(
    start_ts: float,
    end_ts: float,
    platform: Optional[str],
    user_id: Optional[str],
    min_ctx: int,
    max_ctx: int,
) -> List[Tuple[str, str, str]]:
    """Build ``(input, output, message_id)`` tuples for every chat in a time range.

    Messages are fetched without filtering by ``user_id`` so that the input
    context can contain messages from all users. ``user_id`` is only forwarded
    to ``build_pairs_for_chat``, where it restricts which messages become
    outputs.
    """
    fetched = fetch_messages_between(start_ts, end_ts, platform)
    collected: List[Tuple[str, str, str]] = []
    for chat_messages in group_by_chat(fetched).values():
        # The merged view supplies the outputs; the raw messages supply the inputs.
        merged_view = merge_adjacent_same_user(chat_messages)
        collected.extend(build_pairs_for_chat(chat_messages, merged_view, min_ctx, max_ctx, user_id))
    return collected
def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point.

    With no arguments the interactive prompt flow (``run_interactive``) is
    used and its return value is returned. Otherwise the arguments are parsed,
    the pairs are built, and they are written as JSON lines to ``--output``
    or, if no output path is given, printed to stdout.

    Args:
        argv: Argument list without the program name; ``sys.argv[1:]`` when None.

    Returns:
        0 after the pairs have been written or printed.

    Raises:
        ValueError: If the end time is not after the start time, if
            ``max_ctx`` is smaller than ``min_ctx``, or if a time cannot be parsed.
    """
    cli_args = sys.argv[1:] if argv is None else argv
    # No arguments at all: fall back to interactive mode.
    if len(cli_args) == 0:
        return run_interactive()

    parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表支持按用户ID筛选消息")
    parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
    parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
    parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
    parser.add_argument("--user_id", default=None, help="仅选择指定 user_id 的消息")
    parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数默认20")
    parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数默认30")
    parser.add_argument(
        "--output",
        default=None,
        help="输出保存路径,支持 .jsonl每行 {input, output}若不指定则打印到stdout",
    )
    options = parser.parse_args(cli_args)

    begin = parse_datetime_to_timestamp(options.start)
    finish = parse_datetime_to_timestamp(options.end)
    if end_before_or_at_start := finish <= begin:
        raise ValueError("结束时间必须大于起始时间")
    if options.max_ctx < options.min_ctx:
        raise ValueError("max_ctx 不能小于 min_ctx")

    pairs = build_pairs(begin, finish, options.platform, options.user_id, options.min_ctx, options.max_ctx)

    # One JSON object per pair: {input, output, message_id}. Serialized lazily
    # so each record is encoded right before it is emitted.
    json_lines = (
        json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False)
        for input_str, output_str, message_id in pairs
    )

    if not options.output:
        # No output path: print to stdout.
        for json_line in json_lines:
            print(json_line)
        return 0

    # Save as JSONL, one record per line.
    with open(options.output, "w", encoding="utf-8") as f:
        for json_line in json_lines:
            f.write(json_line + "\n")
    print(f"已保存 {len(pairs)} 条到 {options.output}")
    return 0
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
    """Prompt on the console and return the stripped answer, or the default.

    The default is shown in square brackets after the prompt text unless it is
    ``None`` or empty. An empty answer yields ``default`` whenever ``default``
    is not ``None``; otherwise the (possibly empty) answer is returned.
    """
    hint = "" if default in (None, "") else f" [{default}]"
    answer = input(f"{prompt_text}{hint}: ").strip()
    if not answer and default is not None:
        return default
    return answer
def run_interactive() -> int:
    """Collect all parameters through console prompts, then build and emit the pairs.

    Pressing Enter at a prompt accepts its default. If either context size is
    not a valid integer, both fall back to 20/30. The pairs are written as
    JSON lines to the given output path, or printed to stdout followed by a
    total count when no path is given.

    Returns:
        0 on success; 2 when a time is missing or unparsable, the end time is
        not after the start time, or the maximum context size is below the minimum.
    """
    print("进入交互模式直接回车采用默认值。时间格式例如2025-09-28 00:00:00 或 2025-09-28")
    start_str = _prompt_with_default("请输入起始时间", None)
    end_str = _prompt_with_default("请输入结束时间", None)
    platform = _prompt_with_default("平台(可留空表示不限)", "")
    user_id = _prompt_with_default("用户ID可留空表示不限", "")
    try:
        min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
        max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
    except ValueError:
        # int() raises ValueError for non-numeric text; anything else is a
        # real failure and should not be hidden behind the defaults.
        print("上下文条数输入有误,使用默认 20/30")
        min_ctx, max_ctx = 20, 30
    output_path = _prompt_with_default("输出路径(.jsonl可留空打印到控制台", "")

    if not start_str or not end_str:
        print("必须提供起始与结束时间。")
        return 2
    try:
        start_ts = parse_datetime_to_timestamp(start_str)
        end_ts = parse_datetime_to_timestamp(end_str)
    except ValueError as e:
        # parse_datetime_to_timestamp reports unparsable input as ValueError.
        print(f"时间解析失败:{e}")
        return 2
    if end_ts <= start_ts:
        print("结束时间必须大于起始时间。")
        return 2
    if max_ctx < min_ctx:
        print("最多条数不能小于最少条数。")
        return 2

    # An empty answer means "no filter".
    platform_val = platform if platform != "" else None
    user_id_val = user_id if user_id != "" else None
    pairs = build_pairs(start_ts, end_ts, platform_val, user_id_val, min_ctx, max_ctx)

    if output_path:
        # Save as JSONL, one {input, output, message_id} object per line.
        with open(output_path, "w", encoding="utf-8") as f:
            for input_str, output_str, message_id in pairs:
                obj = {"input": input_str, "output": output_str, "message_id": message_id}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        print(f"已保存 {len(pairs)} 条到 {output_path}")
    else:
        for input_str, output_str, message_id in pairs:
            print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
        print(f"总计 {len(pairs)} 条。")
    return 0
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)