feat:增加聊天名称获取功能和聊天列表接口

pull/1385/head
墨梓柒 2025-11-26 18:30:51 +08:00
parent 11f2d2dec3
commit 1bcd37b206
No known key found for this signature in database
GPG Key ID: 4A65B9DBA35F7635
1 changed files with 86 additions and 2 deletions

View File

@ -2,9 +2,9 @@
from fastapi import APIRouter, HTTPException, Header, Query from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List from typing import Optional, List, Dict
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Expression from src.common.database.database_model import Expression, ChatStreams
from .token_manager import get_token_manager from .token_manager import get_token_manager
import time import time
@ -115,6 +115,90 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
) )
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream:
# 优先使用群聊名称,否则使用用户昵称
if chat_stream.group_name:
return chat_stream.group_name
elif chat_stream.user_nickname:
return chat_stream.user_nickname
return chat_id # 找不到时返回原始ID
except Exception:
return chat_id
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称"""
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try:
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
for cs in chat_streams:
if cs.group_name:
result[cs.stream_id] = cs.group_name
elif cs.user_nickname:
result[cs.stream_id] = cs.user_nickname
except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}")
return result
class ChatInfo(BaseModel):
"""聊天信息"""
chat_id: str
chat_name: str
platform: Optional[str] = None
is_group: bool = False
class ChatListResponse(BaseModel):
"""聊天列表响应"""
success: bool
data: List[ChatInfo]
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list(authorization: Optional[str] = Header(None)):
"""
获取所有聊天列表用于下拉选择
Args:
authorization: Authorization header
Returns:
聊天列表
"""
try:
verify_auth_token(authorization)
chat_list = []
for cs in ChatStreams.select():
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
chat_list.append(
ChatInfo(
chat_id=cs.stream_id,
chat_name=chat_name,
platform=cs.platform,
is_group=bool(cs.group_id),
)
)
# 按名称排序
chat_list.sort(key=lambda x: x.chat_name)
return ChatListResponse(success=True, data=chat_list)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取聊天列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
@router.get("/list", response_model=ExpressionListResponse) @router.get("/list", response_model=ExpressionListResponse)
async def get_expression_list( async def get_expression_list(
page: int = Query(1, ge=1, description="页码"), page: int = Query(1, ge=1, description="页码"),