feat:添加数据提取脚本

pull/1273/head
SengokuCola 2025-09-29 20:28:26 +08:00
parent 0683f56e23
commit 741e123496
2 changed files with 495 additions and 1 deletions

View File

@ -0,0 +1,327 @@
import argparse
import json
import random
import sys
import os
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple
# Allow running from any working directory: put the project root (the parent
# of the directory containing this script) at the front of sys.path.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import build_readable_messages, build_readable_messages_anonymized
SECONDS_5_MINUTES = 5 * 60
def parse_datetime_to_timestamp(value: str) -> float:
    """Convert a date/time string in one of several common layouts to a timestamp.

    Supported examples:
    - 2025-09-29
    - 2025-09-29 00:00:00
    - 2025/09/29 00:00
    - 2025-09-29T00:00:00

    The string is parsed into a naive ``datetime``; ``datetime.timestamp()``
    therefore interprets it in the local time zone of the machine.

    Args:
        value: The text to parse. Surrounding whitespace is ignored.

    Returns:
        The corresponding POSIX timestamp as a float.

    Raises:
        ValueError: If ``value`` matches none of the supported formats (or the
            parsed date cannot be converted to a timestamp). The error from
            the last attempted format is attached as ``__cause__``.
    """
    value = value.strip()
    fmts = (
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M",
        "%Y/%m/%d %H:%M:%S",
        "%Y/%m/%d %H:%M",
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M",
    )
    last_err: Optional[Exception] = None
    for fmt in fmts:
        try:
            return datetime.strptime(value, fmt).timestamp()
        except (ValueError, OverflowError, OSError) as e:
            # ValueError: the format did not match. OverflowError / OSError:
            # the date parsed but is outside the platform's timestamp range.
            # Anything else is a genuine bug and must not be swallowed.
            last_err = e
    raise ValueError(f"无法解析时间: {value} ({last_err})") from last_err
