mirror of https://github.com/Mai-with-u/MaiBot.git
169 lines
7.1 KiB
Python
169 lines
7.1 KiB
Python
import os
|
||
from typing import AsyncGenerator
|
||
from src.mais4u.openai_client import AsyncOpenAIClient
|
||
from src.config.config import global_config, model_config
|
||
from src.chat.message_receive.message import MessageRecvS4U
|
||
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
||
from src.common.logger import get_logger
|
||
import asyncio
|
||
import re
|
||
|
||
|
||
# Logger for this module, obtained from the project's get_logger under the
# name "s4u_stream_generator".
logger = get_logger("s4u_stream_generator")
|
||
|
||
|
||
class S4UStreamGenerator:
    """Stream a reply from the ``replyer_1`` model, re-chunked into sentences.

    The raw token stream returned by the client is buffered and cut into
    sentence-like pieces so that callers receive text one piece at a time.

    Attributes:
        client_1: API client built from the provider configured for the
            first model of the ``replyer_1`` task.
        model_1_name: Name of that model.
        replyer_1_config: The ``replyer_1`` task configuration object.
        current_model_name: Model used by the latest ``generate_response``
            call; ``"unknown model"`` until the first call.
        partial_response: Concatenation of every piece yielded so far by the
            current ``generate_response`` call (reset at the start of a call).
        chat_stream: Copied onto each incoming message before the prompt is
            built. Starts as ``None``; presumably assigned by the owner of
            this generator -- confirm against the caller.
    """

    # Sentence splitter. The first alternative keeps a span wrapped in
    # quotes/brackets in one piece; the second matches a run of text up to and
    # including a sentence terminator that is not directly followed by a quote,
    # or up to the end of the buffer. Both ASCII and full-width "!" / "?" are
    # terminators, matching the punctuation tables below.
    sentence_split_pattern = re.compile(
        r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|'
        r'[^.。!?!?\n\r]+(?:[.。!?!?\n\r](?![\'"])|$))',
        re.UNICODE | re.DOTALL,
    )

    # A piece at the very end of the buffer is only treated as finished when it
    # ends with one of these characters.
    _SENTENCE_DELIMITERS = tuple(",。!?,.!?\n\r")
    # Pieces consisting of exactly one of these are held back and prepended to
    # the next real sentence instead of being emitted on their own.
    _LONE_PUNCTUATION = frozenset(",,.。!!??")
    # Trailing commas (ASCII and full-width) are removed from emitted pieces.
    _TRAILING_COMMAS = ",,"

    def __init__(self):
        """Resolve the ``replyer_1`` model and provider and build the client.

        Raises:
            ValueError: If the model or its provider is missing from the
                configuration, or the provider has no API key.
        """
        replyer_1_config = model_config.model_task_config.replyer_1
        model_to_use = replyer_1_config.model_list[0]

        model_info = model_config.get_model_info(model_to_use)
        if not model_info:
            raise self._config_error(f"模型 {model_to_use} 在配置中未找到")

        provider_name = model_info.api_provider
        provider_info = model_config.get_provider(provider_name)
        if not provider_info:
            raise self._config_error("`replyer_1` 找不到对应的Provider")

        api_key = provider_info.api_key
        base_url = provider_info.base_url
        if not api_key:
            raise self._config_error(f"{provider_name}没有配置API KEY")

        self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
        self.model_1_name = model_to_use
        self.replyer_1_config = replyer_1_config

        self.current_model_name = "unknown model"
        self.partial_response = ""

        self.chat_stream = None

    @staticmethod
    def _config_error(message: str) -> ValueError:
        """Log *message* as an error and return a ``ValueError`` carrying it."""
        logger.error(message)
        return ValueError(message)

    @staticmethod
    def _read_option(config, key: str):
        """Return option *key* from *config*, or ``None`` when it is absent.

        The task configuration is read by attribute elsewhere in this class,
        so both mapping-style (``.get``) and attribute-style objects are
        accepted here.
        """
        getter = getattr(config, "get", None)
        if callable(getter):
            return getter(key)
        return getattr(config, key, None)

    async def build_last_internal_message(self, message: "MessageRecvS4U", previous_reply_context: str = ""):
        """Build the text to answer, noting an interrupted previous reply.

        Args:
            message: Incoming message; only ``processed_plain_text`` is read.
            previous_reply_context: What had already been said in the reply
                that was interrupted, or ``""`` if there was none.

        Returns:
            ``(True, text)`` where ``text`` wraps the previous reply and the
            new message when ``previous_reply_context`` is non-empty,
            otherwise ``(False, message.processed_plain_text)``.
        """
        if not previous_reply_context:
            return False, message.processed_plain_text

        message_txt = (
            "\n"
            "你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:\n"
            f"[你已经对上一条消息说的话]: {previous_reply_context}\n"
            "---\n"
            "[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:\n"
            f"{message.processed_plain_text}\n"
        )
        return True, message_txt

    async def generate_response(
        self, message: "MessageRecvS4U", previous_reply_context: str = ""
    ) -> AsyncGenerator[str, None]:
        """Build the prompt for *message* and stream the reply in pieces.

        Args:
            message: Incoming message. Its ``chat_stream`` attribute is
                overwritten with ``self.chat_stream`` before the prompt is
                built.
            previous_reply_context: Text of an interrupted previous reply;
                only used when the message is not internal.

        Yields:
            Sentence-like pieces of the model's reply, in order.
        """
        self.partial_response = ""
        message_txt = message.processed_plain_text
        if not message.is_internal:
            interrupted, resumed_txt = await self.build_last_internal_message(message, previous_reply_context)
            if interrupted:
                message_txt = resumed_txt

        message.chat_stream = self.chat_stream
        prompt = await prompt_builder.build_prompt_normal(
            message=message,
            message_txt=message_txt,
        )

        # Select the model before logging so the log line names the model that
        # is actually used (also on the very first call).
        self.current_model_name = self.model_1_name
        preview = message_txt[:30] + "..." if len(message_txt) > 30 else message_txt
        logger.info(f"{self.current_model_name}思考:{preview}")

        # Forward the optional thinking switches only when they are configured.
        extra_kwargs = {}
        for option in ("enable_thinking", "thinking_budget"):
            value = self._read_option(self.replyer_1_config, option)
            if value is not None:
                extra_kwargs[option] = value

        async for chunk in self._generate_response_with_model(
            prompt, self.client_1, self.current_model_name, **extra_kwargs
        ):
            yield chunk

    def _strip_trailing_commas(self, text: str) -> str:
        """Return *text* without trailing ASCII/full-width commas."""
        return text.rstrip(self._TRAILING_COMMAS)

    def _split_ready_sentences(self, buffer: str, pending_punctuation: str) -> "tuple[list[str], str, str]":
        """Cut the finished sentences off the front of *buffer*.

        Args:
            buffer: Text received so far that has not been emitted yet.
            pending_punctuation: Lone punctuation marks held back earlier.

        Returns:
            ``(ready, remaining_buffer, pending_punctuation)``: the pieces
            that can be emitted now, the unconsumed tail of the buffer, and
            the updated held-back punctuation.
        """
        ready = []
        consumed = 0
        for match in self.sentence_split_pattern.finditer(buffer):
            sentence = match.group(0).strip()
            if not sentence:
                continue
            # A piece touching the end of the buffer may still grow with the
            # next chunk unless it already ends with a delimiter.
            is_complete = match.end(0) < len(buffer) or sentence.endswith(self._SENTENCE_DELIMITERS)
            if not is_complete:
                continue

            if sentence in self._LONE_PUNCTUATION:
                pending_punctuation += sentence
            else:
                piece = self._strip_trailing_commas(pending_punctuation + sentence)
                pending_punctuation = ""
                # A piece made only of commas becomes empty; never emit "".
                if piece:
                    ready.append(piece)
            consumed = match.end(0)

        return ready, buffer[consumed:], pending_punctuation

    async def _generate_response_with_model(
        self,
        prompt: str,
        client: "AsyncOpenAIClient",
        model_name: str,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream *prompt* through *client* and yield the reply by sentence.

        Args:
            prompt: Full prompt, sent as a single user message.
            client: Client whose ``get_stream_content`` yields text chunks.
            model_name: Model to request.
            **kwargs: Extra options forwarded to ``get_stream_content``.

        Yields:
            Non-empty sentence-like pieces; each one is also appended to
            ``self.partial_response`` before it is yielded.
        """
        buffer = ""
        pending_punctuation = ""

        async for content in client.get_stream_content(
            messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
        ):
            buffer += content
            ready, buffer, pending_punctuation = self._split_ready_sentences(buffer, pending_punctuation)
            for piece in ready:
                self.partial_response += piece
                yield piece
                await asyncio.sleep(0)  # let other tasks run between pieces

        # Flush whatever is left once the stream has ended.
        tail = self._strip_trailing_commas((pending_punctuation + buffer).strip())
        if tail:
            self.partial_response += tail
            yield tail
|