mirror of https://github.com/Mai-with-u/MaiBot.git
feat:增加聊天名称获取功能和聊天列表接口
parent
11f2d2dec3
commit
1bcd37b206
|
|
@ -2,9 +2,9 @@
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query
|
from fastapi import APIRouter, HTTPException, Header, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Dict
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression, ChatStreams
|
||||||
from .token_manager import get_token_manager
|
from .token_manager import get_token_manager
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
@ -115,6 +115,90 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_name(chat_id: str) -> str:
|
||||||
|
"""根据 chat_id 获取聊天名称"""
|
||||||
|
try:
|
||||||
|
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||||
|
if chat_stream:
|
||||||
|
# 优先使用群聊名称,否则使用用户昵称
|
||||||
|
if chat_stream.group_name:
|
||||||
|
return chat_stream.group_name
|
||||||
|
elif chat_stream.user_nickname:
|
||||||
|
return chat_stream.user_nickname
|
||||||
|
return chat_id # 找不到时返回原始ID
|
||||||
|
except Exception:
|
||||||
|
return chat_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||||
|
"""批量获取聊天名称"""
|
||||||
|
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||||
|
try:
|
||||||
|
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
|
||||||
|
for cs in chat_streams:
|
||||||
|
if cs.group_name:
|
||||||
|
result[cs.stream_id] = cs.group_name
|
||||||
|
elif cs.user_nickname:
|
||||||
|
result[cs.stream_id] = cs.user_nickname
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ChatInfo(BaseModel):
|
||||||
|
"""聊天信息"""
|
||||||
|
|
||||||
|
chat_id: str
|
||||||
|
chat_name: str
|
||||||
|
platform: Optional[str] = None
|
||||||
|
is_group: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ChatListResponse(BaseModel):
|
||||||
|
"""聊天列表响应"""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
data: List[ChatInfo]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/chats", response_model=ChatListResponse)
|
||||||
|
async def get_chat_list(authorization: Optional[str] = Header(None)):
|
||||||
|
"""
|
||||||
|
获取所有聊天列表(用于下拉选择)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
authorization: Authorization header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
聊天列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
verify_auth_token(authorization)
|
||||||
|
|
||||||
|
chat_list = []
|
||||||
|
for cs in ChatStreams.select():
|
||||||
|
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
|
||||||
|
chat_list.append(
|
||||||
|
ChatInfo(
|
||||||
|
chat_id=cs.stream_id,
|
||||||
|
chat_name=chat_name,
|
||||||
|
platform=cs.platform,
|
||||||
|
is_group=bool(cs.group_id),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 按名称排序
|
||||||
|
chat_list.sort(key=lambda x: x.chat_name)
|
||||||
|
|
||||||
|
return ChatListResponse(success=True, data=chat_list)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"获取聊天列表失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=ExpressionListResponse)
|
@router.get("/list", response_model=ExpressionListResponse)
|
||||||
async def get_expression_list(
|
async def get_expression_list(
|
||||||
page: int = Query(1, ge=1, description="页码"),
|
page: int = Query(1, ge=1, description="页码"),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue