mirror of https://github.com/Mai-with-u/MaiBot.git
Ruff fix
parent
2f58605644
commit
44f427dc64
4
bot.py
4
bot.py
|
|
@ -1,7 +1,6 @@
|
|||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import platform
|
||||
import traceback
|
||||
|
|
@ -30,7 +29,7 @@ else:
|
|||
raise
|
||||
|
||||
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
|
||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
|
||||
|
||||
initialize_logging()
|
||||
|
||||
|
|
@ -215,6 +214,7 @@ if __name__ == "__main__":
|
|||
|
||||
# 初始化 WebSocket 日志推送
|
||||
from src.common.logger import initialize_ws_handler
|
||||
|
||||
initialize_ws_handler(loop)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path:
|
|||
sys.path.insert(0, PROJECT_ROOT)
|
||||
|
||||
|
||||
|
||||
|
||||
SECONDS_5_MINUTES = 5 * 60
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|||
sys.path.insert(0, project_root)
|
||||
|
||||
|
||||
|
||||
# 设置中文字体
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
|
|
|||
|
|
@ -57,8 +57,8 @@ from src.common.database.database import db
|
|||
from src.common.database.database_model import Emoji
|
||||
|
||||
# 常量定义
|
||||
MAGIC = b'MMIP'
|
||||
FOOTER_MAGIC = b'MMFF'
|
||||
MAGIC = b"MMIP"
|
||||
FOOTER_MAGIC = b"MMFF"
|
||||
VERSION = 1
|
||||
FOOTER_VERSION = 1
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ MAX_MANIFEST_SIZE = 200 * 1024 * 1024 # 200 MB
|
|||
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 * 1024 # 10 GB
|
||||
|
||||
# 支持的图片格式
|
||||
SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.avif', '.bmp'}
|
||||
SUPPORTED_FORMATS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".avif", ".bmp"}
|
||||
|
||||
# 创建控制台对象
|
||||
console = Console()
|
||||
|
|
@ -75,6 +75,7 @@ console = Console()
|
|||
|
||||
class MMIPKGError(Exception):
|
||||
"""MMIPKG 相关错误"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -97,55 +98,55 @@ def get_image_info(file_path: str) -> Tuple[int, int, str]:
|
|||
try:
|
||||
with Image.open(file_path) as img:
|
||||
width, height = img.size
|
||||
format_lower = img.format.lower() if img.format else 'unknown'
|
||||
format_lower = img.format.lower() if img.format else "unknown"
|
||||
mime_map = {
|
||||
'jpeg': 'image/jpeg',
|
||||
'jpg': 'image/jpeg',
|
||||
'png': 'image/png',
|
||||
'gif': 'image/gif',
|
||||
'webp': 'image/webp',
|
||||
'avif': 'image/avif',
|
||||
'bmp': 'image/bmp'
|
||||
"jpeg": "image/jpeg",
|
||||
"jpg": "image/jpeg",
|
||||
"png": "image/png",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"avif": "image/avif",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
mime_type = mime_map.get(format_lower, f'image/{format_lower}')
|
||||
mime_type = mime_map.get(format_lower, f"image/{format_lower}")
|
||||
return width, height, mime_type
|
||||
except Exception as e:
|
||||
print(f"警告: 无法读取图片信息 {file_path}: {e}")
|
||||
return 0, 0, 'image/unknown'
|
||||
return 0, 0, "image/unknown"
|
||||
|
||||
|
||||
def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 80) -> bytes:
|
||||
def reencode_image(file_path: str, output_format: str = "webp", quality: int = 80) -> bytes:
|
||||
"""重新编码图片"""
|
||||
try:
|
||||
with Image.open(file_path) as img:
|
||||
# 转换为 RGB(如果需要)
|
||||
if img.mode in ('RGBA', 'LA', 'P'):
|
||||
if output_format.lower() == 'jpeg':
|
||||
if img.mode in ("RGBA", "LA", "P"):
|
||||
if output_format.lower() == "jpeg":
|
||||
# JPEG 不支持透明度,转为白色背景
|
||||
background = Image.new('RGB', img.size, (255, 255, 255))
|
||||
if img.mode == 'P':
|
||||
img = img.convert('RGBA')
|
||||
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
|
||||
background = Image.new("RGB", img.size, (255, 255, 255))
|
||||
if img.mode == "P":
|
||||
img = img.convert("RGBA")
|
||||
background.paste(img, mask=img.split()[-1] if img.mode == "RGBA" else None)
|
||||
img = background
|
||||
elif output_format.lower() == 'webp':
|
||||
elif output_format.lower() == "webp":
|
||||
# WebP 支持透明度
|
||||
if img.mode == 'P':
|
||||
img = img.convert('RGBA')
|
||||
elif img.mode not in ('RGB', 'RGBA'):
|
||||
img = img.convert('RGB')
|
||||
if img.mode == "P":
|
||||
img = img.convert("RGBA")
|
||||
elif img.mode not in ("RGB", "RGBA"):
|
||||
img = img.convert("RGB")
|
||||
|
||||
# 编码图片
|
||||
output = io.BytesIO()
|
||||
save_kwargs = {'format': output_format.upper()}
|
||||
save_kwargs = {"format": output_format.upper()}
|
||||
|
||||
if output_format.lower() in {'jpeg', 'jpg'}:
|
||||
save_kwargs['quality'] = quality
|
||||
save_kwargs['optimize'] = True
|
||||
elif output_format.lower() == 'webp':
|
||||
save_kwargs['quality'] = quality
|
||||
save_kwargs['method'] = 6 # 更好的压缩
|
||||
elif output_format.lower() == 'png':
|
||||
save_kwargs['optimize'] = True
|
||||
if output_format.lower() in {"jpeg", "jpg"}:
|
||||
save_kwargs["quality"] = quality
|
||||
save_kwargs["optimize"] = True
|
||||
elif output_format.lower() == "webp":
|
||||
save_kwargs["quality"] = quality
|
||||
save_kwargs["method"] = 6 # 更好的压缩
|
||||
elif output_format.lower() == "png":
|
||||
save_kwargs["optimize"] = True
|
||||
|
||||
img.save(output, **save_kwargs)
|
||||
return output.getvalue()
|
||||
|
|
@ -156,11 +157,13 @@ def reencode_image(file_path: str, output_format: str = 'webp', quality: int = 8
|
|||
class MMIPKGPacker:
|
||||
"""MMIPKG 打包器"""
|
||||
|
||||
def __init__(self,
|
||||
use_compression: bool = True,
|
||||
zstd_level: int = 3,
|
||||
reencode: Optional[str] = None,
|
||||
reencode_quality: int = 80):
|
||||
def __init__(
|
||||
self,
|
||||
use_compression: bool = True,
|
||||
zstd_level: int = 3,
|
||||
reencode: Optional[str] = None,
|
||||
reencode_quality: int = 80,
|
||||
):
|
||||
self.use_compression = use_compression and zstd is not None
|
||||
self.zstd_level = zstd_level
|
||||
self.reencode = reencode
|
||||
|
|
@ -170,8 +173,9 @@ class MMIPKGPacker:
|
|||
print("警告: zstandard 未安装,将不使用压缩")
|
||||
self.use_compression = False
|
||||
|
||||
def pack_from_db(self, output_path: str, pack_name: Optional[str] = None,
|
||||
custom_manifest: Optional[Dict] = None) -> bool:
|
||||
def pack_from_db(
|
||||
self, output_path: str, pack_name: Optional[str] = None, custom_manifest: Optional[Dict] = None
|
||||
) -> bool:
|
||||
"""从数据库导出已注册的表情包
|
||||
|
||||
Args:
|
||||
|
|
@ -205,12 +209,14 @@ class MMIPKGPacker:
|
|||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeElapsedColumn(),
|
||||
console=console
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task("[cyan]扫描表情包...", total=emoji_count)
|
||||
|
||||
for idx, emoji in enumerate(emojis, 1):
|
||||
progress.update(task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}")
|
||||
progress.update(
|
||||
task, description=f"[cyan]处理 {idx}/{emoji_count}: {os.path.basename(emoji.full_path)}"
|
||||
)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(emoji.full_path):
|
||||
|
|
@ -224,10 +230,10 @@ class MMIPKGPacker:
|
|||
img_bytes = reencode_image(emoji.full_path, self.reencode, self.reencode_quality)
|
||||
except Exception as e:
|
||||
console.print(f" [yellow]警告: 重新编码失败,使用原始文件: {e}[/yellow]")
|
||||
with open(emoji.full_path, 'rb') as f:
|
||||
with open(emoji.full_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
else:
|
||||
with open(emoji.full_path, 'rb') as f:
|
||||
with open(emoji.full_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
|
||||
# 计算 SHA256
|
||||
|
|
@ -259,7 +265,7 @@ class MMIPKGPacker:
|
|||
"emoji_hash": emoji.emoji_hash or "",
|
||||
"is_registered": True,
|
||||
"is_banned": emoji.is_banned or False,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
items.append(item)
|
||||
|
|
@ -281,7 +287,7 @@ class MMIPKGPacker:
|
|||
"p": pack_id, # pack_id
|
||||
"n": pack_name, # pack_name
|
||||
"t": datetime.now().isoformat(), # created_at
|
||||
"a": items # items array
|
||||
"a": items, # items array
|
||||
}
|
||||
|
||||
# 添加自定义字段
|
||||
|
|
@ -308,26 +314,28 @@ class MMIPKGPacker:
|
|||
except Exception as e:
|
||||
print(f"打包失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
|
||||
def _write_package(self, output_path: str, manifest_bytes: bytes,
|
||||
image_data_list: List[bytes], payload_size: int) -> bool:
|
||||
def _write_package(
|
||||
self, output_path: str, manifest_bytes: bytes, image_data_list: List[bytes], payload_size: int
|
||||
) -> bool:
|
||||
"""写入打包文件"""
|
||||
try:
|
||||
with open(output_path, 'wb') as f:
|
||||
with open(output_path, "wb") as f:
|
||||
# 写入 Header (32 bytes)
|
||||
flags = 0x01 if self.use_compression else 0x00
|
||||
header = MAGIC # 4 bytes
|
||||
header += struct.pack('B', VERSION) # 1 byte
|
||||
header += struct.pack('B', flags) # 1 byte
|
||||
header += b'\x00\x00' # 2 bytes reserved
|
||||
header += struct.pack('>Q', payload_size) # 8 bytes
|
||||
header += struct.pack('>Q', len(manifest_bytes)) # 8 bytes
|
||||
header += b'\x00' * 8 # 8 bytes reserved
|
||||
header += struct.pack("B", VERSION) # 1 byte
|
||||
header += struct.pack("B", flags) # 1 byte
|
||||
header += b"\x00\x00" # 2 bytes reserved
|
||||
header += struct.pack(">Q", payload_size) # 8 bytes
|
||||
header += struct.pack(">Q", len(manifest_bytes)) # 8 bytes
|
||||
header += b"\x00" * 8 # 8 bytes reserved
|
||||
|
||||
assert len(header) == 32, f"Header size mismatch: {len(header)}"
|
||||
f.write(header)
|
||||
|
|
@ -342,7 +350,7 @@ class MMIPKGPacker:
|
|||
|
||||
with compressor.stream_writer(f, closefd=False) as writer:
|
||||
# 写入 manifest
|
||||
manifest_len_bytes = struct.pack('>I', len(manifest_bytes))
|
||||
manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
|
||||
writer.write(manifest_len_bytes)
|
||||
writer.write(manifest_bytes)
|
||||
payload_sha.update(manifest_len_bytes)
|
||||
|
|
@ -355,13 +363,13 @@ class MMIPKGPacker:
|
|||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeRemainingColumn(),
|
||||
console=console
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task("[green]压缩写入图片...", total=len(image_data_list))
|
||||
|
||||
for idx, img_bytes in enumerate(image_data_list, 1):
|
||||
progress.update(task, description=f"[green]压缩写入 {idx}/{len(image_data_list)}")
|
||||
img_len_bytes = struct.pack('>I', len(img_bytes))
|
||||
img_len_bytes = struct.pack(">I", len(img_bytes))
|
||||
writer.write(img_len_bytes)
|
||||
writer.write(img_bytes)
|
||||
payload_sha.update(img_len_bytes)
|
||||
|
|
@ -370,7 +378,7 @@ class MMIPKGPacker:
|
|||
else:
|
||||
# 不压缩,直接写入
|
||||
# 写入 manifest
|
||||
manifest_len_bytes = struct.pack('>I', len(manifest_bytes))
|
||||
manifest_len_bytes = struct.pack(">I", len(manifest_bytes))
|
||||
f.write(manifest_len_bytes)
|
||||
f.write(manifest_bytes)
|
||||
payload_sha.update(manifest_len_bytes)
|
||||
|
|
@ -383,13 +391,13 @@ class MMIPKGPacker:
|
|||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeRemainingColumn(),
|
||||
console=console
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task("[green]写入图片...", total=len(image_data_list))
|
||||
|
||||
for idx, img_bytes in enumerate(image_data_list, 1):
|
||||
progress.update(task, description=f"[green]写入 {idx}/{len(image_data_list)}")
|
||||
img_len_bytes = struct.pack('>I', len(img_bytes))
|
||||
img_len_bytes = struct.pack(">I", len(img_bytes))
|
||||
f.write(img_len_bytes)
|
||||
f.write(img_bytes)
|
||||
payload_sha.update(img_len_bytes)
|
||||
|
|
@ -400,8 +408,8 @@ class MMIPKGPacker:
|
|||
file_sha256 = payload_sha.digest()
|
||||
footer = FOOTER_MAGIC # 4 bytes
|
||||
footer += file_sha256 # 32 bytes
|
||||
footer += struct.pack('B', FOOTER_VERSION) # 1 byte
|
||||
footer += b'\x00' * 3 # 3 bytes reserved
|
||||
footer += struct.pack("B", FOOTER_VERSION) # 1 byte
|
||||
footer += b"\x00" * 3 # 3 bytes reserved
|
||||
|
||||
assert len(footer) == 40, f"Footer size mismatch: {len(footer)}"
|
||||
f.write(footer)
|
||||
|
|
@ -419,6 +427,7 @@ class MMIPKGPacker:
|
|||
except Exception as e:
|
||||
print(f"写入文件失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
|
@ -429,10 +438,9 @@ class MMIPKGUnpacker:
|
|||
def __init__(self, verify_sha: bool = True):
|
||||
self.verify_sha = verify_sha
|
||||
|
||||
def import_to_db(self, package_path: str,
|
||||
output_dir: Optional[str] = None,
|
||||
replace_existing: bool = False,
|
||||
batch_size: int = 500) -> bool:
|
||||
def import_to_db(
|
||||
self, package_path: str, output_dir: Optional[str] = None, replace_existing: bool = False, batch_size: int = 500
|
||||
) -> bool:
|
||||
"""导入到数据库"""
|
||||
try:
|
||||
if not os.path.exists(package_path):
|
||||
|
|
@ -451,7 +459,7 @@ class MMIPKGUnpacker:
|
|||
|
||||
print(f"正在读取包: {package_path}")
|
||||
|
||||
with open(package_path, 'rb') as f:
|
||||
with open(package_path, "rb") as f:
|
||||
# 读取 Header
|
||||
header = f.read(32)
|
||||
if len(header) != 32:
|
||||
|
|
@ -461,15 +469,15 @@ class MMIPKGUnpacker:
|
|||
if magic != MAGIC:
|
||||
raise MMIPKGError(f"无效的 MAGIC: {magic}")
|
||||
|
||||
version = struct.unpack('B', header[4:5])[0]
|
||||
version = struct.unpack("B", header[4:5])[0]
|
||||
if version != VERSION:
|
||||
print(f"警告: 包版本 {version} 与当前版本 {VERSION} 不匹配")
|
||||
|
||||
flags = struct.unpack('B', header[5:6])[0]
|
||||
flags = struct.unpack("B", header[5:6])[0]
|
||||
is_compressed = bool(flags & 0x01)
|
||||
|
||||
payload_uncompressed_len = struct.unpack('>Q', header[8:16])[0]
|
||||
manifest_uncompressed_len = struct.unpack('>Q', header[16:24])[0]
|
||||
payload_uncompressed_len = struct.unpack(">Q", header[8:16])[0]
|
||||
manifest_uncompressed_len = struct.unpack(">Q", header[16:24])[0]
|
||||
|
||||
# 安全检查
|
||||
if manifest_uncompressed_len > MAX_MANIFEST_SIZE:
|
||||
|
|
@ -519,7 +527,9 @@ class MMIPKGUnpacker:
|
|||
# 方法2:如果流式失败,尝试直接解压(兼容旧格式)
|
||||
print(f" 流式解压失败,尝试直接解压: {e}")
|
||||
try:
|
||||
payload_data = decompressor.decompress(compressed_data, max_output_size=payload_uncompressed_len)
|
||||
payload_data = decompressor.decompress(
|
||||
compressed_data, max_output_size=payload_uncompressed_len
|
||||
)
|
||||
except Exception as e2:
|
||||
raise MMIPKGError(f"解压失败: {e2}") from e2
|
||||
else:
|
||||
|
|
@ -537,7 +547,7 @@ class MMIPKGUnpacker:
|
|||
|
||||
# 读取 manifest
|
||||
manifest_len_bytes = payload_stream.read(4)
|
||||
manifest_len = struct.unpack('>I', manifest_len_bytes)[0]
|
||||
manifest_len = struct.unpack(">I", manifest_len_bytes)[0]
|
||||
manifest_bytes = payload_stream.read(manifest_len)
|
||||
manifest = msgpack.unpackb(manifest_bytes, raw=False)
|
||||
|
||||
|
|
@ -553,20 +563,21 @@ class MMIPKGUnpacker:
|
|||
print(f" 表情包数量: {len(items)}")
|
||||
|
||||
# 导入表情包
|
||||
return self._import_items(payload_stream, items, output_dir,
|
||||
replace_existing, batch_size)
|
||||
return self._import_items(payload_stream, items, output_dir, replace_existing, batch_size)
|
||||
|
||||
except Exception as e:
|
||||
print(f"导入失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
finally:
|
||||
if not db.is_closed():
|
||||
db.close()
|
||||
|
||||
def _import_items(self, payload_stream: BinaryIO, items: List[Dict],
|
||||
output_dir: str, replace_existing: bool, batch_size: int) -> bool:
|
||||
def _import_items(
|
||||
self, payload_stream: BinaryIO, items: List[Dict], output_dir: str, replace_existing: bool, batch_size: int
|
||||
) -> bool:
|
||||
"""导入 items 到数据库"""
|
||||
try:
|
||||
imported_count = 0
|
||||
|
|
@ -581,7 +592,7 @@ class MMIPKGUnpacker:
|
|||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeRemainingColumn(),
|
||||
console=console
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task("[cyan]导入表情包...", total=len(items))
|
||||
|
||||
|
|
@ -597,7 +608,7 @@ class MMIPKGUnpacker:
|
|||
progress.advance(task)
|
||||
continue
|
||||
|
||||
img_len = struct.unpack('>I', img_len_bytes)[0]
|
||||
img_len = struct.unpack(">I", img_len_bytes)[0]
|
||||
img_bytes = payload_stream.read(img_len)
|
||||
|
||||
if len(img_bytes) != img_len:
|
||||
|
|
@ -641,7 +652,7 @@ class MMIPKGUnpacker:
|
|||
file_path = os.path.join(output_dir, filename)
|
||||
counter += 1
|
||||
|
||||
with open(file_path, 'wb') as img_file:
|
||||
with open(file_path, "wb") as img_file:
|
||||
img_file.write(img_bytes)
|
||||
|
||||
# 准备数据库记录
|
||||
|
|
@ -700,6 +711,7 @@ class MMIPKGUnpacker:
|
|||
except Exception as e:
|
||||
console.print(f"[red]导入 items 失败: {e}[/red]")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
|
@ -719,8 +731,9 @@ def print_menu():
|
|||
console.print(" [2] [bold]导入表情包[/bold] (从 .mmipkg 文件导入到数据库)")
|
||||
console.print(" [0] [bold]退出[/bold]")
|
||||
console.print()
|
||||
def get_input(prompt: str, default: Optional[str] = None,
|
||||
choices: Optional[List[str]] = None) -> str:
|
||||
|
||||
|
||||
def get_input(prompt: str, default: Optional[str] = None, choices: Optional[List[str]] = None) -> str:
|
||||
"""获取用户输入"""
|
||||
if default:
|
||||
prompt = f"{prompt} (默认: {default})"
|
||||
|
|
@ -760,9 +773,9 @@ def get_yes_no(prompt: str, default: bool = False) -> bool:
|
|||
if not value:
|
||||
return default
|
||||
|
||||
if value in ('y', 'yes', '是'):
|
||||
if value in ("y", "yes", "是"):
|
||||
return True
|
||||
elif value in ('n', 'no', '否'):
|
||||
elif value in ("n", "no", "否"):
|
||||
return False
|
||||
else:
|
||||
console.print(" [yellow]⚠ 请输入 y/yes/是 或 n/no/否[/yellow]")
|
||||
|
|
@ -843,8 +856,8 @@ def interactive_export():
|
|||
output_path = get_input(" 输出文件路径", default_filename)
|
||||
|
||||
# 确保有 .mmipkg 扩展名
|
||||
if not output_path.endswith('.mmipkg'):
|
||||
output_path += '.mmipkg'
|
||||
if not output_path.endswith(".mmipkg"):
|
||||
output_path += ".mmipkg"
|
||||
|
||||
# 获取包名称
|
||||
default_pack_name = f"MaiBot表情包_{datetime.now().strftime('%Y%m%d')}"
|
||||
|
|
@ -853,9 +866,7 @@ def interactive_export():
|
|||
# 自定义 manifest
|
||||
console.print("\n[yellow]2. 包信息设置(可选)[/yellow]")
|
||||
if get_yes_no(" 是否添加包的作者和介绍信息", False):
|
||||
custom_manifest = {
|
||||
"author": author
|
||||
} if (author := input(" 作者名称(可选): ").strip()) else {}
|
||||
custom_manifest = {"author": author} if (author := input(" 作者名称(可选): ").strip()) else {}
|
||||
|
||||
# 介绍信息
|
||||
console.print(" 包介绍(限制 100 字以内):")
|
||||
|
|
@ -888,9 +899,9 @@ def interactive_export():
|
|||
console.print(" webp: 推荐,体积小且支持透明度")
|
||||
console.print(" jpeg: 最小体积,但不支持透明度")
|
||||
console.print(" png: 无损,文件较大")
|
||||
reencode = get_input(" 选择格式", "webp", ['webp', 'jpeg', 'png'])
|
||||
reencode = get_input(" 选择格式", "webp", ["webp", "jpeg", "png"])
|
||||
|
||||
quality = get_int(" 编码质量", 80, 1, 100) if reencode in ('webp', 'jpeg') else 80
|
||||
quality = get_int(" 编码质量", 80, 1, 100) if reencode in ("webp", "jpeg") else 80
|
||||
else:
|
||||
reencode = None
|
||||
quality = 80
|
||||
|
|
@ -920,10 +931,7 @@ def interactive_export():
|
|||
# 开始导出
|
||||
console.print("\n[cyan]开始导出...[/cyan]")
|
||||
packer = MMIPKGPacker(
|
||||
use_compression=use_compression,
|
||||
zstd_level=zstd_level,
|
||||
reencode=reencode,
|
||||
reencode_quality=quality
|
||||
use_compression=use_compression, zstd_level=zstd_level, reencode=reencode, reencode_quality=quality
|
||||
)
|
||||
|
||||
success = packer.pack_from_db(output_path, pack_name, custom_manifest)
|
||||
|
|
@ -944,11 +952,11 @@ def interactive_import():
|
|||
|
||||
# 选择导入模式
|
||||
print_import_mode_selection()
|
||||
import_mode = get_input("请选择", "1", ['1', '2'])
|
||||
import_mode = get_input("请选择", "1", ["1", "2"])
|
||||
|
||||
input_files = []
|
||||
|
||||
if import_mode == '1':
|
||||
if import_mode == "1":
|
||||
# 自动扫描模式
|
||||
import_dir = os.path.join(PROJECT_ROOT, "data", "import_emoji")
|
||||
os.makedirs(import_dir, exist_ok=True)
|
||||
|
|
@ -957,7 +965,7 @@ def interactive_import():
|
|||
|
||||
# 查找所有 .mmipkg 文件
|
||||
for file in os.listdir(import_dir):
|
||||
if file.endswith('.mmipkg'):
|
||||
if file.endswith(".mmipkg"):
|
||||
file_path = os.path.join(import_dir, file)
|
||||
if os.path.isfile(file_path):
|
||||
input_files.append(file_path)
|
||||
|
|
@ -1032,7 +1040,7 @@ def interactive_import():
|
|||
BarColumn(),
|
||||
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
|
||||
TimeElapsedColumn(),
|
||||
console=console
|
||||
console=console,
|
||||
) as progress:
|
||||
task = progress.add_task("[cyan]导入文件...", total=len(input_files))
|
||||
|
||||
|
|
@ -1044,10 +1052,7 @@ def interactive_import():
|
|||
console.print(f"[bold]{'=' * 70}[/bold]")
|
||||
|
||||
success = unpacker.import_to_db(
|
||||
input_path,
|
||||
output_dir=output_dir,
|
||||
replace_existing=replace_existing,
|
||||
batch_size=batch_size
|
||||
input_path, output_dir=output_dir, replace_existing=replace_existing, batch_size=batch_size
|
||||
)
|
||||
|
||||
if success:
|
||||
|
|
@ -1076,16 +1081,16 @@ def main():
|
|||
while True:
|
||||
print_menu()
|
||||
try:
|
||||
choice = get_input("请选择", "1", ['0', '1', '2'])
|
||||
choice = get_input("请选择", "1", ["0", "1", "2"])
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[green]再见![/green]")
|
||||
return 0
|
||||
|
||||
if choice == '0':
|
||||
if choice == "0":
|
||||
console.print("\n[green]再见![/green]")
|
||||
return 0
|
||||
|
||||
elif choice == '1':
|
||||
elif choice == "1":
|
||||
try:
|
||||
interactive_export()
|
||||
except KeyboardInterrupt:
|
||||
|
|
@ -1093,6 +1098,7 @@ def main():
|
|||
except Exception as e:
|
||||
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
try:
|
||||
|
|
@ -1100,7 +1106,7 @@ def main():
|
|||
except (KeyboardInterrupt, EOFError):
|
||||
pass
|
||||
|
||||
elif choice == '2':
|
||||
elif choice == "2":
|
||||
try:
|
||||
interactive_import()
|
||||
except KeyboardInterrupt:
|
||||
|
|
@ -1108,6 +1114,7 @@ def main():
|
|||
except Exception as e:
|
||||
console.print(f"\n[red]✗ 发生错误: {e}[/red]")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
try:
|
||||
|
|
@ -1121,5 +1128,5 @@ def main():
|
|||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
|
|
|||
|
|
@ -334,7 +334,6 @@ class HeartFChatting:
|
|||
self.consecutive_no_reply_count = 0
|
||||
reason = ""
|
||||
|
||||
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=False,
|
||||
|
|
|
|||
|
|
@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
|
|||
qa_manager = None
|
||||
inspire_manager = None
|
||||
|
||||
|
||||
def get_qa_manager():
|
||||
return qa_manager
|
||||
|
||||
|
||||
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||
# 检查LPMM知识库是否启用
|
||||
if global_config.lpmm_knowledge.enable:
|
||||
|
|
|
|||
|
|
@ -128,11 +128,10 @@ class QAManager:
|
|||
selected_knowledge = knowledge[:limit]
|
||||
|
||||
formatted_knowledge = [
|
||||
f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}"
|
||||
for i, k in enumerate(selected_knowledge)
|
||||
f"第{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge)
|
||||
]
|
||||
# if max_score is not None:
|
||||
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
|
||||
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
|
||||
|
||||
found_knowledge = "\n".join(formatted_knowledge)
|
||||
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
|
||||
|
|
|
|||
|
|
@ -226,7 +226,9 @@ class DefaultReplyer:
|
|||
traceback.print_exc()
|
||||
return False, llm_response
|
||||
|
||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
||||
async def build_expression_habits(
|
||||
self, chat_history: str, target: str, reply_reason: str = ""
|
||||
) -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
|
||||
|
|
|
|||
|
|
@ -241,7 +241,9 @@ class PrivateReplyer:
|
|||
|
||||
return f"{sender_relation}"
|
||||
|
||||
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
|
||||
async def build_expression_habits(
|
||||
self, chat_history: str, target: str, reply_reason: str = ""
|
||||
) -> Tuple[str, List[int]]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""构建表达习惯块
|
||||
|
||||
|
|
|
|||
|
|
@ -204,8 +204,9 @@ class WebSocketLogHandler(logging.Handler):
|
|||
message = formatted_msg
|
||||
try:
|
||||
import json
|
||||
|
||||
log_dict = json.loads(formatted_msg)
|
||||
message = log_dict.get('event', formatted_msg)
|
||||
message = log_dict.get("event", formatted_msg)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# 不是 JSON,直接使用消息
|
||||
message = formatted_msg
|
||||
|
|
@ -228,10 +229,7 @@ class WebSocketLogHandler(logging.Handler):
|
|||
import asyncio
|
||||
from src.webui.logs_ws import broadcast_log
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
broadcast_log(log_data),
|
||||
self.loop
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop)
|
||||
except Exception:
|
||||
# WebSocket 推送失败不影响日志记录
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -467,11 +467,7 @@ class ExpressionLearner:
|
|||
up_content: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
expr_obj = (
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == self.chat_id) & (Expression.style == style))
|
||||
.first()
|
||||
)
|
||||
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
|
||||
|
||||
if expr_obj:
|
||||
await self._update_existing_expression(
|
||||
|
|
|
|||
|
|
@ -42,8 +42,6 @@ def init_prompt():
|
|||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
|
|
@ -262,7 +260,6 @@ class ExpressionSelector:
|
|||
# 4. 调用LLM
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
|
||||
# print(prompt)
|
||||
|
||||
if not content:
|
||||
|
|
|
|||
|
|
@ -36,10 +36,7 @@ def _contains_bot_self_name(content: str) -> bool:
|
|||
|
||||
target = content.strip().lower()
|
||||
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
||||
alias_names = [
|
||||
str(alias or "").strip().lower()
|
||||
for alias in getattr(bot_config, "alias_names", []) or []
|
||||
]
|
||||
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
|
||||
|
||||
candidates = [name for name in [nickname, *alias_names] if name]
|
||||
|
||||
|
|
@ -188,9 +185,7 @@ async def _enrich_raw_content_if_needed(
|
|||
# 获取该消息的前三条消息
|
||||
try:
|
||||
previous_messages = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_id,
|
||||
timestamp=target_message.time,
|
||||
limit=3
|
||||
chat_id=chat_id, timestamp=target_message.time, limit=3
|
||||
)
|
||||
|
||||
if previous_messages:
|
||||
|
|
@ -245,7 +240,7 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
|||
last_inference = jargon_obj.last_inference_count or 0
|
||||
|
||||
# 阈值列表:3,6, 10, 20, 40, 60, 100
|
||||
thresholds = [3,6, 10, 20, 40, 60, 100]
|
||||
thresholds = [3, 6, 10, 20, 40, 60, 100]
|
||||
|
||||
if count < thresholds[0]:
|
||||
return False
|
||||
|
|
@ -311,7 +306,9 @@ class JargonMiner:
|
|||
raw_content_list = []
|
||||
if raw_content_str:
|
||||
try:
|
||||
raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
||||
raw_content_list = (
|
||||
json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
||||
)
|
||||
if not isinstance(raw_content_list, list):
|
||||
raw_content_list = [raw_content_list] if raw_content_list else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
|
|
@ -360,7 +357,6 @@ class JargonMiner:
|
|||
jargon_obj.save()
|
||||
return
|
||||
|
||||
|
||||
# 步骤2: 仅基于content推断
|
||||
prompt2 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_content_only_prompt",
|
||||
|
|
@ -388,7 +384,6 @@ class JargonMiner:
|
|||
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
|
|
@ -457,7 +452,9 @@ class JargonMiner:
|
|||
jargon_obj.is_complete = True
|
||||
|
||||
jargon_obj.save()
|
||||
logger.debug(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
|
||||
logger.debug(
|
||||
f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
|
||||
)
|
||||
|
||||
# 固定输出推断结果,格式化为可读形式
|
||||
if is_jargon:
|
||||
|
|
@ -475,6 +472,7 @@ class JargonMiner:
|
|||
except Exception as e:
|
||||
logger.error(f"jargon推断失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def should_trigger(self) -> bool:
|
||||
|
|
@ -571,10 +569,7 @@ class JargonMiner:
|
|||
if _contains_bot_self_name(content):
|
||||
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
||||
continue
|
||||
entries.append({
|
||||
"content": content,
|
||||
"raw_content": raw_content_list
|
||||
})
|
||||
entries.append({"content": content, "raw_content": raw_content_list})
|
||||
except Exception as e:
|
||||
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
||||
return
|
||||
|
|
@ -612,19 +607,10 @@ class JargonMiner:
|
|||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.jargon.all_global:
|
||||
# 开启all_global:无视chat_id,查询所有content匹配的记录(所有记录都是全局的)
|
||||
query = (
|
||||
Jargon.select()
|
||||
.where(Jargon.content == content)
|
||||
)
|
||||
query = Jargon.select().where(Jargon.content == content)
|
||||
else:
|
||||
# 关闭all_global:只查询chat_id匹配的记录(不考虑is_global)
|
||||
query = (
|
||||
Jargon.select()
|
||||
.where(
|
||||
(Jargon.chat_id == self.chat_id) &
|
||||
(Jargon.content == content)
|
||||
)
|
||||
)
|
||||
query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content))
|
||||
|
||||
if query.exists():
|
||||
obj = query.get()
|
||||
|
|
@ -637,7 +623,9 @@ class JargonMiner:
|
|||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
|
|
@ -676,7 +664,7 @@ class JargonMiner:
|
|||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||
chat_id=self.chat_id,
|
||||
is_global=is_global_new,
|
||||
count=1
|
||||
count=1,
|
||||
)
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
|
|
@ -720,11 +708,7 @@ async def extract_and_store_jargon(chat_id: str) -> None:
|
|||
|
||||
|
||||
def search_jargon(
|
||||
keyword: str,
|
||||
chat_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
case_sensitive: bool = False,
|
||||
fuzzy: bool = True
|
||||
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
搜索jargon,支持大小写不敏感和模糊搜索
|
||||
|
|
@ -747,10 +731,7 @@ def search_jargon(
|
|||
keyword = keyword.strip()
|
||||
|
||||
# 构建查询
|
||||
query = Jargon.select(
|
||||
Jargon.content,
|
||||
Jargon.meaning
|
||||
)
|
||||
query = Jargon.select(Jargon.content, Jargon.meaning)
|
||||
|
||||
# 构建搜索条件
|
||||
if case_sensitive:
|
||||
|
|
@ -760,7 +741,7 @@ def search_jargon(
|
|||
search_condition = Jargon.content.contains(keyword)
|
||||
else:
|
||||
# 精确匹配
|
||||
search_condition = (Jargon.content == keyword)
|
||||
search_condition = Jargon.content == keyword
|
||||
else:
|
||||
# 大小写不敏感
|
||||
if fuzzy:
|
||||
|
|
@ -768,7 +749,7 @@ def search_jargon(
|
|||
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
|
||||
else:
|
||||
# 精确匹配(使用LOWER函数)
|
||||
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
|
||||
search_condition = fn.LOWER(Jargon.content) == keyword.lower()
|
||||
|
||||
query = query.where(search_condition)
|
||||
|
||||
|
|
@ -779,14 +760,10 @@ def search_jargon(
|
|||
else:
|
||||
# 关闭all_global:如果提供了chat_id,优先搜索该聊天或global的jargon
|
||||
if chat_id:
|
||||
query = query.where(
|
||||
(Jargon.chat_id == chat_id) | Jargon.is_global
|
||||
)
|
||||
query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global)
|
||||
|
||||
# 只返回有meaning的记录
|
||||
query = query.where(
|
||||
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
||||
)
|
||||
query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
|
||||
# 按count降序排序,优先返回出现频率高的
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
|
|
@ -797,10 +774,7 @@ def search_jargon(
|
|||
# 执行查询并返回结果
|
||||
results = []
|
||||
for jargon in query:
|
||||
results.append({
|
||||
"content": jargon.content or "",
|
||||
"meaning": jargon.meaning or ""
|
||||
})
|
||||
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
|
||||
|
||||
return results
|
||||
|
||||
|
|
@ -840,10 +814,7 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
|||
if global_config.jargon.all_global:
|
||||
query = Jargon.select().where(Jargon.content == jargon_keyword)
|
||||
else:
|
||||
query = Jargon.select().where(
|
||||
(Jargon.chat_id == chat_id) &
|
||||
(Jargon.content == jargon_keyword)
|
||||
)
|
||||
query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword))
|
||||
|
||||
if query.exists():
|
||||
# 更新现有记录
|
||||
|
|
@ -854,7 +825,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
|||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
|
|
@ -877,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
|
|||
raw_content=json.dumps([raw_content], ensure_ascii=False),
|
||||
chat_id=chat_id,
|
||||
is_global=is_global_new,
|
||||
count=1
|
||||
count=1,
|
||||
)
|
||||
logger.info(f"创建新jargon记录: {jargon_keyword}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"存储jargon失败: {e}")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -116,9 +116,7 @@ class MessageBuilder:
|
|||
构建消息对象
|
||||
:return: Message对象
|
||||
"""
|
||||
if len(self.__content) == 0 and not (
|
||||
self.__role == RoleType.Assistant and self.__tool_calls
|
||||
):
|
||||
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
|
||||
raise ValueError("内容不能为空")
|
||||
if self.__role == RoleType.Tool and self.__tool_call_id is None:
|
||||
raise ValueError("Tool角色的工具调用ID不能为空")
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class MainSystem:
|
|||
"""注册 WebUI API 路由"""
|
||||
try:
|
||||
from src.webui.routes import router as webui_router
|
||||
|
||||
self.server.register_router(webui_router)
|
||||
logger.info("WebUI API 路由已注册")
|
||||
except Exception as e:
|
||||
|
|
@ -55,6 +56,7 @@ class MainSystem:
|
|||
def _setup_webui(self):
|
||||
"""设置 WebUI(根据环境变量决定模式)"""
|
||||
import os
|
||||
|
||||
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
|
||||
if not webui_enabled:
|
||||
logger.info("WebUI 已禁用")
|
||||
|
|
@ -64,6 +66,7 @@ class MainSystem:
|
|||
|
||||
try:
|
||||
from src.webui.manager import setup_webui
|
||||
|
||||
setup_webui(mode=webui_mode)
|
||||
except Exception as e:
|
||||
logger.error(f"设置 WebUI 失败: {e}")
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from src.llm_models.payload_content.message import MessageBuilder, RoleType, Mes
|
|||
logger = get_logger("memory_retrieval")
|
||||
|
||||
THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 3600 # 未找到答案记录保留时长
|
||||
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 300 # 清理频率
|
||||
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 300 # 清理频率
|
||||
_last_not_found_cleanup_ts: float = 0.0
|
||||
|
||||
|
||||
|
|
@ -33,10 +33,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
|
|||
try:
|
||||
deleted_rows = (
|
||||
ThinkingBack.delete()
|
||||
.where(
|
||||
(ThinkingBack.found_answer == 0) &
|
||||
(ThinkingBack.update_time < threshold_time)
|
||||
)
|
||||
.where((ThinkingBack.found_answer == 0) & (ThinkingBack.update_time < threshold_time))
|
||||
.execute()
|
||||
)
|
||||
if deleted_rows:
|
||||
|
|
@ -45,6 +42,7 @@ def _cleanup_stale_not_found_thinking_back() -> None:
|
|||
except Exception as e:
|
||||
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
|
||||
|
||||
|
||||
def init_memory_retrieval_prompt():
|
||||
"""初始化记忆检索相关的 prompt 模板和工具"""
|
||||
# 首先注册所有工具
|
||||
|
|
@ -221,10 +219,7 @@ def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
|
|||
return None
|
||||
|
||||
|
||||
async def _retrieve_concepts_with_jargon(
|
||||
concepts: List[str],
|
||||
chat_id: str
|
||||
) -> str:
|
||||
async def _retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
|
||||
"""对概念列表进行jargon检索
|
||||
|
||||
Args:
|
||||
|
|
@ -246,25 +241,13 @@ async def _retrieve_concepts_with_jargon(
|
|||
continue
|
||||
|
||||
# 先尝试精确匹配
|
||||
jargon_results = search_jargon(
|
||||
keyword=concept,
|
||||
chat_id=chat_id,
|
||||
limit=10,
|
||||
case_sensitive=False,
|
||||
fuzzy=False
|
||||
)
|
||||
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||
|
||||
is_fuzzy_match = False
|
||||
|
||||
# 如果精确匹配未找到,尝试模糊搜索
|
||||
if not jargon_results:
|
||||
jargon_results = search_jargon(
|
||||
keyword=concept,
|
||||
chat_id=chat_id,
|
||||
limit=10,
|
||||
case_sensitive=False,
|
||||
fuzzy=True
|
||||
)
|
||||
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||
is_fuzzy_match = True
|
||||
|
||||
if jargon_results:
|
||||
|
|
@ -298,11 +281,7 @@ async def _retrieve_concepts_with_jargon(
|
|||
|
||||
|
||||
async def _react_agent_solve_question(
|
||||
question: str,
|
||||
chat_id: str,
|
||||
max_iterations: int = 5,
|
||||
timeout: float = 30.0,
|
||||
initial_info: str = ""
|
||||
question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0, initial_info: str = ""
|
||||
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
|
||||
"""使用ReAct架构的Agent来解决问题
|
||||
|
||||
|
|
@ -343,11 +322,12 @@ async def _react_agent_solve_question(
|
|||
remaining_iterations = max_iterations - current_iteration
|
||||
is_final_iteration = current_iteration >= max_iterations
|
||||
|
||||
|
||||
if is_final_iteration:
|
||||
# 最后一次迭代,使用最终prompt
|
||||
tool_definitions = []
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0(最后一次迭代,不提供工具调用)")
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0(最后一次迭代,不提供工具调用)"
|
||||
)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_retrieval_react_final_prompt",
|
||||
|
|
@ -370,7 +350,9 @@ async def _react_agent_solve_question(
|
|||
else:
|
||||
# 非最终迭代,使用head_prompt
|
||||
tool_definitions = tool_registry.get_tool_definitions()
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}")
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}"
|
||||
)
|
||||
|
||||
head_prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_retrieval_react_prompt_head",
|
||||
|
|
@ -401,7 +383,7 @@ async def _react_agent_solve_question(
|
|||
# 优化日志展示 - 合并所有消息到一条日志
|
||||
log_lines = []
|
||||
for idx, msg in enumerate(messages, 1):
|
||||
role_name = msg.role.value if hasattr(msg.role, 'value') else str(msg.role)
|
||||
role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
|
||||
# 处理内容 - 显示完整内容,不截断
|
||||
if isinstance(msg.content, str):
|
||||
|
|
@ -437,14 +419,22 @@ async def _react_agent_solve_question(
|
|||
|
||||
return messages
|
||||
|
||||
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||
(
|
||||
success,
|
||||
response,
|
||||
reasoning_content,
|
||||
model_name,
|
||||
tool_calls,
|
||||
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||
message_factory,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
tool_options=tool_definitions,
|
||||
request_type="memory.react",
|
||||
)
|
||||
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}")
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
||||
|
|
@ -465,12 +455,7 @@ async def _react_agent_solve_question(
|
|||
assistant_message = assistant_builder.build()
|
||||
|
||||
# 记录思考步骤
|
||||
step = {
|
||||
"iteration": iteration + 1,
|
||||
"thought": response,
|
||||
"actions": [],
|
||||
"observations": []
|
||||
}
|
||||
step = {"iteration": iteration + 1, "thought": response, "actions": [], "observations": []}
|
||||
|
||||
# 优先从思考内容中提取found_answer或not_enough_info
|
||||
def extract_quoted_content(text, func_name, param_name):
|
||||
|
|
@ -495,14 +480,14 @@ async def _react_agent_solve_question(
|
|||
return None
|
||||
|
||||
# 查找参数名和等号
|
||||
param_pattern = f'{param_name}='
|
||||
param_pattern = f"{param_name}="
|
||||
param_pos = text_lower.find(param_pattern, func_pos)
|
||||
if param_pos == -1:
|
||||
return None
|
||||
|
||||
# 跳过参数名、等号和空白
|
||||
start_pos = param_pos + len(param_pattern)
|
||||
while start_pos < len(text) and text[start_pos] in ' \t\n':
|
||||
while start_pos < len(text) and text[start_pos] in " \t\n":
|
||||
start_pos += 1
|
||||
|
||||
if start_pos >= len(text):
|
||||
|
|
@ -518,13 +503,13 @@ async def _react_agent_solve_question(
|
|||
while end_pos < len(text):
|
||||
if text[end_pos] == quote_char:
|
||||
# 检查是否是转义的引号
|
||||
if end_pos > start_pos + 1 and text[end_pos - 1] == '\\':
|
||||
if end_pos > start_pos + 1 and text[end_pos - 1] == "\\":
|
||||
end_pos += 1
|
||||
continue
|
||||
# 找到匹配的引号
|
||||
content = text[start_pos + 1:end_pos]
|
||||
content = text[start_pos + 1 : end_pos]
|
||||
# 处理转义字符
|
||||
content = content.replace('\\"', '"').replace("\\'", "'").replace('\\\\', '\\')
|
||||
content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\")
|
||||
return content
|
||||
end_pos += 1
|
||||
|
||||
|
|
@ -536,27 +521,35 @@ async def _react_agent_solve_question(
|
|||
|
||||
# 只检查response(LLM的直接输出内容),不检查reasoning_content
|
||||
if response:
|
||||
found_answer_content = extract_quoted_content(response, 'found_answer', 'answer')
|
||||
found_answer_content = extract_quoted_content(response, "found_answer", "answer")
|
||||
if not found_answer_content:
|
||||
not_enough_info_reason = extract_quoted_content(response, 'not_enough_info', 'reason')
|
||||
not_enough_info_reason = extract_quoted_content(response, "not_enough_info", "reason")
|
||||
|
||||
# 如果从输出内容中找到了答案,直接返回
|
||||
if found_answer_content:
|
||||
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
|
||||
step["observations"] = ["从LLM输出内容中检测到found_answer"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}...")
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到found_answer: {found_answer_content[:100]}..."
|
||||
)
|
||||
return True, found_answer_content, thinking_steps, False
|
||||
|
||||
if not_enough_info_reason:
|
||||
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}})
|
||||
step["actions"].append(
|
||||
{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}
|
||||
)
|
||||
step["observations"] = ["从LLM输出内容中检测到not_enough_info"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}...")
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 从LLM输出内容中检测到not_enough_info: {not_enough_info_reason[:100]}..."
|
||||
)
|
||||
return False, not_enough_info_reason, thinking_steps, False
|
||||
|
||||
if is_final_iteration:
|
||||
step["actions"].append({"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}})
|
||||
step["actions"].append(
|
||||
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}}
|
||||
)
|
||||
step["observations"] = ["已到达最后一次迭代,无法找到答案"]
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案")
|
||||
|
|
@ -596,7 +589,9 @@ async def _react_agent_solve_question(
|
|||
tool_name = tool_call.func_name
|
||||
tool_args = tool_call.args or {}
|
||||
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i+1}/{len(tool_calls)}: {tool_name}({tool_args})")
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
|
||||
)
|
||||
|
||||
# 普通工具调用
|
||||
tool = tool_registry.get_tool(tool_name)
|
||||
|
|
@ -606,6 +601,7 @@ async def _react_agent_solve_question(
|
|||
|
||||
# 如果工具函数签名需要chat_id,添加它
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(tool.execute_func)
|
||||
if "chat_id" in sig.parameters:
|
||||
tool_params["chat_id"] = chat_id
|
||||
|
|
@ -625,7 +621,7 @@ async def _react_agent_solve_question(
|
|||
step["actions"].append({"action_type": tool_name, "action_params": tool_args})
|
||||
else:
|
||||
error_msg = f"未知的工具类型: {tool_name}"
|
||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1}/{len(tool_calls)} {error_msg}")
|
||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}")
|
||||
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}")))
|
||||
|
||||
# 并行执行所有工具
|
||||
|
|
@ -636,7 +632,7 @@ async def _react_agent_solve_question(
|
|||
for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)):
|
||||
if isinstance(observation, Exception):
|
||||
observation = f"工具执行异常: {str(observation)}"
|
||||
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i+1} 执行异常: {observation}")
|
||||
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}")
|
||||
|
||||
observation_text = observation if isinstance(observation, str) else str(observation)
|
||||
step["observations"].append(observation_text)
|
||||
|
|
@ -655,7 +651,9 @@ async def _react_agent_solve_question(
|
|||
# 迭代超时应该直接视为not_enough_info,而不是使用已有信息
|
||||
# 只有Agent明确返回found_answer时,才认为找到了答案
|
||||
if collected_info:
|
||||
logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回found_answer。已收集信息: {collected_info[:100]}...")
|
||||
logger.warning(
|
||||
f"ReAct Agent达到最大迭代次数或超时,但未明确返回found_answer。已收集信息: {collected_info[:100]}..."
|
||||
)
|
||||
if is_timeout:
|
||||
logger.warning("ReAct Agent超时,直接视为not_enough_info")
|
||||
else:
|
||||
|
|
@ -680,10 +678,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
|
|||
# 查询最近时间窗口内的记录,按更新时间倒序
|
||||
records = (
|
||||
ThinkingBack.select()
|
||||
.where(
|
||||
(ThinkingBack.chat_id == chat_id) &
|
||||
(ThinkingBack.update_time >= start_time)
|
||||
)
|
||||
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time))
|
||||
.order_by(ThinkingBack.update_time.desc())
|
||||
.limit(5) # 最多返回5条最近的记录
|
||||
)
|
||||
|
|
@ -735,9 +730,9 @@ def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> Li
|
|||
records = (
|
||||
ThinkingBack.select()
|
||||
.where(
|
||||
(ThinkingBack.chat_id == chat_id) &
|
||||
(ThinkingBack.update_time >= start_time) &
|
||||
(ThinkingBack.found_answer == 1)
|
||||
(ThinkingBack.chat_id == chat_id)
|
||||
& (ThinkingBack.update_time >= start_time)
|
||||
& (ThinkingBack.found_answer == 1)
|
||||
)
|
||||
.order_by(ThinkingBack.update_time.desc())
|
||||
.limit(5) # 最多返回5条最近的记录
|
||||
|
|
@ -775,10 +770,7 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st
|
|||
# 按更新时间倒序,获取最新的记录
|
||||
records = (
|
||||
ThinkingBack.select()
|
||||
.where(
|
||||
(ThinkingBack.chat_id == chat_id) &
|
||||
(ThinkingBack.question == question)
|
||||
)
|
||||
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
|
||||
.order_by(ThinkingBack.update_time.desc())
|
||||
.limit(1)
|
||||
)
|
||||
|
|
@ -857,6 +849,7 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
|
|||
jargon_keyword = analysis_result.get("jargon_keyword", "").strip()
|
||||
if jargon_keyword:
|
||||
from src.jargon.jargon_miner import store_jargon_from_answer
|
||||
|
||||
await store_jargon_from_answer(jargon_keyword, answer, chat_id)
|
||||
else:
|
||||
logger.warning(f"分析为黑话但未提取到关键词,问题: {question[:50]}...")
|
||||
|
|
@ -882,14 +875,8 @@ async def _analyze_question_answer(question: str, answer: str, chat_id: str) ->
|
|||
logger.error(f"分析问题和答案时发生异常: {e}")
|
||||
|
||||
|
||||
|
||||
def _store_thinking_back(
|
||||
chat_id: str,
|
||||
question: str,
|
||||
context: str,
|
||||
found_answer: bool,
|
||||
answer: str,
|
||||
thinking_steps: List[Dict[str, Any]]
|
||||
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
|
||||
|
||||
|
|
@ -907,10 +894,7 @@ def _store_thinking_back(
|
|||
# 先查询是否已存在相同chat_id和问题的记录
|
||||
existing = (
|
||||
ThinkingBack.select()
|
||||
.where(
|
||||
(ThinkingBack.chat_id == chat_id) &
|
||||
(ThinkingBack.question == question)
|
||||
)
|
||||
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
|
||||
.order_by(ThinkingBack.update_time.desc())
|
||||
.limit(1)
|
||||
)
|
||||
|
|
@ -935,19 +919,14 @@ def _store_thinking_back(
|
|||
answer=answer,
|
||||
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
||||
create_time=now,
|
||||
update_time=now
|
||||
update_time=now,
|
||||
)
|
||||
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"存储思考过程失败: {e}")
|
||||
|
||||
|
||||
async def _process_single_question(
|
||||
question: str,
|
||||
chat_id: str,
|
||||
context: str,
|
||||
initial_info: str = ""
|
||||
) -> Optional[str]:
|
||||
async def _process_single_question(question: str, chat_id: str, context: str, initial_info: str = "") -> Optional[str]:
|
||||
"""处理单个问题的查询(包含缓存检查逻辑)
|
||||
|
||||
Args:
|
||||
|
|
@ -1015,7 +994,7 @@ async def _process_single_question(
|
|||
chat_id=chat_id,
|
||||
max_iterations=global_config.memory.max_agent_iterations,
|
||||
timeout=120.0,
|
||||
initial_info=question_initial_info
|
||||
initial_info=question_initial_info,
|
||||
)
|
||||
|
||||
# 存储到数据库(超时时不存储)
|
||||
|
|
@ -1026,7 +1005,7 @@ async def _process_single_question(
|
|||
context=context,
|
||||
found_answer=found_answer,
|
||||
answer=answer,
|
||||
thinking_steps=thinking_steps
|
||||
thinking_steps=thinking_steps,
|
||||
)
|
||||
else:
|
||||
logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...")
|
||||
|
|
@ -1112,7 +1091,6 @@ async def build_memory_retrieval_prompt(
|
|||
else:
|
||||
logger.info("概念检索未找到任何结果")
|
||||
|
||||
|
||||
# 获取缓存的记忆(与question时使用相同的时间窗口和数量限制)
|
||||
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
|
||||
|
||||
|
|
@ -1141,12 +1119,7 @@ async def build_memory_retrieval_prompt(
|
|||
|
||||
# 并行处理所有问题,将概念检索结果作为初始信息传递
|
||||
question_tasks = [
|
||||
_process_single_question(
|
||||
question=question,
|
||||
chat_id=chat_id,
|
||||
context=message,
|
||||
initial_info=initial_info
|
||||
)
|
||||
_process_single_question(question=question, chat_id=chat_id, context=message, initial_info=initial_info)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
|
|
@ -1179,7 +1152,9 @@ async def build_memory_retrieval_prompt(
|
|||
|
||||
if all_results:
|
||||
retrieved_memory = "\n\n".join(all_results)
|
||||
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)")
|
||||
logger.info(
|
||||
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)"
|
||||
)
|
||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||
else:
|
||||
logger.debug("所有问题均未找到答案,且无缓存记忆")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
记忆系统工具函数
|
||||
包含模糊查找、相似度计算等工具函数
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
|
@ -14,6 +15,7 @@ from src.common.logger import get_logger
|
|||
|
||||
logger = get_logger("memory_utils")
|
||||
|
||||
|
||||
def parse_md_json(json_text: str) -> list[str]:
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
|
|
@ -52,6 +54,7 @@ def parse_md_json(json_text: str) -> list[str]:
|
|||
|
||||
return json_objects, reasoning_content
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度
|
||||
|
|
@ -97,10 +100,10 @@ def preprocess_text(text: str) -> str:
|
|||
text = text.lower()
|
||||
|
||||
# 移除标点符号和特殊字符
|
||||
text = re.sub(r'[^\w\s]', '', text)
|
||||
text = re.sub(r"[^\w\s]", "", text)
|
||||
|
||||
# 移除多余空格
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
|
@ -109,7 +112,6 @@ def preprocess_text(text: str) -> str:
|
|||
return text
|
||||
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
|
|
@ -164,4 +166,3 @@ def parse_time_range(time_range: str) -> Tuple[float, float]:
|
|||
end_timestamp = parse_datetime_to_timestamp(end_str)
|
||||
|
||||
return start_timestamp, end_timestamp
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
|||
from .query_person_info import register_tool as register_query_person_info
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_jargon()
|
||||
|
|
|
|||
|
|
@ -15,10 +15,7 @@ logger = get_logger("memory_retrieval_tools")
|
|||
|
||||
|
||||
async def query_chat_history(
|
||||
chat_id: str,
|
||||
keyword: Optional[str] = None,
|
||||
time_range: Optional[str] = None,
|
||||
fuzzy: bool = True
|
||||
chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
|
||||
) -> str:
|
||||
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||
|
||||
|
|
@ -50,17 +47,11 @@ async def query_chat_history(
|
|||
# 时间范围:查询与时间范围有交集的记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||
time_filter = (
|
||||
(ChatHistory.start_time < end_timestamp) &
|
||||
(ChatHistory.end_time > start_timestamp)
|
||||
)
|
||||
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
|
||||
else:
|
||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||
target_timestamp = parse_datetime_to_timestamp(time_range)
|
||||
time_filter = (
|
||||
(ChatHistory.start_time <= target_timestamp) &
|
||||
(ChatHistory.end_time >= target_timestamp)
|
||||
)
|
||||
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
|
||||
query = query.where(time_filter)
|
||||
|
||||
# 执行查询
|
||||
|
|
@ -91,7 +82,9 @@ async def query_chat_history(
|
|||
record_keywords_list = []
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
keywords_data = (
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
|
|
@ -102,20 +95,24 @@ async def query_chat_history(
|
|||
if fuzzy:
|
||||
# 模糊匹配:只要包含任意一个关键词即匹配(OR关系)
|
||||
for kw in keywords_lower:
|
||||
if (kw in theme or
|
||||
kw in summary or
|
||||
kw in original_text or
|
||||
any(kw in k for k in record_keywords_list)):
|
||||
if (
|
||||
kw in theme
|
||||
or kw in summary
|
||||
or kw in original_text
|
||||
or any(kw in k for k in record_keywords_list)
|
||||
):
|
||||
matched = True
|
||||
break
|
||||
else:
|
||||
# 全匹配:必须包含所有关键词才匹配(AND关系)
|
||||
matched = True
|
||||
for kw in keywords_lower:
|
||||
kw_matched = (kw in theme or
|
||||
kw in summary or
|
||||
kw in original_text or
|
||||
any(kw in k for k in record_keywords_list))
|
||||
kw_matched = (
|
||||
kw in theme
|
||||
or kw in summary
|
||||
or kw in original_text
|
||||
or any(kw in k for k in record_keywords_list)
|
||||
)
|
||||
if not kw_matched:
|
||||
matched = False
|
||||
break
|
||||
|
|
@ -160,6 +157,7 @@ async def query_chat_history(
|
|||
|
||||
# 添加时间范围
|
||||
from datetime import datetime
|
||||
|
||||
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||
|
|
@ -199,20 +197,20 @@ def register_tool():
|
|||
"name": "keyword",
|
||||
"type": "string",
|
||||
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
|
||||
"required": False
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "time_range",
|
||||
"type": "string",
|
||||
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
||||
"required": False
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "fuzzy",
|
||||
"type": "boolean",
|
||||
"description": "是否使用模糊匹配模式(默认True)。True表示模糊匹配(只要包含任意一个关键词即匹配,OR关系),False表示全匹配(必须包含所有关键词才匹配,AND关系)",
|
||||
"required": False
|
||||
}
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=query_chat_history
|
||||
execute_func=query_chat_history,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -73,5 +73,3 @@ def register_tool():
|
|||
],
|
||||
execute_func=query_lpmm_knowledge,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,9 @@ def _format_group_nick_names(group_nick_name_field) -> str:
|
|||
|
||||
try:
|
||||
# 解析JSON格式的群昵称列表
|
||||
group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
|
||||
group_nick_names_data = (
|
||||
json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
|
||||
)
|
||||
|
||||
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
|
||||
return ""
|
||||
|
|
@ -71,9 +73,7 @@ async def query_person_info(person_name: str) -> str:
|
|||
return "用户名称为空"
|
||||
|
||||
# 构建查询条件(使用模糊查询)
|
||||
query = PersonInfo.select().where(
|
||||
PersonInfo.person_name.contains(person_name)
|
||||
)
|
||||
query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name))
|
||||
|
||||
# 执行查询
|
||||
records = list(query.limit(20)) # 最多返回20条记录
|
||||
|
|
@ -137,7 +137,11 @@ async def query_person_info(person_name: str) -> str:
|
|||
# 记忆点(memory_points)
|
||||
if record.memory_points:
|
||||
try:
|
||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
||||
memory_points_data = (
|
||||
json.loads(record.memory_points)
|
||||
if isinstance(record.memory_points, str)
|
||||
else record.memory_points
|
||||
)
|
||||
if isinstance(memory_points_data, list) and memory_points_data:
|
||||
# 解析记忆点格式:category:content:weight
|
||||
memory_list = []
|
||||
|
|
@ -206,7 +210,11 @@ async def query_person_info(person_name: str) -> str:
|
|||
# 记忆点(memory_points)
|
||||
if record.memory_points:
|
||||
try:
|
||||
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
|
||||
memory_points_data = (
|
||||
json.loads(record.memory_points)
|
||||
if isinstance(record.memory_points, str)
|
||||
else record.memory_points
|
||||
)
|
||||
if isinstance(memory_points_data, list) and memory_points_data:
|
||||
# 解析记忆点格式:category:content:weight
|
||||
memory_list = []
|
||||
|
|
@ -275,13 +283,7 @@ def register_tool():
|
|||
name="query_person_info",
|
||||
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
|
||||
parameters=[
|
||||
{
|
||||
"name": "person_name",
|
||||
"type": "string",
|
||||
"description": "用户名称,用于查询用户信息",
|
||||
"required": True
|
||||
}
|
||||
{"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True}
|
||||
],
|
||||
execute_func=query_person_info
|
||||
execute_func=query_person_info,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -82,11 +82,7 @@ class MemoryRetrievalTool:
|
|||
param_tuples.append(param_tuple)
|
||||
|
||||
# 构建工具定义,格式与BaseTool.get_tool_definition()一致
|
||||
tool_def = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": param_tuples
|
||||
}
|
||||
tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
|
||||
|
||||
return tool_def
|
||||
|
||||
|
|
|
|||
|
|
@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int:
|
|||
class Person:
|
||||
@classmethod
|
||||
def register_person(
|
||||
cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None
|
||||
cls,
|
||||
platform: str,
|
||||
user_id: str,
|
||||
nickname: str,
|
||||
group_id: Optional[str] = None,
|
||||
group_nick_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
注册新用户的类方法
|
||||
|
|
@ -781,7 +786,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
|||
if len(parts) >= 2:
|
||||
existing_content = parts[1].strip()
|
||||
# 简单相似度检查(如果内容相同或非常相似,则跳过)
|
||||
if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content:
|
||||
if (
|
||||
existing_content == memory_content
|
||||
or memory_content in existing_content
|
||||
or existing_content in memory_content
|
||||
):
|
||||
is_duplicate = True
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -125,7 +125,6 @@ class ToolExecutor:
|
|||
prompt=prompt, tools=tools, raise_when_empty=False
|
||||
)
|
||||
|
||||
|
||||
# 执行工具调用
|
||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""表情包管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -18,6 +19,7 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
|||
|
||||
class EmojiResponse(BaseModel):
|
||||
"""表情包响应"""
|
||||
|
||||
id: int
|
||||
full_path: str
|
||||
format: str
|
||||
|
|
@ -35,6 +37,7 @@ class EmojiResponse(BaseModel):
|
|||
|
||||
class EmojiListResponse(BaseModel):
|
||||
"""表情包列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
|
|
@ -44,12 +47,14 @@ class EmojiListResponse(BaseModel):
|
|||
|
||||
class EmojiDetailResponse(BaseModel):
|
||||
"""表情包详情响应"""
|
||||
|
||||
success: bool
|
||||
data: EmojiResponse
|
||||
|
||||
|
||||
class EmojiUpdateRequest(BaseModel):
|
||||
"""表情包更新请求"""
|
||||
|
||||
description: Optional[str] = None
|
||||
is_registered: Optional[bool] = None
|
||||
is_banned: Optional[bool] = None
|
||||
|
|
@ -58,6 +63,7 @@ class EmojiUpdateRequest(BaseModel):
|
|||
|
||||
class EmojiUpdateResponse(BaseModel):
|
||||
"""表情包更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[EmojiResponse] = None
|
||||
|
|
@ -65,6 +71,7 @@ class EmojiUpdateResponse(BaseModel):
|
|||
|
||||
class EmojiDeleteResponse(BaseModel):
|
||||
"""表情包删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
|
@ -120,7 +127,7 @@ async def get_emoji_list(
|
|||
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
|
||||
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
|
||||
format: Optional[str] = Query(None, description="格式筛选"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表情包列表
|
||||
|
|
@ -145,10 +152,7 @@ async def get_emoji_list(
|
|||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Emoji.description.contains(search)) |
|
||||
(Emoji.emoji_hash.contains(search))
|
||||
)
|
||||
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
|
||||
|
||||
# 注册状态过滤
|
||||
if is_registered is not None:
|
||||
|
|
@ -164,10 +168,9 @@ async def get_emoji_list(
|
|||
|
||||
# 排序:使用次数倒序,然后按记录时间倒序
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(
|
||||
Emoji.usage_count.desc(),
|
||||
Case(None, [(Emoji.record_time.is_null(), 1)], 0),
|
||||
Emoji.record_time.desc()
|
||||
Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
|
||||
)
|
||||
|
||||
# 获取总数
|
||||
|
|
@ -180,13 +183,7 @@ async def get_emoji_list(
|
|||
# 转换为响应对象
|
||||
data = [emoji_to_response(emoji) for emoji in emojis]
|
||||
|
||||
return EmojiListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data
|
||||
)
|
||||
return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -196,10 +193,7 @@ async def get_emoji_list(
|
|||
|
||||
|
||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||
async def get_emoji_detail(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包详细信息
|
||||
|
||||
|
|
@ -218,10 +212,7 @@ async def get_emoji_detail(
|
|||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
return EmojiDetailResponse(
|
||||
success=True,
|
||||
data=emoji_to_response(emoji)
|
||||
)
|
||||
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -231,11 +222,7 @@ async def get_emoji_detail(
|
|||
|
||||
|
||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||
async def update_emoji(
|
||||
emoji_id: int,
|
||||
request: EmojiUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量更新表情包(只更新提供的字段)
|
||||
|
||||
|
|
@ -262,15 +249,15 @@ async def update_emoji(
|
|||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# 处理情感标签(转换为 JSON)
|
||||
if 'emotion' in update_data:
|
||||
if update_data['emotion'] is None:
|
||||
update_data['emotion'] = None
|
||||
if "emotion" in update_data:
|
||||
if update_data["emotion"] is None:
|
||||
update_data["emotion"] = None
|
||||
else:
|
||||
update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False)
|
||||
update_data["emotion"] = json.dumps(update_data["emotion"], ensure_ascii=False)
|
||||
|
||||
# 如果注册状态从 False 变为 True,记录注册时间
|
||||
if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered:
|
||||
update_data['register_time'] = time.time()
|
||||
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
|
||||
update_data["register_time"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
|
|
@ -281,9 +268,7 @@ async def update_emoji(
|
|||
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {len(update_data)} 个字段",
|
||||
data=emoji_to_response(emoji)
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -294,10 +279,7 @@ async def update_emoji(
|
|||
|
||||
|
||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||
async def delete_emoji(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表情包
|
||||
|
||||
|
|
@ -324,10 +306,7 @@ async def delete_emoji(
|
|||
|
||||
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
|
||||
|
||||
return EmojiDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除表情包: {emoji_hash}"
|
||||
)
|
||||
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -337,9 +316,7 @@ async def delete_emoji(
|
|||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_emoji_stats(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包统计数据
|
||||
|
||||
|
|
@ -369,7 +346,7 @@ async def get_emoji_stats(
|
|||
"id": emoji.id,
|
||||
"emoji_hash": emoji.emoji_hash,
|
||||
"description": emoji.description,
|
||||
"usage_count": emoji.usage_count
|
||||
"usage_count": emoji.usage_count,
|
||||
}
|
||||
for emoji in top_used
|
||||
]
|
||||
|
|
@ -382,8 +359,8 @@ async def get_emoji_stats(
|
|||
"banned": banned,
|
||||
"unregistered": total - registered,
|
||||
"formats": formats,
|
||||
"top_used": top_used_list
|
||||
}
|
||||
"top_used": top_used_list,
|
||||
},
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -394,10 +371,7 @@ async def get_emoji_stats(
|
|||
|
||||
|
||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||
async def register_emoji(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
注册表情包(快捷操作)
|
||||
|
||||
|
|
@ -429,11 +403,7 @@ async def register_emoji(
|
|||
|
||||
logger.info(f"表情包已注册: ID={emoji_id}")
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True,
|
||||
message="表情包注册成功",
|
||||
data=emoji_to_response(emoji)
|
||||
)
|
||||
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -443,10 +413,7 @@ async def register_emoji(
|
|||
|
||||
|
||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||
async def ban_emoji(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
禁用表情包(快捷操作)
|
||||
|
||||
|
|
@ -472,11 +439,7 @@ async def ban_emoji(
|
|||
|
||||
logger.info(f"表情包已禁用: ID={emoji_id}")
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True,
|
||||
message="表情包禁用成功",
|
||||
data=emoji_to_response(emoji)
|
||||
)
|
||||
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -489,7 +452,7 @@ async def ban_emoji(
|
|||
async def get_emoji_thumbnail(
|
||||
emoji_id: int,
|
||||
token: Optional[str] = Query(None, description="访问令牌"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表情包缩略图
|
||||
|
|
@ -523,25 +486,20 @@ async def get_emoji_thumbnail(
|
|||
|
||||
# 根据格式设置 MIME 类型
|
||||
mime_types = {
|
||||
'png': 'image/png',
|
||||
'jpg': 'image/jpeg',
|
||||
'jpeg': 'image/jpeg',
|
||||
'gif': 'image/gif',
|
||||
'webp': 'image/webp',
|
||||
'bmp': 'image/bmp'
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
|
||||
media_type = mime_types.get(emoji.format.lower(), 'application/octet-stream')
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
|
||||
return FileResponse(
|
||||
path=emoji.full_path,
|
||||
media_type=media_type,
|
||||
filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||
)
|
||||
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取表情包缩略图失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
|
|
@ -15,6 +16,7 @@ router = APIRouter(prefix="/expression", tags=["Expression"])
|
|||
|
||||
class ExpressionResponse(BaseModel):
|
||||
"""表达方式响应"""
|
||||
|
||||
id: int
|
||||
situation: str
|
||||
style: str
|
||||
|
|
@ -27,6 +29,7 @@ class ExpressionResponse(BaseModel):
|
|||
|
||||
class ExpressionListResponse(BaseModel):
|
||||
"""表达方式列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
|
|
@ -36,12 +39,14 @@ class ExpressionListResponse(BaseModel):
|
|||
|
||||
class ExpressionDetailResponse(BaseModel):
|
||||
"""表达方式详情响应"""
|
||||
|
||||
success: bool
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
class ExpressionCreateRequest(BaseModel):
|
||||
"""表达方式创建请求"""
|
||||
|
||||
situation: str
|
||||
style: str
|
||||
context: Optional[str] = None
|
||||
|
|
@ -51,6 +56,7 @@ class ExpressionCreateRequest(BaseModel):
|
|||
|
||||
class ExpressionUpdateRequest(BaseModel):
|
||||
"""表达方式更新请求"""
|
||||
|
||||
situation: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
context: Optional[str] = None
|
||||
|
|
@ -60,6 +66,7 @@ class ExpressionUpdateRequest(BaseModel):
|
|||
|
||||
class ExpressionUpdateResponse(BaseModel):
|
||||
"""表达方式更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[ExpressionResponse] = None
|
||||
|
|
@ -67,12 +74,14 @@ class ExpressionUpdateResponse(BaseModel):
|
|||
|
||||
class ExpressionDeleteResponse(BaseModel):
|
||||
"""表达方式删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class ExpressionCreateResponse(BaseModel):
|
||||
"""表达方式创建响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: ExpressionResponse
|
||||
|
|
@ -112,7 +121,7 @@ async def get_expression_list(
|
|||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表达方式列表
|
||||
|
|
@ -136,9 +145,9 @@ async def get_expression_list(
|
|||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Expression.situation.contains(search)) |
|
||||
(Expression.style.contains(search)) |
|
||||
(Expression.context.contains(search))
|
||||
(Expression.situation.contains(search))
|
||||
| (Expression.style.contains(search))
|
||||
| (Expression.context.contains(search))
|
||||
)
|
||||
|
||||
# 聊天ID过滤
|
||||
|
|
@ -147,9 +156,9 @@ async def get_expression_list(
|
|||
|
||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(
|
||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0),
|
||||
Expression.last_active_time.desc()
|
||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
|
||||
)
|
||||
|
||||
# 获取总数
|
||||
|
|
@ -162,13 +171,7 @@ async def get_expression_list(
|
|||
# 转换为响应对象
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data
|
||||
)
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -178,10 +181,7 @@ async def get_expression_list(
|
|||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(
|
||||
expression_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
|
|
@ -200,10 +200,7 @@ async def get_expression_detail(
|
|||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
return ExpressionDetailResponse(
|
||||
success=True,
|
||||
data=expression_to_response(expression)
|
||||
)
|
||||
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -213,10 +210,7 @@ async def get_expression_detail(
|
|||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
|
|
@ -246,9 +240,7 @@ async def create_expression(
|
|||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||
|
||||
return ExpressionCreateResponse(
|
||||
success=True,
|
||||
message="表达方式创建成功",
|
||||
data=expression_to_response(expression)
|
||||
success=True, message="表达方式创建成功", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -260,9 +252,7 @@ async def create_expression(
|
|||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
|
|
@ -290,7 +280,7 @@ async def update_expression(
|
|||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# 更新最后活跃时间
|
||||
update_data['last_active_time'] = time.time()
|
||||
update_data["last_active_time"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
|
|
@ -301,9 +291,7 @@ async def update_expression(
|
|||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return ExpressionUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {len(update_data)} 个字段",
|
||||
data=expression_to_response(expression)
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -314,10 +302,7 @@ async def update_expression(
|
|||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(
|
||||
expression_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
|
|
@ -344,10 +329,7 @@ async def delete_expression(
|
|||
|
||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||
|
||||
return ExpressionDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除表达方式: {situation}"
|
||||
)
|
||||
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -357,9 +339,7 @@ async def delete_expression(
|
|||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
|
|
@ -382,10 +362,11 @@ async def get_expression_stats(
|
|||
|
||||
# 获取最近创建的记录数(7天内)
|
||||
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
||||
recent = Expression.select().where(
|
||||
(Expression.create_date.is_null(False)) &
|
||||
(Expression.create_date >= seven_days_ago)
|
||||
).count()
|
||||
recent = (
|
||||
Expression.select()
|
||||
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
|
@ -393,8 +374,8 @@ async def get_expression_stats(
|
|||
"total": total,
|
||||
"recent_7days": recent,
|
||||
"chat_count": len(chat_stats),
|
||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10])
|
||||
}
|
||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
|
||||
},
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
import httpx
|
||||
|
|
@ -15,6 +16,7 @@ logger = get_logger("webui.git_mirror")
|
|||
# 导入进度更新函数(避免循环导入)
|
||||
_update_progress = None
|
||||
|
||||
|
||||
def set_update_progress_callback(callback):
|
||||
"""设置进度更新回调函数"""
|
||||
global _update_progress
|
||||
|
|
@ -23,6 +25,7 @@ def set_update_progress_callback(callback):
|
|||
|
||||
class MirrorType(str, Enum):
|
||||
"""镜像源类型"""
|
||||
|
||||
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
||||
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
||||
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
||||
|
|
@ -47,7 +50,7 @@ class GitMirrorConfig:
|
|||
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 1,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "hk-gh-proxy",
|
||||
|
|
@ -56,7 +59,7 @@ class GitMirrorConfig:
|
|||
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 2,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "cdn-gh-proxy",
|
||||
|
|
@ -65,7 +68,7 @@ class GitMirrorConfig:
|
|||
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 3,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "edgeone-gh-proxy",
|
||||
|
|
@ -74,7 +77,7 @@ class GitMirrorConfig:
|
|||
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 4,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "meyzh-github",
|
||||
|
|
@ -83,7 +86,7 @@ class GitMirrorConfig:
|
|||
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 5,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "github",
|
||||
|
|
@ -92,8 +95,8 @@ class GitMirrorConfig:
|
|||
"clone_prefix": "https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 999,
|
||||
"created_at": None
|
||||
}
|
||||
"created_at": None,
|
||||
},
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
|
|
@ -106,7 +109,7 @@ class GitMirrorConfig:
|
|||
"""加载配置文件"""
|
||||
try:
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 检查是否有镜像源配置
|
||||
|
|
@ -145,14 +148,14 @@ class GitMirrorConfig:
|
|||
# 读取现有配置
|
||||
existing_data = {}
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
# 更新镜像源配置
|
||||
existing_data["git_mirrors"] = self.mirrors
|
||||
|
||||
# 写入文件
|
||||
with open(self.config_file, 'w', encoding='utf-8') as f:
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.debug(f"配置已保存到 {self.config_file}")
|
||||
|
|
@ -182,7 +185,7 @@ class GitMirrorConfig:
|
|||
raw_prefix: str,
|
||||
clone_prefix: str,
|
||||
enabled: bool = True,
|
||||
priority: Optional[int] = None
|
||||
priority: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
添加新的镜像源
|
||||
|
|
@ -209,7 +212,7 @@ class GitMirrorConfig:
|
|||
"clone_prefix": clone_prefix,
|
||||
"enabled": enabled,
|
||||
"priority": priority,
|
||||
"created_at": datetime.now().isoformat()
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
self.mirrors.append(new_mirror)
|
||||
|
|
@ -225,7 +228,7 @@ class GitMirrorConfig:
|
|||
raw_prefix: Optional[str] = None,
|
||||
clone_prefix: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
priority: Optional[int] = None
|
||||
priority: Optional[int] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
更新镜像源配置
|
||||
|
|
@ -279,12 +282,7 @@ class GitMirrorConfig:
|
|||
class GitMirrorService:
|
||||
"""Git 镜像源服务"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retries: int = 3,
|
||||
timeout: int = 30,
|
||||
config: Optional[GitMirrorConfig] = None
|
||||
):
|
||||
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
|
||||
"""
|
||||
初始化 Git 镜像源服务
|
||||
|
||||
|
|
@ -323,46 +321,25 @@ class GitMirrorService:
|
|||
|
||||
if not git_path:
|
||||
logger.warning("未找到 Git 可执行文件")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": "系统中未找到 Git,请先安装 Git"
|
||||
}
|
||||
return {"installed": False, "error": "系统中未找到 Git,请先安装 Git"}
|
||||
|
||||
# 获取 Git 版本
|
||||
result = subprocess.run(
|
||||
["git", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
|
||||
|
||||
if result.returncode == 0:
|
||||
version = result.stdout.strip()
|
||||
logger.info(f"检测到 Git: {version} at {git_path}")
|
||||
return {
|
||||
"installed": True,
|
||||
"version": version,
|
||||
"path": git_path
|
||||
}
|
||||
return {"installed": True, "version": version, "path": git_path}
|
||||
else:
|
||||
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": f"Git 命令执行失败: {result.stderr}"
|
||||
}
|
||||
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Git 版本检测超时")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": "Git 版本检测超时"
|
||||
}
|
||||
return {"installed": False, "error": "Git 版本检测超时"}
|
||||
except Exception as e:
|
||||
logger.error(f"检测 Git 时发生错误: {e}")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": f"检测 Git 时发生错误: {str(e)}"
|
||||
}
|
||||
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
|
||||
|
||||
async def fetch_raw_file(
|
||||
self,
|
||||
|
|
@ -371,7 +348,7 @@ class GitMirrorService:
|
|||
branch: str,
|
||||
file_path: str,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None
|
||||
custom_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
|
|
@ -403,12 +380,7 @@ class GitMirrorService:
|
|||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"未找到镜像源: {mirror_id}",
|
||||
"mirror_used": None,
|
||||
"attempts": 0
|
||||
}
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
|
|
@ -427,14 +399,12 @@ class GitMirrorService:
|
|||
progress=progress,
|
||||
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
result = await self._fetch_raw_from_mirror(
|
||||
owner, repo, branch, file_path, mirror
|
||||
)
|
||||
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
|
||||
|
||||
if result["success"]:
|
||||
# 成功,推送进度
|
||||
|
|
@ -445,7 +415,7 @@ class GitMirrorService:
|
|||
progress=70,
|
||||
message=f"成功从 {mirror['name']} 获取数据",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
|
@ -461,26 +431,16 @@ class GitMirrorService:
|
|||
progress=30 + int(index / total_mirrors * 40),
|
||||
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {
|
||||
"success": False,
|
||||
"error": "所有镜像源均失败",
|
||||
"mirror_used": None,
|
||||
"attempts": len(mirrors_to_try)
|
||||
}
|
||||
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _fetch_raw_from_mirror(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
file_path: str,
|
||||
mirror: Dict[str, Any]
|
||||
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源获取文件"""
|
||||
# 构建 URL
|
||||
|
|
@ -508,7 +468,7 @@ class GitMirrorService:
|
|||
"data": response.text,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url
|
||||
"url": url,
|
||||
}
|
||||
except httpx.HTTPStatusError as e:
|
||||
last_error = f"HTTP {e.response.status_code}: {e}"
|
||||
|
|
@ -520,13 +480,7 @@ class GitMirrorService:
|
|||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": last_error,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url
|
||||
}
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
async def clone_repository(
|
||||
self,
|
||||
|
|
@ -536,7 +490,7 @@ class GitMirrorService:
|
|||
branch: Optional[str] = None,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None,
|
||||
depth: Optional[int] = None
|
||||
depth: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
克隆 GitHub 仓库
|
||||
|
|
@ -569,12 +523,7 @@ class GitMirrorService:
|
|||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"未找到镜像源: {mirror_id}",
|
||||
"mirror_used": None,
|
||||
"attempts": 0
|
||||
}
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
|
|
@ -582,20 +531,13 @@ class GitMirrorService:
|
|||
|
||||
# 依次尝试每个镜像源
|
||||
for mirror in mirrors_to_try:
|
||||
result = await self._clone_from_mirror(
|
||||
owner, repo, target_path, branch, depth, mirror
|
||||
)
|
||||
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
|
||||
if result["success"]:
|
||||
return result
|
||||
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {
|
||||
"success": False,
|
||||
"error": "所有镜像源克隆均失败",
|
||||
"mirror_used": None,
|
||||
"attempts": len(mirrors_to_try)
|
||||
}
|
||||
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _clone_from_mirror(
|
||||
self,
|
||||
|
|
@ -604,7 +546,7 @@ class GitMirrorService:
|
|||
target_path: Path,
|
||||
branch: Optional[str],
|
||||
depth: Optional[int],
|
||||
mirror: Dict[str, Any]
|
||||
mirror: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源克隆仓库"""
|
||||
# 构建克隆 URL
|
||||
|
|
@ -614,12 +556,7 @@ class GitMirrorService:
|
|||
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
||||
|
||||
async def _clone_with_url(
|
||||
self,
|
||||
url: str,
|
||||
target_path: Path,
|
||||
branch: Optional[str],
|
||||
depth: Optional[int],
|
||||
mirror_type: str
|
||||
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""使用指定 URL 克隆仓库,支持重试"""
|
||||
attempts = 0
|
||||
|
|
@ -657,7 +594,7 @@ class GitMirrorService:
|
|||
stage="loading",
|
||||
progress=20 + attempt * 10,
|
||||
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
||||
operation="install"
|
||||
operation="install",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
|
@ -670,7 +607,7 @@ class GitMirrorService:
|
|||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300 # 5分钟超时
|
||||
timeout=300, # 5分钟超时
|
||||
)
|
||||
|
||||
process = await loop.run_in_executor(None, run_git_clone)
|
||||
|
|
@ -683,7 +620,7 @@ class GitMirrorService:
|
|||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url,
|
||||
"branch": branch or "default"
|
||||
"branch": branch or "default",
|
||||
}
|
||||
else:
|
||||
last_error = f"Git 克隆失败: {process.stderr}"
|
||||
|
|
@ -710,13 +647,7 @@ class GitMirrorService:
|
|||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": last_error,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url
|
||||
}
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""WebSocket 日志推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set
|
||||
import json
|
||||
|
|
@ -49,7 +50,9 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
|||
log_entry = json.loads(line.strip())
|
||||
# 转换为前端期望的格式
|
||||
# 使用时间戳 + 计数器生成唯一 ID
|
||||
timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||
timestamp_id = (
|
||||
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||
)
|
||||
formatted_log = {
|
||||
"id": f"{timestamp_id}_{log_counter}",
|
||||
"timestamp": log_entry.get("timestamp", ""),
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
|
|
@ -55,10 +56,10 @@ def setup_production_mode() -> bool:
|
|||
|
||||
# 确保正确的 MIME 类型映射
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
mimetypes.add_type('application/javascript', '.mjs')
|
||||
mimetypes.add_type('text/css', '.css')
|
||||
mimetypes.add_type('application/json', '.json')
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("application/javascript", ".mjs")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
mimetypes.add_type("application/json", ".json")
|
||||
|
||||
server = get_global_server()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""人物信息管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
|
|
@ -16,6 +17,7 @@ router = APIRouter(prefix="/person", tags=["Person"])
|
|||
|
||||
class PersonInfoResponse(BaseModel):
|
||||
"""人物信息响应"""
|
||||
|
||||
id: int
|
||||
is_known: bool
|
||||
person_id: str
|
||||
|
|
@ -33,6 +35,7 @@ class PersonInfoResponse(BaseModel):
|
|||
|
||||
class PersonListResponse(BaseModel):
|
||||
"""人物列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
|
|
@ -42,12 +45,14 @@ class PersonListResponse(BaseModel):
|
|||
|
||||
class PersonDetailResponse(BaseModel):
|
||||
"""人物详情响应"""
|
||||
|
||||
success: bool
|
||||
data: PersonInfoResponse
|
||||
|
||||
|
||||
class PersonUpdateRequest(BaseModel):
|
||||
"""人物信息更新请求"""
|
||||
|
||||
person_name: Optional[str] = None
|
||||
name_reason: Optional[str] = None
|
||||
nickname: Optional[str] = None
|
||||
|
|
@ -57,6 +62,7 @@ class PersonUpdateRequest(BaseModel):
|
|||
|
||||
class PersonUpdateResponse(BaseModel):
|
||||
"""人物信息更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[PersonInfoResponse] = None
|
||||
|
|
@ -64,6 +70,7 @@ class PersonUpdateResponse(BaseModel):
|
|||
|
||||
class PersonDeleteResponse(BaseModel):
|
||||
"""人物删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
|
@ -118,7 +125,7 @@ async def get_person_list(
|
|||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取人物信息列表
|
||||
|
|
@ -143,9 +150,9 @@ async def get_person_list(
|
|||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(PersonInfo.person_name.contains(search)) |
|
||||
(PersonInfo.nickname.contains(search)) |
|
||||
(PersonInfo.user_id.contains(search))
|
||||
(PersonInfo.person_name.contains(search))
|
||||
| (PersonInfo.nickname.contains(search))
|
||||
| (PersonInfo.user_id.contains(search))
|
||||
)
|
||||
|
||||
# 已认识状态过滤
|
||||
|
|
@ -159,10 +166,8 @@ async def get_person_list(
|
|||
# 排序:最后更新时间倒序(NULL 值放在最后)
|
||||
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
||||
from peewee import Case
|
||||
query = query.order_by(
|
||||
Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
|
||||
PersonInfo.last_know.desc()
|
||||
)
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
|
@ -174,13 +179,7 @@ async def get_person_list(
|
|||
# 转换为响应对象
|
||||
data = [person_to_response(person) for person in persons]
|
||||
|
||||
return PersonListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data
|
||||
)
|
||||
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -190,10 +189,7 @@ async def get_person_list(
|
|||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(
|
||||
person_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
|
|
@ -212,10 +208,7 @@ async def get_person_detail(
|
|||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
return PersonDetailResponse(
|
||||
success=True,
|
||||
data=person_to_response(person)
|
||||
)
|
||||
return PersonDetailResponse(success=True, data=person_to_response(person))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -225,11 +218,7 @@ async def get_person_detail(
|
|||
|
||||
|
||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
|
||||
|
|
@ -256,7 +245,7 @@ async def update_person(
|
|||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# 更新最后修改时间
|
||||
update_data['last_know'] = time.time()
|
||||
update_data["last_know"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
|
|
@ -267,9 +256,7 @@ async def update_person(
|
|||
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return PersonUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {len(update_data)} 个字段",
|
||||
data=person_to_response(person)
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -280,10 +267,7 @@ async def update_person(
|
|||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(
|
||||
person_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
|
|
@ -310,10 +294,7 @@ async def delete_person(
|
|||
|
||||
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
||||
|
||||
return PersonDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除人物信息: {person_name}"
|
||||
)
|
||||
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -323,9 +304,7 @@ async def delete_person(
|
|||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_person_stats(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物信息统计数据
|
||||
|
||||
|
|
@ -348,15 +327,7 @@ async def get_person_stats(
|
|||
platform = person.platform
|
||||
platforms[platform] = platforms.get(platform, 0) + 1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total": total,
|
||||
"known": known,
|
||||
"unknown": unknown,
|
||||
"platforms": platforms
|
||||
}
|
||||
}
|
||||
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""WebSocket 插件加载进度推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set, Dict, Any
|
||||
import json
|
||||
|
|
@ -22,7 +23,7 @@ current_progress: Dict[str, Any] = {
|
|||
"error": None,
|
||||
"plugin_id": None, # 当前操作的插件 ID
|
||||
"total_plugins": 0,
|
||||
"loaded_plugins": 0
|
||||
"loaded_plugins": 0,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -57,7 +58,7 @@ async def update_progress(
|
|||
error: str = None,
|
||||
plugin_id: str = None,
|
||||
total_plugins: int = 0,
|
||||
loaded_plugins: int = 0
|
||||
loaded_plugins: int = 0,
|
||||
):
|
||||
"""更新并广播进度
|
||||
|
||||
|
|
@ -80,7 +81,7 @@ async def update_progress(
|
|||
"plugin_id": plugin_id,
|
||||
"total_plugins": total_plugins,
|
||||
"loaded_plugins": loaded_plugins,
|
||||
"timestamp": asyncio.get_event_loop().time()
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
}
|
||||
|
||||
await broadcast_progress(progress_data)
|
||||
|
|
|
|||
|
|
@ -30,12 +30,12 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
|||
(major, minor, patch) 三元组
|
||||
"""
|
||||
# 移除 snapshot 等后缀
|
||||
base_version = version_str.split('.snapshot')[0].split('.dev')[0].split('.alpha')[0].split('.beta')[0]
|
||||
base_version = version_str.split(".snapshot")[0].split(".dev")[0].split(".alpha")[0].split(".beta")[0]
|
||||
|
||||
parts = base_version.split('.')
|
||||
parts = base_version.split(".")
|
||||
if len(parts) < 3:
|
||||
# 补齐到 3 位
|
||||
parts.extend(['0'] * (3 - len(parts)))
|
||||
parts.extend(["0"] * (3 - len(parts)))
|
||||
|
||||
try:
|
||||
major = int(parts[0])
|
||||
|
|
@ -49,8 +49,10 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
|||
|
||||
# ============ 请求/响应模型 ============
|
||||
|
||||
|
||||
class FetchRawFileRequest(BaseModel):
|
||||
"""获取 Raw 文件请求"""
|
||||
|
||||
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
||||
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
||||
branch: str = Field(..., description="分支名称", example="main")
|
||||
|
|
@ -61,6 +63,7 @@ class FetchRawFileRequest(BaseModel):
|
|||
|
||||
class FetchRawFileResponse(BaseModel):
|
||||
"""获取 Raw 文件响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
data: Optional[str] = Field(None, description="文件内容")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
|
|
@ -71,6 +74,7 @@ class FetchRawFileResponse(BaseModel):
|
|||
|
||||
class CloneRepositoryRequest(BaseModel):
|
||||
"""克隆仓库请求"""
|
||||
|
||||
owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
|
||||
repo: str = Field(..., description="仓库名称", example="plugin-repo")
|
||||
target_path: str = Field(..., description="目标路径(相对于插件目录)")
|
||||
|
|
@ -82,6 +86,7 @@ class CloneRepositoryRequest(BaseModel):
|
|||
|
||||
class CloneRepositoryResponse(BaseModel):
|
||||
"""克隆仓库响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
path: Optional[str] = Field(None, description="克隆路径")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
|
|
@ -93,6 +98,7 @@ class CloneRepositoryResponse(BaseModel):
|
|||
|
||||
class MirrorConfigResponse(BaseModel):
|
||||
"""镜像源配置响应"""
|
||||
|
||||
id: str = Field(..., description="镜像源 ID")
|
||||
name: str = Field(..., description="镜像源名称")
|
||||
raw_prefix: str = Field(..., description="Raw 文件前缀")
|
||||
|
|
@ -103,12 +109,14 @@ class MirrorConfigResponse(BaseModel):
|
|||
|
||||
class AvailableMirrorsResponse(BaseModel):
|
||||
"""可用镜像源列表响应"""
|
||||
|
||||
mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
|
||||
default_priority: List[str] = Field(..., description="默认优先级顺序(ID 列表)")
|
||||
|
||||
|
||||
class AddMirrorRequest(BaseModel):
|
||||
"""添加镜像源请求"""
|
||||
|
||||
id: str = Field(..., description="镜像源 ID", example="custom-mirror")
|
||||
name: str = Field(..., description="镜像源名称", example="自定义镜像源")
|
||||
raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw")
|
||||
|
|
@ -119,6 +127,7 @@ class AddMirrorRequest(BaseModel):
|
|||
|
||||
class UpdateMirrorRequest(BaseModel):
|
||||
"""更新镜像源请求"""
|
||||
|
||||
name: Optional[str] = Field(None, description="镜像源名称")
|
||||
raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
|
||||
clone_prefix: Optional[str] = Field(None, description="克隆前缀")
|
||||
|
|
@ -128,6 +137,7 @@ class UpdateMirrorRequest(BaseModel):
|
|||
|
||||
class GitStatusResponse(BaseModel):
|
||||
"""Git 安装状态响应"""
|
||||
|
||||
installed: bool = Field(..., description="是否已安装 Git")
|
||||
version: Optional[str] = Field(None, description="Git 版本号")
|
||||
path: Optional[str] = Field(None, description="Git 可执行文件路径")
|
||||
|
|
@ -136,6 +146,7 @@ class GitStatusResponse(BaseModel):
|
|||
|
||||
class InstallPluginRequest(BaseModel):
|
||||
"""安装插件请求"""
|
||||
|
||||
plugin_id: str = Field(..., description="插件 ID")
|
||||
repository_url: str = Field(..., description="插件仓库 URL")
|
||||
branch: Optional[str] = Field("main", description="分支名称")
|
||||
|
|
@ -144,6 +155,7 @@ class InstallPluginRequest(BaseModel):
|
|||
|
||||
class VersionResponse(BaseModel):
|
||||
"""麦麦版本响应"""
|
||||
|
||||
version: str = Field(..., description="麦麦版本号")
|
||||
version_major: int = Field(..., description="主版本号")
|
||||
version_minor: int = Field(..., description="次版本号")
|
||||
|
|
@ -152,11 +164,13 @@ class VersionResponse(BaseModel):
|
|||
|
||||
class UninstallPluginRequest(BaseModel):
|
||||
"""卸载插件请求"""
|
||||
|
||||
plugin_id: str = Field(..., description="插件 ID")
|
||||
|
||||
|
||||
class UpdatePluginRequest(BaseModel):
|
||||
"""更新插件请求"""
|
||||
|
||||
plugin_id: str = Field(..., description="插件 ID")
|
||||
repository_url: str = Field(..., description="插件仓库 URL")
|
||||
branch: Optional[str] = Field("main", description="分支名称")
|
||||
|
|
@ -165,6 +179,7 @@ class UpdatePluginRequest(BaseModel):
|
|||
|
||||
# ============ API 路由 ============
|
||||
|
||||
|
||||
@router.get("/version", response_model=VersionResponse)
|
||||
async def get_maimai_version() -> VersionResponse:
|
||||
"""
|
||||
|
|
@ -174,12 +189,7 @@ async def get_maimai_version() -> VersionResponse:
|
|||
"""
|
||||
major, minor, patch = parse_version(MMC_VERSION)
|
||||
|
||||
return VersionResponse(
|
||||
version=MMC_VERSION,
|
||||
version_major=major,
|
||||
version_minor=minor,
|
||||
version_patch=patch
|
||||
)
|
||||
return VersionResponse(version=MMC_VERSION, version_major=major, version_minor=minor, version_patch=patch)
|
||||
|
||||
|
||||
@router.get("/git-status", response_model=GitStatusResponse)
|
||||
|
|
@ -196,9 +206,7 @@ async def check_git_status() -> GitStatusResponse:
|
|||
|
||||
|
||||
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
||||
async def get_available_mirrors(
|
||||
authorization: Optional[str] = Header(None)
|
||||
) -> AvailableMirrorsResponse:
|
||||
async def get_available_mirrors(authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
|
||||
"""
|
||||
获取所有可用的镜像源配置
|
||||
"""
|
||||
|
|
@ -219,22 +227,16 @@ async def get_available_mirrors(
|
|||
raw_prefix=m["raw_prefix"],
|
||||
clone_prefix=m["clone_prefix"],
|
||||
enabled=m["enabled"],
|
||||
priority=m["priority"]
|
||||
priority=m["priority"],
|
||||
)
|
||||
for m in all_mirrors
|
||||
]
|
||||
|
||||
return AvailableMirrorsResponse(
|
||||
mirrors=mirrors,
|
||||
default_priority=config.get_default_priority_list()
|
||||
)
|
||||
return AvailableMirrorsResponse(mirrors=mirrors, default_priority=config.get_default_priority_list())
|
||||
|
||||
|
||||
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
||||
async def add_mirror(
|
||||
request: AddMirrorRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
) -> MirrorConfigResponse:
|
||||
async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
|
||||
"""
|
||||
添加新的镜像源
|
||||
"""
|
||||
|
|
@ -254,7 +256,7 @@ async def add_mirror(
|
|||
raw_prefix=request.raw_prefix,
|
||||
clone_prefix=request.clone_prefix,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
return MirrorConfigResponse(
|
||||
|
|
@ -263,7 +265,7 @@ async def add_mirror(
|
|||
raw_prefix=mirror["raw_prefix"],
|
||||
clone_prefix=mirror["clone_prefix"],
|
||||
enabled=mirror["enabled"],
|
||||
priority=mirror["priority"]
|
||||
priority=mirror["priority"],
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
|
@ -274,9 +276,7 @@ async def add_mirror(
|
|||
|
||||
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
||||
async def update_mirror(
|
||||
mirror_id: str,
|
||||
request: UpdateMirrorRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
mirror_id: str, request: UpdateMirrorRequest, authorization: Optional[str] = Header(None)
|
||||
) -> MirrorConfigResponse:
|
||||
"""
|
||||
更新镜像源配置
|
||||
|
|
@ -297,7 +297,7 @@ async def update_mirror(
|
|||
raw_prefix=request.raw_prefix,
|
||||
clone_prefix=request.clone_prefix,
|
||||
enabled=request.enabled,
|
||||
priority=request.priority
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
if not mirror:
|
||||
|
|
@ -309,7 +309,7 @@ async def update_mirror(
|
|||
raw_prefix=mirror["raw_prefix"],
|
||||
clone_prefix=mirror["clone_prefix"],
|
||||
enabled=mirror["enabled"],
|
||||
priority=mirror["priority"]
|
||||
priority=mirror["priority"],
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -319,10 +319,7 @@ async def update_mirror(
|
|||
|
||||
|
||||
@router.delete("/mirrors/{mirror_id}")
|
||||
async def delete_mirror(
|
||||
mirror_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
删除镜像源
|
||||
"""
|
||||
|
|
@ -340,16 +337,12 @@ async def delete_mirror(
|
|||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"未找到镜像源: {mirror_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已删除镜像源: {mirror_id}"
|
||||
}
|
||||
return {"success": True, "message": f"已删除镜像源: {mirror_id}"}
|
||||
|
||||
|
||||
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
||||
async def fetch_raw_file(
|
||||
request: FetchRawFileRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
request: FetchRawFileRequest, authorization: Optional[str] = Header(None)
|
||||
) -> FetchRawFileResponse:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
|
|
@ -376,7 +369,7 @@ async def fetch_raw_file(
|
|||
progress=10,
|
||||
message=f"正在获取插件列表: {request.file_path}",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -389,22 +382,19 @@ async def fetch_raw_file(
|
|||
branch=request.branch,
|
||||
file_path=request.file_path,
|
||||
mirror_id=request.mirror_id,
|
||||
custom_url=request.custom_url
|
||||
custom_url=request.custom_url,
|
||||
)
|
||||
|
||||
if result.get("success"):
|
||||
# 更新进度:成功获取
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=70,
|
||||
message="正在解析插件数据...",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
stage="loading", progress=70, message="正在解析插件数据...", total_plugins=0, loaded_plugins=0
|
||||
)
|
||||
|
||||
# 尝试解析插件数量
|
||||
try:
|
||||
import json
|
||||
|
||||
data = json.loads(result.get("data", "[]"))
|
||||
total = len(data) if isinstance(data, list) else 0
|
||||
|
||||
|
|
@ -414,16 +404,12 @@ async def fetch_raw_file(
|
|||
progress=100,
|
||||
message=f"成功加载 {total} 个插件",
|
||||
total_plugins=total,
|
||||
loaded_plugins=total
|
||||
loaded_plugins=total,
|
||||
)
|
||||
except Exception:
|
||||
# 如果解析失败,仍然发送成功状态
|
||||
await update_progress(
|
||||
stage="success",
|
||||
progress=100,
|
||||
message="加载完成",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
stage="success", progress=100, message="加载完成", total_plugins=0, loaded_plugins=0
|
||||
)
|
||||
|
||||
return FetchRawFileResponse(**result)
|
||||
|
|
@ -433,12 +419,7 @@ async def fetch_raw_file(
|
|||
|
||||
# 发送错误进度
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="加载失败",
|
||||
error=str(e),
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
stage="error", progress=0, message="加载失败", error=str(e), total_plugins=0, loaded_plugins=0
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
|
@ -446,8 +427,7 @@ async def fetch_raw_file(
|
|||
|
||||
@router.post("/clone", response_model=CloneRepositoryResponse)
|
||||
async def clone_repository(
|
||||
request: CloneRepositoryRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
request: CloneRepositoryRequest, authorization: Optional[str] = Header(None)
|
||||
) -> CloneRepositoryResponse:
|
||||
"""
|
||||
克隆 GitHub 仓库到本地
|
||||
|
|
@ -460,9 +440,7 @@ async def clone_repository(
|
|||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(
|
||||
f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}"
|
||||
)
|
||||
logger.info(f"收到克隆仓库请求: {request.owner}/{request.repo} -> {request.target_path}")
|
||||
|
||||
try:
|
||||
# TODO: 验证 target_path 的安全性,防止路径遍历攻击
|
||||
|
|
@ -478,7 +456,7 @@ async def clone_repository(
|
|||
branch=request.branch,
|
||||
mirror_id=request.mirror_id,
|
||||
custom_url=request.custom_url,
|
||||
depth=request.depth
|
||||
depth=request.depth,
|
||||
)
|
||||
|
||||
return CloneRepositoryResponse(**result)
|
||||
|
|
@ -489,10 +467,7 @@ async def clone_repository(
|
|||
|
||||
|
||||
@router.post("/install")
|
||||
async def install_plugin(
|
||||
request: InstallPluginRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
async def install_plugin(request: InstallPluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
安装插件
|
||||
|
||||
|
|
@ -513,16 +488,16 @@ async def install_plugin(
|
|||
progress=5,
|
||||
message=f"开始安装插件: {request.plugin_id}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
# 1. 解析仓库 URL
|
||||
# repository_url 格式: https://github.com/owner/repo
|
||||
repo_url = request.repository_url.rstrip('/')
|
||||
if repo_url.endswith('.git'):
|
||||
repo_url = request.repository_url.rstrip("/")
|
||||
if repo_url.endswith(".git"):
|
||||
repo_url = repo_url[:-4]
|
||||
|
||||
parts = repo_url.split('/')
|
||||
parts = repo_url.split("/")
|
||||
if len(parts) < 2:
|
||||
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
||||
|
||||
|
|
@ -534,7 +509,7 @@ async def install_plugin(
|
|||
progress=10,
|
||||
message=f"解析仓库信息: {owner}/{repo}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
# 2. 确定插件安装路径
|
||||
|
|
@ -548,10 +523,10 @@ async def install_plugin(
|
|||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message=f"插件已存在",
|
||||
message="插件已存在",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
error="插件已安装,请先卸载"
|
||||
error="插件已安装,请先卸载",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="插件已安装")
|
||||
|
||||
|
|
@ -560,31 +535,26 @@ async def install_plugin(
|
|||
progress=15,
|
||||
message=f"准备克隆到: {target_path}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
# 3. 克隆仓库(这里会自动推送 20%-80% 的进度)
|
||||
service = get_git_mirror_service()
|
||||
|
||||
# 如果是 GitHub 仓库,使用镜像源
|
||||
if 'github.com' in repo_url:
|
||||
if "github.com" in repo_url:
|
||||
result = await service.clone_repository(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
target_path=target_path,
|
||||
branch=request.branch,
|
||||
mirror_id=request.mirror_id,
|
||||
depth=1 # 浅克隆,节省时间和空间
|
||||
depth=1, # 浅克隆,节省时间和空间
|
||||
)
|
||||
else:
|
||||
# 自定义仓库,直接使用 URL
|
||||
result = await service.clone_repository(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
target_path=target_path,
|
||||
branch=request.branch,
|
||||
custom_url=repo_url,
|
||||
depth=1
|
||||
owner=owner, repo=repo, target_path=target_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
|
|
@ -595,23 +565,20 @@ async def install_plugin(
|
|||
message="克隆仓库失败",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
error=error_msg
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# 4. 验证插件完整性
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=85,
|
||||
message="验证插件文件...",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id
|
||||
stage="loading", progress=85, message="验证插件文件...", operation="install", plugin_id=request.plugin_id
|
||||
)
|
||||
|
||||
manifest_path = target_path / "_manifest.json"
|
||||
if not manifest_path.exists():
|
||||
# 清理失败的安装
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
await update_progress(
|
||||
|
|
@ -620,26 +587,23 @@ async def install_plugin(
|
|||
message="插件缺少 _manifest.json",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
error="无效的插件格式"
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
||||
# 5. 读取并验证 manifest
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=90,
|
||||
message="读取插件配置...",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id
|
||||
stage="loading", progress=90, message="读取插件配置...", operation="install", plugin_id=request.plugin_id
|
||||
)
|
||||
|
||||
try:
|
||||
import json as json_module
|
||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json_module.load(f)
|
||||
|
||||
# 基本验证
|
||||
required_fields = ['manifest_version', 'name', 'version', 'author']
|
||||
required_fields = ["manifest_version", "name", "version", "author"]
|
||||
for field in required_fields:
|
||||
if field not in manifest:
|
||||
raise ValueError(f"缺少必需字段: {field}")
|
||||
|
|
@ -647,6 +611,7 @@ async def install_plugin(
|
|||
except Exception as e:
|
||||
# 清理失败的安装
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
await update_progress(
|
||||
|
|
@ -655,7 +620,7 @@ async def install_plugin(
|
|||
message="_manifest.json 无效",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
error=str(e)
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
|
||||
|
|
@ -665,16 +630,16 @@ async def install_plugin(
|
|||
progress=100,
|
||||
message=f"成功安装插件: {manifest['name']} v{manifest['version']}",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件安装成功",
|
||||
"plugin_id": request.plugin_id,
|
||||
"plugin_name": manifest['name'],
|
||||
"version": manifest['version'],
|
||||
"path": str(target_path)
|
||||
"plugin_name": manifest["name"],
|
||||
"version": manifest["version"],
|
||||
"path": str(target_path),
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
|
|
@ -688,7 +653,7 @@ async def install_plugin(
|
|||
message="安装失败",
|
||||
operation="install",
|
||||
plugin_id=request.plugin_id,
|
||||
error=str(e)
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
|
@ -696,8 +661,7 @@ async def install_plugin(
|
|||
|
||||
@router.post("/uninstall")
|
||||
async def uninstall_plugin(
|
||||
request: UninstallPluginRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
request: UninstallPluginRequest, authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
卸载插件
|
||||
|
|
@ -719,7 +683,7 @@ async def uninstall_plugin(
|
|||
progress=10,
|
||||
message=f"开始卸载插件: {request.plugin_id}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
# 1. 检查插件是否存在
|
||||
|
|
@ -733,7 +697,7 @@ async def uninstall_plugin(
|
|||
message="插件不存在",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
error="插件未安装或已被删除"
|
||||
error="插件未安装或已被删除",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
||||
|
|
@ -742,7 +706,7 @@ async def uninstall_plugin(
|
|||
progress=30,
|
||||
message=f"正在删除插件文件: {plugin_path}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
# 2. 读取插件信息(用于日志)
|
||||
|
|
@ -752,7 +716,8 @@ async def uninstall_plugin(
|
|||
if manifest_path.exists():
|
||||
try:
|
||||
import json as json_module
|
||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json_module.load(f)
|
||||
plugin_name = manifest.get("name", request.plugin_id)
|
||||
except Exception:
|
||||
|
|
@ -763,7 +728,7 @@ async def uninstall_plugin(
|
|||
progress=50,
|
||||
message=f"正在删除 {plugin_name}...",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
# 3. 删除插件目录
|
||||
|
|
@ -773,6 +738,7 @@ async def uninstall_plugin(
|
|||
def remove_readonly(func, path, _):
|
||||
"""清除只读属性并删除文件"""
|
||||
import os
|
||||
|
||||
os.chmod(path, stat.S_IWRITE)
|
||||
func(path)
|
||||
|
||||
|
|
@ -786,15 +752,10 @@ async def uninstall_plugin(
|
|||
progress=100,
|
||||
message=f"成功卸载插件: {plugin_name}",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "插件卸载成功",
|
||||
"plugin_id": request.plugin_id,
|
||||
"plugin_name": plugin_name
|
||||
}
|
||||
return {"success": True, "message": "插件卸载成功", "plugin_id": request.plugin_id, "plugin_name": plugin_name}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -807,7 +768,7 @@ async def uninstall_plugin(
|
|||
message="卸载失败",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
error="权限不足,无法删除插件文件"
|
||||
error="权限不足,无法删除插件文件",
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail="权限不足,无法删除插件文件") from e
|
||||
|
|
@ -820,17 +781,14 @@ async def uninstall_plugin(
|
|||
message="卸载失败",
|
||||
operation="uninstall",
|
||||
plugin_id=request.plugin_id,
|
||||
error=str(e)
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/update")
|
||||
async def update_plugin(
|
||||
request: UpdatePluginRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
async def update_plugin(request: UpdatePluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
更新插件
|
||||
|
||||
|
|
@ -851,7 +809,7 @@ async def update_plugin(
|
|||
progress=5,
|
||||
message=f"开始更新插件: {request.plugin_id}",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
# 1. 检查插件是否已安装
|
||||
|
|
@ -865,7 +823,7 @@ async def update_plugin(
|
|||
message="插件不存在",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
error="插件未安装,请先安装"
|
||||
error="插件未安装,请先安装",
|
||||
)
|
||||
raise HTTPException(status_code=404, detail="插件未安装")
|
||||
|
||||
|
|
@ -877,10 +835,11 @@ async def update_plugin(
|
|||
if manifest_path.exists():
|
||||
try:
|
||||
import json as json_module
|
||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json_module.load(f)
|
||||
old_version = manifest.get("version", "unknown")
|
||||
plugin_name = manifest.get("name", request.plugin_id)
|
||||
_plugin_name = manifest.get("name", request.plugin_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
@ -889,16 +848,12 @@ async def update_plugin(
|
|||
progress=10,
|
||||
message=f"当前版本: {old_version},准备更新...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
# 3. 删除旧版本
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=20,
|
||||
message="正在删除旧版本...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id
|
||||
stage="loading", progress=20, message="正在删除旧版本...", operation="update", plugin_id=request.plugin_id
|
||||
)
|
||||
|
||||
import shutil
|
||||
|
|
@ -907,6 +862,7 @@ async def update_plugin(
|
|||
def remove_readonly(func, path, _):
|
||||
"""清除只读属性并删除文件"""
|
||||
import os
|
||||
|
||||
os.chmod(path, stat.S_IWRITE)
|
||||
func(path)
|
||||
|
||||
|
|
@ -920,14 +876,14 @@ async def update_plugin(
|
|||
progress=30,
|
||||
message="正在准备下载新版本...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
repo_url = request.repository_url.rstrip('/')
|
||||
if repo_url.endswith('.git'):
|
||||
repo_url = request.repository_url.rstrip("/")
|
||||
if repo_url.endswith(".git"):
|
||||
repo_url = repo_url[:-4]
|
||||
|
||||
parts = repo_url.split('/')
|
||||
parts = repo_url.split("/")
|
||||
if len(parts) < 2:
|
||||
raise HTTPException(status_code=400, detail="无效的仓库 URL")
|
||||
|
||||
|
|
@ -937,23 +893,18 @@ async def update_plugin(
|
|||
# 5. 克隆新版本(这里会推送 35%-85% 的进度)
|
||||
service = get_git_mirror_service()
|
||||
|
||||
if 'github.com' in repo_url:
|
||||
if "github.com" in repo_url:
|
||||
result = await service.clone_repository(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
target_path=plugin_path,
|
||||
branch=request.branch,
|
||||
mirror_id=request.mirror_id,
|
||||
depth=1
|
||||
depth=1,
|
||||
)
|
||||
else:
|
||||
result = await service.clone_repository(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
target_path=plugin_path,
|
||||
branch=request.branch,
|
||||
custom_url=repo_url,
|
||||
depth=1
|
||||
owner=owner, repo=repo, target_path=plugin_path, branch=request.branch, custom_url=repo_url, depth=1
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
|
|
@ -964,17 +915,13 @@ async def update_plugin(
|
|||
message="下载新版本失败",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
error=error_msg
|
||||
error=error_msg,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=error_msg)
|
||||
|
||||
# 6. 验证新版本
|
||||
await update_progress(
|
||||
stage="loading",
|
||||
progress=90,
|
||||
message="验证新版本...",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id
|
||||
stage="loading", progress=90, message="验证新版本...", operation="update", plugin_id=request.plugin_id
|
||||
)
|
||||
|
||||
new_manifest_path = plugin_path / "_manifest.json"
|
||||
|
|
@ -983,6 +930,7 @@ async def update_plugin(
|
|||
def remove_readonly(func, path, _):
|
||||
"""清除只读属性并删除文件"""
|
||||
import os
|
||||
|
||||
os.chmod(path, stat.S_IWRITE)
|
||||
func(path)
|
||||
|
||||
|
|
@ -994,13 +942,13 @@ async def update_plugin(
|
|||
message="新版本缺少 _manifest.json",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
error="无效的插件格式"
|
||||
error="无效的插件格式",
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="无效的插件:缺少 _manifest.json")
|
||||
|
||||
# 7. 读取新版本信息
|
||||
try:
|
||||
with open(new_manifest_path, 'r', encoding='utf-8') as f:
|
||||
with open(new_manifest_path, "r", encoding="utf-8") as f:
|
||||
new_manifest = json_module.load(f)
|
||||
|
||||
new_version = new_manifest.get("version", "unknown")
|
||||
|
|
@ -1014,7 +962,7 @@ async def update_plugin(
|
|||
progress=100,
|
||||
message=f"成功更新 {new_name}: {old_version} → {new_version}",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id
|
||||
plugin_id=request.plugin_id,
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
@ -1023,7 +971,7 @@ async def update_plugin(
|
|||
"plugin_id": request.plugin_id,
|
||||
"plugin_name": new_name,
|
||||
"old_version": old_version,
|
||||
"new_version": new_version
|
||||
"new_version": new_version,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -1036,7 +984,7 @@ async def update_plugin(
|
|||
message="_manifest.json 无效",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
error=str(e)
|
||||
error=str(e),
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"无效的 _manifest.json: {e}") from e
|
||||
|
||||
|
|
@ -1046,21 +994,14 @@ async def update_plugin(
|
|||
logger.error(f"更新插件失败: {e}", exc_info=True)
|
||||
|
||||
await update_progress(
|
||||
stage="error",
|
||||
progress=0,
|
||||
message="更新失败",
|
||||
operation="update",
|
||||
plugin_id=request.plugin_id,
|
||||
error=str(e)
|
||||
stage="error", progress=0, message="更新失败", operation="update", plugin_id=request.plugin_id, error=str(e)
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/installed")
|
||||
async def get_installed_plugins(
|
||||
authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
获取已安装的插件列表
|
||||
|
||||
|
|
@ -1081,10 +1022,7 @@ async def get_installed_plugins(
|
|||
if not plugins_dir.exists():
|
||||
logger.info("插件目录不存在,创建目录")
|
||||
plugins_dir.mkdir(exist_ok=True)
|
||||
return {
|
||||
"success": True,
|
||||
"plugins": []
|
||||
}
|
||||
return {"success": True, "plugins": []}
|
||||
|
||||
installed_plugins = []
|
||||
|
||||
|
|
@ -1098,7 +1036,7 @@ async def get_installed_plugins(
|
|||
plugin_id = plugin_path.name
|
||||
|
||||
# 跳过隐藏目录和特殊目录
|
||||
if plugin_id.startswith('.') or plugin_id.startswith('__'):
|
||||
if plugin_id.startswith(".") or plugin_id.startswith("__"):
|
||||
continue
|
||||
|
||||
# 读取 _manifest.json
|
||||
|
|
@ -1110,20 +1048,23 @@ async def get_installed_plugins(
|
|||
|
||||
try:
|
||||
import json as json_module
|
||||
with open(manifest_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json_module.load(f)
|
||||
|
||||
# 基本验证
|
||||
if 'name' not in manifest or 'version' not in manifest:
|
||||
if "name" not in manifest or "version" not in manifest:
|
||||
logger.warning(f"插件 {plugin_id} 的 _manifest.json 格式无效,跳过")
|
||||
continue
|
||||
|
||||
# 添加到已安装列表(返回完整的 manifest 信息)
|
||||
installed_plugins.append({
|
||||
"id": plugin_id,
|
||||
"manifest": manifest, # 返回完整的 manifest 对象
|
||||
"path": str(plugin_path.absolute())
|
||||
})
|
||||
installed_plugins.append(
|
||||
{
|
||||
"id": plugin_id,
|
||||
"manifest": manifest, # 返回完整的 manifest 对象
|
||||
"path": str(plugin_path.absolute()),
|
||||
}
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"插件 {plugin_id} 的 _manifest.json 解析失败: {e}")
|
||||
|
|
@ -1134,11 +1075,7 @@ async def get_installed_plugins(
|
|||
|
||||
logger.info(f"找到 {len(installed_plugins)} 个已安装插件")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"plugins": installed_plugins,
|
||||
"total": len(installed_plugins)
|
||||
}
|
||||
return {"success": True, "plugins": installed_plugins, "total": len(installed_plugins)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
提供系统重启、状态查询等功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
|
@ -19,12 +20,14 @@ _start_time = time.time()
|
|||
|
||||
class RestartResponse(BaseModel):
|
||||
"""重启响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
"""状态响应"""
|
||||
|
||||
running: bool
|
||||
uptime: float
|
||||
version: str
|
||||
|
|
@ -52,15 +55,9 @@ async def restart_maibot():
|
|||
# 但我们仍然返回它以保持 API 一致性
|
||||
os.execv(python, args)
|
||||
|
||||
return RestartResponse(
|
||||
success=True,
|
||||
message="麦麦正在重启中..."
|
||||
)
|
||||
return RestartResponse(success=True, message="麦麦正在重启中...")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"重启失败: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/status", response_model=StatusResponse)
|
||||
|
|
@ -77,20 +74,15 @@ async def get_maibot_status():
|
|||
version = MMC_VERSION # 可以从配置或常量中读取
|
||||
|
||||
return StatusResponse(
|
||||
running=True,
|
||||
uptime=uptime,
|
||||
version=version,
|
||||
start_time=datetime.fromtimestamp(_start_time).isoformat()
|
||||
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"获取状态失败: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
||||
|
||||
|
||||
# 可选:添加更多系统控制功能
|
||||
|
||||
|
||||
@router.post("/reload-config")
|
||||
async def reload_config():
|
||||
"""
|
||||
|
|
@ -102,7 +94,4 @@ async def reload_config():
|
|||
# 这里需要调用主程序的配置重载函数
|
||||
# 示例:await app_instance.reload_config()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置重载功能待实现"
|
||||
}
|
||||
return {"success": True, "message": "配置重载功能待实现"}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
|
|
@ -38,28 +39,33 @@ router.include_router(system_router)
|
|||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
"""Token 验证请求"""
|
||||
|
||||
token: str = Field(..., description="访问令牌")
|
||||
|
||||
|
||||
class TokenVerifyResponse(BaseModel):
|
||||
"""Token 验证响应"""
|
||||
|
||||
valid: bool = Field(..., description="Token 是否有效")
|
||||
message: str = Field(..., description="验证结果消息")
|
||||
|
||||
|
||||
class TokenUpdateRequest(BaseModel):
|
||||
"""Token 更新请求"""
|
||||
|
||||
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
||||
|
||||
|
||||
class TokenUpdateResponse(BaseModel):
|
||||
"""Token 更新响应"""
|
||||
|
||||
success: bool = Field(..., description="是否更新成功")
|
||||
message: str = Field(..., description="更新结果消息")
|
||||
|
||||
|
||||
class TokenRegenerateResponse(BaseModel):
|
||||
"""Token 重新生成响应"""
|
||||
|
||||
success: bool = Field(..., description="是否生成成功")
|
||||
token: str = Field(..., description="新生成的令牌")
|
||||
message: str = Field(..., description="生成结果消息")
|
||||
|
|
@ -67,18 +73,21 @@ class TokenRegenerateResponse(BaseModel):
|
|||
|
||||
class FirstSetupStatusResponse(BaseModel):
|
||||
"""首次配置状态响应"""
|
||||
|
||||
is_first_setup: bool = Field(..., description="是否为首次配置")
|
||||
message: str = Field(..., description="状态消息")
|
||||
|
||||
|
||||
class CompleteSetupResponse(BaseModel):
|
||||
"""完成配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
|
||||
class ResetSetupResponse(BaseModel):
|
||||
"""重置配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
|
|
@ -105,25 +114,16 @@ async def verify_token(request: TokenVerifyRequest):
|
|||
is_valid = token_manager.verify_token(request.token)
|
||||
|
||||
if is_valid:
|
||||
return TokenVerifyResponse(
|
||||
valid=True,
|
||||
message="Token 验证成功"
|
||||
)
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功")
|
||||
else:
|
||||
return TokenVerifyResponse(
|
||||
valid=False,
|
||||
message="Token 无效或已过期"
|
||||
)
|
||||
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
||||
except Exception as e:
|
||||
logger.error(f"Token 验证失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||
async def update_token(
|
||||
request: TokenUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
更新访问令牌(需要当前有效的 token)
|
||||
|
||||
|
|
@ -148,10 +148,7 @@ async def update_token(
|
|||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
return TokenUpdateResponse(
|
||||
success=success,
|
||||
message=message
|
||||
)
|
||||
return TokenUpdateResponse(success=success, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -184,11 +181,7 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
|||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
return TokenRegenerateResponse(
|
||||
success=True,
|
||||
token=new_token,
|
||||
message="Token 已重新生成"
|
||||
)
|
||||
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -221,10 +214,7 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
|
|||
# 检查是否为首次配置
|
||||
is_first = token_manager.is_first_setup()
|
||||
|
||||
return FirstSetupStatusResponse(
|
||||
is_first_setup=is_first,
|
||||
message="首次配置" if is_first else "已完成配置"
|
||||
)
|
||||
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -257,10 +247,7 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
|
|||
# 标记配置完成
|
||||
success = token_manager.mark_setup_completed()
|
||||
|
||||
return CompleteSetupResponse(
|
||||
success=success,
|
||||
message="配置已完成" if success else "标记失败"
|
||||
)
|
||||
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -293,10 +280,7 @@ async def reset_setup(authorization: Optional[str] = Header(None)):
|
|||
# 重置配置状态
|
||||
success = token_manager.reset_setup_status()
|
||||
|
||||
return ResetSetupResponse(
|
||||
success=success,
|
||||
message="配置状态已重置" if success else "重置失败"
|
||||
)
|
||||
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""统计数据 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List
|
||||
|
|
@ -15,6 +16,7 @@ router = APIRouter(prefix="/statistics", tags=["statistics"])
|
|||
|
||||
class StatisticsSummary(BaseModel):
|
||||
"""统计数据摘要"""
|
||||
|
||||
total_requests: int = Field(0, description="总请求数")
|
||||
total_cost: float = Field(0.0, description="总花费")
|
||||
total_tokens: int = Field(0, description="总token数")
|
||||
|
|
@ -28,6 +30,7 @@ class StatisticsSummary(BaseModel):
|
|||
|
||||
class ModelStatistics(BaseModel):
|
||||
"""模型统计"""
|
||||
|
||||
model_name: str
|
||||
request_count: int
|
||||
total_cost: float
|
||||
|
|
@ -37,6 +40,7 @@ class ModelStatistics(BaseModel):
|
|||
|
||||
class TimeSeriesData(BaseModel):
|
||||
"""时间序列数据"""
|
||||
|
||||
timestamp: str
|
||||
requests: int = 0
|
||||
cost: float = 0.0
|
||||
|
|
@ -45,6 +49,7 @@ class TimeSeriesData(BaseModel):
|
|||
|
||||
class DashboardData(BaseModel):
|
||||
"""仪表盘数据"""
|
||||
|
||||
summary: StatisticsSummary
|
||||
model_stats: List[ModelStatistics]
|
||||
hourly_data: List[TimeSeriesData]
|
||||
|
|
@ -88,7 +93,7 @@ async def get_dashboard_data(hours: int = 24):
|
|||
model_stats=model_stats,
|
||||
hourly_data=hourly_data,
|
||||
daily_data=daily_data,
|
||||
recent_activity=recent_activity
|
||||
recent_activity=recent_activity,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取仪表盘数据失败: {e}")
|
||||
|
|
@ -100,11 +105,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||
summary = StatisticsSummary()
|
||||
|
||||
# 查询 LLM 使用记录
|
||||
llm_records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.where(LLMUsage.timestamp <= end_time)
|
||||
)
|
||||
llm_records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
|
||||
|
||||
total_time_cost = 0.0
|
||||
time_cost_count = 0
|
||||
|
|
@ -124,11 +125,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||
|
||||
# 查询在线时间
|
||||
online_records = list(
|
||||
OnlineTime.select()
|
||||
.where(
|
||||
(OnlineTime.start_timestamp >= start_time) |
|
||||
(OnlineTime.end_timestamp >= start_time)
|
||||
)
|
||||
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
|
||||
)
|
||||
|
||||
for record in online_records:
|
||||
|
|
@ -139,9 +136,7 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||
|
||||
# 查询消息数量
|
||||
messages = list(
|
||||
Messages.select()
|
||||
.where(Messages.time >= start_time.timestamp())
|
||||
.where(Messages.time <= end_time.timestamp())
|
||||
Messages.select().where(Messages.time >= start_time.timestamp()).where(Messages.time <= end_time.timestamp())
|
||||
)
|
||||
|
||||
summary.total_messages = len(messages)
|
||||
|
|
@ -159,38 +154,32 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
|||
|
||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||
"""获取模型统计数据"""
|
||||
model_data = defaultdict(lambda: {
|
||||
'request_count': 0,
|
||||
'total_cost': 0.0,
|
||||
'total_tokens': 0,
|
||||
'time_costs': []
|
||||
})
|
||||
model_data = defaultdict(lambda: {"request_count": 0, "total_cost": 0.0, "total_tokens": 0, "time_costs": []})
|
||||
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
)
|
||||
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time))
|
||||
|
||||
for record in records:
|
||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||
model_data[model_name]['request_count'] += 1
|
||||
model_data[model_name]['total_cost'] += record.cost or 0.0
|
||||
model_data[model_name]['total_tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
model_data[model_name]["request_count"] += 1
|
||||
model_data[model_name]["total_cost"] += record.cost or 0.0
|
||||
model_data[model_name]["total_tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
if record.time_cost and record.time_cost > 0:
|
||||
model_data[model_name]['time_costs'].append(record.time_cost)
|
||||
model_data[model_name]["time_costs"].append(record.time_cost)
|
||||
|
||||
# 转换为列表并排序
|
||||
result = []
|
||||
for model_name, data in model_data.items():
|
||||
avg_time = sum(data['time_costs']) / len(data['time_costs']) if data['time_costs'] else 0.0
|
||||
result.append(ModelStatistics(
|
||||
model_name=model_name,
|
||||
request_count=data['request_count'],
|
||||
total_cost=data['total_cost'],
|
||||
total_tokens=data['total_tokens'],
|
||||
avg_response_time=avg_time
|
||||
))
|
||||
avg_time = sum(data["time_costs"]) / len(data["time_costs"]) if data["time_costs"] else 0.0
|
||||
result.append(
|
||||
ModelStatistics(
|
||||
model_name=model_name,
|
||||
request_count=data["request_count"],
|
||||
total_cost=data["total_cost"],
|
||||
total_tokens=data["total_tokens"],
|
||||
avg_response_time=avg_time,
|
||||
)
|
||||
)
|
||||
|
||||
# 按请求数排序
|
||||
result.sort(key=lambda x: x.request_count, reverse=True)
|
||||
|
|
@ -200,35 +189,28 @@ async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
|||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取小时级统计数据"""
|
||||
# 创建小时桶
|
||||
hourly_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
hourly_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.where(LLMUsage.timestamp <= end_time)
|
||||
)
|
||||
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
|
||||
|
||||
for record in records:
|
||||
# 获取小时键(去掉分钟和秒)
|
||||
hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
|
||||
hour_str = hour_key.isoformat()
|
||||
|
||||
hourly_buckets[hour_str]['requests'] += 1
|
||||
hourly_buckets[hour_str]['cost'] += record.cost or 0.0
|
||||
hourly_buckets[hour_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
hourly_buckets[hour_str]["requests"] += 1
|
||||
hourly_buckets[hour_str]["cost"] += record.cost or 0.0
|
||||
hourly_buckets[hour_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
# 填充所有小时(包括没有数据的)
|
||||
result = []
|
||||
current = start_time.replace(minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
hour_str = current.isoformat()
|
||||
data = hourly_buckets.get(hour_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
result.append(TimeSeriesData(
|
||||
timestamp=hour_str,
|
||||
requests=data['requests'],
|
||||
cost=data['cost'],
|
||||
tokens=data['tokens']
|
||||
))
|
||||
data = hourly_buckets.get(hour_str, {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=hour_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
|
||||
)
|
||||
current += timedelta(hours=1)
|
||||
|
||||
return result
|
||||
|
|
@ -236,35 +218,28 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> Li
|
|||
|
||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取日级统计数据"""
|
||||
daily_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
daily_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.where(LLMUsage.timestamp <= end_time)
|
||||
)
|
||||
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
|
||||
|
||||
for record in records:
|
||||
# 获取日期键
|
||||
day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
day_str = day_key.isoformat()
|
||||
|
||||
daily_buckets[day_str]['requests'] += 1
|
||||
daily_buckets[day_str]['cost'] += record.cost or 0.0
|
||||
daily_buckets[day_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
daily_buckets[day_str]["requests"] += 1
|
||||
daily_buckets[day_str]["cost"] += record.cost or 0.0
|
||||
daily_buckets[day_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
# 填充所有天
|
||||
result = []
|
||||
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
day_str = current.isoformat()
|
||||
data = daily_buckets.get(day_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
result.append(TimeSeriesData(
|
||||
timestamp=day_str,
|
||||
requests=data['requests'],
|
||||
cost=data['cost'],
|
||||
tokens=data['tokens']
|
||||
))
|
||||
data = daily_buckets.get(day_str, {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=day_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
|
||||
)
|
||||
current += timedelta(days=1)
|
||||
|
||||
return result
|
||||
|
|
@ -272,23 +247,21 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis
|
|||
|
||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取最近活动"""
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.order_by(LLMUsage.timestamp.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
|
||||
|
||||
activities = []
|
||||
for record in records:
|
||||
activities.append({
|
||||
'timestamp': record.timestamp.isoformat(),
|
||||
'model': record.model_assign_name or record.model_name,
|
||||
'request_type': record.request_type,
|
||||
'tokens': (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||
'cost': record.cost or 0.0,
|
||||
'time_cost': record.time_cost or 0.0,
|
||||
'status': record.status
|
||||
})
|
||||
activities.append(
|
||||
{
|
||||
"timestamp": record.timestamp.isoformat(),
|
||||
"model": record.model_assign_name or record.model_name,
|
||||
"request_type": record.request_type,
|
||||
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||
"cost": record.cost or 0.0,
|
||||
"time_cost": record.time_cost or 0.0,
|
||||
"status": record.status,
|
||||
}
|
||||
)
|
||||
|
||||
return activities
|
||||
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class TokenManager:
|
|||
"access_token": token,
|
||||
"created_at": self._get_current_timestamp(),
|
||||
"updated_at": self._get_current_timestamp(),
|
||||
"first_setup_completed": False # 标记首次配置未完成
|
||||
"first_setup_completed": False, # 标记首次配置未完成
|
||||
}
|
||||
|
||||
self._save_config(config)
|
||||
|
|
@ -91,6 +91,7 @@ class TokenManager:
|
|||
def _get_current_timestamp(self) -> str:
|
||||
"""获取当前时间戳字符串"""
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().isoformat()
|
||||
|
||||
def get_token(self) -> str:
|
||||
|
|
|
|||
Loading…
Reference in New Issue