refactor(webui): migrate jargon routes from Peewee to SQLModel

- 完全迁移到 SQLModel
- chat_id → session_id 映射
- ChatStreams → ChatSession 替代
- 移除 is_global 字段
- 使用 group_id 替代 group_name
pull/1496/head
DrSmoothl 2026-02-17 19:58:21 +08:00
parent 7da0811b5c
commit 390d1daefd
No known key found for this signature in database
1 changed files with 159 additions and 166 deletions

View File

@ -1,13 +1,16 @@
"""黑话(俚语)管理路由""" """黑话(俚语)管理路由"""
import json from typing import Annotated, Any, List, Optional
from typing import Optional, List, Annotated
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func as fn from sqlalchemy import func as fn
from sqlmodel import Session, col, delete, select
import json
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession, Jargon
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Jargon, ChatStreams
logger = get_logger("webui.jargon") logger = get_logger("webui.jargon")
@ -43,7 +46,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
return [chat_id_str] return [chat_id_str]
def get_display_name_for_chat_id(chat_id_str: str) -> str: def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
""" """
获取 chat_id 的显示名称 获取 chat_id 的显示名称
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称 尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
@ -51,19 +54,18 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
stream_ids = parse_chat_id_to_stream_ids(chat_id_str) stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids: if not stream_ids:
return chat_id_str return chat_id_str[:20]
# 查询所有 stream_id 对应的名称 stream_id = stream_ids[0]
names = [] chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
for stream_id in stream_ids:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream and chat_stream.group_name:
names.append(chat_stream.group_name)
else:
# 如果没找到,显示截断的 stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
return ", ".join(names) if names else chat_id_str if not chat_session:
return stream_id[:20]
if chat_session.group_id:
return str(chat_session.group_id)
return chat_session.session_id[:20]
# ==================== 请求/响应模型 ==================== # ==================== 请求/响应模型 ====================
@ -79,7 +81,6 @@ class JargonResponse(BaseModel):
chat_id: str chat_id: str
stream_id: Optional[str] = None # 解析后的 stream_id用于前端编辑时匹配 stream_id: Optional[str] = None # 解析后的 stream_id用于前端编辑时匹配
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示 chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
is_global: bool = False
count: int = 0 count: int = 0
is_jargon: Optional[bool] = None is_jargon: Optional[bool] = None
is_complete: bool = False is_complete: bool = False
@ -94,7 +95,7 @@ class JargonListResponse(BaseModel):
total: int total: int
page: int page: int
page_size: int page_size: int
data: List[JargonResponse] data: List[dict[str, Any]]
class JargonDetailResponse(BaseModel): class JargonDetailResponse(BaseModel):
@ -111,7 +112,6 @@ class JargonCreateRequest(BaseModel):
raw_content: Optional[str] = Field(None, description="原始内容") raw_content: Optional[str] = Field(None, description="原始内容")
meaning: Optional[str] = Field(None, description="含义") meaning: Optional[str] = Field(None, description="含义")
chat_id: str = Field(..., description="聊天ID") chat_id: str = Field(..., description="聊天ID")
is_global: bool = Field(False, description="是否全局")
class JargonUpdateRequest(BaseModel): class JargonUpdateRequest(BaseModel):
@ -121,7 +121,6 @@ class JargonUpdateRequest(BaseModel):
raw_content: Optional[str] = None raw_content: Optional[str] = None
meaning: Optional[str] = None meaning: Optional[str] = None
chat_id: Optional[str] = None chat_id: Optional[str] = None
is_global: Optional[bool] = None
is_jargon: Optional[bool] = None is_jargon: Optional[bool] = None
@ -159,7 +158,7 @@ class JargonStatsResponse(BaseModel):
"""黑话统计响应""" """黑话统计响应"""
success: bool = True success: bool = True
data: dict data: dict[str, Any]
class ChatInfoResponse(BaseModel): class ChatInfoResponse(BaseModel):
@ -181,27 +180,24 @@ class ChatListResponse(BaseModel):
# ==================== 工具函数 ==================== # ==================== 工具函数 ====================
def jargon_to_dict(jargon: Jargon) -> dict: def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
"""将 Jargon ORM 对象转换为字典""" """将 Jargon ORM 对象转换为字典"""
# 解析 chat_id 获取显示名称和 stream_id chat_id = jargon.session_id or ""
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None
return { return {
"id": jargon.id, "id": jargon.id,
"content": jargon.content, "content": jargon.content,
"raw_content": jargon.raw_content, "raw_content": jargon.raw_content,
"meaning": jargon.meaning, "meaning": jargon.meaning,
"chat_id": jargon.chat_id, "chat_id": chat_id,
"stream_id": stream_id, "stream_id": jargon.session_id,
"chat_name": chat_name, "chat_name": chat_name,
"is_global": jargon.is_global,
"count": jargon.count, "count": jargon.count,
"is_jargon": jargon.is_jargon, "is_jargon": jargon.is_jargon,
"is_complete": jargon.is_complete, "is_complete": jargon.is_complete,
"inference_with_context": jargon.inference_with_context, "inference_with_context": jargon.inference_with_context,
"inference_content_only": jargon.inference_content_only, "inference_content_only": jargon.inference_with_content_only,
} }
@ -215,49 +211,41 @@ async def get_jargon_list(
search: Optional[str] = Query(None, description="搜索关键词"), search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"), chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"), is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
): ):
"""获取黑话列表""" """获取黑话列表"""
try: try:
# 构建查询 statement = select(Jargon)
query = Jargon.select() count_statement = select(fn.count()).select_from(Jargon)
# 搜索过滤
if search: if search:
query = query.where( search_filter = (
(Jargon.content.contains(search)) (col(Jargon.content).contains(search))
| (Jargon.meaning.contains(search)) | (col(Jargon.meaning).contains(search))
| (Jargon.raw_content.contains(search)) | (col(Jargon.raw_content).contains(search))
) )
statement = statement.where(search_filter)
count_statement = count_statement.where(search_filter)
# 按聊天ID筛选使用 contains 匹配,因为 chat_id 是 JSON 格式)
if chat_id: if chat_id:
# 从传入的 chat_id 中解析出 stream_id
stream_ids = parse_chat_id_to_stream_ids(chat_id) stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids: if stream_ids:
# 使用第一个 stream_id 进行模糊匹配 chat_filter = col(Jargon.session_id).contains(stream_ids[0])
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
else: else:
# 如果无法解析,使用精确匹配 chat_filter = col(Jargon.session_id) == chat_id
query = query.where(Jargon.chat_id == chat_id) statement = statement.where(chat_filter)
count_statement = count_statement.where(chat_filter)
# 按是否是黑话筛选
if is_jargon is not None: if is_jargon is not None:
query = query.where(Jargon.is_jargon == is_jargon) statement = statement.where(col(Jargon.is_jargon) == is_jargon)
count_statement = count_statement.where(col(Jargon.is_jargon) == is_jargon)
# 按是否全局筛选 statement = statement.order_by(col(Jargon.count).desc(), col(Jargon.id).desc())
if is_global is not None: statement = statement.offset((page - 1) * page_size).limit(page_size)
query = query.where(Jargon.is_global == is_global)
# 获取总数 with get_db_session() as session:
total = query.count() total = session.exec(count_statement).one()
jargons = session.exec(statement).all()
# 分页和排序(按使用次数降序) data = [jargon_to_dict(jargon, session) for jargon in jargons]
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
query = query.paginate(page, page_size)
# 转换为响应格式
data = [jargon_to_dict(j) for j in query]
return JargonListResponse( return JargonListResponse(
success=True, success=True,
@ -276,10 +264,9 @@ async def get_jargon_list(
async def get_chat_list(): async def get_chat_list():
"""获取所有有黑话记录的聊天列表""" """获取所有有黑话记录的聊天列表"""
try: try:
# 获取所有不同的 chat_id with get_db_session() as session:
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)) statement = select(Jargon.session_id).distinct().where(col(Jargon.session_id).is_not(None))
chat_id_list = [chat_id for chat_id in session.exec(statement).all() if chat_id]
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
# 用于按 stream_id 去重 # 用于按 stream_id 去重
seen_stream_ids: set[str] = set() seen_stream_ids: set[str] = set()
@ -290,27 +277,28 @@ async def get_chat_list():
seen_stream_ids.add(stream_ids[0]) seen_stream_ids.add(stream_ids[0])
result = [] result = []
for stream_id in seen_stream_ids: with get_db_session() as session:
# 尝试从 ChatStreams 表获取聊天名称 for stream_id in seen_stream_ids:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id) chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
if chat_stream: if chat_session:
result.append( chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
ChatInfoResponse( result.append(
chat_id=stream_id, # 使用 stream_id方便筛选匹配 ChatInfoResponse(
chat_name=chat_stream.group_name or stream_id, chat_id=stream_id,
platform=chat_stream.platform, chat_name=chat_name,
is_group=True, platform=chat_session.platform,
is_group=bool(chat_session.group_id),
)
) )
) else:
else: result.append(
result.append( ChatInfoResponse(
ChatInfoResponse( chat_id=stream_id,
chat_id=stream_id, # 使用 stream_id chat_name=stream_id[:20],
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id, platform=None,
platform=None, is_group=False,
is_group=False, )
) )
)
return ChatListResponse(success=True, data=result) return ChatListResponse(success=True, data=result)
@ -323,35 +311,35 @@ async def get_chat_list():
async def get_jargon_stats(): async def get_jargon_stats():
"""获取黑话统计数据""" """获取黑话统计数据"""
try: try:
# 总数量 with get_db_session() as session:
total = Jargon.select().count() total = session.exec(select(fn.count()).select_from(Jargon)).one()
# 已确认是黑话的数量 confirmed_jargon = session.exec(
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count() select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True)
).one()
confirmed_not_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False)
).one()
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
# 已确认不是黑话的数量 complete_count = session.exec(
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count() select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True)
).one()
# 未判定的数量 chat_count = session.exec(
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count() select(fn.count()).select_from(
select(col(Jargon.session_id)).distinct().where(col(Jargon.session_id).is_not(None)).subquery()
)
).one()
# 全局黑话数量 top_chats = session.exec(
global_count = Jargon.select().where(Jargon.is_global).count() select(col(Jargon.session_id), fn.count().label("count"))
.where(col(Jargon.session_id).is_not(None))
# 已完成推断的数量 .group_by(col(Jargon.session_id))
complete_count = Jargon.select().where(Jargon.is_complete).count() .order_by(fn.count().desc())
.limit(5)
# 关联的聊天数量 ).all()
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count() top_chats_dict = {session_id: count for session_id, count in top_chats if session_id}
# 按聊天统计 TOP 5
top_chats = (
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
.group_by(Jargon.chat_id)
.order_by(fn.COUNT(Jargon.id).desc())
.limit(5)
)
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
return JargonStatsResponse( return JargonStatsResponse(
success=True, success=True,
@ -360,7 +348,6 @@ async def get_jargon_stats():
"confirmed_jargon": confirmed_jargon, "confirmed_jargon": confirmed_jargon,
"confirmed_not_jargon": confirmed_not_jargon, "confirmed_not_jargon": confirmed_not_jargon,
"pending": pending, "pending": pending,
"global_count": global_count,
"complete_count": complete_count, "complete_count": complete_count,
"chat_count": chat_count, "chat_count": chat_count,
"top_chats": top_chats_dict, "top_chats": top_chats_dict,
@ -376,11 +363,13 @@ async def get_jargon_stats():
async def get_jargon_detail(jargon_id: int): async def get_jargon_detail(jargon_id: int):
"""获取黑话详情""" """获取黑话详情"""
try: try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id) with get_db_session() as session:
if not jargon: jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
raise HTTPException(status_code=404, detail="黑话不存在") if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon)) return JargonDetailResponse(success=True, data=data)
except HTTPException: except HTTPException:
raise raise
@ -393,30 +382,31 @@ async def get_jargon_detail(jargon_id: int):
async def create_jargon(request: JargonCreateRequest): async def create_jargon(request: JargonCreateRequest):
"""创建黑话""" """创建黑话"""
try: try:
# 检查是否已存在相同内容的黑话 with get_db_session() as session:
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)) existing = session.exec(
if existing: select(Jargon).where(
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话") (col(Jargon.content) == request.content) & (col(Jargon.session_id) == request.chat_id)
)
).first()
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
# 创建黑话 jargon = Jargon(
jargon = Jargon.create( content=request.content,
content=request.content, raw_content=request.raw_content,
raw_content=request.raw_content, meaning=request.meaning or "",
meaning=request.meaning, session_id=request.chat_id,
chat_id=request.chat_id, count=0,
is_global=request.is_global, is_jargon=None,
count=0, is_complete=False,
is_jargon=None, )
is_complete=False, session.add(jargon)
) session.flush()
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}") logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonCreateResponse( return JargonCreateResponse(success=True, message="创建成功", data=data)
success=True,
message="创建成功",
data=jargon_to_dict(jargon),
)
except HTTPException: except HTTPException:
raise raise
@ -429,25 +419,27 @@ async def create_jargon(request: JargonCreateRequest):
async def update_jargon(jargon_id: int, request: JargonUpdateRequest): async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
"""更新黑话(增量更新)""" """更新黑话(增量更新)"""
try: try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id) with get_db_session() as session:
if not jargon: jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
raise HTTPException(status_code=404, detail="黑话不存在") if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
# 增量更新字段 update_data = request.model_dump(exclude_unset=True)
update_data = request.model_dump(exclude_unset=True) if update_data:
if update_data: for field, value in update_data.items():
for field, value in update_data.items(): if field == "is_global":
if value is not None or field in ["meaning", "raw_content", "is_jargon"]: continue
setattr(jargon, field, value) if field == "chat_id":
jargon.save() jargon.session_id = value
continue
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)
session.add(jargon)
logger.info(f"更新黑话成功: id={jargon_id}") logger.info(f"更新黑话成功: id={jargon_id}")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonUpdateResponse( return JargonUpdateResponse(success=True, message="更新成功", data=data)
success=True,
message="更新成功",
data=jargon_to_dict(jargon),
)
except HTTPException: except HTTPException:
raise raise
@ -460,20 +452,17 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
async def delete_jargon(jargon_id: int): async def delete_jargon(jargon_id: int):
"""删除黑话""" """删除黑话"""
try: try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id) with get_db_session() as session:
if not jargon: jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
raise HTTPException(status_code=404, detail="黑话不存在") if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
content = jargon.content content = jargon.content
jargon.delete_instance() session.delete(jargon)
logger.info(f"删除黑话成功: id={jargon_id}, content={content}") logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
return JargonDeleteResponse( return JargonDeleteResponse(success=True, message="删除成功", deleted_count=1)
success=True,
message="删除成功",
deleted_count=1,
)
except HTTPException: except HTTPException:
raise raise
@ -489,9 +478,11 @@ async def batch_delete_jargons(request: BatchDeleteRequest):
if not request.ids: if not request.ids:
raise HTTPException(status_code=400, detail="ID列表不能为空") raise HTTPException(status_code=400, detail="ID列表不能为空")
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute() with get_db_session() as session:
result = session.exec(delete(Jargon).where(col(Jargon.id).in_(request.ids)))
deleted_count = result.rowcount or 0
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录") logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
return JargonDeleteResponse( return JargonDeleteResponse(
success=True, success=True,
@ -516,14 +507,16 @@ async def batch_set_jargon_status(
if not ids: if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空") raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute() with get_db_session() as session:
jargons = session.exec(select(Jargon).where(col(Jargon.id).in_(ids))).all()
for jargon in jargons:
jargon.is_jargon = is_jargon
session.add(jargon)
updated_count = len(jargons)
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}") logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")
return JargonUpdateResponse( return JargonUpdateResponse(success=True, message=f"成功更新 {updated_count} 条黑话状态")
success=True,
message=f"成功更新 {updated_count} 条黑话状态",
)
except HTTPException: except HTTPException:
raise raise