import asyncio
import html
import json
import weakref
from collections import deque
from datetime import datetime
from threading import Thread
from typing import Dict, List, Optional

import aiohttp_cors
from aiohttp import web, WSMsgType

from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
logger = get_logger("context_web")
class ContextMessage:
    """Lightweight snapshot of one received chat message.

    Copies the sender nickname, sender id, plain-text content and group name
    out of a ``MessageRecv`` and stamps the snapshot with the local time at
    which it was created.
    """

    def __init__(self, message: MessageRecv):
        info = message.message_info
        sender = info.user_info

        self.user_name = sender.user_nickname
        self.user_id = sender.user_id
        self.content = message.processed_plain_text
        # Creation time of this snapshot, not a timestamp carried by the message.
        self.timestamp = datetime.now()

        # A message without group info is labelled as a private chat.
        group = info.group_info
        if group:
            self.group_name = group.group_name
        else:
            self.group_name = "私聊"

    def to_dict(self):
        """Return the snapshot as a plain dict with a formatted timestamp."""
        formatted_time = self.timestamp.strftime("%m-%d %H:%M:%S")
        return {
            "user_name": self.user_name,
            "user_id": self.user_id,
            "content": self.content,
            "timestamp": formatted_time,
            "group_name": self.group_name,
        }
class ContextWebManager:
    """Serve recent chat context over HTTP and WebSocket.

    Keeps one bounded deque of context messages per chat id and exposes them
    through an aiohttp application: an index page (``/``), a JSON endpoint
    (``/api/contexts``), a debug page (``/debug``) and a WebSocket endpoint
    (``/ws``) that receives a fresh snapshot whenever a message is added.
    """

    def __init__(self, max_messages: int = 10, port: int = 8765):
        """Create an idle manager.

        Args:
            max_messages: Capacity of each per-chat deque and the maximum
                number of messages included in any snapshot sent to clients.
            port: TCP port the server binds to on ``localhost``.
        """
        self.max_messages = max_messages
        self.port = port
        self.contexts: Dict[str, deque] = {}  # chat_id -> deque of ContextMessage
        self.websockets: List[web.WebSocketResponse] = []
        self.app = None
        self.runner = None
        self.site = None
        # Guards against two coroutines starting the server concurrently.
        self._server_starting = False

    async def start_server(self):
        """Start the web server; a no-op if it is already running.

        If another coroutine is already starting the server, wait until that
        attempt finishes and return.

        Raises:
            Exception: Whatever aiohttp raised during startup, re-raised after
                partially created resources have been cleaned up.
        """
        if self.site is not None:
            logger.debug("Web服务器已经启动,跳过重复启动")
            return

        if self._server_starting:
            logger.debug("Web服务器正在启动中,等待启动完成...")
            # Poll until the in-flight start attempt finishes.
            while self._server_starting and self.site is None:
                await asyncio.sleep(0.1)
            return

        self._server_starting = True
        try:
            self.app = web.Application()

            # CORS configuration applied to every route registered below.
            cors = aiohttp_cors.setup(self.app, defaults={
                "*": aiohttp_cors.ResourceOptions(
                    allow_credentials=True,
                    expose_headers="*",
                    allow_headers="*",
                    allow_methods="*",
                )
            })

            self.app.router.add_get('/', self.index_handler)
            self.app.router.add_get('/ws', self.websocket_handler)
            self.app.router.add_get('/api/contexts', self.get_contexts_handler)
            self.app.router.add_get('/debug', self.debug_handler)

            for route in list(self.app.router.routes()):
                cors.add(route)

            self.runner = web.AppRunner(self.app)
            await self.runner.setup()

            site = web.TCPSite(self.runner, 'localhost', self.port)
            await site.start()
            # Publish the site only once it is listening, so concurrent
            # callers never mistake a half-started server for a running one.
            self.site = site

            logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
        except Exception as e:
            logger.error(f"❌ 启动Web服务器失败: {e}")
            # Release whatever was created before the failure.
            if self.runner:
                await self.runner.cleanup()
            self.app = None
            self.runner = None
            self.site = None
            raise
        finally:
            self._server_starting = False

    async def stop_server(self):
        """Stop the web server and reset the manager to its idle state."""
        # Close live WebSocket connections first (the pattern aiohttp documents
        # for graceful shutdown) and drop them so no stale entries survive.
        for ws in list(self.websockets):
            try:
                await ws.close()
            except Exception as e:
                logger.debug(f"Failed to close WebSocket during shutdown: {e}")
        self.websockets.clear()

        if self.site:
            await self.site.stop()
        if self.runner:
            await self.runner.cleanup()
        self.app = None
        self.runner = None
        self.site = None
        self._server_starting = False

    async def index_handler(self, request):
        """Serve the index page."""
        html_content = "\n聊天上下文\n"
        return web.Response(text=html_content, content_type='text/html')

    async def websocket_handler(self, request):
        """Accept a WebSocket, send the current snapshot, then hold it open.

        The connection is registered for broadcasts until the client
        disconnects or a WebSocket error message is received.
        """
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        self.websockets.append(ws)
        logger.debug(f"WebSocket连接建立,当前连接数: {len(self.websockets)}")

        try:
            # Initial payload so the client does not wait for the next message.
            await self.send_contexts_to_websocket(ws)

            async for msg in ws:
                if msg.type == WSMsgType.ERROR:
                    logger.error(f'WebSocket错误: {ws.exception()}')
                    break
        finally:
            # Always unregister, even when the send fails or the handler is
            # cancelled, so the connection list does not leak dead sockets.
            if ws in self.websockets:
                self.websockets.remove(ws)
            logger.debug(f"WebSocket连接断开,当前连接数: {len(self.websockets)}")

        return ws

    def _recent_contexts_data(self):
        """Return the newest ``max_messages`` messages across all chats as dicts.

        Messages from every chat are merged and sorted by timestamp, oldest
        first, so the newest message is last.
        """
        merged = [msg for chat_msgs in self.contexts.values() for msg in chat_msgs]
        merged.sort(key=lambda m: m.timestamp)
        return [msg.to_dict() for msg in merged[-self.max_messages:]]

    async def get_contexts_handler(self, request):
        """Return the current context snapshot as JSON."""
        contexts_data = self._recent_contexts_data()
        logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
        return web.json_response({"contexts": contexts_data})

    async def debug_handler(self, request):
        """Render a debug page with connection counts and per-chat messages.

        Nickname, message content and chat id originate from chat traffic and
        are HTML-escaped before being placed in the ``text/html`` response.
        """
        debug_info = {
            "server_status": "running",
            "websocket_connections": len(self.websockets),
            "total_chats": len(self.contexts),
            "total_messages": sum(len(chat_msgs) for chat_msgs in self.contexts.values()),
        }

        chat_sections = []
        for chat_id, chat_msgs in self.contexts.items():
            message_lines = []
            for msg in chat_msgs:
                timestamp = msg.timestamp.strftime("%H:%M:%S")
                # Truncate the raw text first so an escaped entity is never cut in half.
                content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
                message_lines.append(
                    f"[{timestamp}] {html.escape(str(msg.user_name))}: "
                    f"{html.escape(str(content))}\n"
                )
            messages_html = "".join(message_lines)
            chat_sections.append(
                f"\n聊天 {html.escape(str(chat_id))} ({len(chat_msgs)} 条消息)\n"
                f"{messages_html}\n"
            )
        chats_html = "".join(chat_sections)

        html_content = (
            "\n"
            "调试信息\n"
            "上下文网页管理器调试信息\n"
            "服务器状态\n"
            f"状态: {debug_info['server_status']}\n"
            f"WebSocket连接数: {debug_info['websocket_connections']}\n"
            f"聊天总数: {debug_info['total_chats']}\n"
            f"消息总数: {debug_info['total_messages']}\n"
            "聊天详情\n"
            f"{chats_html}\n"
            "操作\n"
        )

        return web.Response(text=html_content, content_type='text/html')

    async def add_message(self, chat_id: str, message: MessageRecv):
        """Record a new message for ``chat_id`` and broadcast the update.

        Args:
            chat_id: Key of the chat the message belongs to; a new bounded
                deque is created the first time a chat id is seen.
            message: The received message to snapshot into the context.
        """
        if chat_id not in self.contexts:
            self.contexts[chat_id] = deque(maxlen=self.max_messages)
            logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")

        context_msg = ContextMessage(message)
        self.contexts[chat_id].append(context_msg)

        total_messages = sum(len(chat_msgs) for chat_msgs in self.contexts.values())
        logger.info(f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}")

        # Diagnostic dump of every message currently held.
        logger.info("📝 当前上下文中的所有消息:")
        for cid, chat_msgs in self.contexts.items():
            logger.info(f" 聊天 {cid}: {len(chat_msgs)} 条消息")
            for i, msg in enumerate(chat_msgs):
                logger.info(f" {i+1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}...")

        await self.broadcast_contexts()

    async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
        """Send the current context snapshot to a single WebSocket as JSON text."""
        data = {"contexts": self._recent_contexts_data()}
        await ws.send_str(json.dumps(data, ensure_ascii=False))

    async def broadcast_contexts(self):
        """Send the current context snapshot to every registered WebSocket.

        Connections that are already closed, or whose send fails, are removed
        from the registry.
        """
        if not self.websockets:
            logger.debug("没有WebSocket连接,跳过广播")
            return

        contexts_data = self._recent_contexts_data()
        message = json.dumps({"contexts": contexts_data}, ensure_ascii=False)

        logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")

        removed_count = 0
        # Iterate over a copy because dead connections are removed in the loop.
        for ws in list(self.websockets):
            if ws.closed:
                if ws in self.websockets:
                    self.websockets.remove(ws)
                    removed_count += 1
                continue

            try:
                await ws.send_str(message)
                logger.debug("消息发送成功")
            except Exception as e:
                logger.error(f"发送WebSocket消息失败: {e}")
                if ws in self.websockets:
                    self.websockets.remove(ws)
                    removed_count += 1

        if removed_count > 0:
            logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
# Module-level singleton, created lazily on first access.
_context_web_manager: Optional[ContextWebManager] = None


def get_context_web_manager() -> ContextWebManager:
    """Return the shared ContextWebManager, creating it on the first call."""
    global _context_web_manager
    manager = _context_web_manager
    if manager is None:
        manager = ContextWebManager()
        _context_web_manager = manager
    return manager
async def init_context_web_manager():
    """Fetch the shared manager, start its web server and return the manager."""
    web_manager = get_context_web_manager()
    await web_manager.start_server()
    return web_manager