pull/1374/head
墨梓柒 2025-11-19 23:35:14 +08:00
parent 2f58605644
commit 44f427dc64
No known key found for this signature in database
GPG Key ID: 4A65B9DBA35F7635
42 changed files with 1742 additions and 2062 deletions

4
bot.py
View File

@ -1,7 +1,6 @@
import asyncio
import hashlib
import os
import sys
import time
import platform
import traceback
@ -30,7 +29,7 @@ else:
raise
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
initialize_logging()
@ -215,6 +214,7 @@ if __name__ == "__main__":
# 初始化 WebSocket 日志推送
from src.common.logger import initialize_ws_handler
initialize_ws_handler(loop)
try:

View File

@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
SECONDS_5_MINUTES = 5 * 60

View File

@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# 设置中文字体
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False

View File

@ -57,8 +57,8 @@ from src.common.database.database import db
from src.common.database.database_model import Emoji
# 常量定义
MAGIC = b'MMIP'
FOOTER_MAGIC = b'MMFF'
MAGIC = b"MMIP"
FOOTER_MAGIC = b"MMFF"
VERSION = 1
FOOTER_VERSION = 1
@ -67,7 +67,7 @@ MAX_MANIFEST_SIZE = 200 * 1024 * 1024 # 200 MB
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB
# 支持的图片格式
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.avif', '.bmp'}
SUPPORTED_FORMATS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".avif", ".bmp"}
# 创建控制台对象
console = Console()
@ -75,6 +75,7 @@ console = Console()
class MMIPKGError(Exception):
"""MMIPKG 相关错误"""
pass
@ -97,55 +98,55 @@ def get_image_info(file_path: str) -> Tuple[int, int, str]:
try:
with Image.open(file_path) as img:
width, height = img.size
format_lower = img.format.lower() if img.format else 'unknown'
format_lower = img.format.lower() if img.format else "unknown"
mime_map = {
'jpeg': 'image/jpeg',
'jpg': 'image/jpeg',
'png': 'image/png',
'gif': 'image/gif',
'webp': 'image/webp',
'avif': 'image/avif',
'bmp': 'image/bmp'
"jpeg": "image/jpeg",
"jpg": "image/jpeg",
"png": "image/png",
"gif": "image/gif",
"webp": "image/webp",
"avif": "image/avif",
"bmp": "image/bmp",
}
mime_type = mime_map.get(format_lower, f'image/{format_lower}')
mime_type = mime_map.get(format_lower, f"image/{format_lower}")
return width, height, mime_type
except Exception as e:
print(f"警告: 无法读取图片信息 {file_path}: {e}")
return 0, 0, 'image/unknown'
return 0, 0, "image/unknown"
def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 80) -> bytes:
def reencode_image(file_path: str, output_format: str = "webp", quality: int = 80) -> bytes:
"""重新编码图片"""
try:
with Image.open(file_path) as img:
# 转换为 RGB如果需要
if img.mode in ('RGBA', 'LA', 'P'):
if output_format.lower() == 'jpeg':
if img.mode in ("RGBA", "LA", "P"):
if output_format.lower() == "jpeg":
# JPEG 不支持透明度,转为白色背景
background = Image.new('RGB', img.size, (255, 255, 255))
if img.mode == 'P':
img = img.convert('RGBA')
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
background = Image.new("RGB", img.size, (255, 255, 255))
if img.mode == "P":
img = img.convert("RGBA")
background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
img = background
elif output_format.lower() == 'webp':
elif output_format.lower() == "webp":
# WebP 支持透明度
if img.mode == 'P':
img = img.convert('RGBA')
elif img.mode not in ('RGB', 'RGBA'):
img = img.convert('RGB')
if img.mode == "P":
img = img.convert("RGBA")
elif img.mode not in ("RGB", "RGBA"):
img = img.convert("RGB")
# 编码图片
output = io.BytesIO()
save_kwargs = {'format': output_format.upper()}
save_kwargs = {"format": output_format.upper()}
if output_format.lower() in {'jpeg', 'jpg'}:
save_kwargs['quality'] = quality
save_kwargs['optimize'] = True
elif output_format.lower() == 'webp':
save_kwargs['quality'] = quality
save_kwargs['method'] = 6 # 更好的压缩
elif output_format.lower() == 'png':
save_kwargs['optimize'] = True
if output_format.lower() in {"jpeg", "jpg"}:
save_kwargs["quality"] = quality
save_kwargs["optimize"] = True
elif output_format.lower() == "webp":
save_kwargs["quality"] = quality
save_kwargs["method"] = 6 # 更好的压缩
elif output_format.lower() == "png":
save_kwargs["optimize"] = True
img.save(output, **save_kwargs)
return output.getvalue()
@ -156,11 +157,13 @@ def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 8
class MMIPKGPacker:
"""MMIPKG 打包器"""
def __init__(self,
use_compression: bool = True,
zstd_level: int = 3,
reencode: Optional[str] = None,
reencode_quality: int = 80):
def __init__(
self,
use_compression: bool = True,
zstd_level: int = 3,
reencode: Optional[str] = None,
reencode_quality: int = 80,
):
self.use_compression = use_compression and zstd is not None
self.zstd_level = zstd_level
self.reencode = reencode
@ -170,8 +173,9 @@ class MMIPKGPacker:
print("警告: zstandard 未安装,将不使用压缩")
self.use_compression = False
def pack_from_db(self, output_path: str, pack_name: Optional[str] = None,
custom_manifest: Optional[Dict] = None) -> bool:
def pack_from_db(
self, output_path: str, pack_name: Optional[str] = None, custom_manifest: Optional[Dict] = None
) -> bool:
"""从数据库导出已注册的表情包
Args:
@ -205,12 +209,14 @@ class MMIPKGPacker:
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
console=console
console=console,
) as progress:
task = progress.add_task("[cyan]扫描表情包...", total=emoji_count)
for idx, emoji in enumerate(emojis, 1):
progress.update(task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}")
progress.update(
task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}"
)
# 检查文件是否存在
if not os.path.exists(emoji.full_path):
@ -224,10 +230,10 @@ class MMIPKGPacker:
img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality)
except Exception as e:
console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]")
with open(emoji.full_path, 'rb') as f:
with open(emoji.full_path, "rb") as f:
img_bytes = f.read()
else:
with open(emoji.full_path, 'rb') as f:
with open(emoji.full_path, "rb") as f:
img_bytes = f.read()
# 计算 SHA256
@ -259,7 +265,7 @@ class MMIPKGPacker:
"emoji_hash": emoji.emoji_hash or "",
"is_registered": True,
"is_banned": emoji.is_banned or False,
}
},
}
items.append(item)
@ -281,7 +287,7 @@ class MMIPKGPacker:
"p": pack_id, # pack_id
"n": pack_name, # pack_name
"t": datetime.now().isoformat(), # created_at
"a": items # items array
"a": items, # items array
}
# 添加自定义字段
@ -308,26 +314,28 @@ class MMIPKGPacker:
except Exception as e:
print(f"打包失败: {e}")
import traceback
traceback.print_exc()
return False
finally:
if not db.is_closed():
db.close()
def _write_package(self, output_path: str, manifest_bytes: bytes,
image_data_list: List[bytes], payload_size: int) -> bool:
def _write_package(
self, output_path: str, manifest_bytes: bytes, image_data_list: List[bytes], payload_size: int
) -> bool:
"""写入打包文件"""
try:
with open(output_path, 'wb') as f:
with open(output_path, "wb") as f:
# 写入 Header (32 bytes)
flags = 0x01 if self.use_compression else 0x00
header = MAGIC # 4 bytes
header += struct.pack('B', VERSION) # 1 byte
header += struct.pack('B', flags) # 1 byte
header += b'\x00\x00' # 2 bytes reserved
header += struct.pack('>Q', payload_size) # 8 bytes
header += struct.pack('>Q', len(manifest_bytes)) # 8 bytes
header += b'\x00' * 8 # 8 bytes reserved
header += struct.pack("B", VERSION) # 1 byte
header += struct.pack("B", flags) # 1 byte
header += b"\x00\x00" # 2 bytes reserved
header += struct.pack(">Q", payload_size) # 8 bytes
header += struct.pack(">Q", len(manifest_bytes)) # 8 bytes
header += b"\x00" * 8 # 8 bytes reserved
assert len(header) == 32, f"Header size mismatch: {len(header)}"
f.write(header)
@ -342,7 +350,7 @@ class MMIPKGPacker:
with compressor.stream_writer(f, closefd=False) as writer:
# 写入 manifest
manifest_len_bytes = struct.pack('>I', len(manifest_bytes))
manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
writer.write(manifest_len_bytes)
writer.write(manifest_bytes)
payload_sha.update(manifest_len_bytes)
@ -355,13 +363,13 @@ class MMIPKGPacker:
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(),
console=console
console=console,
) as progress:
task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list))
for idx, img_bytes in enumerate(image_data_list, 1):
progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}")
img_len_bytes = struct.pack('>I', len(img_bytes))
img_len_bytes = struct.pack(">I", len(img_bytes))
writer.write(img_len_bytes)
writer.write(img_bytes)
payload_sha.update(img_len_bytes)
@ -370,7 +378,7 @@ class MMIPKGPacker:
else:
# 不压缩,直接写入
# 写入 manifest
manifest_len_bytes = struct.pack('>I', len(manifest_bytes))
manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
f.write(manifest_len_bytes)
f.write(manifest_bytes)
payload_sha.update(manifest_len_bytes)
@ -383,13 +391,13 @@ class MMIPKGPacker:
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(),
console=console
console=console,
) as progress:
task = progress.add_task("[green]写入图片...", total=len(image_data_list))
for idx, img_bytes in enumerate(image_data_list, 1):
progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}")
img_len_bytes = struct.pack('>I', len(img_bytes))
img_len_bytes = struct.pack(">I", len(img_bytes))
f.write(img_len_bytes)
f.write(img_bytes)
payload_sha.update(img_len_bytes)
@ -400,8 +408,8 @@ class MMIPKGPacker:
file_sha256 = payload_sha.digest()
footer = FOOTER_MAGIC # 4 bytes
footer += file_sha256 # 32 bytes
footer += struct.pack('B', FOOTER_VERSION) # 1 byte
footer += b'\x00' * 3 # 3 bytes reserved
footer += struct.pack("B", FOOTER_VERSION) # 1 byte
footer += b"\x00" * 3 # 3 bytes reserved
assert len(footer) == 40, f"Footer size mismatch: {len(footer)}"
f.write(footer)
@ -419,6 +427,7 @@ class MMIPKGPacker:
except Exception as e:
print(f"写入文件失败: {e}")
import traceback
traceback.print_exc()
return False
@ -429,10 +438,9 @@ class MMIPKGUnpacker:
def __init__(self, verify_sha: bool = True):
self.verify_sha = verify_sha
def import_to_db(self, package_path: str,
output_dir: Optional[str] = None,
replace_existing: bool = False,
batch_size: int = 500) -> bool:
def import_to_db(
self, package_path: str, output_dir: Optional[str] = None, replace_existing: bool = False, batch_size: int = 500
) -> bool:
"""导入到数据库"""
try:
if not os.path.exists(package_path):
@ -451,7 +459,7 @@ class MMIPKGUnpacker:
print(f"正在读取包: {package_path}")
with open(package_path, 'rb') as f:
with open(package_path, "rb") as f:
# 读取 Header
header = f.read(32)
if len(header) != 32:
@ -461,15 +469,15 @@ class MMIPKGUnpacker:
if magic != MAGIC:
raise MMIPKGError(f"无效的 MAGIC: {magic}")
version = struct.unpack('B', header[4:5])[0]
version = struct.unpack("B", header[4:5])[0]
if version != VERSION:
print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配")
flags = struct.unpack('B', header[5:6])[0]
flags = struct.unpack("B", header[5:6])[0]
is_compressed = bool(flags & 0x01)
payload_uncompressed_len = struct.unpack('>Q', header[8:16])[0]
manifest_uncompressed_len = struct.unpack('>Q', header[16:24])[0]
payload_uncompressed_len = struct.unpack(">Q", header[8:16])[0]
manifest_uncompressed_len = struct.unpack(">Q", header[16:24])[0]
# 安全检查
if manifest_uncompressed_len > MAX_MANIFEST_SIZE:
@ -519,7 +527,9 @@ class MMIPKGUnpacker:
# 方法2如果流式失败尝试直接解压兼容旧格式
print(f" 流式解压失败,尝试直接解压: {e}")
try:
payload_data = decompressor.decompress(compressed_data, max_output_size=payload_uncompressed_len)
payload_data = decompressor.decompress(
compressed_data, max_output_size=payload_uncompressed_len
)
except Exception as e2:
raise MMIPKGError(f"解压失败: {e2}") from e2
else:
@ -537,7 +547,7 @@ class MMIPKGUnpacker:
# 读取 manifest
manifest_len_bytes = payload_stream.read(4)
manifest_len = struct.unpack('>I', manifest_len_bytes)[0]
manifest_len = struct.unpack(">I", manifest_len_bytes)[0]
manifest_bytes = payload_stream.read(manifest_len)
manifest = msgpack.unpackb(manifest_bytes, raw=False)
@ -553,20 +563,21 @@ class MMIPKGUnpacker:
print(f" 表情包数量: {len(items)}")
# 导入表情包
return self._import_items(payload_stream, items, output_dir,
replace_existing, batch_size)
return self._import_items(payload_stream, items, output_dir, replace_existing, batch_size)
except Exception as e:
print(f"导入失败: {e}")
import traceback
traceback.print_exc()
return False
finally:
if not db.is_closed():
db.close()
def _import_items(self, payload_stream: BinaryIO, items: List[Dict],
output_dir: str, replace_existing: bool, batch_size: int) -> bool:
def _import_items(
self, payload_stream: BinaryIO, items: List[Dict], output_dir: str, replace_existing: bool, batch_size: int
) -> bool:
"""导入 items 到数据库"""
try:
imported_count = 0
@ -581,7 +592,7 @@ class MMIPKGUnpacker:
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeRemainingColumn(),
console=console
console=console,
) as progress:
task = progress.add_task("[cyan]导入表情包...", total=len(items))
@ -597,7 +608,7 @@ class MMIPKGUnpacker:
progress.advance(task)
continue
img_len = struct.unpack('>I', img_len_bytes)[0]
img_len = struct.unpack(">I", img_len_bytes)[0]
img_bytes = payload_stream.read(img_len)
if len(img_bytes) != img_len:
@ -641,7 +652,7 @@ class MMIPKGUnpacker:
file_path = os.path.join(output_dir, filename)
counter += 1
with open(file_path, 'wb') as img_file:
with open(file_path, "wb") as img_file:
img_file.write(img_bytes)
# 准备数据库记录
@ -700,6 +711,7 @@ class MMIPKGUnpacker:
except Exception as e:
console.print(f"[red]导入 items 失败: {e}[/red]")
import traceback
traceback.print_exc()
return False
@ -719,8 +731,9 @@ def print_menu():
console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)")
console.print(" [0] [bold]退出[/bold]")
console.print()
def get_input(prompt: str, default: Optional[str] = None,
choices: Optional[List[str]] = None) -> str:
def get_input(prompt: str, default: Optional[str] = None, choices: Optional[List[str]] = None) -> str:
"""获取用户输入"""
if default:
prompt = f"{prompt} (默认: {default})"
@ -760,9 +773,9 @@ def get_yes_no(prompt: str, default: bool = False) -> bool:
if not value:
return default
if value in ('y', 'yes', ''):
if value in ("y", "yes", ""):
return True
elif value in ('n', 'no', ''):
elif value in ("n", "no", ""):
return False
else:
console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]")
@ -843,8 +856,8 @@ def interactive_export():
output_path = get_input(" 输出文件路径", default_filename)
# 确保有 .mmipkg 扩展名
if not output_path.endswith('.mmipkg'):
output_path += '.mmipkg'
if not output_path.endswith(".mmipkg"):
output_path += ".mmipkg"
# 获取包名称
default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}"
@ -853,9 +866,7 @@ def interactive_export():
# 自定义 manifest
console.print("\n[yellow]2. 包信息设置(可选)[/yellow]")
if get_yes_no(" 是否添加包的作者和介绍信息", False):
custom_manifest = {
"author": author
} if (author := input(" 作者名称(可选): ").strip()) else {}
custom_manifest = {"author": author} if (author := input(" 作者名称(可选): ").strip()) else {}
# 介绍信息
console.print(" 包介绍(限制 100 字以内):")
@ -888,9 +899,9 @@ def interactive_export():
console.print(" webp: 推荐,体积小且支持透明度")
console.print(" jpeg: 最小体积,但不支持透明度")
console.print(" png: 无损,文件较大")
reencode = get_input(" 选择格式", "webp", ['webp', 'jpeg', 'png'])
reencode = get_input(" 选择格式", "webp", ["webp", "jpeg", "png"])
quality = get_int(" 编码质量", 80, 1, 100) if reencode in ('webp', 'jpeg') else 80
quality = get_int(" 编码质量", 80, 1, 100) if reencode in ("webp", "jpeg") else 80
else:
reencode = None
quality = 80
@ -920,10 +931,7 @@ def interactive_export():
# 开始导出
console.print("\n[cyan]开始导出...[/cyan]")
packer = MMIPKGPacker(
use_compression=use_compression,
zstd_level=zstd_level,
reencode=reencode,
reencode_quality=quality
use_compression=use_compression, zstd_level=zstd_level, reencode=reencode, reencode_quality=quality
)
success = packer.pack_from_db(output_path, pack_name, custom_manifest)
@ -944,11 +952,11 @@ def interactive_import():
# 选择导入模式
print_import_mode_selection()
import_mode = get_input("请选择", "1", ['1', '2'])
import_mode = get_input("请选择", "1", ["1", "2"])
input_files = []
if import_mode == '1':
if import_mode == "1":
# 自动扫描模式
import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji")
os.makedirs(import_dir, exist_ok=True)
@ -957,7 +965,7 @@ def interactive_import():
# 查找所有 .mmipkg 文件
for file in os.listdir(import_dir):
if file.endswith('.mmipkg'):
if file.endswith(".mmipkg"):
file_path = os.path.join(import_dir, file)
if os.path.isfile(file_path):
input_files.append(file_path)
@ -1032,7 +1040,7 @@ def interactive_import():
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
console=console
console=console,
) as progress:
task = progress.add_task("[cyan]导入文件...", total=len(input_files))
@ -1044,10 +1052,7 @@ def interactive_import():
console.print(f"[bold]{'=' * 70}[/bold]")
success = unpacker.import_to_db(
input_path,
output_dir=output_dir,
replace_existing=replace_existing,
batch_size=batch_size
input_path, output_dir=output_dir, replace_existing=replace_existing, batch_size=batch_size
)
if success:
@ -1076,16 +1081,16 @@ def main():
while True:
print_menu()
try:
choice = get_input("请选择", "1", ['0', '1', '2'])
choice = get_input("请选择", "1", ["0", "1", "2"])
except KeyboardInterrupt:
console.print("\n[green]再见![/green]")
return 0
if choice == '0':
if choice == "0":
console.print("\n[green]再见![/green]")
return 0
elif choice == '1':
elif choice == "1":
try:
interactive_export()
except KeyboardInterrupt:
@ -1093,6 +1098,7 @@ def main():
except Exception as e:
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
import traceback
traceback.print_exc()
try:
@ -1100,7 +1106,7 @@ def main():
except (KeyboardInterrupt, EOFError):
pass
elif choice == '2':
elif choice == "2":
try:
interactive_import()
except KeyboardInterrupt:
@ -1108,6 +1114,7 @@ def main():
except Exception as e:
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
import traceback
traceback.print_exc()
try:
@ -1121,5 +1128,5 @@ def main():
return 0
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())

View File

@ -334,7 +334,6 @@ class HeartFChatting:
self.consecutive_no_reply_count = 0
reason = ""
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,

View File

@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None
inspire_manager = None
def get_qa_manager():
return qa_manager
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable:

View File

@ -128,11 +128,10 @@ class QAManager:
selected_knowledge = knowledge[:limit]
formatted_knowledge = [
f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}"
for i, k in enumerate(selected_knowledge)
f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge)
]
# if max_score is not None:
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
found_knowledge = "\n".join(formatted_knowledge)
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:

View File

@ -226,7 +226,9 @@ class DefaultReplyer:
traceback.print_exc()
return False, llm_response
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
async def build_expression_habits(
self, chat_history: str, target: str, reply_reason: str = ""
) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块

View File

@ -241,7 +241,9 @@ class PrivateReplyer:
return f"{sender_relation}"
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
async def build_expression_habits(
self, chat_history: str, target: str, reply_reason: str = ""
) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块

View File

@ -204,8 +204,9 @@ class WebSocketLogHandler(logging.Handler):
message = formatted_msg
try:
import json
log_dict = json.loads(formatted_msg)
message = log_dict.get('event', formatted_msg)
message = log_dict.get("event", formatted_msg)
except (json.JSONDecodeError, ValueError):
# 不是 JSON,直接使用消息
message = formatted_msg
@ -228,10 +229,7 @@ class WebSocketLogHandler(logging.Handler):
import asyncio
from src.webui.logs_ws import broadcast_log
asyncio.run_coroutine_threadsafe(
broadcast_log(log_data),
self.loop
)
asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop)
except Exception:
# WebSocket 推送失败不影响日志记录
pass

View File

@ -467,11 +467,7 @@ class ExpressionLearner:
up_content: str,
current_time: float,
) -> None:
expr_obj = (
Expression.select()
.where((Expression.chat_id == self.chat_id) & (Expression.style == style))
.first()
)
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
if expr_obj:
await self._update_existing_expression(

View File

@ -42,8 +42,6 @@ def init_prompt():
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
@ -262,7 +260,6 @@ class ExpressionSelector:
# 4. 调用LLM
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# print(prompt)
if not content:

View File

@ -36,10 +36,7 @@ def _contains_bot_self_name(content: str) -> bool:
target = content.strip().lower()
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
alias_names = [
str(alias or "").strip().lower()
for alias in getattr(bot_config, "alias_names", []) or []
]
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
candidates = [name for name in [nickname, *alias_names] if name]
@ -188,9 +185,7 @@ async def _enrich_raw_content_if_needed(
# 获取该消息的前三条消息
try:
previous_messages = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=target_message.time,
limit=3
chat_id=chat_id, timestamp=target_message.time, limit=3
)
if previous_messages:
@ -245,7 +240,7 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
last_inference = jargon_obj.last_inference_count or 0
# 阈值列表3,6, 10, 20, 40, 60, 100
thresholds = [3,6, 10, 20, 40, 60, 100]
thresholds = [3, 6, 10, 20, 40, 60, 100]
if count < thresholds[0]:
return False
@ -311,7 +306,9 @@ class JargonMiner:
raw_content_list = []
if raw_content_str:
try:
raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
raw_content_list = (
json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
)
if not isinstance(raw_content_list, list):
raw_content_list = [raw_content_list] if raw_content_list else []
except (json.JSONDecodeError, TypeError):
@ -360,7 +357,6 @@ class JargonMiner:
jargon_obj.save()
return
# 步骤2: 仅基于content推断
prompt2 = await global_prompt_manager.format_prompt(
"jargon_inference_content_only_prompt",
@ -388,7 +384,6 @@ class JargonMiner:
logger.error(f"jargon {content} 推断2解析失败: {e}")
return
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
logger.info(f"jargon {content} 推断2结果: {response2}")
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
@ -457,7 +452,9 @@ class JargonMiner:
jargon_obj.is_complete = True
jargon_obj.save()
logger.debug(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
logger.debug(
f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
)
# 固定输出推断结果,格式化为可读形式
if is_jargon:
@ -475,6 +472,7 @@ class JargonMiner:
except Exception as e:
logger.error(f"jargon推断失败: {e}")
import traceback
traceback.print_exc()
def should_trigger(self) -> bool:
@ -571,10 +569,7 @@ class JargonMiner:
if _contains_bot_self_name(content):
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
continue
entries.append({
"content": content,
"raw_content": raw_content_list
})
entries.append({"content": content, "raw_content": raw_content_list})
except Exception as e:
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
return
@ -612,19 +607,10 @@ class JargonMiner:
# 根据all_global配置决定查询逻辑
if global_config.jargon.all_global:
# 开启all_global无视chat_id查询所有content匹配的记录所有记录都是全局的
query = (
Jargon.select()
.where(Jargon.content == content)
)
query = Jargon.select().where(Jargon.content == content)
else:
# 关闭all_global只查询chat_id匹配的记录不考虑is_global
query = (
Jargon.select()
.where(
(Jargon.chat_id == self.chat_id) &
(Jargon.content == content)
)
)
query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content))
if query.exists():
obj = query.get()
@ -637,7 +623,9 @@ class JargonMiner:
existing_raw_content = []
if obj.raw_content:
try:
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
except (json.JSONDecodeError, TypeError):
@ -676,7 +664,7 @@ class JargonMiner:
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
chat_id=self.chat_id,
is_global=is_global_new,
count=1
count=1,
)
saved += 1
except Exception as e:
@ -720,11 +708,7 @@ async def extract_and_store_jargon(chat_id: str) -> None:
def search_jargon(
keyword: str,
chat_id: Optional[str] = None,
limit: int = 10,
case_sensitive: bool = False,
fuzzy: bool = True
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]:
"""
搜索jargon支持大小写不敏感和模糊搜索
@ -747,10 +731,7 @@ def search_jargon(
keyword = keyword.strip()
# 构建查询
query = Jargon.select(
Jargon.content,
Jargon.meaning
)
query = Jargon.select(Jargon.content, Jargon.meaning)
# 构建搜索条件
if case_sensitive:
@ -760,7 +741,7 @@ def search_jargon(
search_condition = Jargon.content.contains(keyword)
else:
# 精确匹配
search_condition = (Jargon.content == keyword)
search_condition = Jargon.content == keyword
else:
# 大小写不敏感
if fuzzy:
@ -768,7 +749,7 @@ def search_jargon(
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
else:
# 精确匹配使用LOWER函数
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
search_condition = fn.LOWER(Jargon.content) == keyword.lower()
query = query.where(search_condition)
@ -779,14 +760,10 @@ def search_jargon(
else:
# 关闭all_global如果提供了chat_id优先搜索该聊天或global的jargon
if chat_id:
query = query.where(
(Jargon.chat_id == chat_id) | Jargon.is_global
)
query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global)
# 只返回有meaning的记录
query = query.where(
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
)
query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
# 按count降序排序优先返回出现频率高的
query = query.order_by(Jargon.count.desc())
@ -797,10 +774,7 @@ def search_jargon(
# 执行查询并返回结果
results = []
for jargon in query:
results.append({
"content": jargon.content or "",
"meaning": jargon.meaning or ""
})
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
return results
@ -840,10 +814,7 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
if global_config.jargon.all_global:
query = Jargon.select().where(Jargon.content == jargon_keyword)
else:
query = Jargon.select().where(
(Jargon.chat_id == chat_id) &
(Jargon.content == jargon_keyword)
)
query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword))
if query.exists():
# 更新现有记录
@ -854,7 +825,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
existing_raw_content = []
if obj.raw_content:
try:
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
except (json.JSONDecodeError, TypeError):
@ -877,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
raw_content=json.dumps([raw_content], ensure_ascii=False),
chat_id=chat_id,
is_global=is_global_new,
count=1
count=1,
)
logger.info(f"创建新jargon记录: {jargon_keyword}")
except Exception as e:
logger.error(f"存储jargon失败: {e}")

View File

@ -116,9 +116,7 @@ class MessageBuilder:
构建消息对象
:return: Message对象
"""
if len(self.__content) == 0 and not (
self.__role == RoleType.Assistant and self.__tool_calls
):
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
raise ValueError("内容不能为空")
if self.__role == RoleType.Tool and self.__tool_call_id is None:
raise ValueError("Tool角色的工具调用ID不能为空")

View File

@ -47,6 +47,7 @@ class MainSystem:
"""注册 WebUI API 路由"""
try:
from src.webui.routes import router as webui_router
self.server.register_router(webui_router)
logger.info("WebUI API 路由已注册")
except Exception as e:
@ -55,6 +56,7 @@ class MainSystem:
def _setup_webui(self):
"""设置 WebUI根据环境变量决定模式"""
import os
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
if not webui_enabled:
logger.info("WebUI 已禁用")
@ -64,6 +66,7 @@ class MainSystem:
try:
from src.webui.manager import setup_webui
setup_webui(mode=webui_mode)
except Exception as e:
logger.error(f"设置 WebUI 失败: {e}")

View File

@ -17,7 +17,7 @@ from src.llm_models.payload_content.message import MessageBuilder, RoleType, Mes
logger = get_logger("memory_retrieval")
THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 3600 # 未找到答案记录保留时长
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 300 # 清理频率
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 300 # 清理频率
_last_not_found_cleanup_ts: float = 0.0
@ -33,10 +33,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
try:
deleted_rows = (
ThinkingBack.delete()
.where(
(ThinkingBack.found_answer == 0) &
(ThinkingBack.update_time < threshold_time)
)
.where((ThinkingBack.found_answer == 0) & (ThinkingBack.update_time < threshold_time))
.execute()
)
if deleted_rows:
@ -45,6 +42,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
except Exception as e:
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
def init_memory_retrieval_prompt():
"""初始化记忆检索相关的 prompt 模板和工具"""
# 首先注册所有工具
@ -221,10 +219,7 @@ def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
return None
async def _retrieve_concepts_with_jargon(
concepts: List[str],
chat_id: str
) -> str:
async def _retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
"""对概念列表进行jargon检索
Args:
@ -246,25 +241,13 @@ async def _retrieve_concepts_with_jargon(
continue
# 先尝试精确匹配
jargon_results = search_jargon(
keyword=concept,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=False
)
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
is_fuzzy_match = False
# 如果精确匹配未找到,尝试模糊搜索
if not jargon_results:
jargon_results = search_jargon(
keyword=concept,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=True
)
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
is_fuzzy_match = True
if jargon_results:
@ -298,11 +281,7 @@ async def _retrieve_concepts_with_jargon(
async def _react_agent_solve_question(
question: str,
chat_id: str,
max_iterations: int = 5,
timeout: float = 30.0,
initial_info: str = ""
question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0, initial_info: str = ""
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
"""使用ReAct架构的Agent来解决问题
@ -343,11 +322,12 @@ async def _react_agent_solve_question(
remaining_iterations = max_iterations - current_iteration
is_final_iteration = current_iteration >= max_iterations
if is_final_iteration:
# 最后一次迭代使用最终prompt
tool_definitions = []
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0最后一次迭代不提供工具调用")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0最后一次迭代不提供工具调用"
)
prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_final_prompt",
@ -370,7 +350,9 @@ async def _react_agent_solve_question(
else:
# 非最终迭代使用head_prompt
tool_definitions = tool_registry.get_tool_definitions()
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}"
)
head_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_prompt_head",
@ -401,7 +383,7 @@ async def _react_agent_solve_question(
# 优化日志展示 - 合并所有消息到一条日志
log_lines = []
for idx, msg in enumerate(messages, 1):
role_name = msg.role.value if hasattr(msg.role, 'value') else str(msg.role)
role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
# 处理内容 - 显示完整内容,不截断
if isinstance(msg.content, str):
@ -437,14 +419,22 @@ async def _react_agent_solve_question(
return messages
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools_by_message_factory(
(
success,
response,
reasoning_content,
model_name,
tool_calls,
) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_definitions,
request_type="memory.react",
)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
)
if not success:
logger.error(f"ReAct Agent LLM调用失败: {response}")
@ -465,12 +455,7 @@ async def _react_agent_solve_question(
assistant_message = assistant_builder.build()
# 记录思考步骤
step = {
"iteration": iteration + 1,
"thought": response,
"actions": [],
"observations": []
}
step = {"iteration": iteration + 1, "thought": response, "actions": [], "observations": []}
# 优先从思考内容中提取found_answer或not_enough_info
def extract_quoted_content(text, func_name, param_name):
@ -495,14 +480,14 @@ async def _react_agent_solve_question(
return None
# 查找参数名和等号
param_pattern = f'{param_name}='
param_pattern = f"{param_name}="
param_pos = text_lower.find(param_pattern, func_pos)
if param_pos == -1:
return None
# 跳过参数名、等号和空白
start_pos = param_pos + len(param_pattern)
while start_pos < len(text) and text[start_pos] in ' \t\n':
while start_pos < len(text) and text[start_pos] in " \t\n":
start_pos += 1
if start_pos >= len(text):
@ -518,13 +503,13 @@ async def _react_agent_solve_question(
while end_pos < len(text):
if text[end_pos] == quote_char:
# 检查是否是转义的引号
if end_pos > start_pos + 1 and text[end_pos - 1] == '\\':
if end_pos > start_pos + 1 and text[end_pos - 1] == "\\":
end_pos += 1
continue
# 找到匹配的引号
content = text[start_pos + 1:end_pos]
content = text[start_pos + 1 : end_pos]
# 处理转义字符
content = content.replace('\\"', '"').replace("\\'", "'").replace('\\\\', '\\')
content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\")
return content
end_pos += 1
@ -536,27 +521,35 @@ async def _react_agent_solve_question(
# 只检查responseLLM的直接输出内容不检查reasoning_content
if response:
found_answer_content = extract_quoted_content(response, 'found_answer', 'answer')
found_answer_content = extract_quoted_content(response, "found_answer", "answer")
if not found_answer_content:
not_enough_info_reason = extract_quoted_content(response, 'not_enough_info', 'reason')
not_enough_info_reason = extract_quoted_content(response, "not_enough_info", "reason")
# 如果从输出内容中找到了答案,直接返回
if found_answer_content:
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
step["observations"] = ["从LLM输出内容中检测到found_answer"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}...")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}..."
)
return True, found_answer_content, thinking_steps, False
if not_enough_info_reason:
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}})
step["actions"].append(
{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}
)
step["observations"] = ["从LLM输出内容中检测到not_enough_info"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}...")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}..."
)
return False, not_enough_info_reason, thinking_steps, False
if is_final_iteration:
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}})
step["actions"].append(
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}}
)
step["observations"] = ["已到达最后一次迭代,无法找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案")
@ -596,7 +589,9 @@ async def _react_agent_solve_question(
tool_name = tool_call.func_name
tool_args = tool_call.args or {}
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i+1}/{len(tool_calls)}: {tool_name}({tool_args})")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
)
# 普通工具调用
tool = tool_registry.get_tool(tool_name)
@ -606,6 +601,7 @@ async def _react_agent_solve_question(
# 如果工具函数签名需要chat_id添加它
import inspect
sig = inspect.signature(tool.execute_func)
if "chat_id" in sig.parameters:
tool_params["chat_id"] = chat_id
@ -625,7 +621,7 @@ async def _react_agent_solve_question(
step["actions"].append({"action_type": tool_name, "action_params": tool_args})
else:
error_msg = f"未知的工具类型: {tool_name}"
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1}/{len(tool_calls)} {error_msg}")
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}")
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))
# 并行执行所有工具
@ -636,7 +632,7 @@ async def _react_agent_solve_question(
for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
if isinstance(observation, Exception):
observation = f"工具执行异常: {str(observation)}"
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1} 执行异常: {observation}")
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}")
observation_text = observation if isinstance(observation, str) else str(observation)
step["observations"].append(observation_text)
@ -655,7 +651,9 @@ async def _react_agent_solve_question(
# 迭代超时应该直接视为not_enough_info而不是使用已有信息
# 只有Agent明确返回found_answer时才认为找到了答案
if collected_info:
logger.warning(f"ReAct Agent达到最大迭代次数或超时但未明确返回found_answer。已收集信息: {collected_info[:100]}...")
logger.warning(
f"ReAct Agent达到最大迭代次数或超时但未明确返回found_answer。已收集信息: {collected_info[:100]}..."
)
if is_timeout:
logger.warning("ReAct Agent超时直接视为not_enough_info")
else:
@ -680,10 +678,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
# 查询最近时间窗口内的记录,按更新时间倒序
records = (
ThinkingBack.select()
.where(
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.update_time >= start_time)
)
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time))
.order_by(ThinkingBack.update_time.desc())
.limit(5) # 最多返回5条最近的记录
)
@ -735,9 +730,9 @@ def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> Li
records = (
ThinkingBack.select()
.where(
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.update_time >= start_time) &
(ThinkingBack.found_answer == 1)
(ThinkingBack.chat_id == chat_id)
& (ThinkingBack.update_time >= start_time)
& (ThinkingBack.found_answer == 1)
)
.order_by(ThinkingBack.update_time.desc())
.limit(5) # 最多返回5条最近的记录
@ -775,10 +770,7 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st
# 按更新时间倒序,获取最新的记录
records = (
ThinkingBack.select()
.where(
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.question == question)
)
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
.order_by(ThinkingBack.update_time.desc())
.limit(1)
)
@ -857,6 +849,7 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
jargon_keyword = analysis_result.get("jargon_keyword", "").strip()
if jargon_keyword:
from src.jargon.jargon_miner import store_jargon_from_answer
await store_jargon_from_answer(jargon_keyword, answer, chat_id)
else:
logger.warning(f"分析为黑话但未提取到关键词,问题: {question[:50]}...")
@ -882,14 +875,8 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
logger.error(f"分析问题和答案时发生异常: {e}")
def _store_thinking_back(
chat_id: str,
question: str,
context: str,
found_answer: bool,
answer: str,
thinking_steps: List[Dict[str, Any]]
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
) -> None:
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
@ -907,10 +894,7 @@ def _store_thinking_back(
# 先查询是否已存在相同chat_id和问题的记录
existing = (
ThinkingBack.select()
.where(
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.question == question)
)
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
.order_by(ThinkingBack.update_time.desc())
.limit(1)
)
@ -935,19 +919,14 @@ def _store_thinking_back(
answer=answer,
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
create_time=now,
update_time=now
update_time=now,
)
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
except Exception as e:
logger.error(f"存储思考过程失败: {e}")
async def _process_single_question(
question: str,
chat_id: str,
context: str,
initial_info: str = ""
) -> Optional[str]:
async def _process_single_question(question: str, chat_id: str, context: str, initial_info: str = "") -> Optional[str]:
"""处理单个问题的查询(包含缓存检查逻辑)
Args:
@ -1015,7 +994,7 @@ async def _process_single_question(
chat_id=chat_id,
max_iterations=global_config.memory.max_agent_iterations,
timeout=120.0,
initial_info=question_initial_info
initial_info=question_initial_info,
)
# 存储到数据库(超时时不存储)
@ -1026,7 +1005,7 @@ async def _process_single_question(
context=context,
found_answer=found_answer,
answer=answer,
thinking_steps=thinking_steps
thinking_steps=thinking_steps,
)
else:
logger.info(f"ReAct Agent超时不存储到数据库问题: {question[:50]}...")
@ -1112,7 +1091,6 @@ async def build_memory_retrieval_prompt(
else:
logger.info("概念检索未找到任何结果")
# 获取缓存的记忆与question时使用相同的时间窗口和数量限制
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
@ -1141,12 +1119,7 @@ async def build_memory_retrieval_prompt(
# 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [
_process_single_question(
question=question,
chat_id=chat_id,
context=message,
initial_info=initial_info
)
_process_single_question(question=question, chat_id=chat_id, context=message, initial_info=initial_info)
for question in questions
]
@ -1179,7 +1152,9 @@ async def build_memory_retrieval_prompt(
if all_results:
retrieved_memory = "\n\n".join(all_results)
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)")
logger.info(
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)"
)
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
else:
logger.debug("所有问题均未找到答案,且无缓存记忆")

View File

@ -3,6 +3,7 @@
记忆系统工具函数
包含模糊查找相似度计算等工具函数
"""
import json
import re
from datetime import datetime
@ -14,6 +15,7 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_md_json(json_text: str) -> list[str]:
"""从Markdown格式的内容中提取JSON对象和推理内容"""
json_objects = []
@ -52,6 +54,7 @@ def parse_md_json(json_text: str) -> list[str]:
return json_objects, reasoning_content
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度
@ -97,10 +100,10 @@ def preprocess_text(text: str) -> str:
text = text.lower()
# 移除标点符号和特殊字符
text = re.sub(r'[^\w\s]', '', text)
text = re.sub(r"[^\w\s]", "", text)
# 移除多余空格
text = re.sub(r'\s+', ' ', text).strip()
text = re.sub(r"\s+", " ", text).strip()
return text
@ -109,7 +112,6 @@ def preprocess_text(text: str) -> str:
return text
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳
@ -164,4 +166,3 @@ def parse_time_range(time_range: str) -> Tuple[float, float]:
end_timestamp = parse_datetime_to_timestamp(end_str)
return start_timestamp, end_timestamp

View File

@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_person_info import register_tool as register_query_person_info
from src.config.config import global_config
def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_jargon()

View File

@ -15,10 +15,7 @@ logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str,
keyword: Optional[str] = None,
time_range: Optional[str] = None,
fuzzy: bool = True
chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
@ -50,17 +47,11 @@ async def query_chat_history(
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_filter = (
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
else:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = (
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
query = query.where(time_filter)
# 执行查询
@ -91,7 +82,9 @@ async def query_chat_history(
record_keywords_list = []
if record.keywords:
try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
@ -102,20 +95,24 @@ async def query_chat_history(
if fuzzy:
# 模糊匹配只要包含任意一个关键词即匹配OR关系
for kw in keywords_lower:
if (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
if (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
):
matched = True
break
else:
# 全匹配必须包含所有关键词才匹配AND关系
matched = True
for kw in keywords_lower:
kw_matched = (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list))
kw_matched = (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
)
if not kw_matched:
matched = False
break
@ -160,6 +157,7 @@ async def query_chat_history(
# 添加时间范围
from datetime import datetime
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
result_parts.append(f"时间:{start_str} - {end_str}")
@ -199,20 +197,20 @@ def register_tool():
"name": "keyword",
"type": "string",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
"required": False
"required": False,
},
{
"name": "time_range",
"type": "string",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False
"required": False,
},
{
"name": "fuzzy",
"type": "boolean",
"description": "是否使用模糊匹配模式默认True。True表示模糊匹配只要包含任意一个关键词即匹配OR关系False表示全匹配必须包含所有关键词才匹配AND关系",
"required": False
}
"required": False,
},
],
execute_func=query_chat_history
execute_func=query_chat_history,
)

View File

@ -73,5 +73,3 @@ def register_tool():
],
execute_func=query_lpmm_knowledge,
)

View File

@ -26,7 +26,9 @@ def _format_group_nick_names(group_nick_name_field) -> str:
try:
# 解析JSON格式的群昵称列表
group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
group_nick_names_data = (
json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
)
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
return ""
@ -71,9 +73,7 @@ async def query_person_info(person_name: str) -> str:
return "用户名称为空"
# 构建查询条件(使用模糊查询)
query = PersonInfo.select().where(
PersonInfo.person_name.contains(person_name)
)
query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name))
# 执行查询
records = list(query.limit(20)) # 最多返回20条记录
@ -137,7 +137,11 @@ async def query_person_info(person_name: str) -> str:
# 记忆点memory_points
if record.memory_points:
try:
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
memory_points_data = (
json.loads(record.memory_points)
if isinstance(record.memory_points, str)
else record.memory_points
)
if isinstance(memory_points_data, list) and memory_points_data:
# 解析记忆点格式category:content:weight
memory_list = []
@ -206,7 +210,11 @@ async def query_person_info(person_name: str) -> str:
# 记忆点memory_points
if record.memory_points:
try:
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
memory_points_data = (
json.loads(record.memory_points)
if isinstance(record.memory_points, str)
else record.memory_points
)
if isinstance(memory_points_data, list) and memory_points_data:
# 解析记忆点格式category:content:weight
memory_list = []
@ -275,13 +283,7 @@ def register_tool():
name="query_person_info",
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
parameters=[
{
"name": "person_name",
"type": "string",
"description": "用户名称,用于查询用户信息",
"required": True
}
{"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True}
],
execute_func=query_person_info
execute_func=query_person_info,
)

View File

@ -82,11 +82,7 @@ class MemoryRetrievalTool:
param_tuples.append(param_tuple)
# 构建工具定义格式与BaseTool.get_tool_definition()一致
tool_def = {
"name": self.name,
"description": self.description,
"parameters": param_tuples
}
tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
return tool_def

View File

@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int:
class Person:
@classmethod
def register_person(
cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None
cls,
platform: str,
user_id: str,
nickname: str,
group_id: Optional[str] = None,
group_nick_name: Optional[str] = None,
):
"""
注册新用户的类方法
@ -781,7 +786,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
if len(parts) >= 2:
existing_content = parts[1].strip()
# 简单相似度检查(如果内容相同或非常相似,则跳过)
if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content:
if (
existing_content == memory_content
or memory_content in existing_content
or existing_content in memory_content
):
is_duplicate = True
break

View File

@ -125,7 +125,6 @@ class ToolExecutor:
prompt=prompt, tools=tools, raise_when_empty=False
)
# 执行工具调用
tool_results, used_tools = await self.execute_tool_calls(tool_calls)

View File

@ -1,4 +1,5 @@
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel
@ -18,6 +19,7 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"])
class EmojiResponse(BaseModel):
"""表情包响应"""
id: int
full_path: str
format: str
@ -35,6 +37,7 @@ class EmojiResponse(BaseModel):
class EmojiListResponse(BaseModel):
"""表情包列表响应"""
success: bool
total: int
page: int
@ -44,12 +47,14 @@ class EmojiListResponse(BaseModel):
class EmojiDetailResponse(BaseModel):
"""表情包详情响应"""
success: bool
data: EmojiResponse
class EmojiUpdateRequest(BaseModel):
"""表情包更新请求"""
description: Optional[str] = None
is_registered: Optional[bool] = None
is_banned: Optional[bool] = None
@ -58,6 +63,7 @@ class EmojiUpdateRequest(BaseModel):
class EmojiUpdateResponse(BaseModel):
"""表情包更新响应"""
success: bool
message: str
data: Optional[EmojiResponse] = None
@ -65,6 +71,7 @@ class EmojiUpdateResponse(BaseModel):
class EmojiDeleteResponse(BaseModel):
"""表情包删除响应"""
success: bool
message: str
@ -120,7 +127,7 @@ async def get_emoji_list(
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
format: Optional[str] = Query(None, description="格式筛选"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取表情包列表
@ -145,10 +152,7 @@ async def get_emoji_list(
# 搜索过滤
if search:
query = query.where(
(Emoji.description.contains(search)) |
(Emoji.emoji_hash.contains(search))
)
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
# 注册状态过滤
if is_registered is not None:
@ -164,10 +168,9 @@ async def get_emoji_list(
# 排序:使用次数倒序,然后按记录时间倒序
from peewee import Case
query = query.order_by(
Emoji.usage_count.desc(),
Case(None, [(Emoji.record_time.is_null(), 1)], 0),
Emoji.record_time.desc()
Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
)
# 获取总数
@ -180,13 +183,7 @@ async def get_emoji_list(
# 转换为响应对象
data = [emoji_to_response(emoji) for emoji in emojis]
return EmojiListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
@ -196,10 +193,7 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
获取表情包详细信息
@ -218,10 +212,7 @@ async def get_emoji_detail(
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
return EmojiDetailResponse(
success=True,
data=emoji_to_response(emoji)
)
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
except HTTPException:
raise
@ -231,11 +222,7 @@ async def get_emoji_detail(
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
authorization: Optional[str] = Header(None)
):
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
"""
增量更新表情包只更新提供的字段
@ -262,15 +249,15 @@ async def update_emoji(
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 处理情感标签(转换为 JSON
if 'emotion' in update_data:
if update_data['emotion'] is None:
update_data['emotion'] = None
if "emotion" in update_data:
if update_data["emotion"] is None:
update_data["emotion"] = None
else:
update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False)
update_data["emotion"] = json.dumps(update_data["emotion"], ensure_ascii=False)
# 如果注册状态从 False 变为 True记录注册时间
if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered:
update_data['register_time'] = time.time()
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
update_data["register_time"] = time.time()
# 执行更新
for field, value in update_data.items():
@ -281,9 +268,7 @@ async def update_emoji(
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
return EmojiUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=emoji_to_response(emoji)
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
)
except HTTPException:
@ -294,10 +279,7 @@ async def update_emoji(
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
删除表情包
@ -324,10 +306,7 @@ async def delete_emoji(
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
return EmojiDeleteResponse(
success=True,
message=f"成功删除表情包: {emoji_hash}"
)
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
except HTTPException:
raise
@ -337,9 +316,7 @@ async def delete_emoji(
@router.get("/stats/summary")
async def get_emoji_stats(
authorization: Optional[str] = Header(None)
):
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
"""
获取表情包统计数据
@ -369,7 +346,7 @@ async def get_emoji_stats(
"id": emoji.id,
"emoji_hash": emoji.emoji_hash,
"description": emoji.description,
"usage_count": emoji.usage_count
"usage_count": emoji.usage_count,
}
for emoji in top_used
]
@ -382,8 +359,8 @@ async def get_emoji_stats(
"banned": banned,
"unregistered": total - registered,
"formats": formats,
"top_used": top_used_list
}
"top_used": top_used_list,
},
}
except HTTPException:
@ -394,10 +371,7 @@ async def get_emoji_stats(
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
注册表情包快捷操作
@ -429,11 +403,7 @@ async def register_emoji(
logger.info(f"表情包已注册: ID={emoji_id}")
return EmojiUpdateResponse(
success=True,
message="表情包注册成功",
data=emoji_to_response(emoji)
)
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
except HTTPException:
raise
@ -443,10 +413,7 @@ async def register_emoji(
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
禁用表情包快捷操作
@ -472,11 +439,7 @@ async def ban_emoji(
logger.info(f"表情包已禁用: ID={emoji_id}")
return EmojiUpdateResponse(
success=True,
message="表情包禁用成功",
data=emoji_to_response(emoji)
)
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
except HTTPException:
raise
@ -489,7 +452,7 @@ async def ban_emoji(
async def get_emoji_thumbnail(
emoji_id: int,
token: Optional[str] = Query(None, description="访问令牌"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取表情包缩略图
@ -523,25 +486,20 @@ async def get_emoji_thumbnail(
# 根据格式设置 MIME 类型
mime_types = {
'png': 'image/png',
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'gif': 'image/gif',
'webp': 'image/webp',
'bmp': 'image/bmp'
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
"webp": "image/webp",
"bmp": "image/bmp",
}
media_type = mime_types.get(emoji.format.lower(), 'application/octet-stream')
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(
path=emoji.full_path,
media_type=media_type,
filename=f"{emoji.emoji_hash}.{emoji.format}"
)
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表情包缩略图失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e

View File

@ -1,4 +1,5 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List
@ -15,6 +16,7 @@ router = APIRouter(prefix="/expression", tags=["Expression"])
class ExpressionResponse(BaseModel):
"""表达方式响应"""
id: int
situation: str
style: str
@ -27,6 +29,7 @@ class ExpressionResponse(BaseModel):
class ExpressionListResponse(BaseModel):
"""表达方式列表响应"""
success: bool
total: int
page: int
@ -36,12 +39,14 @@ class ExpressionListResponse(BaseModel):
class ExpressionDetailResponse(BaseModel):
"""表达方式详情响应"""
success: bool
data: ExpressionResponse
class ExpressionCreateRequest(BaseModel):
"""表达方式创建请求"""
situation: str
style: str
context: Optional[str] = None
@ -51,6 +56,7 @@ class ExpressionCreateRequest(BaseModel):
class ExpressionUpdateRequest(BaseModel):
"""表达方式更新请求"""
situation: Optional[str] = None
style: Optional[str] = None
context: Optional[str] = None
@ -60,6 +66,7 @@ class ExpressionUpdateRequest(BaseModel):
class ExpressionUpdateResponse(BaseModel):
"""表达方式更新响应"""
success: bool
message: str
data: Optional[ExpressionResponse] = None
@ -67,12 +74,14 @@ class ExpressionUpdateResponse(BaseModel):
class ExpressionDeleteResponse(BaseModel):
"""表达方式删除响应"""
success: bool
message: str
class ExpressionCreateResponse(BaseModel):
"""表达方式创建响应"""
success: bool
message: str
data: ExpressionResponse
@ -112,7 +121,7 @@ async def get_expression_list(
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取表达方式列表
@ -136,9 +145,9 @@ async def get_expression_list(
# 搜索过滤
if search:
query = query.where(
(Expression.situation.contains(search)) |
(Expression.style.contains(search)) |
(Expression.context.contains(search))
(Expression.situation.contains(search))
| (Expression.style.contains(search))
| (Expression.context.contains(search))
)
# 聊天ID过滤
@ -147,9 +156,9 @@ async def get_expression_list(
# 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0),
Expression.last_active_time.desc()
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
)
# 获取总数
@ -162,13 +171,7 @@ async def get_expression_list(
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
@ -178,10 +181,7 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(
expression_id: int,
authorization: Optional[str] = Header(None)
):
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
"""
获取表达方式详细信息
@ -200,10 +200,7 @@ async def get_expression_detail(
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
return ExpressionDetailResponse(
success=True,
data=expression_to_response(expression)
)
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
except HTTPException:
raise
@ -213,10 +210,7 @@ async def get_expression_detail(
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
request: ExpressionCreateRequest,
authorization: Optional[str] = Header(None)
):
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
"""
创建新的表达方式
@ -246,9 +240,7 @@ async def create_expression(
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
return ExpressionCreateResponse(
success=True,
message="表达方式创建成功",
data=expression_to_response(expression)
success=True, message="表达方式创建成功", data=expression_to_response(expression)
)
except HTTPException:
@ -260,9 +252,7 @@ async def create_expression(
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int,
request: ExpressionUpdateRequest,
authorization: Optional[str] = Header(None)
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
):
"""
增量更新表达方式只更新提供的字段
@ -290,7 +280,7 @@ async def update_expression(
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后活跃时间
update_data['last_active_time'] = time.time()
update_data["last_active_time"] = time.time()
# 执行更新
for field, value in update_data.items():
@ -301,9 +291,7 @@ async def update_expression(
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
return ExpressionUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=expression_to_response(expression)
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
)
except HTTPException:
@ -314,10 +302,7 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(
expression_id: int,
authorization: Optional[str] = Header(None)
):
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
"""
删除表达方式
@ -344,10 +329,7 @@ async def delete_expression(
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
return ExpressionDeleteResponse(
success=True,
message=f"成功删除表达方式: {situation}"
)
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
except HTTPException:
raise
@ -357,9 +339,7 @@ async def delete_expression(
@router.get("/stats/summary")
async def get_expression_stats(
authorization: Optional[str] = Header(None)
):
async def get_expression_stats(authorization: Optional[str] = Header(None)):
"""
获取表达方式统计数据
@ -382,10 +362,11 @@ async def get_expression_stats(
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = Expression.select().where(
(Expression.create_date.is_null(False)) &
(Expression.create_date >= seven_days_ago)
).count()
recent = (
Expression.select()
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
.count()
)
return {
"success": True,
@ -393,8 +374,8 @@ async def get_expression_stats(
"total": total,
"recent_7days": recent,
"chat_count": len(chat_stats),
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10])
}
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
},
}
except HTTPException:

View File

@ -1,4 +1,5 @@
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
@ -15,6 +16,7 @@ logger = get_logger("webui.git_mirror")
# 导入进度更新函数(避免循环导入)
_update_progress = None
def set_update_progress_callback(callback):
"""设置进度更新回调函数"""
global _update_progress
@ -23,6 +25,7 @@ def set_update_progress_callback(callback):
class MirrorType(str, Enum):
"""镜像源类型"""
GH_PROXY = "gh-proxy" # gh-proxy 主节点
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
@ -47,7 +50,7 @@ class GitMirrorConfig:
"clone_prefix": "https://gh-proxy.org/https://github.com",
"enabled": True,
"priority": 1,
"created_at": None
"created_at": None,
},
{
"id": "hk-gh-proxy",
@ -56,7 +59,7 @@ class GitMirrorConfig:
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 2,
"created_at": None
"created_at": None,
},
{
"id": "cdn-gh-proxy",
@ -65,7 +68,7 @@ class GitMirrorConfig:
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 3,
"created_at": None
"created_at": None,
},
{
"id": "edgeone-gh-proxy",
@ -74,7 +77,7 @@ class GitMirrorConfig:
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 4,
"created_at": None
"created_at": None,
},
{
"id": "meyzh-github",
@ -83,7 +86,7 @@ class GitMirrorConfig:
"clone_prefix": "https://meyzh.github.io/https://github.com",
"enabled": True,
"priority": 5,
"created_at": None
"created_at": None,
},
{
"id": "github",
@ -92,8 +95,8 @@ class GitMirrorConfig:
"clone_prefix": "https://github.com",
"enabled": True,
"priority": 999,
"created_at": None
}
"created_at": None,
},
]
def __init__(self):
@ -106,7 +109,7 @@ class GitMirrorConfig:
"""加载配置文件"""
try:
if self.config_file.exists():
with open(self.config_file, 'r', encoding='utf-8') as f:
with open(self.config_file, "r", encoding="utf-8") as f:
data = json.load(f)
# 检查是否有镜像源配置
@ -145,14 +148,14 @@ class GitMirrorConfig:
# 读取现有配置
existing_data = {}
if self.config_file.exists():
with open(self.config_file, 'r', encoding='utf-8') as f:
with open(self.config_file, "r", encoding="utf-8") as f:
existing_data = json.load(f)
# 更新镜像源配置
existing_data["git_mirrors"] = self.mirrors
# 写入文件
with open(self.config_file, 'w', encoding='utf-8') as f:
with open(self.config_file, "w", encoding="utf-8") as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
logger.debug(f"配置已保存到 {self.config_file}")
@ -182,7 +185,7 @@ class GitMirrorConfig:
raw_prefix: str,
clone_prefix: str,
enabled: bool = True,
priority: Optional[int] = None
priority: Optional[int] = None,
) -> Dict[str, Any]:
"""
添加新的镜像源
@ -209,7 +212,7 @@ class GitMirrorConfig:
"clone_prefix": clone_prefix,
"enabled": enabled,
"priority": priority,
"created_at": datetime.now().isoformat()
"created_at": datetime.now().isoformat(),
}
self.mirrors.append(new_mirror)
@ -225,7 +228,7 @@ class GitMirrorConfig:
raw_prefix: Optional[str] = None,
clone_prefix: Optional[str] = None,
enabled: Optional[bool] = None,
priority: Optional[int] = None
priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新镜像源配置
@ -279,12 +282,7 @@ class GitMirrorConfig:
class GitMirrorService:
"""Git 镜像源服务"""
def __init__(
self,
max_retries: int = 3,
timeout: int = 30,
config: Optional[GitMirrorConfig] = None
):
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
"""
初始化 Git 镜像源服务
@ -323,46 +321,25 @@ class GitMirrorService:
if not git_path:
logger.warning("未找到 Git 可执行文件")
return {
"installed": False,
"error": "系统中未找到 Git请先安装 Git"
}
return {"installed": False, "error": "系统中未找到 Git请先安装 Git"}
# 获取 Git 版本
result = subprocess.run(
["git", "--version"],
capture_output=True,
text=True,
timeout=5
)
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
version = result.stdout.strip()
logger.info(f"检测到 Git: {version} at {git_path}")
return {
"installed": True,
"version": version,
"path": git_path
}
return {"installed": True, "version": version, "path": git_path}
else:
logger.warning(f"Git 命令执行失败: {result.stderr}")
return {
"installed": False,
"error": f"Git 命令执行失败: {result.stderr}"
}
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
except subprocess.TimeoutExpired:
logger.error("Git 版本检测超时")
return {
"installed": False,
"error": "Git 版本检测超时"
}
return {"installed": False, "error": "Git 版本检测超时"}
except Exception as e:
logger.error(f"检测 Git 时发生错误: {e}")
return {
"installed": False,
"error": f"检测 Git 时发生错误: {str(e)}"
}
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
async def fetch_raw_file(
self,
@ -371,7 +348,7 @@ class GitMirrorService:
branch: str,
file_path: str,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None
custom_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
获取 GitHub 仓库的 Raw 文件内容
@ -403,12 +380,7 @@ class GitMirrorService:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {
"success": False,
"error": f"未找到镜像源: {mirror_id}",
"mirror_used": None,
"attempts": 0
}
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
@ -427,14 +399,12 @@ class GitMirrorService:
progress=progress,
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
result = await self._fetch_raw_from_mirror(
owner, repo, branch, file_path, mirror
)
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
if result["success"]:
# 成功,推送进度
@ -445,7 +415,7 @@ class GitMirrorService:
progress=70,
message=f"成功从 {mirror['name']} 获取数据",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
@ -461,26 +431,16 @@ class GitMirrorService:
progress=30 + int(index / total_mirrors * 40),
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 所有镜像源都失败
return {
"success": False,
"error": "所有镜像源均失败",
"mirror_used": None,
"attempts": len(mirrors_to_try)
}
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _fetch_raw_from_mirror(
self,
owner: str,
repo: str,
branch: str,
file_path: str,
mirror: Dict[str, Any]
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
) -> Dict[str, Any]:
"""从指定镜像源获取文件"""
# 构建 URL
@ -508,7 +468,7 @@ class GitMirrorService:
"data": response.text,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
"url": url,
}
except httpx.HTTPStatusError as e:
last_error = f"HTTP {e.response.status_code}: {e}"
@ -520,13 +480,7 @@ class GitMirrorService:
last_error = f"未知错误: {e}"
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
return {
"success": False,
"error": last_error,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
}
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
async def clone_repository(
self,
@ -536,7 +490,7 @@ class GitMirrorService:
branch: Optional[str] = None,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
depth: Optional[int] = None
depth: Optional[int] = None,
) -> Dict[str, Any]:
"""
克隆 GitHub 仓库
@ -569,12 +523,7 @@ class GitMirrorService:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {
"success": False,
"error": f"未找到镜像源: {mirror_id}",
"mirror_used": None,
"attempts": 0
}
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
@ -582,20 +531,13 @@ class GitMirrorService:
# 依次尝试每个镜像源
for mirror in mirrors_to_try:
result = await self._clone_from_mirror(
owner, repo, target_path, branch, depth, mirror
)
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
if result["success"]:
return result
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
# 所有镜像源都失败
return {
"success": False,
"error": "所有镜像源克隆均失败",
"mirror_used": None,
"attempts": len(mirrors_to_try)
}
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _clone_from_mirror(
self,
@ -604,7 +546,7 @@ class GitMirrorService:
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror: Dict[str, Any]
mirror: Dict[str, Any],
) -> Dict[str, Any]:
"""从指定镜像源克隆仓库"""
# 构建克隆 URL
@ -614,12 +556,7 @@ class GitMirrorService:
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
async def _clone_with_url(
self,
url: str,
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror_type: str
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
) -> Dict[str, Any]:
"""使用指定 URL 克隆仓库,支持重试"""
attempts = 0
@ -657,7 +594,7 @@ class GitMirrorService:
stage="loading",
progress=20 + attempt * 10,
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
operation="install"
operation="install",
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
@ -670,7 +607,7 @@ class GitMirrorService:
cmd,
capture_output=True,
text=True,
timeout=300 # 5分钟超时
timeout=300, # 5分钟超时
)
process = await loop.run_in_executor(None, run_git_clone)
@ -683,7 +620,7 @@ class GitMirrorService:
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
"branch": branch or "default"
"branch": branch or "default",
}
else:
last_error = f"Git 克隆失败: {process.stderr}"
@ -710,13 +647,7 @@ class GitMirrorService:
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
return {
"success": False,
"error": last_error,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
}
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
# 全局服务实例

View File

@ -1,4 +1,5 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set
import json
@ -49,7 +50,9 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
log_entry = json.loads(line.strip())
# 转换为前端期望的格式
# 使用时间戳 + 计数器生成唯一 ID
timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
timestamp_id = (
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
)
formatted_log = {
"id": f"{timestamp_id}_{log_counter}",
"timestamp": log_entry.get("timestamp", ""),

View File

@ -1,4 +1,5 @@
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
import os
from pathlib import Path
from src.common.logger import get_logger
@ -55,10 +56,10 @@ def setup_production_mode() -> bool:
# 确保正确的 MIME 类型映射
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('application/javascript', '.mjs')
mimetypes.add_type('text/css', '.css')
mimetypes.add_type('application/json', '.json')
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("application/javascript", ".mjs")
mimetypes.add_type("text/css", ".css")
mimetypes.add_type("application/json", ".json")
server = get_global_server()

View File

@ -1,4 +1,5 @@
"""人物信息管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List, Dict
@ -16,6 +17,7 @@ router = APIRouter(prefix="/person", tags=["Person"])
class PersonInfoResponse(BaseModel):
"""人物信息响应"""
id: int
is_known: bool
person_id: str
@ -33,6 +35,7 @@ class PersonInfoResponse(BaseModel):
class PersonListResponse(BaseModel):
"""人物列表响应"""
success: bool
total: int
page: int
@ -42,12 +45,14 @@ class PersonListResponse(BaseModel):
class PersonDetailResponse(BaseModel):
"""人物详情响应"""
success: bool
data: PersonInfoResponse
class PersonUpdateRequest(BaseModel):
"""人物信息更新请求"""
person_name: Optional[str] = None
name_reason: Optional[str] = None
nickname: Optional[str] = None
@ -57,6 +62,7 @@ class PersonUpdateRequest(BaseModel):
class PersonUpdateResponse(BaseModel):
"""人物信息更新响应"""
success: bool
message: str
data: Optional[PersonInfoResponse] = None
@ -64,6 +70,7 @@ class PersonUpdateResponse(BaseModel):
class PersonDeleteResponse(BaseModel):
"""人物删除响应"""
success: bool
message: str
@ -118,7 +125,7 @@ async def get_person_list(
search: Optional[str] = Query(None, description="搜索关键词"),
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
platform: Optional[str] = Query(None, description="平台筛选"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取人物信息列表
@ -143,9 +150,9 @@ async def get_person_list(
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search)) |
(PersonInfo.nickname.contains(search)) |
(PersonInfo.user_id.contains(search))
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
)
# 已认识状态过滤
@ -159,10 +166,8 @@ async def get_person_list(
# 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现
from peewee import Case
query = query.order_by(
Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
PersonInfo.last_know.desc()
)
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
# 获取总数
total = query.count()
@ -174,13 +179,7 @@ async def get_person_list(
# 转换为响应对象
data = [person_to_response(person) for person in persons]
return PersonListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
@ -190,10 +189,7 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(
person_id: str,
authorization: Optional[str] = Header(None)
):
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
"""
获取人物详细信息
@ -212,10 +208,7 @@ async def get_person_detail(
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
return PersonDetailResponse(
success=True,
data=person_to_response(person)
)
return PersonDetailResponse(success=True, data=person_to_response(person))
except HTTPException:
raise
@ -225,11 +218,7 @@ async def get_person_detail(
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(
person_id: str,
request: PersonUpdateRequest,
authorization: Optional[str] = Header(None)
):
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
"""
增量更新人物信息只更新提供的字段
@ -256,7 +245,7 @@ async def update_person(
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后修改时间
update_data['last_know'] = time.time()
update_data["last_know"] = time.time()
# 执行更新
for field, value in update_data.items():
@ -267,9 +256,7 @@ async def update_person(
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
return PersonUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=person_to_response(person)
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
)
except HTTPException:
@ -280,10 +267,7 @@ async def update_person(
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(
person_id: str,
authorization: Optional[str] = Header(None)
):
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
"""
删除人物信息
@ -310,10 +294,7 @@ async def delete_person(
logger.info(f"人物信息已删除: {person_id} ({person_name})")
return PersonDeleteResponse(
success=True,
message=f"成功删除人物信息: {person_name}"
)
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
except HTTPException:
raise
@ -323,9 +304,7 @@ async def delete_person(
@router.get("/stats/summary")
async def get_person_stats(
authorization: Optional[str] = Header(None)
):
async def get_person_stats(authorization: Optional[str] = Header(None)):
"""
获取人物信息统计数据
@ -348,15 +327,7 @@ async def get_person_stats(
platform = person.platform
platforms[platform] = platforms.get(platform, 0) + 1
return {
"success": True,
"data": {
"total": total,
"known": known,
"unknown": unknown,
"platforms": platforms
}
}
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
except HTTPException:
raise

View File

@ -1,4 +1,5 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set, Dict, Any
import json
@ -22,7 +23,7 @@ current_progress: Dict[str, Any] = {
"error": None,
"plugin_id": None, # 当前操作的插件 ID
"total_plugins": 0,
"loaded_plugins": 0
"loaded_plugins": 0,
}
@ -57,7 +58,7 @@ async def update_progress(
error: str = None,
plugin_id: str = None,
total_plugins: int = 0,
loaded_plugins: int = 0
loaded_plugins: int = 0,
):
"""更新并广播进度
@ -80,7 +81,7 @@ async def update_progress(
"plugin_id": plugin_id,
"total_plugins": total_plugins,
"loaded_plugins": loaded_plugins,
"timestamp": asyncio.get_event_loop().time()
"timestamp": asyncio.get_event_loop().time(),
}
await broadcast_progress(progress_data)

View File

@ -30,12 +30,12 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
(major, minor, patch) 三元组
"""
# 移除 snapshot 等后缀
base_version = version_str.split('.snapshot')[0].split('.dev')[0].split('.alpha')[0].split('.beta')[0]
base_version = version_str.split(".snapshot")[0].split(".dev")[0].split(".alpha")[0].split(".beta")[0]
parts = base_version.split('.')
parts = base_version.split(".")
if len(parts) < 3:
# 补齐到 3 位
parts.extend(['0'] * (3 - len(parts)))
parts.extend(["0"] * (3 - len(parts)))
try:
major = int(parts[0])
@ -49,8 +49,10 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
# ============ 请求/响应模型 ============
class FetchRawFileRequest(BaseModel):
"""获取 Raw 文件请求"""
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
repo: str = Field(..., description="仓库名称", example="plugin-repo")
branch: str = Field(..., description="分支名称", example="main")
@ -61,6 +63,7 @@ class FetchRawFileRequest(BaseModel):
class FetchRawFileResponse(BaseModel):
"""获取 Raw 文件响应"""
success: bool = Field(..., description="是否成功")
data: Optional[str] = Field(None, description="文件内容")
error: Optional[str] = Field(None, description="错误信息")
@ -71,6 +74,7 @@ class FetchRawFileResponse(BaseModel):
class CloneRepositoryRequest(BaseModel):
"""克隆仓库请求"""
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
repo: str = Field(..., description="仓库名称", example="plugin-repo")
target_path: str = Field(..., description="目标路径(相对于插件目录)")
@ -82,6 +86,7 @@ class CloneRepositoryRequest(BaseModel):
class CloneRepositoryResponse(BaseModel):
"""克隆仓库响应"""
success: bool = Field(..., description="是否成功")
path: Optional[str] = Field(None, description="克隆路径")
error: Optional[str] = Field(None, description="错误信息")
@ -93,6 +98,7 @@ class CloneRepositoryResponse(BaseModel):
class MirrorConfigResponse(BaseModel):
"""镜像源配置响应"""
id: str = Field(..., description="镜像源 ID")
name: str = Field(..., description="镜像源名称")
raw_prefix: str = Field(..., description="Raw 文件前缀")
@ -103,12 +109,14 @@ class MirrorConfigResponse(BaseModel):
class AvailableMirrorsResponse(BaseModel):
"""可用镜像源列表响应"""
mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
default_priority: List[str] = Field(..., description="默认优先级顺序ID 列表)")
class AddMirrorRequest(BaseModel):
"""添加镜像源请求"""
id: str = Field(..., description="镜像源 ID", example="custom-mirror")
name: str = Field(..., description="镜像源名称", example="自定义镜像源")
raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw")
@ -119,6 +127,7 @@ class AddMirrorRequest(BaseModel):
class UpdateMirrorRequest(BaseModel):
"""更新镜像源请求"""
name: Optional[str] = Field(None, description="镜像源名称")
raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
clone_prefix: Optional[str] = Field(None, description="克隆前缀")
@ -128,6 +137,7 @@ class UpdateMirrorRequest(BaseModel):
class GitStatusResponse(BaseModel):
"""Git 安装状态响应"""
installed: bool = Field(..., description="是否已安装 Git")
version: Optional[str] = Field(None, description="Git 版本号")
path: Optional[str] = Field(None, description="Git 可执行文件路径")
@ -136,6 +146,7 @@ class GitStatusResponse(BaseModel):
class InstallPluginRequest(BaseModel):
"""安装插件请求"""
plugin_id: str = Field(..., description="插件 ID")
repository_url: str = Field(..., description="插件仓库 URL")
branch: Optional[str] = Field("main", description="分支名称")
@ -144,6 +155,7 @@ class InstallPluginRequest(BaseModel):
class VersionResponse(BaseModel):
"""麦麦版本响应"""
version: str = Field(..., description="麦麦版本号")
version_major: int = Field(..., description="主版本号")
version_minor: int = Field(..., description="次版本号")
@ -152,11 +164,13 @@ class VersionResponse(BaseModel):
class UninstallPluginRequest(BaseModel):
"""卸载插件请求"""
plugin_id: str = Field(..., description="插件 ID")
class UpdatePluginRequest(BaseModel):
"""更新插件请求"""
plugin_id: str = Field(..., description="插件 ID")
repository_url: str = Field(..., description="插件仓库 URL")
branch: Optional[str] = Field("main", description="分支名称")
@ -165,6 +179,7 @@ class UpdatePluginRequest(BaseModel):
# ============ API 路由 ============
@router.get("/version", response_model=VersionResponse)
async def get_maimai_version() -> VersionResponse:
"""
@ -174,12 +189,7 @@ async def get_maimai_version() -> VersionResponse:
"""
major, minor, patch = parse_version(MMC_VERSION)
return VersionResponse(
version=MMC_VERSION,
version_major=major,
version_minor=minor,
version_patch=patch
)
return VersionResponse(version=MMC_VERSION, version_major=major, version_minor=minor, version_patch=patch)
@router.get("/git-status", response_model=GitStatusResponse)
@ -196,9 +206,7 @@ async def check_git_status() -> GitStatusResponse:
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
async def get_available_mirrors(
authorization: Optional[str] = Header(None)
) -> AvailableMirrorsResponse:
async def get_available_mirrors(authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
"""
获取所有可用的镜像源配置
"""
@ -219,22 +227,16 @@ async def get_available_mirrors(
raw_prefix=m["raw_prefix"],
clone_prefix=m["clone_prefix"],
enabled=m["enabled"],
priority=m["priority"]
priority=m["priority"],
)
for m in all_mirrors
]
return AvailableMirrorsResponse(
mirrors=mirrors,
default_priority=config.get_default_priority_list()
)
return AvailableMirrorsResponse(mirrors=mirrors, default_priority=config.get_default_priority_list())
@router.post("/mirrors", response_model=MirrorConfigResponse)
async def add_mirror(
request: AddMirrorRequest,
authorization: Optional[str] = Header(None)
) -> MirrorConfigResponse:
async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
"""
添加新的镜像源
"""
@ -254,7 +256,7 @@ async def add_mirror(
raw_prefix=request.raw_prefix,
clone_prefix=request.clone_prefix,
enabled=request.enabled,
priority=request.priority
priority=request.priority,
)
return MirrorConfigResponse(
@ -263,7 +265,7 @@ async def add_mirror(
raw_prefix=mirror["raw_prefix"],
clone_prefix=mirror["clone_prefix"],
enabled=mirror["enabled"],
priority=mirror["priority"]
priority=mirror["priority"],
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
@ -274,9 +276,7 @@ async def add_mirror(
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
async def update_mirror(
mirror_id: str,
request: UpdateMirrorRequest,
authorization: Optional[str] = Header(None)
mirror_id: str, request: UpdateMirrorRequest, authorization: Optional[str] = Header(None)
) -> MirrorConfigResponse:
"""
更新镜像源配置
@ -297,7 +297,7 @@ async def update_mirror(
raw_prefix=request.raw_prefix,
clone_prefix=request.clone_prefix,
enabled=request.enabled,
priority=request.priority
priority=request.priority,
)
if not mirror:
@ -309,7 +309,7 @@ async def update_mirror(
raw_prefix=mirror["raw_prefix"],
clone_prefix=mirror["clone_prefix"],
enabled=mirror["enabled"],
priority=mirror["priority"]
priority=mirror["priority"],
)
except HTTPException:
raise
@ -319,10 +319,7 @@ async def update_mirror(
@router.delete("/mirrors/{mirror_id}")
async def delete_mirror(
mirror_id: str,
authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
"""
删除镜像源
"""
@ -340,16 +337,12 @@ async def delete_mirror(
if not success:
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
return {
"success": True,
"message": f"已删除镜像源: {mirror_id}"
}
return {"success": True, "message": f"已删除镜像源: {mirror_id}"}
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
async def fetch_raw_file(
request: FetchRawFileRequest,
authorization: Optional[str] = Header(None)
request: FetchRawFileRequest, authorization: Optional[str] = Header(None)
) -> FetchRawFileResponse:
"""
获取 GitHub 仓库的 Raw 文件内容
@ -376,7 +369,7 @@ async def fetch_raw_file(
progress=10,
message=f"正在获取插件列表: {request.file_path}",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
try:
@ -389,22 +382,19 @@ async def fetch_raw_file(
branch=request.branch,
file_path=request.file_path,
mirror_id=request.mirror_id,
custom_url=request.custom_url
custom_url=request.custom_url,
)
if result.get("success"):
# 更新进度:成功获取
await update_progress(
stage="loading",
progress=70,
message="正在解析插件数据...",
total_plugins=0,
loaded_plugins=0
stage="loading", progress=70, message="正在解析插件数据...", total_plugins=0, loaded_plugins=0
)
# 尝试解析插件数量
try:
import json
data = json.loads(result.get("data", "[]"))
total = len(data) if isinstance(data, list) else 0
@ -414,16 +404,12 @@ async def fetch_raw_file(
progress=100,
message=f"成功加载 {total} 个插件",
total_plugins=total,
loaded_plugins=total
loaded_plugins=total,
)
except Exception:
# 如果解析失败,仍然发送成功状态
await update_progress(
stage="success",
progress=100,
message="加载完成",
total_plugins=0,
loaded_plugins=0
stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0
)
return FetchRawFileResponse(**result)
@ -433,12 +419,7 @@ async def fetch_raw_file(
# 发送错误进度
await update_progress(
stage="error",
progress=0,
message="加载失败",
error=str(e),
total_plugins=0,
loaded_plugins=0
stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@ -446,8 +427,7 @@ async def fetch_raw_file(
@router.post("/clone", response_model=CloneRepositoryResponse)
async def clone_repository(
request: CloneRepositoryRequest,
authorization: Optional[str] = Header(None)
request: CloneRepositoryRequest, authorization: Optional[str] = Header(None)
) -> CloneRepositoryResponse:
"""
克隆 GitHub 仓库到本地
@ -460,9 +440,7 @@ async def clone_repository(
if not token or not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
logger.info(
f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}"
)
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
try:
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
@ -478,7 +456,7 @@ async def clone_repository(
branch=request.branch,
mirror_id=request.mirror_id,
custom_url=request.custom_url,
depth=request.depth
depth=request.depth,
)
return CloneRepositoryResponse(**result)
@ -489,10 +467,7 @@ async def clone_repository(
@router.post("/install")
async def install_plugin(
request: InstallPluginRequest,
authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
async def install_plugin(request: InstallPluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
"""
安装插件
@ -513,16 +488,16 @@ async def install_plugin(
progress=5,
message=f"开始安装插件: {request.plugin_id}",
operation="install",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
# 1. 解析仓库 URL
# repository_url 格式: https://github.com/owner/repo
repo_url = request.repository_url.rstrip('/')
if repo_url.endswith('.git'):
repo_url = request.repository_url.rstrip("/")
if repo_url.endswith(".git"):
repo_url = repo_url[:-4]
parts = repo_url.split('/')
parts = repo_url.split("/")
if len(parts) < 2:
raise HTTPException(status_code=400, detail="无效的仓库 URL")
@ -534,7 +509,7 @@ async def install_plugin(
progress=10,
message=f"解析仓库信息: {owner}/{repo}",
operation="install",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
# 2. 确定插件安装路径
@ -548,10 +523,10 @@ async def install_plugin(
await update_progress(
stage="error",
progress=0,
message=f"插件已存在",
message="插件已存在",
operation="install",
plugin_id=request.plugin_id,
error="插件已安装,请先卸载"
error="插件已安装,请先卸载",
)
raise HTTPException(status_code=400, detail="插件已安装")
@ -560,31 +535,26 @@ async def install_plugin(
progress=15,
message=f"准备克隆到: {target_path}",
operation="install",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
service = get_git_mirror_service()
# 如果是 GitHub 仓库,使用镜像源
if 'github.com' in repo_url:
if "github.com" in repo_url:
result = await service.clone_repository(
owner=owner,
repo=repo,
target_path=target_path,
branch=request.branch,
mirror_id=request.mirror_id,
depth=1 # 浅克隆,节省时间和空间
depth=1, # 浅克隆,节省时间和空间
)
else:
# 自定义仓库,直接使用 URL
result = await service.clone_repository(
owner=owner,
repo=repo,
target_path=target_path,
branch=request.branch,
custom_url=repo_url,
depth=1
owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1
)
if not result.get("success"):
@ -595,23 +565,20 @@ async def install_plugin(
message="克隆仓库失败",
operation="install",
plugin_id=request.plugin_id,
error=error_msg
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# 4. 验证插件完整性
await update_progress(
stage="loading",
progress=85,
message="验证插件文件...",
operation="install",
plugin_id=request.plugin_id
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
)
manifest_path = target_path / "_manifest.json"
if not manifest_path.exists():
# 清理失败的安装
import shutil
shutil.rmtree(target_path, ignore_errors=True)
await update_progress(
@ -620,26 +587,23 @@ async def install_plugin(
message="插件缺少 _manifest.json",
operation="install",
plugin_id=request.plugin_id,
error="无效的插件格式"
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
# 5. 读取并验证 manifest
await update_progress(
stage="loading",
progress=90,
message="读取插件配置...",
operation="install",
plugin_id=request.plugin_id
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
)
try:
import json as json_module
with open(manifest_path, 'r', encoding='utf-8') as f:
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json_module.load(f)
# 基本验证
required_fields = ['manifest_version', 'name', 'version', 'author']
required_fields = ["manifest_version", "name", "version", "author"]
for field in required_fields:
if field not in manifest:
raise ValueError(f"缺少必需字段: {field}")
@ -647,6 +611,7 @@ async def install_plugin(
except Exception as e:
# 清理失败的安装
import shutil
shutil.rmtree(target_path, ignore_errors=True)
await update_progress(
@ -655,7 +620,7 @@ async def install_plugin(
message="_manifest.json 无效",
operation="install",
plugin_id=request.plugin_id,
error=str(e)
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
@ -665,16 +630,16 @@ async def install_plugin(
progress=100,
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
operation="install",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
return {
"success": True,
"message": "插件安装成功",
"plugin_id": request.plugin_id,
"plugin_name": manifest['name'],
"version": manifest['version'],
"path": str(target_path)
"plugin_name": manifest["name"],
"version": manifest["version"],
"path": str(target_path),
}
except HTTPException:
@ -688,7 +653,7 @@ async def install_plugin(
message="安装失败",
operation="install",
plugin_id=request.plugin_id,
error=str(e)
error=str(e),
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@ -696,8 +661,7 @@ async def install_plugin(
@router.post("/uninstall")
async def uninstall_plugin(
request: UninstallPluginRequest,
authorization: Optional[str] = Header(None)
request: UninstallPluginRequest, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
卸载插件
@ -719,7 +683,7 @@ async def uninstall_plugin(
progress=10,
message=f"开始卸载插件: {request.plugin_id}",
operation="uninstall",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
# 1. 检查插件是否存在
@ -733,7 +697,7 @@ async def uninstall_plugin(
message="插件不存在",
operation="uninstall",
plugin_id=request.plugin_id,
error="插件未安装或已被删除"
error="插件未安装或已被删除",
)
raise HTTPException(status_code=404, detail="插件未安装")
@ -742,7 +706,7 @@ async def uninstall_plugin(
progress=30,
message=f"正在删除插件文件: {plugin_path}",
operation="uninstall",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
# 2. 读取插件信息(用于日志)
@ -752,7 +716,8 @@ async def uninstall_plugin(
if manifest_path.exists():
try:
import json as json_module
with open(manifest_path, 'r', encoding='utf-8') as f:
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json_module.load(f)
plugin_name = manifest.get("name", request.plugin_id)
except Exception:
@ -763,7 +728,7 @@ async def uninstall_plugin(
progress=50,
message=f"正在删除 {plugin_name}...",
operation="uninstall",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
# 3. 删除插件目录
@ -773,6 +738,7 @@ async def uninstall_plugin(
def remove_readonly(func, path, _):
"""清除只读属性并删除文件"""
import os
os.chmod(path, stat.S_IWRITE)
func(path)
@ -786,15 +752,10 @@ async def uninstall_plugin(
progress=100,
message=f"成功卸载插件: {plugin_name}",
operation="uninstall",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
return {
"success": True,
"message": "插件卸载成功",
"plugin_id": request.plugin_id,
"plugin_name": plugin_name
}
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
except HTTPException:
raise
@ -807,7 +768,7 @@ async def uninstall_plugin(
message="卸载失败",
operation="uninstall",
plugin_id=request.plugin_id,
error="权限不足,无法删除插件文件"
error="权限不足,无法删除插件文件",
)
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
@ -820,17 +781,14 @@ async def uninstall_plugin(
message="卸载失败",
operation="uninstall",
plugin_id=request.plugin_id,
error=str(e)
error=str(e),
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.post("/update")
async def update_plugin(
request: UpdatePluginRequest,
authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
async def update_plugin(request: UpdatePluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
"""
更新插件
@ -851,7 +809,7 @@ async def update_plugin(
progress=5,
message=f"开始更新插件: {request.plugin_id}",
operation="update",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
# 1. 检查插件是否已安装
@ -865,7 +823,7 @@ async def update_plugin(
message="插件不存在",
operation="update",
plugin_id=request.plugin_id,
error="插件未安装,请先安装"
error="插件未安装,请先安装",
)
raise HTTPException(status_code=404, detail="插件未安装")
@ -877,10 +835,11 @@ async def update_plugin(
if manifest_path.exists():
try:
import json as json_module
with open(manifest_path, 'r', encoding='utf-8') as f:
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json_module.load(f)
old_version = manifest.get("version", "unknown")
plugin_name = manifest.get("name", request.plugin_id)
_plugin_name = manifest.get("name", request.plugin_id)
except Exception:
pass
@ -889,16 +848,12 @@ async def update_plugin(
progress=10,
message=f"当前版本: {old_version},准备更新...",
operation="update",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
# 3. 删除旧版本
await update_progress(
stage="loading",
progress=20,
message="正在删除旧版本...",
operation="update",
plugin_id=request.plugin_id
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
)
import shutil
@ -907,6 +862,7 @@ async def update_plugin(
def remove_readonly(func, path, _):
"""清除只读属性并删除文件"""
import os
os.chmod(path, stat.S_IWRITE)
func(path)
@ -920,14 +876,14 @@ async def update_plugin(
progress=30,
message="正在准备下载新版本...",
operation="update",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
repo_url = request.repository_url.rstrip('/')
if repo_url.endswith('.git'):
repo_url = request.repository_url.rstrip("/")
if repo_url.endswith(".git"):
repo_url = repo_url[:-4]
parts = repo_url.split('/')
parts = repo_url.split("/")
if len(parts) < 2:
raise HTTPException(status_code=400, detail="无效的仓库 URL")
@ -937,23 +893,18 @@ async def update_plugin(
# 5. 克隆新版本(这里会推送 35%-85% 的进度)
service = get_git_mirror_service()
if 'github.com' in repo_url:
if "github.com" in repo_url:
result = await service.clone_repository(
owner=owner,
repo=repo,
target_path=plugin_path,
branch=request.branch,
mirror_id=request.mirror_id,
depth=1
depth=1,
)
else:
result = await service.clone_repository(
owner=owner,
repo=repo,
target_path=plugin_path,
branch=request.branch,
custom_url=repo_url,
depth=1
owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1
)
if not result.get("success"):
@ -964,17 +915,13 @@ async def update_plugin(
message="下载新版本失败",
operation="update",
plugin_id=request.plugin_id,
error=error_msg
error=error_msg,
)
raise HTTPException(status_code=500, detail=error_msg)
# 6. 验证新版本
await update_progress(
stage="loading",
progress=90,
message="验证新版本...",
operation="update",
plugin_id=request.plugin_id
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
)
new_manifest_path = plugin_path / "_manifest.json"
@ -983,6 +930,7 @@ async def update_plugin(
def remove_readonly(func, path, _):
"""清除只读属性并删除文件"""
import os
os.chmod(path, stat.S_IWRITE)
func(path)
@ -994,13 +942,13 @@ async def update_plugin(
message="新版本缺少 _manifest.json",
operation="update",
plugin_id=request.plugin_id,
error="无效的插件格式"
error="无效的插件格式",
)
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
# 7. 读取新版本信息
try:
with open(new_manifest_path, 'r', encoding='utf-8') as f:
with open(new_manifest_path, "r", encoding="utf-8") as f:
new_manifest = json_module.load(f)
new_version = new_manifest.get("version", "unknown")
@ -1014,7 +962,7 @@ async def update_plugin(
progress=100,
message=f"成功更新 {new_name}: {old_version}{new_version}",
operation="update",
plugin_id=request.plugin_id
plugin_id=request.plugin_id,
)
return {
@ -1023,7 +971,7 @@ async def update_plugin(
"plugin_id": request.plugin_id,
"plugin_name": new_name,
"old_version": old_version,
"new_version": new_version
"new_version": new_version,
}
except Exception as e:
@ -1036,7 +984,7 @@ async def update_plugin(
message="_manifest.json 无效",
operation="update",
plugin_id=request.plugin_id,
error=str(e)
error=str(e),
)
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
@ -1046,21 +994,14 @@ async def update_plugin(
logger.error(f"更新插件失败: {e}", exc_info=True)
await update_progress(
stage="error",
progress=0,
message="更新失败",
operation="update",
plugin_id=request.plugin_id,
error=str(e)
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
)
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
@router.get("/installed")
async def get_installed_plugins(
authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
"""
获取已安装的插件列表
@ -1081,10 +1022,7 @@ async def get_installed_plugins(
if not plugins_dir.exists():
logger.info("插件目录不存在,创建目录")
plugins_dir.mkdir(exist_ok=True)
return {
"success": True,
"plugins": []
}
return {"success": True, "plugins": []}
installed_plugins = []
@ -1098,7 +1036,7 @@ async def get_installed_plugins(
plugin_id = plugin_path.name
# 跳过隐藏目录和特殊目录
if plugin_id.startswith('.') or plugin_id.startswith('__'):
if plugin_id.startswith(".") or plugin_id.startswith("__"):
continue
# 读取 _manifest.json
@ -1110,20 +1048,23 @@ async def get_installed_plugins(
try:
import json as json_module
with open(manifest_path, 'r', encoding='utf-8') as f:
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json_module.load(f)
# 基本验证
if 'name' not in manifest or 'version' not in manifest:
if "name" not in manifest or "version" not in manifest:
logger.warning(f"插件 {plugin_id} 的 _manifest.json 格式无效,跳过")
continue
# 添加到已安装列表(返回完整的 manifest 信息)
installed_plugins.append({
"id": plugin_id,
"manifest": manifest, # 返回完整的 manifest 对象
"path": str(plugin_path.absolute())
})
installed_plugins.append(
{
"id": plugin_id,
"manifest": manifest, # 返回完整的 manifest 对象
"path": str(plugin_path.absolute()),
}
)
except json.JSONDecodeError as e:
logger.warning(f"插件 {plugin_id} 的 _manifest.json 解析失败: {e}")
@ -1134,11 +1075,7 @@ async def get_installed_plugins(
logger.info(f"找到 {len(installed_plugins)} 个已安装插件")
return {
"success": True,
"plugins": installed_plugins,
"total": len(installed_plugins)
}
return {"success": True, "plugins": installed_plugins, "total": len(installed_plugins)}
except Exception as e:
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)

View File

@ -3,6 +3,7 @@
提供系统重启状态查询等功能
"""
import os
import sys
import time
@ -19,12 +20,14 @@ _start_time = time.time()
class RestartResponse(BaseModel):
"""重启响应"""
success: bool
message: str
class StatusResponse(BaseModel):
"""状态响应"""
running: bool
uptime: float
version: str
@ -52,15 +55,9 @@ async def restart_maibot():
# 但我们仍然返回它以保持 API 一致性
os.execv(python, args)
return RestartResponse(
success=True,
message="麦麦正在重启中..."
)
return RestartResponse(success=True, message="麦麦正在重启中...")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"重启失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
@router.get("/status", response_model=StatusResponse)
@ -77,20 +74,15 @@ async def get_maibot_status():
version = MMC_VERSION # 可以从配置或常量中读取
return StatusResponse(
running=True,
uptime=uptime,
version=version,
start_time=datetime.fromtimestamp(_start_time).isoformat()
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"获取状态失败: {str(e)}"
) from e
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
# 可选:添加更多系统控制功能
@router.post("/reload-config")
async def reload_config():
"""
@ -102,7 +94,4 @@ async def reload_config():
# 这里需要调用主程序的配置重载函数
# 示例await app_instance.reload_config()
return {
"success": True,
"message": "配置重载功能待实现"
}
return {"success": True, "message": "配置重载功能待实现"}

View File

@ -1,4 +1,5 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel, Field
from typing import Optional
@ -38,28 +39,33 @@ router.include_router(system_router)
class TokenVerifyRequest(BaseModel):
"""Token 验证请求"""
token: str = Field(..., description="访问令牌")
class TokenVerifyResponse(BaseModel):
"""Token 验证响应"""
valid: bool = Field(..., description="Token 是否有效")
message: str = Field(..., description="验证结果消息")
class TokenUpdateRequest(BaseModel):
"""Token 更新请求"""
new_token: str = Field(..., description="新的访问令牌", min_length=10)
class TokenUpdateResponse(BaseModel):
"""Token 更新响应"""
success: bool = Field(..., description="是否更新成功")
message: str = Field(..., description="更新结果消息")
class TokenRegenerateResponse(BaseModel):
"""Token 重新生成响应"""
success: bool = Field(..., description="是否生成成功")
token: str = Field(..., description="新生成的令牌")
message: str = Field(..., description="生成结果消息")
@ -67,18 +73,21 @@ class TokenRegenerateResponse(BaseModel):
class FirstSetupStatusResponse(BaseModel):
"""首次配置状态响应"""
is_first_setup: bool = Field(..., description="是否为首次配置")
message: str = Field(..., description="状态消息")
class CompleteSetupResponse(BaseModel):
"""完成配置响应"""
success: bool = Field(..., description="是否成功")
message: str = Field(..., description="结果消息")
class ResetSetupResponse(BaseModel):
"""重置配置响应"""
success: bool = Field(..., description="是否成功")
message: str = Field(..., description="结果消息")
@ -105,25 +114,16 @@ async def verify_token(request: TokenVerifyRequest):
is_valid = token_manager.verify_token(request.token)
if is_valid:
return TokenVerifyResponse(
valid=True,
message="Token 验证成功"
)
return TokenVerifyResponse(valid=True, message="Token 验证成功")
else:
return TokenVerifyResponse(
valid=False,
message="Token 无效或已过期"
)
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
except Exception as e:
logger.error(f"Token 验证失败: {e}")
raise HTTPException(status_code=500, detail="Token 验证失败") from e
@router.post("/auth/update", response_model=TokenUpdateResponse)
async def update_token(
request: TokenUpdateRequest,
authorization: Optional[str] = Header(None)
):
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
"""
更新访问令牌需要当前有效的 token
@ -148,10 +148,7 @@ async def update_token(
# 更新 token
success, message = token_manager.update_token(request.new_token)
return TokenUpdateResponse(
success=success,
message=message
)
return TokenUpdateResponse(success=success, message=message)
except HTTPException:
raise
except Exception as e:
@ -184,11 +181,7 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
# 重新生成 token
new_token = token_manager.regenerate_token()
return TokenRegenerateResponse(
success=True,
token=new_token,
message="Token 已重新生成"
)
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
except HTTPException:
raise
except Exception as e:
@ -221,10 +214,7 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
# 检查是否为首次配置
is_first = token_manager.is_first_setup()
return FirstSetupStatusResponse(
is_first_setup=is_first,
message="首次配置" if is_first else "已完成配置"
)
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
except HTTPException:
raise
except Exception as e:
@ -257,10 +247,7 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
# 标记配置完成
success = token_manager.mark_setup_completed()
return CompleteSetupResponse(
success=success,
message="配置已完成" if success else "标记失败"
)
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
except HTTPException:
raise
except Exception as e:
@ -293,10 +280,7 @@ async def reset_setup(authorization: Optional[str] = Header(None)):
# 重置配置状态
success = token_manager.reset_setup_status()
return ResetSetupResponse(
success=success,
message="配置状态已重置" if success else "重置失败"
)
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
except HTTPException:
raise
except Exception as e:

View File

@ -1,4 +1,5 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Any, List
@ -15,6 +16,7 @@ router = APIRouter(prefix="/statistics", tags=["statistics"])
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
total_requests: int = Field(0, description="总请求数")
total_cost: float = Field(0.0, description="总花费")
total_tokens: int = Field(0, description="总token数")
@ -28,6 +30,7 @@ class StatisticsSummary(BaseModel):
class ModelStatistics(BaseModel):
"""模型统计"""
model_name: str
request_count: int
total_cost: float
@ -37,6 +40,7 @@ class ModelStatistics(BaseModel):
class TimeSeriesData(BaseModel):
"""时间序列数据"""
timestamp: str
requests: int = 0
cost: float = 0.0
@ -45,6 +49,7 @@ class TimeSeriesData(BaseModel):
class DashboardData(BaseModel):
"""仪表盘数据"""
summary: StatisticsSummary
model_stats: List[ModelStatistics]
hourly_data: List[TimeSeriesData]
@ -88,7 +93,7 @@ async def get_dashboard_data(hours: int = 24):
model_stats=model_stats,
hourly_data=hourly_data,
daily_data=daily_data,
recent_activity=recent_activity
recent_activity=recent_activity,
)
except Exception as e:
logger.error(f"获取仪表盘数据失败: {e}")
@ -100,11 +105,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
summary = StatisticsSummary()
# 查询 LLM 使用记录
llm_records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
.where(LLMUsage.timestamp <= end_time)
)
llm_records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
total_time_cost = 0.0
time_cost_count = 0
@ -124,11 +125,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
# 查询在线时间
online_records = list(
OnlineTime.select()
.where(
(OnlineTime.start_timestamp >= start_time) |
(OnlineTime.end_timestamp >= start_time)
)
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
)
for record in online_records:
@ -139,9 +136,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
# 查询消息数量
messages = list(
Messages.select()
.where(Messages.time >= start_time.timestamp())
.where(Messages.time <= end_time.timestamp())
Messages.select().where(Messages.time >= start_time.timestamp()).where(Messages.time <= end_time.timestamp())
)
summary.total_messages = len(messages)
@ -159,38 +154,32 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
"""获取模型统计数据"""
model_data = defaultdict(lambda: {
'request_count': 0,
'total_cost': 0.0,
'total_tokens': 0,
'time_costs': []
})
model_data = defaultdict(lambda: {"request_count": 0, "total_cost": 0.0, "total_tokens": 0, "time_costs": []})
records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
)
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time))
for record in records:
model_name = record.model_assign_name or record.model_name or "unknown"
model_data[model_name]['request_count'] += 1
model_data[model_name]['total_cost'] += record.cost or 0.0
model_data[model_name]['total_tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
model_data[model_name]["request_count"] += 1
model_data[model_name]["total_cost"] += record.cost or 0.0
model_data[model_name]["total_tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
if record.time_cost and record.time_cost > 0:
model_data[model_name]['time_costs'].append(record.time_cost)
model_data[model_name]["time_costs"].append(record.time_cost)
# 转换为列表并排序
result = []
for model_name, data in model_data.items():
avg_time = sum(data['time_costs']) / len(data['time_costs']) if data['time_costs'] else 0.0
result.append(ModelStatistics(
model_name=model_name,
request_count=data['request_count'],
total_cost=data['total_cost'],
total_tokens=data['total_tokens'],
avg_response_time=avg_time
))
avg_time = sum(data["time_costs"]) / len(data["time_costs"]) if data["time_costs"] else 0.0
result.append(
ModelStatistics(
model_name=model_name,
request_count=data["request_count"],
total_cost=data["total_cost"],
total_tokens=data["total_tokens"],
avg_response_time=avg_time,
)
)
# 按请求数排序
result.sort(key=lambda x: x.request_count, reverse=True)
@ -200,35 +189,28 @@ async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取小时级统计数据"""
# 创建小时桶
hourly_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
hourly_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
.where(LLMUsage.timestamp <= end_time)
)
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
for record in records:
# 获取小时键(去掉分钟和秒)
hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
hour_str = hour_key.isoformat()
hourly_buckets[hour_str]['requests'] += 1
hourly_buckets[hour_str]['cost'] += record.cost or 0.0
hourly_buckets[hour_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
hourly_buckets[hour_str]["requests"] += 1
hourly_buckets[hour_str]["cost"] += record.cost or 0.0
hourly_buckets[hour_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
# 填充所有小时(包括没有数据的)
result = []
current = start_time.replace(minute=0, second=0, microsecond=0)
while current <= end_time:
hour_str = current.isoformat()
data = hourly_buckets.get(hour_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
result.append(TimeSeriesData(
timestamp=hour_str,
requests=data['requests'],
cost=data['cost'],
tokens=data['tokens']
))
data = hourly_buckets.get(hour_str, {"requests": 0, "cost": 0.0, "tokens": 0})
result.append(
TimeSeriesData(timestamp=hour_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
)
current += timedelta(hours=1)
return result
@ -236,35 +218,28 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> Li
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取日级统计数据"""
daily_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
daily_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
.where(LLMUsage.timestamp <= end_time)
)
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
for record in records:
# 获取日期键
day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
day_str = day_key.isoformat()
daily_buckets[day_str]['requests'] += 1
daily_buckets[day_str]['cost'] += record.cost or 0.0
daily_buckets[day_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
daily_buckets[day_str]["requests"] += 1
daily_buckets[day_str]["cost"] += record.cost or 0.0
daily_buckets[day_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
# 填充所有天
result = []
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
while current <= end_time:
day_str = current.isoformat()
data = daily_buckets.get(day_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
result.append(TimeSeriesData(
timestamp=day_str,
requests=data['requests'],
cost=data['cost'],
tokens=data['tokens']
))
data = daily_buckets.get(day_str, {"requests": 0, "cost": 0.0, "tokens": 0})
result.append(
TimeSeriesData(timestamp=day_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
)
current += timedelta(days=1)
return result
@ -272,23 +247,21 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
"""获取最近活动"""
records = list(
LLMUsage.select()
.order_by(LLMUsage.timestamp.desc())
.limit(limit)
)
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
activities = []
for record in records:
activities.append({
'timestamp': record.timestamp.isoformat(),
'model': record.model_assign_name or record.model_name,
'request_type': record.request_type,
'tokens': (record.prompt_tokens or 0) + (record.completion_tokens or 0),
'cost': record.cost or 0.0,
'time_cost': record.time_cost or 0.0,
'status': record.status
})
activities.append(
{
"timestamp": record.timestamp.isoformat(),
"model": record.model_assign_name or record.model_name,
"request_type": record.request_type,
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
"cost": record.cost or 0.0,
"time_cost": record.time_cost or 0.0,
"status": record.status,
}
)
return activities

View File

@ -80,7 +80,7 @@ class TokenManager:
"access_token": token,
"created_at": self._get_current_timestamp(),
"updated_at": self._get_current_timestamp(),
"first_setup_completed": False # 标记首次配置未完成
"first_setup_completed": False, # 标记首次配置未完成
}
self._save_config(config)
@ -91,6 +91,7 @@ class TokenManager:
def _get_current_timestamp(self) -> str:
"""获取当前时间戳字符串"""
from datetime import datetime
return datetime.now().isoformat()
def get_token(self) -> str: