# Mirror of https://github.com/Mai-with-u/MaiBot.git (Python, 387 lines, 14 KiB)
import argparse
import json
import os
import random
import re
import sys
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple

# Make the script runnable from any working directory by putting the project
# root (the parent of this file's directory) on sys.path. This must happen
# BEFORE the ``src.*`` imports below; otherwise those imports are resolved
# without the project root and fail when the script is launched from elsewhere.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.common.data_models.database_data_model import DatabaseMessages  # noqa: E402
from src.common.message_repository import find_messages  # noqa: E402
from src.chat.utils.chat_message_builder import build_readable_messages  # noqa: E402

# Largest gap, in seconds, between two adjacent messages of the same user for
# them to still be merged into one output message.
SECONDS_5_MINUTES = 5 * 60


def clean_output_text(text: str) -> str:
    """Strip sticker and reply markers from an output text.

    Three substitutions are applied in order:

    - ``[表情包:...]`` segments are removed;
    - a reply prefix of the form ``[回复...],说:...@...:`` is removed;
    - every run of whitespace becomes a single space.

    The result is then trimmed. A falsy ``text`` (empty string or ``None``)
    is returned unchanged.
    """
    if not text:
        return text

    substitutions = (
        # Sticker marker: [表情包:...]
        (r"\[表情包:[^\]]*\]", ""),
        # Full reply prefix: [回复...],说:...@...:
        (r"\[回复[^\]]*\],说:[^@]*@[^:]*:", ""),
        # Collapse redundant spaces and line breaks.
        (r"\s+", " "),
    )

    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()


def parse_datetime_to_timestamp(value: str) -> float:
    """Convert a date/time string into a timestamp in seconds.

    The stripped string is tried against a fixed list of formats; the first
    one that parses wins. Accepted examples:

    - 2025-09-29
    - 2025-09-29 00:00:00
    - 2025/09/29 00:00
    - 2025-09-29T00:00:00

    The parsed ``datetime`` carries no timezone, so ``datetime.timestamp()``
    interprets it as local time.

    Raises:
        ValueError: if none of the formats matches; the error from the last
            attempted format is attached as ``__cause__``.
    """
    value = value.strip()
    fmts = (
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M",
        "%Y/%m/%d %H:%M:%S",
        "%Y/%m/%d %H:%M",
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M",
    )
    last_err: Optional[Exception] = None
    for fmt in fmts:
        try:
            # strptime raises ValueError on a mismatch; timestamp() may raise
            # OverflowError/OSError for dates the platform cannot represent.
            return datetime.strptime(value, fmt).timestamp()
        except (ValueError, OverflowError, OSError) as e:
            last_err = e
    raise ValueError(f"无法解析时间: {value} ({last_err})") from last_err


def fetch_messages_between(
    start_ts: float,
    end_ts: float,
    platform: Optional[str] = None,
) -> List[DatabaseMessages]:
    """Query ``find_messages`` for messages whose ``time`` lies between the bounds.

    The filter uses ``$gt`` / ``$lt`` on ``time`` (presumably exclusive bounds,
    Mongo-style; confirm against ``find_messages``). When ``platform`` is
    truthy the query is additionally restricted to that ``chat_info_platform``.
    The query asks for ascending order by ``time``.
    """
    time_range = {"$gt": start_ts, "$lt": end_ts}
    filter_query: Dict[str, object] = {"time": time_range}
    if platform:
        filter_query["chat_info_platform"] = platform

    # Per the original author's note, ``sort`` is honoured when limit == 0;
    # 1 requests ascending time order.
    ascending_by_time = [("time", 1)]
    return find_messages(message_filter=filter_query, sort=ascending_by_time, limit=0)


def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
    """Bucket messages by ``chat_id`` and sort each bucket by ascending time.

    Keys appear in order of first occurrence. A missing (falsy) ``time`` is
    treated as 0 for sorting purposes.
    """
    by_chat: Dict[str, List[DatabaseMessages]] = {}
    for message in messages:
        if message.chat_id not in by_chat:
            by_chat[message.chat_id] = []
        by_chat[message.chat_id].append(message)

    # Guarantee ascending time order inside every group.
    for chat_messages in by_chat.values():
        chat_messages.sort(key=lambda item: item.time or 0)
    return by_chat


