增加群prompt

pull/580/head
HexatomicRing 2025-03-26 11:01:51 +08:00
parent 79e383e90b
commit 8626aa9e76
5 changed files with 25 additions and 8 deletions

View File

@ -357,7 +357,8 @@ class ChatBot:
)
if isinstance(event, GroupRecallNoticeEvent):
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
group_info = GroupInfo.from_dict(await bot.get_group_info(group_id=event.group_id))
group_info.platform = "qq"
else:
group_info = None
@ -418,9 +419,9 @@ class ChatBot:
platform="qq",
)
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
group_info = GroupInfo.from_dict(await bot.get_group_info(group_id=event.group_id))
group_info.platform = "qq"
# group_info = await bot.get_group_info(group_id=event.group_id)
# sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
message_cq = MessageRecvCQ(

View File

@ -173,9 +173,13 @@ class ChatManager:
await self._save_stream(stream)
return copy.deepcopy(stream)
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
def get_stream(self, stream_id: str, new_group_info: GroupInfo = None) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
stream = self.streams.get(stream_id)
# 动态更新群信息
if stream and new_group_info:
stream.group_info = new_group_info
return stream
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None

View File

@ -28,6 +28,7 @@ class BotConfig:
talk_allowed_groups = set()
talk_frequency_down_groups = set()
ban_user_id = set()
group_prompts = dict()
#personality
PROMPT_PERSONALITY = [
@ -413,6 +414,7 @@ class BotConfig:
config.talk_allowed_groups = set(groups_config.get("talk_allowed", []))
config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", []))
config.ban_user_id = set(groups_config.get("ban_user_id", []))
config.group_prompts = groups_config.get("group_prompts", dict())
def experimental(parent: dict):
experimental_config = parent["experimental"]

View File

@ -68,7 +68,7 @@ class PromptBuilder:
chat_talking_prompt = get_recent_group_detailed_plain_text(
stream_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True
)
chat_stream = chat_manager.get_stream(stream_id)
chat_stream = chat_manager.get_stream(stream_id, chat_stream.group_info)
if chat_stream.group_info:
chat_talking_prompt = chat_talking_prompt
else:
@ -100,7 +100,16 @@ class PromptBuilder:
# 类型
if chat_in_group:
chat_target = "你正在qq群里聊天下面是群里在聊的内容"
group_info = chat_stream.group_info
group_prompt = global_config.group_prompts.get(str(group_info.group_id))
if not group_prompt:
group_prompt = ''
else:
group_prompt = ",这是" + group_prompt
group_name = group_info.group_name
if not group_name:
group_name = "qq群"
chat_target = f"你正在{group_name}里聊天{group_prompt},下面是群里在聊的内容:"
chat_target_2 = "和群里聊天"
else:
chat_target = f"你正在和{sender_name}聊天,这是你们之前聊的内容:"

View File

@ -1,5 +1,5 @@
[inner]
version = "0.0.11"
version = "0.0.12"
[mai_version]
version = "0.6.0"
@ -29,6 +29,7 @@ talk_allowed = [
] #可以回复消息的群号码
talk_frequency_down = [] #降低回复频率的群号码
ban_user_id = [] #禁止回复和读取消息的QQ号
group_prompts = {123 = "这是一个qq群聊"} # 群聊描述(避免换行)
[personality]
prompt_personality = [