mirror of https://github.com/Mai-with-u/MaiBot.git
feat:添加数据提取脚本
parent
0683f56e23
commit
741e123496
|
|
@ -0,0 +1,327 @@
|
|||
import argparse
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
# 确保可从任意工作目录运行:将项目根目录加入 sys.path(scripts 的上一级)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, build_readable_messages_anonymized
|
||||
|
||||
|
||||
SECONDS_5_MINUTES = 5 * 60
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
支持示例:
|
||||
- 2025-09-29
|
||||
- 2025-09-29 00:00:00
|
||||
- 2025/09/29 00:00
|
||||
- 2025-09-29T00:00:00
|
||||
"""
|
||||
value = value.strip()
|
||||
fmts = [
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
]
|
||||
last_err: Optional[Exception] = None
|
||||
for fmt in fmts:
|
||||
try:
|
||||
dt = datetime.strptime(value, fmt)
|
||||
return dt.timestamp()
|
||||
except Exception as e: # noqa: BLE001
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
|
||||
def fetch_messages_between(
|
||||
start_ts: float,
|
||||
end_ts: float,
|
||||
platform: Optional[str] = None,
|
||||
) -> List[DatabaseMessages]:
|
||||
"""使用 find_messages 获取指定区间的消息,可选按 chat_info_platform 过滤。按时间升序返回。"""
|
||||
filter_query: Dict[str, object] = {"time": {"$gt": start_ts, "$lt": end_ts}}
|
||||
if platform:
|
||||
filter_query["chat_info_platform"] = platform
|
||||
# 当 limit==0 时,sort 生效,这里按时间升序
|
||||
return find_messages(message_filter=filter_query, sort=[("time", 1)], limit=0)
|
||||
|
||||
|
||||
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
|
||||
groups: Dict[str, List[DatabaseMessages]] = {}
|
||||
for msg in messages:
|
||||
groups.setdefault(msg.chat_id, []).append(msg)
|
||||
# 保证每个分组内按时间升序
|
||||
for chat_id, msgs in groups.items():
|
||||
msgs.sort(key=lambda m: m.time or 0)
|
||||
return groups
|
||||
|
||||
|
||||
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
|
||||
"""
|
||||
将相邻、同一 user_id 且 5 分钟内的消息 bucket 合并为一条。
|
||||
processed_plain_text 合并(以换行连接),其余字段取最新一条(时间最大)。
|
||||
"""
|
||||
if not bucket:
|
||||
raise ValueError("bucket 为空,无法合并")
|
||||
|
||||
latest = bucket[-1]
|
||||
merged_texts: List[str] = []
|
||||
for m in bucket:
|
||||
text = m.processed_plain_text or ""
|
||||
if text:
|
||||
merged_texts.append(text)
|
||||
|
||||
merged = DatabaseMessages(
|
||||
# 其他信息采用最新消息
|
||||
message_id=latest.message_id,
|
||||
time=latest.time,
|
||||
chat_id=latest.chat_id,
|
||||
reply_to=latest.reply_to,
|
||||
interest_value=latest.interest_value,
|
||||
key_words=latest.key_words,
|
||||
key_words_lite=latest.key_words_lite,
|
||||
is_mentioned=latest.is_mentioned,
|
||||
is_at=latest.is_at,
|
||||
reply_probability_boost=latest.reply_probability_boost,
|
||||
processed_plain_text="\n".join(merged_texts) if merged_texts else latest.processed_plain_text,
|
||||
display_message=latest.display_message,
|
||||
priority_mode=latest.priority_mode,
|
||||
priority_info=latest.priority_info,
|
||||
additional_config=latest.additional_config,
|
||||
is_emoji=latest.is_emoji,
|
||||
is_picid=latest.is_picid,
|
||||
is_command=latest.is_command,
|
||||
is_notify=latest.is_notify,
|
||||
selected_expressions=latest.selected_expressions,
|
||||
user_id=latest.user_info.user_id,
|
||||
user_nickname=latest.user_info.user_nickname,
|
||||
user_cardname=latest.user_info.user_cardname,
|
||||
user_platform=latest.user_info.platform,
|
||||
chat_info_group_id=(latest.group_info.group_id if latest.group_info else None),
|
||||
chat_info_group_name=(latest.group_info.group_name if latest.group_info else None),
|
||||
chat_info_group_platform=(latest.group_info.group_platform if latest.group_info else None),
|
||||
chat_info_user_id=latest.chat_info.user_info.user_id,
|
||||
chat_info_user_nickname=latest.chat_info.user_info.user_nickname,
|
||||
chat_info_user_cardname=latest.chat_info.user_info.user_cardname,
|
||||
chat_info_user_platform=latest.chat_info.user_info.platform,
|
||||
chat_info_stream_id=latest.chat_info.stream_id,
|
||||
chat_info_platform=latest.chat_info.platform,
|
||||
chat_info_create_time=latest.chat_info.create_time,
|
||||
chat_info_last_active_time=latest.chat_info.last_active_time,
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
|
||||
"""按 5 分钟窗口合并相邻同 user_id 的消息。输入需按时间升序。"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
merged: List[DatabaseMessages] = []
|
||||
bucket: List[DatabaseMessages] = []
|
||||
|
||||
def flush_bucket() -> None:
|
||||
nonlocal bucket
|
||||
if bucket:
|
||||
merged.append(_merge_bucket_to_message(bucket))
|
||||
bucket = []
|
||||
|
||||
for msg in messages:
|
||||
if not bucket:
|
||||
bucket = [msg]
|
||||
continue
|
||||
|
||||
last = bucket[-1]
|
||||
same_user = (msg.user_info.user_id == last.user_info.user_id)
|
||||
close_enough = ((msg.time or 0) - (last.time or 0) <= SECONDS_5_MINUTES)
|
||||
|
||||
if same_user and close_enough:
|
||||
bucket.append(msg)
|
||||
else:
|
||||
flush_bucket()
|
||||
bucket = [msg]
|
||||
|
||||
flush_bucket()
|
||||
return merged
|
||||
|
||||
|
||||
def build_pairs_for_chat(
|
||||
merged_messages: List[DatabaseMessages],
|
||||
min_ctx: int,
|
||||
max_ctx: int,
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
对每条消息作为 output,从其前面取 20-30 条(可配置)的消息作为 input。
|
||||
input 使用 chat_message_builder.build_readable_messages 构建为字符串。
|
||||
output 使用该消息的 processed_plain_text。
|
||||
"""
|
||||
pairs: List[Tuple[str, str, str]] = []
|
||||
n = len(merged_messages)
|
||||
if n == 0:
|
||||
return pairs
|
||||
|
||||
for i in range(n):
|
||||
# 选择上下文窗口大小
|
||||
window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
|
||||
start = max(0, i - window)
|
||||
context_msgs = merged_messages[start:i]
|
||||
|
||||
# 使用匿名化构建 input,并拿到原始显示名 -> 匿名名的映射
|
||||
input_str, name_mapping = build_readable_messages_anonymized(
|
||||
messages=context_msgs,
|
||||
timestamp_mode="relative",
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
)
|
||||
|
||||
# 输出取 processed_plain_text(不再额外替换)
|
||||
output_text = merged_messages[i].processed_plain_text or ""
|
||||
output_id = merged_messages[i].message_id or ""
|
||||
pairs.append((input_str, output_text, output_id))
|
||||
|
||||
return pairs
|
||||
|
||||
|
||||
def build_pairs(
|
||||
start_ts: float,
|
||||
end_ts: float,
|
||||
platform: Optional[str],
|
||||
min_ctx: int,
|
||||
max_ctx: int,
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
messages = fetch_messages_between(start_ts, end_ts, platform)
|
||||
groups = group_by_chat(messages)
|
||||
|
||||
all_pairs: List[Tuple[str, str, str]] = []
|
||||
for chat_id, msgs in groups.items(): # noqa: F841 - chat_id 未直接使用
|
||||
merged = merge_adjacent_same_user(msgs)
|
||||
pairs = build_pairs_for_chat(merged, min_ctx, max_ctx)
|
||||
all_pairs.extend(pairs)
|
||||
|
||||
return all_pairs
|
||||
|
||||
|
||||
def main(argv: Optional[List[str]] = None) -> int:
|
||||
# 若未提供参数,则进入交互模式
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
|
||||
if len(argv) == 0:
|
||||
return run_interactive()
|
||||
|
||||
parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表")
|
||||
parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
|
||||
parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
|
||||
parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
|
||||
parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数,默认20")
|
||||
parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数,默认30")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default=None,
|
||||
help="输出保存路径,支持 .jsonl(每行 {input, output}),若不指定则打印到stdout",
|
||||
)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
start_ts = parse_datetime_to_timestamp(args.start)
|
||||
end_ts = parse_datetime_to_timestamp(args.end)
|
||||
if end_ts <= start_ts:
|
||||
raise ValueError("结束时间必须大于起始时间")
|
||||
|
||||
if args.max_ctx < args.min_ctx:
|
||||
raise ValueError("max_ctx 不能小于 min_ctx")
|
||||
|
||||
pairs = build_pairs(start_ts, end_ts, args.platform, args.min_ctx, args.max_ctx)
|
||||
|
||||
if args.output:
|
||||
# 保存为 JSONL,每行一个 {input, output, message_id}
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
obj = {"input": input_str, "output": output_str, "message_id": message_id}
|
||||
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
||||
print(f"已保存 {len(pairs)} 条到 {args.output}")
|
||||
else:
|
||||
# 打印到 stdout
|
||||
for input_str, output_str, message_id in pairs:
|
||||
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
|
||||
suffix = f"[{default}]" if default not in (None, "") else ""
|
||||
value = input(f"{prompt_text}{' ' + suffix if suffix else ''}: ").strip()
|
||||
if value == "" and default is not None:
|
||||
return default
|
||||
return value
|
||||
|
||||
|
||||
def run_interactive() -> int:
|
||||
print("进入交互模式(直接回车采用默认值)。时间格式例如:2025-09-28 00:00:00 或 2025-09-28")
|
||||
start_str = _prompt_with_default("请输入起始时间", None)
|
||||
end_str = _prompt_with_default("请输入结束时间", None)
|
||||
platform = _prompt_with_default("平台(可留空表示不限)", "")
|
||||
try:
|
||||
min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
|
||||
max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
|
||||
except Exception:
|
||||
print("上下文条数输入有误,使用默认 20/30")
|
||||
min_ctx, max_ctx = 20, 30
|
||||
output_path = _prompt_with_default("输出路径(.jsonl,可留空打印到控制台)", "")
|
||||
|
||||
if not start_str or not end_str:
|
||||
print("必须提供起始与结束时间。")
|
||||
return 2
|
||||
|
||||
try:
|
||||
start_ts = parse_datetime_to_timestamp(start_str)
|
||||
end_ts = parse_datetime_to_timestamp(end_str)
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"时间解析失败:{e}")
|
||||
return 2
|
||||
|
||||
if end_ts <= start_ts:
|
||||
print("结束时间必须大于起始时间。")
|
||||
return 2
|
||||
|
||||
if max_ctx < min_ctx:
|
||||
print("最多条数不能小于最少条数。")
|
||||
return 2
|
||||
|
||||
platform_val = platform if platform != "" else None
|
||||
pairs = build_pairs(start_ts, end_ts, platform_val, min_ctx, max_ctx)
|
||||
|
||||
if output_path:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
obj = {"input": input_str, "output": output_str, "message_id": message_id}
|
||||
f.write(json.dumps(obj, ensure_ascii=False) + "\n")
|
||||
print(f"已保存 {len(pairs)} 条到 {output_path}")
|
||||
else:
|
||||
for input_str, output_str, message_id in pairs:
|
||||
print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
|
||||
print(f"总计 {len(pairs)} 条。")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
|
||||
|
|
@ -2,7 +2,7 @@ import time
|
|||
import random
|
||||
import re
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
|
|
@ -895,6 +895,173 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
|||
return formatted_string
|
||||
|
||||
|
||||
def build_readable_messages_anonymized(
|
||||
messages: List[DatabaseMessages],
|
||||
timestamp_mode: str = "relative",
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
replace_bot_name: bool = True,
|
||||
) -> Tuple[str, Dict[str, str]]:
|
||||
"""
|
||||
仿照 build_readable_messages,构建匿名化的可读消息:
|
||||
- 所有用户名替换为 用户A、用户B、...、用户Z、用户AA、用户AB ...
|
||||
- 内容中的 回复<aaa:bbb> 与 @<aaa:bbb> 也替换为匿名名
|
||||
|
||||
Returns:
|
||||
formatted_string: 格式化后的聊天记录字符串
|
||||
mapping: 原始显示用户名 -> 匿名名 的映射表
|
||||
"""
|
||||
if not messages:
|
||||
return "", {}
|
||||
|
||||
# 生成匿名标签:A..Z, AA..AZ, BA.. 等
|
||||
def alphabet_labels() -> Iterable[str]:
|
||||
import string
|
||||
|
||||
letters = string.ascii_uppercase
|
||||
# 单字母
|
||||
for ch in letters:
|
||||
yield ch
|
||||
# 多字母(简单生成两位,若需要可继续扩展)
|
||||
for a in letters:
|
||||
for b in letters:
|
||||
yield f"{a}{b}"
|
||||
|
||||
label_iter = alphabet_labels()
|
||||
user_to_label: Dict[Tuple[str, str], str] = {}
|
||||
name_mapping: Dict[str, str] = {}
|
||||
|
||||
def get_display_name(platform: str, user_id: str, user_nickname: str, user_cardname: Optional[str]) -> str:
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
|
||||
|
||||
def get_anon_name(platform: str, user_id: str, user_nickname: str, user_cardname: Optional[str]) -> str:
|
||||
key = (platform or "", user_id or "")
|
||||
# 机器人处理:若需要替换机器人名称,则直接返回 昵称(你)
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
anon = f"{global_config.bot.nickname}(你)"
|
||||
original_display = get_display_name(platform, user_id, user_nickname, user_cardname)
|
||||
if original_display not in name_mapping:
|
||||
name_mapping[original_display] = anon
|
||||
return anon
|
||||
if key not in user_to_label:
|
||||
user_to_label[key] = f"用户{next(label_iter)}"
|
||||
anon = user_to_label[key]
|
||||
# 记录原始显示名到匿名名(可能重复显示名时后写覆盖)
|
||||
original_display = get_display_name(platform, user_id, user_nickname, user_cardname)
|
||||
if original_display not in name_mapping:
|
||||
name_mapping[original_display] = anon
|
||||
return anon
|
||||
|
||||
# 将 DatabaseMessages 转换为可处理结构,并可选拼入动作
|
||||
copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
|
||||
|
||||
if show_actions and copy_messages:
|
||||
min_time = min(msg.time or 0 for msg in copy_messages)
|
||||
max_time = max(msg.time or 0 for msg in copy_messages)
|
||||
chat_id = messages[0].chat_id if messages else None
|
||||
|
||||
actions_in_range = (
|
||||
ActionRecords.select()
|
||||
.where((ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
action_after_latest = (
|
||||
ActionRecords.select()
|
||||
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
|
||||
for action in actions:
|
||||
if action.action_build_into_prompt:
|
||||
action_msg = MessageAndActionModel(
|
||||
time=float(action.time), # type: ignore
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_platform=global_config.bot.platform,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
user_cardname="",
|
||||
processed_plain_text=f"{action.action_prompt_display}",
|
||||
display_message=f"{action.action_prompt_display}",
|
||||
chat_info_platform=str(action.chat_info_platform),
|
||||
is_action_record=True,
|
||||
action_name=str(action.action_name),
|
||||
)
|
||||
copy_messages.append(action_msg)
|
||||
|
||||
copy_messages.sort(key=lambda x: x.time or 0)
|
||||
|
||||
# 图片替换帮助
|
||||
def process_pic_ids(content: Optional[str]) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(_m: re.Match) -> str:
|
||||
return "[图片]" if show_pic else ""
|
||||
|
||||
return re.sub(pic_pattern, replace_pic_id, content)
|
||||
|
||||
# 内容引用替换的 resolver:将 <aaa:bbb> / @<aaa:bbb> 中的 bbb 映射为匿名名
|
||||
def anon_name_resolver(platform: str, user_id: str) -> str:
|
||||
try:
|
||||
# 与主流程一致处理机器人名字
|
||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
return get_anon_name(platform, user_id, "", None)
|
||||
except Exception:
|
||||
return "用户?"
|
||||
|
||||
# 构建结果
|
||||
detailed: List[Tuple[float, str, str, bool]] = []
|
||||
|
||||
for m in copy_messages:
|
||||
if m.is_action_record:
|
||||
content = process_pic_ids(m.display_message)
|
||||
detailed.append((m.time or 0.0, "", content, True))
|
||||
continue
|
||||
|
||||
platform = m.user_platform
|
||||
user_id = m.user_id
|
||||
user_nickname = m.user_nickname
|
||||
user_cardname = m.user_cardname
|
||||
content = m.display_message or m.processed_plain_text or ""
|
||||
|
||||
content = process_pic_ids(content)
|
||||
anon_name = get_anon_name(platform, user_id, user_nickname, user_cardname)
|
||||
try:
|
||||
content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
detailed.append((m.time or 0.0, anon_name, content, False))
|
||||
|
||||
if not detailed:
|
||||
return "", name_mapping
|
||||
|
||||
detailed.sort(key=lambda x: x[0])
|
||||
|
||||
output_lines: List[str] = []
|
||||
for ts, name, content, is_action in detailed:
|
||||
readable_time = translate_timestamp_to_human_readable(ts, mode=timestamp_mode)
|
||||
if is_action:
|
||||
output_lines.append(f"{readable_time}, {content}")
|
||||
else:
|
||||
output_lines.append(f"{readable_time}, {name}: {content}")
|
||||
output_lines.append("\n")
|
||||
|
||||
formatted_string = "".join(output_lines).strip()
|
||||
|
||||
# 最后对完整字符串再按映射表做一次替换,处理正文里直接出现的原始昵称
|
||||
if name_mapping:
|
||||
for original_name, anon_name in sorted(name_mapping.items(), key=lambda x: len(x[0]), reverse=True):
|
||||
if original_name:
|
||||
formatted_string = formatted_string.replace(original_name, anon_name)
|
||||
|
||||
return formatted_string, name_mapping
|
||||
|
||||
|
||||
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)。
|
||||
|
|
|
|||
Loading…
Reference in New Issue