def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
    """Collapse a bucket of messages into a single ``DatabaseMessages``.

    The non-empty ``processed_plain_text`` values are joined with newlines in
    bucket order; when none is non-empty, the last message's value is used as
    is. Every other field is copied from the last element of ``bucket``
    (callers are expected to pass the bucket in ascending time order, so that
    element is the newest one).

    Raises:
        ValueError: if ``bucket`` is empty.
    """
    if not bucket:
        raise ValueError("bucket 为空,无法合并")

    newest = bucket[-1]
    sender = newest.user_info
    group = newest.group_info
    chat = newest.chat_info
    chat_user = chat.user_info

    texts: List[str] = [t for t in (m.processed_plain_text for m in bucket) if t]
    combined_text = "\n".join(texts) if texts else newest.processed_plain_text

    # All remaining fields mirror the newest message.
    return DatabaseMessages(
        message_id=newest.message_id,
        time=newest.time,
        chat_id=newest.chat_id,
        reply_to=newest.reply_to,
        interest_value=newest.interest_value,
        key_words=newest.key_words,
        key_words_lite=newest.key_words_lite,
        is_mentioned=newest.is_mentioned,
        is_at=newest.is_at,
        reply_probability_boost=newest.reply_probability_boost,
        processed_plain_text=combined_text,
        display_message=newest.display_message,
        priority_mode=newest.priority_mode,
        priority_info=newest.priority_info,
        additional_config=newest.additional_config,
        is_emoji=newest.is_emoji,
        is_picid=newest.is_picid,
        is_command=newest.is_command,
        is_notify=newest.is_notify,
        selected_expressions=newest.selected_expressions,
        user_id=sender.user_id,
        user_nickname=sender.user_nickname,
        user_cardname=sender.user_cardname,
        user_platform=sender.platform,
        chat_info_group_id=(group.group_id if group else None),
        chat_info_group_name=(group.group_name if group else None),
        chat_info_group_platform=(group.group_platform if group else None),
        chat_info_user_id=chat_user.user_id,
        chat_info_user_nickname=chat_user.user_nickname,
        chat_info_user_cardname=chat_user.user_cardname,
        chat_info_user_platform=chat_user.platform,
        chat_info_stream_id=chat.stream_id,
        chat_info_platform=chat.platform,
        chat_info_create_time=chat.create_time,
        chat_info_last_active_time=chat.last_active_time,
    )


def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
    """Merge runs of adjacent messages sent by the same user.

    A message joins the current run when it has the same ``user_info.user_id``
    as the run's last message and arrives no more than ``SECONDS_5_MINUTES``
    after it (a missing time counts as 0). Each run is then collapsed with
    ``_merge_bucket_to_message``. The input must already be in ascending time
    order.
    """
    if not messages:
        return []

    runs: List[List[DatabaseMessages]] = []
    for message in messages:
        if runs:
            previous = runs[-1][-1]
            same_user = message.user_info.user_id == previous.user_info.user_id
            gap = (message.time or 0) - (previous.time or 0)
            if same_user and gap <= SECONDS_5_MINUTES:
                runs[-1].append(message)
                continue
        # First message overall, a different user, or too large a gap: new run.
        runs.append([message])

    return [_merge_bucket_to_message(run) for run in runs]