def fetch_messages_between(
    start_ts: float,
    end_ts: float,
    platform: Optional[str] = None,
) -> List[DatabaseMessages]:
    """Query ``find_messages`` for messages strictly inside a time interval.

    Args:
        start_ts: Exclusive lower bound on the message ``time`` field.
        end_ts: Exclusive upper bound on the message ``time`` field.
        platform: When truthy, additionally require ``chat_info_platform`` to
            equal this value.

    Returns:
        Whatever ``find_messages`` returns for the query; ascending order by
        ``time`` is requested through the ``sort`` argument.
    """
    time_window = {"$gt": start_ts, "$lt": end_ts}
    query: Dict[str, object] = {"time": time_window}
    if platform:
        query["chat_info_platform"] = platform
    # Per the original author's note, find_messages only applies ``sort`` when
    # limit == 0, so limit=0 is passed on purpose to get ascending time order.
    return find_messages(message_filter=query, sort=[("time", 1)], limit=0)
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
    """Bucket messages by ``chat_id`` and order each bucket by ascending time.

    Args:
        messages: Any iterable of messages; it is consumed once.

    Returns:
        A dict keyed by ``chat_id`` in first-seen order. Each value is the list
        of that chat's messages sorted by ``time`` (a missing/falsy time sorts
        as 0); the sort is stable, so ties keep their input order.
    """
    by_chat: Dict[str, List[DatabaseMessages]] = {}
    for message in messages:
        chat_id = message.chat_id
        if chat_id not in by_chat:
            by_chat[chat_id] = []
        by_chat[chat_id].append(message)
    # Re-sort every bucket so callers never depend on the input ordering.
    return {
        chat_id: sorted(chat_messages, key=lambda item: item.time or 0)
        for chat_id, chat_messages in by_chat.items()
    }
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
    """Collapse a bucket of adjacent same-user messages into a single message.

    The non-empty ``processed_plain_text`` values are joined with newlines, in
    bucket order. Every other field is copied from the last element of the
    bucket (the newest message when the bucket is in ascending time order).

    Args:
        bucket: The messages to merge; must not be empty.

    Returns:
        A newly constructed ``DatabaseMessages`` instance.

    Raises:
        ValueError: If ``bucket`` is empty.
    """
    if not bucket:
        raise ValueError("bucket 为空,无法合并")

    newest = bucket[-1]
    sender = newest.user_info
    group = newest.group_info
    chat = newest.chat_info
    chat_user = chat.user_info

    non_empty_texts: List[str] = [m.processed_plain_text for m in bucket if m.processed_plain_text]
    if non_empty_texts:
        combined_text = "\n".join(non_empty_texts)
    else:
        # Nothing to join: keep whatever (empty/None) value the newest message has.
        combined_text = newest.processed_plain_text

    return DatabaseMessages(
        # Everything except the text comes from the newest message.
        message_id=newest.message_id,
        time=newest.time,
        chat_id=newest.chat_id,
        reply_to=newest.reply_to,
        interest_value=newest.interest_value,
        key_words=newest.key_words,
        key_words_lite=newest.key_words_lite,
        is_mentioned=newest.is_mentioned,
        is_at=newest.is_at,
        reply_probability_boost=newest.reply_probability_boost,
        processed_plain_text=combined_text,
        display_message=newest.display_message,
        priority_mode=newest.priority_mode,
        priority_info=newest.priority_info,
        additional_config=newest.additional_config,
        is_emoji=newest.is_emoji,
        is_picid=newest.is_picid,
        is_command=newest.is_command,
        is_notify=newest.is_notify,
        selected_expressions=newest.selected_expressions,
        user_id=sender.user_id,
        user_nickname=sender.user_nickname,
        user_cardname=sender.user_cardname,
        user_platform=sender.platform,
        # group_info may be absent, in which case the group fields are None.
        chat_info_group_id=(group.group_id if group else None),
        chat_info_group_name=(group.group_name if group else None),
        chat_info_group_platform=(group.group_platform if group else None),
        chat_info_user_id=chat_user.user_id,
        chat_info_user_nickname=chat_user.user_nickname,
        chat_info_user_cardname=chat_user.user_cardname,
        chat_info_user_platform=chat_user.platform,
        chat_info_stream_id=chat.stream_id,
        chat_info_platform=chat.platform,
        chat_info_create_time=chat.create_time,
        chat_info_last_active_time=chat.last_active_time,
    )
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
    """Merge runs of consecutive messages sent by the same user close together in time.

    A message joins the current run when it has the same ``user_info.user_id``
    as the previous message in that run and arrives at most
    ``SECONDS_5_MINUTES`` seconds after it (a missing time counts as 0). The
    gap is measured against the previous message, not the first of the run.

    Args:
        messages: Messages of one chat, expected in ascending time order.

    Returns:
        One merged message per run, in the original order; ``[]`` for empty input.
    """
    runs: List[List[DatabaseMessages]] = []
    for message in messages:
        if runs:
            previous = runs[-1][-1]
            gap = (message.time or 0) - (previous.time or 0)
            if message.user_info.user_id == previous.user_info.user_id and gap <= SECONDS_5_MINUTES:
                runs[-1].append(message)
                continue
        # First message overall, a different sender, or too large a gap: start a new run.
        runs.append([message])
    return [_merge_bucket_to_message(run) for run in runs]
def build_pairs_for_chat(
    merged_messages: List[DatabaseMessages],
    min_ctx: int,
    max_ctx: int,
) -> List[Tuple[str, str, str]]:
    """Build one (input, output, message_id) triple for every message of a chat.

    Each message in turn is the output; the messages immediately before it form
    the input context. The context size is drawn with ``random.randint`` from
    ``[min_ctx, max_ctx]`` when ``max_ctx > min_ctx``, otherwise it is
    ``min_ctx``. Early messages simply get fewer (possibly zero) context items.

    Args:
        merged_messages: The chat's messages in chronological order.
        min_ctx: Lower bound of the context window size.
        max_ctx: Upper bound of the context window size.

    Returns:
        A list of ``(input_str, output_text, message_id)`` tuples where
        ``input_str`` is the anonymized rendering of the context produced by
        ``build_readable_messages_anonymized``, ``output_text`` is the target's
        ``processed_plain_text`` and ``message_id`` its id (both default to "").
    """
    results: List[Tuple[str, str, str]] = []
    for index, target in enumerate(merged_messages):
        if max_ctx > min_ctx:
            ctx_size = random.randint(min_ctx, max_ctx)
        else:
            ctx_size = min_ctx
        history = merged_messages[max(0, index - ctx_size):index]

        # The anonymized builder also returns a name mapping; it is not needed here.
        input_str, _ = build_readable_messages_anonymized(
            messages=history,
            timestamp_mode="relative",
            show_actions=False,
            show_pic=True,
        )

        # The output is the raw processed_plain_text, with no extra replacement.
        results.append((input_str, target.processed_plain_text or "", target.message_id or ""))
    return results
