mirror of https://github.com/Mai-with-u/MaiBot.git
feat:增加聊天名称获取功能和聊天列表接口
parent
11f2d2dec3
commit
1bcd37b206
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
from .token_manager import get_token_manager
|
||||
import time
|
||||
|
||||
|
|
@ -115,6 +115,90 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
|||
)
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据 chat_id 获取聊天名称"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream:
|
||||
# 优先使用群聊名称,否则使用用户昵称
|
||||
if chat_stream.group_name:
|
||||
return chat_stream.group_name
|
||||
elif chat_stream.user_nickname:
|
||||
return chat_stream.user_nickname
|
||||
return chat_id # 找不到时返回原始ID
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
"""批量获取聊天名称"""
|
||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||
try:
|
||||
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
|
||||
for cs in chat_streams:
|
||||
if cs.group_name:
|
||||
result[cs.stream_id] = cs.group_name
|
||||
elif cs.user_nickname:
|
||||
result[cs.stream_id] = cs.user_nickname
|
||||
except Exception as e:
|
||||
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||
return result
|
||||
|
||||
|
||||
class ChatInfo(BaseModel):
|
||||
"""聊天信息"""
|
||||
|
||||
chat_id: str
|
||||
chat_name: str
|
||||
platform: Optional[str] = None
|
||||
is_group: bool = False
|
||||
|
||||
|
||||
class ChatListResponse(BaseModel):
|
||||
"""聊天列表响应"""
|
||||
|
||||
success: bool
|
||||
data: List[ChatInfo]
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取所有聊天列表(用于下拉选择)
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
聊天列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
chat_list = []
|
||||
for cs in ChatStreams.select():
|
||||
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
|
||||
chat_list.append(
|
||||
ChatInfo(
|
||||
chat_id=cs.stream_id,
|
||||
chat_name=chat_name,
|
||||
platform=cs.platform,
|
||||
is_group=bool(cs.group_id),
|
||||
)
|
||||
)
|
||||
|
||||
# 按名称排序
|
||||
chat_list.sort(key=lambda x: x.chat_name)
|
||||
|
||||
return ChatListResponse(success=True, data=chat_list)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取聊天列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/list", response_model=ExpressionListResponse)
|
||||
async def get_expression_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
|
|
|
|||
Loading…
Reference in New Issue