def build_pairs_for_chat(
    original_messages: List[DatabaseMessages],
    merged_messages: List[DatabaseMessages],
    min_ctx: int,
    max_ctx: int,
    target_user_id: Optional[str] = None,
) -> List[Tuple[str, str, str]]:
    """Build ``(input, output, message_id)`` triples for one chat.

    Each merged message is a candidate output. Its input is rendered with
    ``build_readable_messages`` from the original (unmerged) messages that
    precede its anchor position; the number of context messages is drawn
    uniformly from ``[min_ctx, max_ctx]`` (exactly ``min_ctx`` when
    ``max_ctx <= min_ctx``) and clipped at the start of the chat. The output is
    the merged message's ``processed_plain_text`` passed through
    ``clean_output_text``.

    Both lists must be in ascending time order. A missing ``time`` is treated
    as 0, matching ``group_by_chat`` and ``merge_adjacent_same_user``.

    Args:
        original_messages: unmerged messages of the chat.
        merged_messages: merged messages derived from ``original_messages``.
        min_ctx: minimum number of context messages.
        max_ctx: maximum number of context messages.
        target_user_id: when truthy, only merged messages from this user are
            emitted as outputs.

    Returns:
        A list of ``(input_str, output_text, message_id)`` tuples; empty when
        either input list is empty.
    """
    pairs: List[Tuple[str, str, str]] = []
    n_merged = len(merged_messages)
    n_original = len(original_messages)

    if n_merged == 0 or n_original == 0:
        return pairs

    # Anchor every merged message to an original message with the same
    # timestamp, using a single forward scan over both sorted lists.
    # NOTE(review): _merge_bucket_to_message copies ``time`` from the LAST
    # message of a bucket, so the anchor is the bucket's last original message
    # and the context window may contain earlier parts of the same output.
    # Confirm whether that is intended.
    merged_to_original_map: Dict[int, int] = {}
    original_idx = 0
    for merged_idx, merged_msg in enumerate(merged_messages):
        merged_time = merged_msg.time or 0
        while original_idx < n_original and (original_messages[original_idx].time or 0) < merged_time:
            original_idx += 1

        if original_idx < n_original and (original_messages[original_idx].time or 0) == merged_time:
            merged_to_original_map[merged_idx] = original_idx

    for merged_idx, merged_msg in enumerate(merged_messages):
        # With a target user configured, only that user's messages are outputs.
        if target_user_id and merged_msg.user_info.user_id != target_user_id:
            continue

        # Skip merged messages that could not be anchored.
        anchor = merged_to_original_map.get(merged_idx)
        if anchor is None:
            continue

        # Pick the context window size.
        window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
        start = max(0, anchor - window)
        context_msgs = original_messages[start:anchor]

        # The input is rendered from the original, unmerged messages.
        input_str = build_readable_messages(
            messages=context_msgs,
            timestamp_mode="normal_no_YMD",
            show_actions=False,
            show_pic=True,
        )

        # The output is the merged text with sticker/reply markers removed.
        output_text = clean_output_text(merged_msg.processed_plain_text or "")
        output_id = merged_msg.message_id or ""
        pairs.append((input_str, output_text, output_id))

    return pairs


def build_pairs(
    start_ts: float,
    end_ts: float,
    platform: Optional[str],
    user_id: Optional[str],
    min_ctx: int,
    max_ctx: int,
) -> List[Tuple[str, str, str]]:
    """Build ``(input, output, message_id)`` triples for every chat in a time range.

    Messages are fetched without filtering by ``user_id`` so that the input
    context can contain messages from all users; ``user_id`` only restricts
    which merged messages become outputs.
    """
    messages = fetch_messages_between(start_ts, end_ts, platform)

    all_pairs: List[Tuple[str, str, str]] = []
    for chat_messages in group_by_chat(messages).values():
        # Merged messages provide the outputs; the originals provide the inputs.
        merged = merge_adjacent_same_user(chat_messages)
        all_pairs.extend(build_pairs_for_chat(chat_messages, merged, min_ctx, max_ctx, user_id))
    return all_pairs


