MaiBot/src/mais4u/mais4u_chat/s4u_stream_generator.py

156 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
from typing import AsyncGenerator
from src.llm_models.utils_model import LLMRequest
from src.mais4u.openai_client import AsyncOpenAIClient
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
from src.common.logger import get_logger
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
import asyncio
import re
# Module-level logger for this file, created under the name "s4u_stream_generator".
logger = get_logger("s4u_stream_generator")
class S4UStreamGenerator:
    """Stream a model reply to a chat message as sentence-sized text chunks.

    The generator builds a prompt for the incoming message, asks the
    ``replyer_1`` model for a streamed completion, and re-chunks the raw
    stream so that callers receive whole sentences instead of arbitrary
    token fragments.
    """

    # Commas (ASCII and full-width) stripped from the end of every chunk.
    _TRAILING_COMMAS = ",\uff0c"
    # A match that reaches the end of the buffer only counts as complete when
    # it ends with one of these characters (ASCII and full-width punctuation,
    # plus line breaks); otherwise more text may still arrive for it.
    _CHUNK_END_CHARS = tuple(",.!?\uff0c\u3002\uff01\uff1f\n\r")
    # Single punctuation marks that are never emitted alone; they are held
    # back and prepended to the next real sentence.
    _BARE_PUNCTUATION = frozenset(",.!?\uff0c\u3002\uff01\uff1f")

    # Splits text into sentences. The first alternative keeps quoted or
    # bracketed spans together; the second consumes text up to a sentence
    # terminator (ASCII or full-width) that is not directly followed by a
    # quote, or up to the end of the input.
    # Compiled once for the class; instances read it as
    # ``self.sentence_split_pattern``.
    sentence_split_pattern = re.compile(
        r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|'
        r'[^.\u3002!?\uff01\uff1f\n\r]+(?:[.\u3002!?\uff01\uff1f\n\r](?![\'"])|$))',
        re.UNICODE | re.DOTALL,
    )

    def __init__(self):
        """Create the streaming client from the ``replyer_1`` model config.

        The API key and base URL are read from the environment variables
        ``<PROVIDER>_KEY`` and ``<PROVIDER>_BASE_URL``, where ``<PROVIDER>``
        is the upper-cased ``provider`` value of the config.

        Raises:
            ValueError: if the config has no ``provider`` or ``name`` value,
                or if the API-key environment variable is not set.
        """
        replyer_1_config = global_config.model.replyer_1

        provider = replyer_1_config.get("provider")
        if not provider:
            logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段")
            raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段")

        env_prefix = provider.upper()
        api_key_var = f"{env_prefix}_KEY"
        api_key = os.environ.get(api_key_var)
        # The base URL is optional here; a missing variable is passed on as None.
        base_url = os.environ.get(f"{env_prefix}_BASE_URL")
        if not api_key:
            logger.error(f"环境变量 {api_key_var} 未设置")
            raise ValueError(f"环境变量 {api_key_var} 未设置")

        self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)

        self.model_1_name = replyer_1_config.get("name")
        if not self.model_1_name:
            # The key that is actually read is `name`, so that is what the
            # message must tell the user to add.
            logger.error("`replyer_1` 在配置文件中缺少 `name` 字段")
            raise ValueError("`replyer_1` 在配置文件中缺少 `name` 字段")

        self.replyer_1_config = replyer_1_config
        self.model_sum = LLMRequest(
            model=global_config.model.memory_summary, temperature=0.7, request_type="relation"
        )
        self.current_model_name = "unknown model"

    async def generate_response(
        self, message: "MessageRecv", previous_reply_context: str = ""
    ) -> AsyncGenerator[str, None]:
        """Build the prompt for ``message`` and stream the reply sentence by sentence.

        Args:
            message: the received message to answer.
            previous_reply_context: text already sent as a reply before the
                bot was interrupted. When non-empty it is placed in the prompt
                ahead of the new message so the model can continue from it.

        Yields:
            Reply chunks as produced by ``_generate_response_with_model``.
        """
        current_client = self.client_1
        self.current_model_name = self.model_1_name

        user_info = message.chat_stream.user_info
        person_id = PersonInfoManager.get_person_id(user_info.platform, user_info.user_id)
        person_name = await get_person_info_manager().get_value(person_id, "person_name")

        if user_info.user_nickname:
            sender_name = f"[{user_info.user_nickname}]你叫ta{person_name}"
        else:
            sender_name = f"用户({user_info.user_id})"

        # Build the text handed to the prompt builder.
        if previous_reply_context:
            message_txt = (
                "\n"
                "你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:\n"
                f"[你已经对上一条消息说的话]: {previous_reply_context}\n"
                "---\n"
                "[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:\n"
                f"{message.processed_plain_text}\n"
            )
        else:
            message_txt = message.processed_plain_text

        prompt = await prompt_builder.build_prompt_normal(
            message=message,
            message_txt=message_txt,
            sender_name=sender_name,
            chat_stream=message.chat_stream,
        )

        # Only the first 30 characters of the message are logged.
        preview = message_txt[:30] + "..." if len(message_txt) > 30 else message_txt
        logger.info(f"{self.current_model_name}思考:{preview}")

        # Forward the optional thinking settings only when they are configured.
        extra_kwargs = {}
        for option in ("enable_thinking", "thinking_budget"):
            value = self.replyer_1_config.get(option)
            if value is not None:
                extra_kwargs[option] = value

        async for chunk in self._generate_response_with_model(
            prompt, current_client, self.current_model_name, **extra_kwargs
        ):
            yield chunk

    def _extract_sentences(self, buffer: str, pending_punctuation: str):
        """Pull every complete sentence out of ``buffer``.

        Args:
            buffer: text received so far that has not been emitted yet.
            pending_punctuation: bare punctuation marks held back from
                earlier calls.

        Returns:
            A ``(ready, remaining, pending_punctuation)`` tuple: the chunks
            that can be emitted now, the unconsumed tail of ``buffer``, and
            the updated held-back punctuation.
        """
        ready = []
        consumed = 0
        for match in self.sentence_split_pattern.finditer(buffer):
            sentence = match.group(0).strip()
            if not sentence:
                continue
            # A match that touches the end of the buffer may still grow with
            # the next stream piece unless it already ends in punctuation.
            is_complete = match.end(0) < len(buffer) or sentence.endswith(self._CHUNK_END_CHARS)
            if not is_complete:
                continue
            if sentence in self._BARE_PUNCTUATION:
                pending_punctuation += sentence
            else:
                chunk = (pending_punctuation + sentence).rstrip(self._TRAILING_COMMAS)
                pending_punctuation = ""
                # A run of commas trims down to nothing; never emit an empty chunk.
                if chunk:
                    ready.append(chunk)
            consumed = match.end(0)
        return ready, buffer[consumed:], pending_punctuation

    async def _generate_response_with_model(
        self,
        prompt: str,
        client: "AsyncOpenAIClient",
        model_name: str,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion for ``prompt`` and re-chunk it into sentences.

        Args:
            prompt: the full prompt, sent as a single user message.
            client: client whose ``get_stream_content`` yields text pieces.
            model_name: model identifier passed to the client.
            **kwargs: extra options forwarded to ``get_stream_content``.

        Yields:
            Non-empty chunks with trailing commas removed. Any text left in
            the buffer when the stream ends is emitted as a final chunk.
        """
        logger.debug(prompt)

        buffer = ""
        pending_punctuation = ""

        async for content in client.get_stream_content(
            messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
        ):
            buffer += content
            ready, buffer, pending_punctuation = self._extract_sentences(buffer, pending_punctuation)
            for chunk in ready:
                yield chunk
                await asyncio.sleep(0)  # let other tasks run between chunks

        # Flush whatever is left once the stream has ended.
        leftover = (pending_punctuation + buffer).strip().rstrip(self._TRAILING_COMMAS)
        if leftover:
            yield leftover