# mirror of https://github.com/Mai-with-u/MaiBot.git
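"""Streaming reply generator for the S4U (mais4u) chat pipeline.

Builds a prompt from an incoming message (via `prompt_builder`) and streams the
LLM completion back to the caller, splitting the raw token stream into
sentence-sized chunks as they arrive.
"""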
import asyncio
import os
import re
from typing import AsyncGenerator

from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
from src.mais4u.openai_client import AsyncOpenAIClient
from src.person_info.person_info import PersonInfoManager, get_person_info_manager

logger = get_logger("s4u_stream_generator")


class S4UStreamGenerator:
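    """Generates replies as an async stream of sentences.

    Expects `global_config.model.replyer_1` to provide `provider` and `name`
    (plus optional `enable_thinking` / `thinking_budget`), and reads the API
    credentials from the `{PROVIDER}_KEY` / `{PROVIDER}_BASE_URL` environment
    variables.
    """
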
    def __init__(self):
        replyer_1_config = global_config.model.replyer_1
        provider = replyer_1_config.get("provider")
        if not provider:
            logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段")
            raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段")

        api_key = os.environ.get(f"{provider.upper()}_KEY")
        base_url = os.environ.get(f"{provider.upper()}_BASE_URL")

        if not api_key:
            logger.error(f"环境变量 {provider.upper()}_KEY 未设置")
            raise ValueError(f"环境变量 {provider.upper()}_KEY 未设置")

        self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
        self.model_1_name = replyer_1_config.get("name")
        if not self.model_1_name:
            logger.error("`replyer_1` 在配置文件中缺少 `name` 字段")
            raise ValueError("`replyer_1` 在配置文件中缺少 `name` 字段")
        self.replyer_1_config = replyer_1_config

        self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation")
        self.current_model_name = "unknown model"

        # Regex used to split text into sentences while handling various punctuation and edge cases.
        # It matches common sentence terminators but ignores punctuation inside quotes and within numbers.
        self.sentence_split_pattern = re.compile(
            r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|'  # content wrapped in quotes/brackets
            r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))',  # text up to a sentence terminator
            re.UNICODE | re.DOTALL,
        )

    async def generate_response(
        self, message: MessageRecv, previous_reply_context: str = ""
    ) -> AsyncGenerator[str, None]:
        """Select the generation function that matches the current model type."""
        # Read the model probability settings from global_config and pick a model
        current_client = self.client_1
        self.current_model_name = self.model_1_name

        person_id = PersonInfoManager.get_person_id(
            message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
        )
        person_info_manager = get_person_info_manager()
        person_name = await person_info_manager.get_value(person_id, "person_name")

        if message.chat_stream.user_info.user_nickname:
            sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
        else:
            sender_name = f"用户({message.chat_stream.user_info.user_id})"

        # Build the prompt
        if previous_reply_context:
            message_txt = f"""
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
[你已经对上一条消息说的话]: {previous_reply_context}
---
[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:
{message.processed_plain_text}
"""
        else:
            message_txt = message.processed_plain_text

        prompt = await prompt_builder.build_prompt_normal(
            message=message,
            message_txt=message_txt,
            sender_name=sender_name,
            chat_stream=message.chat_stream,
        )

        logger.info(
            f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
        )

        extra_kwargs = {}
        if self.replyer_1_config.get("enable_thinking") is not None:
            extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking")
        if self.replyer_1_config.get("thinking_budget") is not None:
            extra_kwargs["thinking_budget"] = self.replyer_1_config.get("thinking_budget")

        async for chunk in self._generate_response_with_model(
            prompt, current_client, self.current_model_name, **extra_kwargs
        ):
            yield chunk

    async def _generate_response_with_model(
        self,
        prompt: str,
        client: AsyncOpenAIClient,
        model_name: str,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
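        """Stream the model output, yielding it sentence by sentence.

        Chunks from the client are accumulated in a buffer; complete sentences
        (per `self.sentence_split_pattern`) are yielded as soon as they appear,
        and any leftover text is flushed once the stream ends.
        """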
        logger.debug(prompt)

        buffer = ""
        delimiters = ",。!?,.!?\n\r"  # For final trimming
        punctuation_buffer = ""

        async for content in client.get_stream_content(
            messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
        ):
            buffer += content

            # Match sentences with the regex
            last_match_end = 0
            for match in self.sentence_split_pattern.finditer(buffer):
                sentence = match.group(0).strip()
                if sentence:
                    # Emit the sentence if it looks complete (i.e. we are not just waiting for more content)
                    if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
                        # Check whether it is only a single punctuation mark
                        if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
                            punctuation_buffer += sentence
                        else:
                            # Emit any accumulated punctuation together with the current sentence
                            to_yield = punctuation_buffer + sentence
                            if to_yield.endswith((',', ',')):
                                to_yield = to_yield.rstrip(',,')

                            yield to_yield
                            punctuation_buffer = ""  # clear the punctuation buffer
                            await asyncio.sleep(0)  # let other tasks run

                        last_match_end = match.end(0)

            # Drop the emitted portion from the buffer
            if last_match_end > 0:
                buffer = buffer[last_match_end:]

        # Emit whatever is left in the buffer
        to_yield = (punctuation_buffer + buffer).strip()
        if to_yield:
            if to_yield.endswith((',', ',')):
                to_yield = to_yield.rstrip(',,')
            if to_yield:
                yield to_yield
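
# Minimal usage sketch (illustrative only; assumes an already-constructed
# `MessageRecv` named `message` from the normal message-receive pipeline, and
# a hypothetical `send_to_client` helper on the caller's side):
#
#     generator = S4UStreamGenerator()
#     async for sentence in generator.generate_response(message):
#         send_to_client(sentence)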