def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point.

    With no arguments the interactive mode is started and its return code is
    passed through. Otherwise the arguments are parsed, the pairs are built and
    either written to ``--output`` as JSONL or printed to stdout as one JSON
    object per line.

    Returns:
        0 on success (or the result of ``run_interactive``).

    Raises:
        ValueError: if the end time is not after the start time, or if
            ``--max_ctx`` is smaller than ``--min_ctx``.
    """
    if argv is None:
        argv = sys.argv[1:]

    # No arguments at all: fall back to interactive mode.
    if not argv:
        return run_interactive()

    parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表,支持按用户ID筛选消息")
    parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
    parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
    parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
    parser.add_argument("--user_id", default=None, help="仅选择指定 user_id 的消息")
    parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数,默认20")
    parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数,默认30")
    parser.add_argument(
        "--output",
        default=None,
        help="输出保存路径,支持 .jsonl(每行 {input, output}),若不指定则打印到stdout",
    )
    args = parser.parse_args(argv)

    start_ts = parse_datetime_to_timestamp(args.start)
    end_ts = parse_datetime_to_timestamp(args.end)
    if end_ts <= start_ts:
        raise ValueError("结束时间必须大于起始时间")
    if args.max_ctx < args.min_ctx:
        raise ValueError("max_ctx 不能小于 min_ctx")

    pairs = build_pairs(start_ts, end_ts, args.platform, args.user_id, args.min_ctx, args.max_ctx)

    if not args.output:
        # No output path: emit one JSON object per line on stdout.
        for input_str, output_str, message_id in pairs:
            print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
        return 0

    # Save as JSONL, one {input, output, message_id} object per line.
    with open(args.output, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False)
            + "\n"
            for input_str, output_str, message_id in pairs
        )
    print(f"已保存 {len(pairs)} 条到 {args.output}")
    return 0


def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
    """Ask the user for a value on stdin, falling back to ``default``.

    The prompt shows ``[default]`` when the default is neither ``None`` nor
    empty. The answer is stripped; an empty answer yields ``default`` unless
    ``default`` is ``None``, in which case the empty string is returned.
    """
    hint = "" if default in (None, "") else f" [{default}]"
    answer = input(f"{prompt_text}{hint}: ").strip()
    if answer or default is None:
        return answer
    return default


def run_interactive() -> int:
    """Collect all parameters interactively, then build and emit the pairs.

    Prompts for start/end time, platform, user id, context bounds and output
    path. Context bounds that are not valid integers fall back to 20/30.
    The pairs are written to the given path as JSONL, or printed to stdout
    followed by a total count when no path is given.

    Returns:
        0 on success; 2 when a time is missing or cannot be parsed, when the
        end time is not after the start time, or when the maximum context is
        smaller than the minimum.
    """
    print("进入交互模式(直接回车采用默认值)。时间格式例如:2025-09-28 00:00:00 或 2025-09-28")
    start_str = _prompt_with_default("请输入起始时间", None)
    end_str = _prompt_with_default("请输入结束时间", None)
    platform = _prompt_with_default("平台(可留空表示不限)", "")
    user_id = _prompt_with_default("用户ID(可留空表示不限)", "")
    try:
        min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
        max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
    except ValueError:
        # int() rejects non-numeric text with ValueError; anything else is a
        # genuine error and should not be reported as a bad number.
        print("上下文条数输入有误,使用默认 20/30")
        min_ctx, max_ctx = 20, 30
    output_path = _prompt_with_default("输出路径(.jsonl,可留空打印到控制台)", "")

    if not start_str or not end_str:
        print("必须提供起始与结束时间。")
        return 2

    try:
        start_ts = parse_datetime_to_timestamp(start_str)
        end_ts = parse_datetime_to_timestamp(end_str)
    except ValueError as e:
        # parse_datetime_to_timestamp signals an unparseable time with ValueError.
        print(f"时间解析失败:{e}")
        return 2

    if end_ts <= start_ts:
        print("结束时间必须大于起始时间。")
        return 2

    if max_ctx < min_ctx:
        print("最多条数不能小于最少条数。")
        return 2

    # Empty answers mean "no restriction".
    platform_val = platform if platform != "" else None
    user_id_val = user_id if user_id != "" else None
    pairs = build_pairs(start_ts, end_ts, platform_val, user_id_val, min_ctx, max_ctx)

    if output_path:
        # Save as JSONL, one {input, output, message_id} object per line.
        with open(output_path, "w", encoding="utf-8") as f:
            for input_str, output_str, message_id in pairs:
                obj = {"input": input_str, "output": output_str, "message_id": message_id}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        print(f"已保存 {len(pairs)} 条到 {output_path}")
    else:
        for input_str, output_str, message_id in pairs:
            print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
        print(f"总计 {len(pairs)} 条。")

    return 0


# Script entry point: the return value of main() becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())