diff --git a/src/webui/routers/jargon.py b/src/webui/routers/jargon.py index cca15c3b..3f4a16d7 100644 --- a/src/webui/routers/jargon.py +++ b/src/webui/routers/jargon.py @@ -1,13 +1,16 @@ """黑话(俚语)管理路由""" -import json -from typing import Optional, List, Annotated +from typing import Annotated, Any, List, Optional from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel, Field from sqlalchemy import func as fn +from sqlmodel import Session, col, delete, select +import json + +from src.common.database.database import get_db_session +from src.common.database.database_model import ChatSession, Jargon from src.common.logger import get_logger -from src.common.database.database_model import Jargon, ChatStreams logger = get_logger("webui.jargon") @@ -43,7 +46,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]: return [chat_id_str] -def get_display_name_for_chat_id(chat_id_str: str) -> str: +def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str: """ 获取 chat_id 的显示名称 尝试解析 JSON 并查询 ChatStreams 表获取群聊名称 @@ -51,19 +54,18 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str: stream_ids = parse_chat_id_to_stream_ids(chat_id_str) if not stream_ids: - return chat_id_str + return chat_id_str[:20] - # 查询所有 stream_id 对应的名称 - names = [] - for stream_id in stream_ids: - chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id) - if chat_stream and chat_stream.group_name: - names.append(chat_stream.group_name) - else: - # 如果没找到,显示截断的 stream_id - names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id) + stream_id = stream_ids[0] + chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first() - return ", ".join(names) if names else chat_id_str + if not chat_session: + return stream_id[:20] + + if chat_session.group_id: + return str(chat_session.group_id) + + return chat_session.session_id[:20] # ==================== 请求/响应模型 ==================== @@ -79,7 +81,6 @@ class JargonResponse(BaseModel): chat_id: str stream_id: Optional[str] = None # 解析后的 stream_id,用于前端编辑时匹配 chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示 - is_global: bool = False count: int = 0 is_jargon: Optional[bool] = None is_complete: bool = False @@ -94,7 +95,7 @@ class JargonListResponse(BaseModel): total: int page: int page_size: int - data: List[JargonResponse] + data: List[dict[str, Any]] class JargonDetailResponse(BaseModel): @@ -111,7 +112,6 @@ class JargonCreateRequest(BaseModel): raw_content: Optional[str] = Field(None, description="原始内容") meaning: Optional[str] = Field(None, description="含义") chat_id: str = Field(..., description="聊天ID") - is_global: bool = Field(False, description="是否全局") class JargonUpdateRequest(BaseModel): @@ -121,7 +121,6 @@ class JargonUpdateRequest(BaseModel): raw_content: Optional[str] = None meaning: Optional[str] = None chat_id: Optional[str] = None - is_global: Optional[bool] = None is_jargon: Optional[bool] = None @@ -159,7 +158,7 @@ class JargonStatsResponse(BaseModel): """黑话统计响应""" success: bool = True - data: dict + data: dict[str, Any] class ChatInfoResponse(BaseModel): @@ -181,27 +180,24 @@ class ChatListResponse(BaseModel): # ==================== 工具函数 ==================== -def jargon_to_dict(jargon: Jargon) -> dict: +def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]: """将 Jargon ORM 对象转换为字典""" - # 解析 chat_id 获取显示名称和 stream_id - chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None - stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else [] - stream_id = stream_ids[0] if stream_ids else None + chat_id = jargon.session_id or "" + chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None return { "id": jargon.id, "content": jargon.content, "raw_content": jargon.raw_content, "meaning": jargon.meaning, - "chat_id": jargon.chat_id, - "stream_id": stream_id, + "chat_id": chat_id, + "stream_id": jargon.session_id, "chat_name": chat_name, - "is_global": jargon.is_global, "count": jargon.count, "is_jargon": jargon.is_jargon, "is_complete": jargon.is_complete, "inference_with_context": jargon.inference_with_context, - "inference_content_only": jargon.inference_content_only, + "inference_content_only": jargon.inference_with_content_only, } @@ -215,49 +211,41 @@ async def get_jargon_list( search: Optional[str] = Query(None, description="搜索关键词"), chat_id: Optional[str] = Query(None, description="按聊天ID筛选"), is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"), - is_global: Optional[bool] = Query(None, description="按是否全局筛选"), ): """获取黑话列表""" try: - # 构建查询 - query = Jargon.select() + statement = select(Jargon) + count_statement = select(fn.count()).select_from(Jargon) - # 搜索过滤 if search: - query = query.where( - (Jargon.content.contains(search)) - | (Jargon.meaning.contains(search)) - | (Jargon.raw_content.contains(search)) + search_filter = ( + (col(Jargon.content).contains(search)) + | (col(Jargon.meaning).contains(search)) + | (col(Jargon.raw_content).contains(search)) ) + statement = statement.where(search_filter) + count_statement = count_statement.where(search_filter) - # 按聊天ID筛选(使用 contains 匹配,因为 chat_id 是 JSON 格式) if chat_id: - # 从传入的 chat_id 中解析出 stream_id stream_ids = parse_chat_id_to_stream_ids(chat_id) if stream_ids: - # 使用第一个 stream_id 进行模糊匹配 - query = query.where(Jargon.chat_id.contains(stream_ids[0])) + chat_filter = col(Jargon.session_id).contains(stream_ids[0]) else: - # 如果无法解析,使用精确匹配 - query = query.where(Jargon.chat_id == chat_id) + chat_filter = col(Jargon.session_id) == chat_id + statement = statement.where(chat_filter) + count_statement = count_statement.where(chat_filter) - # 按是否是黑话筛选 if is_jargon is not None: - query = query.where(Jargon.is_jargon == is_jargon) + statement = statement.where(col(Jargon.is_jargon) == is_jargon) + count_statement = count_statement.where(col(Jargon.is_jargon) == is_jargon) - # 按是否全局筛选 - if is_global is not None: - query = query.where(Jargon.is_global == is_global) + statement = statement.order_by(col(Jargon.count).desc(), col(Jargon.id).desc()) + statement = statement.offset((page - 1) * page_size).limit(page_size) - # 获取总数 - total = query.count() - - # 分页和排序(按使用次数降序) - query = query.order_by(Jargon.count.desc(), Jargon.id.desc()) - query = query.paginate(page, page_size) - - # 转换为响应格式 - data = [jargon_to_dict(j) for j in query] + with get_db_session() as session: + total = session.exec(count_statement).one() + jargons = session.exec(statement).all() + data = [jargon_to_dict(jargon, session) for jargon in jargons] return JargonListResponse( success=True, @@ -276,10 +264,9 @@ async def get_jargon_list( async def get_chat_list(): """获取所有有黑话记录的聊天列表""" try: - # 获取所有不同的 chat_id - chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)) - - chat_id_list = [j.chat_id for j in chat_ids if j.chat_id] + with get_db_session() as session: + statement = select(Jargon.session_id).distinct().where(col(Jargon.session_id).is_not(None)) + chat_id_list = [chat_id for chat_id in session.exec(statement).all() if chat_id] # 用于按 stream_id 去重 seen_stream_ids: set[str] = set() @@ -290,27 +277,28 @@ async def get_chat_list(): seen_stream_ids.add(stream_ids[0]) result = [] - for stream_id in seen_stream_ids: - # 尝试从 ChatStreams 表获取聊天名称 - chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id) - if chat_stream: - result.append( - ChatInfoResponse( - chat_id=stream_id, # 使用 stream_id,方便筛选匹配 - chat_name=chat_stream.group_name or stream_id, - platform=chat_stream.platform, - is_group=True, + with get_db_session() as session: + for stream_id in seen_stream_ids: + chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first() + if chat_session: + chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20] + result.append( + ChatInfoResponse( + chat_id=stream_id, + chat_name=chat_name, + platform=chat_session.platform, + is_group=bool(chat_session.group_id), + ) ) - ) - else: - result.append( - ChatInfoResponse( - chat_id=stream_id, # 使用 stream_id - chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id, - platform=None, - is_group=False, + else: + result.append( + ChatInfoResponse( + chat_id=stream_id, + chat_name=stream_id[:20], + platform=None, + is_group=False, + ) ) - ) return ChatListResponse(success=True, data=result) @@ -323,35 +311,35 @@ async def get_chat_list(): async def get_jargon_stats(): """获取黑话统计数据""" try: - # 总数量 - total = Jargon.select().count() + with get_db_session() as session: + total = session.exec(select(fn.count()).select_from(Jargon)).one() - # 已确认是黑话的数量 - confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count() + confirmed_jargon = session.exec( + select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True) + ).one() + confirmed_not_jargon = session.exec( + select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False) + ).one() + pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one() - # 已确认不是黑话的数量 - confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count() + complete_count = session.exec( + select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True) + ).one() - # 未判定的数量 - pending = Jargon.select().where(Jargon.is_jargon.is_null()).count() + chat_count = session.exec( + select(fn.count()).select_from( + select(col(Jargon.session_id)).distinct().where(col(Jargon.session_id).is_not(None)).subquery() + ) + ).one() - # 全局黑话数量 - global_count = Jargon.select().where(Jargon.is_global).count() - - # 已完成推断的数量 - complete_count = Jargon.select().where(Jargon.is_complete).count() - - # 关联的聊天数量 - chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count() - - # 按聊天统计 TOP 5 - top_chats = ( - Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count")) - .group_by(Jargon.chat_id) - .order_by(fn.COUNT(Jargon.id).desc()) - .limit(5) - ) - top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id} + top_chats = session.exec( + select(col(Jargon.session_id), fn.count().label("count")) + .where(col(Jargon.session_id).is_not(None)) + .group_by(col(Jargon.session_id)) + .order_by(fn.count().desc()) + .limit(5) + ).all() + top_chats_dict = {session_id: count for session_id, count in top_chats if session_id} return JargonStatsResponse( success=True, @@ -360,7 +348,6 @@ async def get_jargon_stats(): "confirmed_jargon": confirmed_jargon, "confirmed_not_jargon": confirmed_not_jargon, "pending": pending, - "global_count": global_count, "complete_count": complete_count, "chat_count": chat_count, "top_chats": top_chats_dict, @@ -376,11 +363,13 @@ async def get_jargon_stats(): async def get_jargon_detail(jargon_id: int): """获取黑话详情""" try: - jargon = Jargon.get_or_none(Jargon.id == jargon_id) - if not jargon: - raise HTTPException(status_code=404, detail="黑话不存在") + with get_db_session() as session: + jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first() + if not jargon: + raise HTTPException(status_code=404, detail="黑话不存在") + data = JargonResponse(**jargon_to_dict(jargon, session)) - return JargonDetailResponse(success=True, data=jargon_to_dict(jargon)) + return JargonDetailResponse(success=True, data=data) except HTTPException: raise @@ -393,30 +382,31 @@ async def get_jargon_detail(jargon_id: int): async def create_jargon(request: JargonCreateRequest): """创建黑话""" try: - # 检查是否已存在相同内容的黑话 - existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)) - if existing: - raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话") + with get_db_session() as session: + existing = session.exec( + select(Jargon).where( + (col(Jargon.content) == request.content) & (col(Jargon.session_id) == request.chat_id) + ) + ).first() + if existing: + raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话") - # 创建黑话 - jargon = Jargon.create( - content=request.content, - raw_content=request.raw_content, - meaning=request.meaning, - chat_id=request.chat_id, - is_global=request.is_global, - count=0, - is_jargon=None, - is_complete=False, - ) + jargon = Jargon( + content=request.content, + raw_content=request.raw_content, + meaning=request.meaning or "", + session_id=request.chat_id, + count=0, + is_jargon=None, + is_complete=False, + ) + session.add(jargon) + session.flush() - logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}") + logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}") + data = JargonResponse(**jargon_to_dict(jargon, session)) - return JargonCreateResponse( - success=True, - message="创建成功", - data=jargon_to_dict(jargon), - ) + return JargonCreateResponse(success=True, message="创建成功", data=data) except HTTPException: raise @@ -429,25 +419,27 @@ async def create_jargon(request: JargonCreateRequest): async def update_jargon(jargon_id: int, request: JargonUpdateRequest): """更新黑话(增量更新)""" try: - jargon = Jargon.get_or_none(Jargon.id == jargon_id) - if not jargon: - raise HTTPException(status_code=404, detail="黑话不存在") + with get_db_session() as session: + jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first() + if not jargon: + raise HTTPException(status_code=404, detail="黑话不存在") - # 增量更新字段 - update_data = request.model_dump(exclude_unset=True) - if update_data: - for field, value in update_data.items(): - if value is not None or field in ["meaning", "raw_content", "is_jargon"]: - setattr(jargon, field, value) - jargon.save() + update_data = request.model_dump(exclude_unset=True) + if update_data: + for field, value in update_data.items(): + if field == "is_global": + continue + if field == "chat_id": + jargon.session_id = value + continue + if value is not None or field in ["meaning", "raw_content", "is_jargon"]: + setattr(jargon, field, value) + session.add(jargon) - logger.info(f"更新黑话成功: id={jargon_id}") + logger.info(f"更新黑话成功: id={jargon_id}") + data = JargonResponse(**jargon_to_dict(jargon, session)) - return JargonUpdateResponse( - success=True, - message="更新成功", - data=jargon_to_dict(jargon), - ) + return JargonUpdateResponse(success=True, message="更新成功", data=data) except HTTPException: raise @@ -460,20 +452,17 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest): async def delete_jargon(jargon_id: int): """删除黑话""" try: - jargon = Jargon.get_or_none(Jargon.id == jargon_id) - if not jargon: - raise HTTPException(status_code=404, detail="黑话不存在") + with get_db_session() as session: + jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first() + if not jargon: + raise HTTPException(status_code=404, detail="黑话不存在") - content = jargon.content - jargon.delete_instance() + content = jargon.content + session.delete(jargon) - logger.info(f"删除黑话成功: id={jargon_id}, content={content}") + logger.info(f"删除黑话成功: id={jargon_id}, content={content}") - return JargonDeleteResponse( - success=True, - message="删除成功", - deleted_count=1, - ) + return JargonDeleteResponse(success=True, message="删除成功", deleted_count=1) except HTTPException: raise @@ -489,9 +478,11 @@ async def batch_delete_jargons(request: BatchDeleteRequest): if not request.ids: raise HTTPException(status_code=400, detail="ID列表不能为空") - deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute() + with get_db_session() as session: + result = session.exec(delete(Jargon).where(col(Jargon.id).in_(request.ids))) + deleted_count = result.rowcount or 0 - logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录") + logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录") return JargonDeleteResponse( success=True, @@ -516,14 +507,16 @@ async def batch_set_jargon_status( if not ids: raise HTTPException(status_code=400, detail="ID列表不能为空") - updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute() + with get_db_session() as session: + jargons = session.exec(select(Jargon).where(col(Jargon.id).in_(ids))).all() + for jargon in jargons: + jargon.is_jargon = is_jargon + session.add(jargon) + updated_count = len(jargons) - logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}") + logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}") - return JargonUpdateResponse( - success=True, - message=f"成功更新 {updated_count} 条黑话状态", - ) + return JargonUpdateResponse(success=True, message=f"成功更新 {updated_count} 条黑话状态") except HTTPException: raise