diff --git a/src/webui/expression_routes.py b/src/webui/expression_routes.py index 983918cf..40f87489 100644 --- a/src/webui/expression_routes.py +++ b/src/webui/expression_routes.py @@ -2,9 +2,9 @@ from fastapi import APIRouter, HTTPException, Header, Query from pydantic import BaseModel -from typing import Optional, List +from typing import Optional, List, Dict from src.common.logger import get_logger -from src.common.database.database_model import Expression +from src.common.database.database_model import Expression, ChatStreams from .token_manager import get_token_manager import time @@ -115,6 +115,90 @@ def expression_to_response(expression: Expression) -> ExpressionResponse: ) +def get_chat_name(chat_id: str) -> str: + """根据 chat_id 获取聊天名称""" + try: + chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) + if chat_stream: + # 优先使用群聊名称,否则使用用户昵称 + if chat_stream.group_name: + return chat_stream.group_name + elif chat_stream.user_nickname: + return chat_stream.user_nickname + return chat_id # 找不到时返回原始ID + except Exception: + return chat_id + + +def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]: + """批量获取聊天名称""" + result = {cid: cid for cid in chat_ids} # 默认值为原始ID + try: + chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids)) + for cs in chat_streams: + if cs.group_name: + result[cs.stream_id] = cs.group_name + elif cs.user_nickname: + result[cs.stream_id] = cs.user_nickname + except Exception as e: + logger.warning(f"批量获取聊天名称失败: {e}") + return result + + +class ChatInfo(BaseModel): + """聊天信息""" + + chat_id: str + chat_name: str + platform: Optional[str] = None + is_group: bool = False + + +class ChatListResponse(BaseModel): + """聊天列表响应""" + + success: bool + data: List[ChatInfo] + + +@router.get("/chats", response_model=ChatListResponse) +async def get_chat_list(authorization: Optional[str] = Header(None)): + """ + 获取所有聊天列表(用于下拉选择) + + Args: + authorization: Authorization header + + Returns: + 聊天列表 + """ + try: + verify_auth_token(authorization) + + chat_list = [] + for cs in ChatStreams.select(): + chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id) + chat_list.append( + ChatInfo( + chat_id=cs.stream_id, + chat_name=chat_name, + platform=cs.platform, + is_group=bool(cs.group_id), + ) + ) + + # 按名称排序 + chat_list.sort(key=lambda x: x.chat_name) + + return ChatListResponse(success=True, data=chat_list) + + except HTTPException: + raise + except Exception as e: + logger.exception(f"获取聊天列表失败: {e}") + raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e + + @router.get("/list", response_model=ExpressionListResponse) async def get_expression_list( page: int = Query(1, ge=1, description="页码"),