def build_pairs(
    start_ts: float,
    end_ts: float,
    platform: Optional[str],
    min_ctx: int,
    max_ctx: int,
) -> List[Tuple[str, str, str]]:
    """Run the whole pipeline: fetch, group by chat, merge, and build pairs.

    Args:
        start_ts: Start of the time interval (timestamp).
        end_ts: End of the time interval (timestamp).
        platform: Optional platform filter forwarded to ``fetch_messages_between``.
        min_ctx: Minimum context window size forwarded to ``build_pairs_for_chat``.
        max_ctx: Maximum context window size forwarded to ``build_pairs_for_chat``.

    Returns:
        The ``(input, output, message_id)`` triples of all chats, concatenated
        chat by chat.
    """
    per_chat = group_by_chat(fetch_messages_between(start_ts, end_ts, platform))
    collected: List[Tuple[str, str, str]] = []
    # Only the message lists matter; the chat ids themselves are not used.
    for chat_messages in per_chat.values():
        merged = merge_adjacent_same_user(chat_messages)
        collected.extend(build_pairs_for_chat(merged, min_ctx, max_ctx))
    return collected
def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point.

    With no arguments the interactive prompt flow (``run_interactive``) is
    used. Otherwise the arguments are parsed with argparse, the pairs are
    built, and each pair is emitted as one JSON object per line, either to the
    ``--output`` file or to stdout.

    Args:
        argv: Argument list without the program name; ``sys.argv[1:]`` when None.

    Returns:
        The process exit code (0 on success, or the result of ``run_interactive``).

    Raises:
        ValueError: If a time cannot be parsed, the end time is not later than
            the start time, or ``--max_ctx`` is smaller than ``--min_ctx``.
    """
    cli_args = sys.argv[1:] if argv is None else argv
    # No arguments at all: fall back to interactive mode.
    if not cli_args:
        return run_interactive()

    parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表")
    parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
    parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
    parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
    parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数默认20")
    parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数默认30")
    parser.add_argument(
        "--output",
        default=None,
        help="输出保存路径,支持 .jsonl每行 {input, output}若不指定则打印到stdout",
    )
    args = parser.parse_args(cli_args)

    start_ts = parse_datetime_to_timestamp(args.start)
    end_ts = parse_datetime_to_timestamp(args.end)
    if end_ts <= start_ts:
        raise ValueError("结束时间必须大于起始时间")
    if args.max_ctx < args.min_ctx:
        raise ValueError("max_ctx 不能小于 min_ctx")

    pairs = build_pairs(start_ts, end_ts, args.platform, args.min_ctx, args.max_ctx)
    # One JSON object {input, output, message_id} per line (JSONL).
    json_lines = [
        json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False)
        for input_str, output_str, message_id in pairs
    ]

    if not args.output:
        for line in json_lines:
            print(line)
        return 0

    with open(args.output, "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in json_lines)
    print(f"已保存 {len(pairs)} 条到 {args.output}")
    return 0
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
    """Ask the user for a value on stdin, falling back to a default on empty input.

    Args:
        prompt_text: The question to display. A non-empty default is shown
            after it in square brackets.
        default: Value returned when the user just presses Enter. ``None``
            means there is no default, and the empty answer is returned as "".

    Returns:
        The stripped answer, or ``default`` when the answer is empty and a
        default (possibly the empty string) was supplied.
    """
    if default in (None, ""):
        label = prompt_text
    else:
        label = f"{prompt_text} [{default}]"
    answer = input(f"{label}: ").strip()
    if answer or default is None:
        return answer
    return default
def run_interactive() -> int:
    """Collect all parameters through console prompts, then build and emit the pairs.

    All questions are asked first (start, end, platform, context sizes, output
    path); validation happens afterwards. Invalid context sizes fall back to
    20/30 instead of aborting.

    Returns:
        0 on success, 2 when the input is missing or invalid.
    """
    print("进入交互模式直接回车采用默认值。时间格式例如2025-09-28 00:00:00 或 2025-09-28")
    start_str = _prompt_with_default("请输入起始时间", None)
    end_str = _prompt_with_default("请输入结束时间", None)
    platform = _prompt_with_default("平台(可留空表示不限)", "")
    try:
        min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
        max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
    except Exception:
        # Any failure while reading/converting either number resets both.
        print("上下文条数输入有误,使用默认 20/30")
        min_ctx, max_ctx = 20, 30
    output_path = _prompt_with_default("输出路径(.jsonl可留空打印到控制台", "")

    if not (start_str and end_str):
        print("必须提供起始与结束时间。")
        return 2

    try:
        start_ts = parse_datetime_to_timestamp(start_str)
        end_ts = parse_datetime_to_timestamp(end_str)
    except Exception as e:  # noqa: BLE001
        print(f"时间解析失败:{e}")
        return 2

    if end_ts <= start_ts:
        print("结束时间必须大于起始时间。")
        return 2
    if max_ctx < min_ctx:
        print("最多条数不能小于最少条数。")
        return 2

    # An empty platform answer means "no platform filter".
    pairs = build_pairs(start_ts, end_ts, platform or None, min_ctx, max_ctx)
    json_lines = [
        json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False)
        for input_str, output_str, message_id in pairs
    ]

    if output_path:
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in json_lines)
        print(f"已保存 {len(pairs)} 条到 {output_path}")
    else:
        for line in json_lines:
            print(line)
        print(f"总计 {len(pairs)} 条。")
    return 0
# Script entry point: the return value of main() becomes the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())

View File

@ -2,7 +2,7 @@ import time
import random
import re
from typing import List, Dict, Any, Tuple, Optional, Callable
from typing import List, Dict, Any, Tuple, Optional, Callable, Iterable
from rich.traceback import install
from src.config.config import global_config
@ -895,6 +895,173 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
return formatted_string
def build_readable_messages_anonymized(
    messages: List[DatabaseMessages],
    timestamp_mode: str = "relative",
    show_actions: bool = False,
    show_pic: bool = True,
    replace_bot_name: bool = True,
) -> Tuple[str, Dict[str, str]]:
    """Render messages as a readable chat log with every speaker anonymized.

    Modelled on ``build_readable_messages``:
    - Every user is renamed 用户A, 用户B, ... 用户Z, 用户AA, 用户AB, ... with no
      upper limit on the number of distinct users.
    - ``回复<aaa:bbb>`` / ``@<aaa:bbb>`` references inside the content are
      rewritten to the anonymized name via ``replace_user_references``.
    - Known original display names that appear literally in the final text are
      replaced with their anonymized name.

    Args:
        messages: The messages to render.
        timestamp_mode: Mode passed to ``translate_timestamp_to_human_readable``.
        show_actions: When true, action records of the same chat that have
            ``action_build_into_prompt`` set are interleaved with the messages.
        show_pic: When true ``[picid:...]`` markers become ``[图片]``;
            otherwise they are removed.
        replace_bot_name: When true the bot account is rendered as
            ``<bot nickname>(你)`` instead of getting a 用户X label.

    Returns:
        ``(formatted_string, mapping)`` where ``formatted_string`` is the
        rendered log (one line per message) and ``mapping`` maps each known
        original display name to its anonymized name. Users without any known
        name are not listed in the mapping. ``("", {})`` for empty input.
    """
    # Local imports, following the original block's style (it imported
    # ``string`` locally), so the file's import header needs no change.
    import itertools
    import string

    if not messages:
        return "", {}

    def alphabet_labels() -> Iterable[str]:
        """Yield A..Z, AA..ZZ, AAA.. and so on, without an upper bound."""
        letters = string.ascii_uppercase
        for width in itertools.count(1):
            for combo in itertools.product(letters, repeat=width):
                yield "".join(combo)

    label_iter = alphabet_labels()
    user_to_label: Dict[Tuple[str, str], str] = {}
    name_mapping: Dict[str, str] = {}
    # Every anonymized name handed out so far. The final replacement pass must
    # leave these untouched even when they contain an original name.
    issued_names: Dict[str, str] = {}

    def get_display_name(
        platform: str, user_id: str, user_nickname: Optional[str], user_cardname: Optional[str]
    ) -> Optional[str]:
        """Return the best known original name of a user, or None if unknown."""
        person = Person(platform=platform, user_id=user_id)
        person_name = person.person_name
        if person_name:
            return person_name
        if user_nickname:
            return f"{user_nickname}"
        if user_cardname:
            return f"昵称:{user_cardname}"
        # No usable name. Returning None (instead of a generic placeholder)
        # keeps ordinary words out of the replacement table.
        return None

    def remember(display_name: Optional[str], anon: str) -> None:
        """Record an issued anonymized name and, if known, the original name behind it."""
        issued_names[anon] = anon
        # First writer wins when two users share a display name.
        if display_name and display_name not in name_mapping:
            name_mapping[display_name] = anon

    def get_anon_name(
        platform: str, user_id: str, user_nickname: Optional[str], user_cardname: Optional[str]
    ) -> str:
        if replace_bot_name and user_id == global_config.bot.qq_account:
            # The bot keeps its nickname, marked with "(你)".
            anon = f"{global_config.bot.nickname}(你)"
        else:
            key = (platform or "", user_id or "")
            if key not in user_to_label:
                user_to_label[key] = f"用户{next(label_iter)}"
            anon = user_to_label[key]
        remember(get_display_name(platform, user_id, user_nickname, user_cardname), anon)
        return anon

    # Convert DatabaseMessages into the common model and optionally add action records.
    copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
    if show_actions and copy_messages:
        min_time = min(msg.time or 0 for msg in copy_messages)
        max_time = max(msg.time or 0 for msg in copy_messages)
        chat_id = messages[0].chat_id
        actions_in_range = (
            ActionRecords.select()
            .where((ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id))
            .order_by(ActionRecords.time)
        )
        # Also take the first action recorded after the newest message.
        action_after_latest = (
            ActionRecords.select()
            .where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
            .order_by(ActionRecords.time)
            .limit(1)
        )
        actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
        for action in actions:
            if action.action_build_into_prompt:
                action_msg = MessageAndActionModel(
                    time=float(action.time),  # type: ignore
                    user_id=global_config.bot.qq_account,
                    user_platform=global_config.bot.platform,
                    user_nickname=global_config.bot.nickname,
                    user_cardname="",
                    processed_plain_text=f"{action.action_prompt_display}",
                    display_message=f"{action.action_prompt_display}",
                    chat_info_platform=str(action.chat_info_platform),
                    is_action_record=True,
                    action_name=str(action.action_name),
                )
                copy_messages.append(action_msg)
        copy_messages.sort(key=lambda x: x.time or 0)

    pic_pattern = re.compile(r"\[picid:([^\]]+)\]")
    pic_replacement = "[图片]" if show_pic else ""

    def process_pic_ids(content: Optional[str]) -> str:
        """Replace (or drop) ``[picid:...]`` markers; None becomes ""."""
        if content is None:
            return ""
        # A callable replacement avoids any backslash/group interpretation.
        return pic_pattern.sub(lambda _m: pic_replacement, content)

    def anon_name_resolver(platform: str, user_id: str) -> str:
        """Resolver for ``replace_user_references``: map a referenced user to its anonymized name."""
        try:
            # Same bot handling as the main flow.
            if replace_bot_name and user_id == global_config.bot.qq_account:
                bot_name = f"{global_config.bot.nickname}(你)"
                issued_names[bot_name] = bot_name
                return bot_name
            return get_anon_name(platform, user_id, "", None)
        except Exception:
            # Best effort: a failed lookup must not abort the rendering.
            return "用户?"

    # (time, speaker, content, is_action) for each line of output.
    detailed: List[Tuple[float, str, str, bool]] = []
    for m in copy_messages:
        if m.is_action_record:
            detailed.append((m.time or 0.0, "", process_pic_ids(m.display_message), True))
            continue
        platform = m.user_platform
        content = process_pic_ids(m.display_message or m.processed_plain_text or "")
        anon_name = get_anon_name(platform, m.user_id, m.user_nickname, m.user_cardname)
        try:
            content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False)
        except Exception:
            # Deliberate best effort: keep the content unchanged when the
            # reference rewriting fails.
            pass
        detailed.append((m.time or 0.0, anon_name, content, False))

    if not detailed:
        return "", name_mapping

    detailed.sort(key=lambda x: x[0])
    output_lines: List[str] = []
    for ts, name, content, is_action in detailed:
        readable_time = translate_timestamp_to_human_readable(ts, mode=timestamp_mode)
        if is_action:
            output_lines.append(f"{readable_time}, {content}")
        else:
            output_lines.append(f"{readable_time}, {name}: {content}")
    formatted_string = "\n".join(output_lines).strip()

    # Finally replace original names that appear literally in the text.
    # This is done in a single regex pass so replaced text is never re-scanned.
    # The anonymized names are part of the alternation (mapped to themselves)
    # and the longest alternative is tried first. So an anonymized name that
    # contains an original name, e.g. "<nickname>(你)" for the bot or "用户A"
    # for a user called "A", is kept intact instead of being rewritten again.
    replacements: Dict[str, str] = dict(issued_names)
    for original_name, anon_name in name_mapping.items():
        replacements.setdefault(original_name, anon_name)
    if len(replacements) > len(issued_names):
        ordered = sorted(replacements, key=len, reverse=True)
        name_pattern = re.compile("|".join(re.escape(name) for name in ordered))
        formatted_string = name_pattern.sub(lambda match: replacements[match.group(0)], formatted_string)

    return formatted_string, name_mapping
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
"""
从消息列表中提取不重复的 person_id 列表 (忽略机器人自身)