mirror of https://github.com/Mai-with-u/MaiBot.git
Optimize async handling: avoid event loop problems and enhance error logging
parent 7a68ab0319
commit fab4656185
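The failure mode this commit works around, shown as a minimal standalone sketch (illustrative only, not code from the repository): asyncio.run() refuses to start when another event loop is already running on the same thread, so synchronous helpers called from async code need a dedicated loop instead.

    import asyncio

    async def fetch_embedding() -> list[float]:
        # Stand-in for a real async embedding request
        await asyncio.sleep(0)
        return [0.1, 0.2]

    async def caller():
        try:
            # Nested asyncio.run() inside a running loop raises:
            # RuntimeError: asyncio.run() cannot be called from a running event loop
            asyncio.run(fetch_embedding())
        except RuntimeError as err:
            print(f"nested asyncio.run() failed: {err}")

    asyncio.run(caller())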
@@ -6,6 +6,7 @@
 import sys
 import os
+import asyncio
 from time import sleep

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

@@ -172,7 +173,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
     return True


-def main():  # sourcery skip: dict-comprehension
+async def main_async():  # sourcery skip: dict-comprehension
     # Newly added confirmation prompt
     print("=== Important operation confirmation ===")
     print("The OpenIE import sends a large number of requests and may hit the request rate limit; pay attention to the model you choose")
@@ -239,6 +240,29 @@ def main():  # sourcery skip: dict-comprehension
     return None


+def main():
+    """Main entry: set up a new event loop and run the async main function"""
+    # Check whether there is an existing event loop
+    try:
+        loop = asyncio.get_running_loop()
+        if loop.is_closed():
+            # The event loop is already closed; create a new one
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+    except RuntimeError:
+        # No running event loop; create a new one
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    try:
+        # Run the async main function in the new event loop
+        loop.run_until_complete(main_async())
+    finally:
+        # Make sure the event loop is closed properly
+        if not loop.is_closed():
+            loop.close()
+
+
 if __name__ == "__main__":
     # logger.info(f"111111111111111111111111{ROOT_PATH}")
     main()
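For comparison, a minimal alternative to this wrapper (not part of the commit) is asyncio.run(), which also creates a fresh event loop and closes it on exit:

    def main():
        # asyncio.run() creates a new event loop, runs main_async()
        # to completion, and closes the loop afterwards
        asyncio.run(main_async())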
@@ -117,30 +117,36 @@ class EmbeddingStore:
         self.idx2hash = None

     def _get_embedding(self, s: str) -> List[float]:
-        """Get the embedding vector of a string, handling the async call"""
+        """Get the embedding vector of a string in a fully synchronous way to avoid event loop problems"""
+        # Create a new event loop and close it immediately after finishing
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
         try:
-            # Try to get the current event loop
-            asyncio.get_running_loop()
-            # If we are inside an event loop, execute in a thread pool
-            import concurrent.futures
-
-            def run_in_thread():
-                return asyncio.run(get_embedding(s))
-
-            with concurrent.futures.ThreadPoolExecutor() as executor:
-                future = executor.submit(run_in_thread)
-                result = future.result()
-                if result is None:
-                    logger.error(f"Failed to get embedding: {s}")
-                    return []
-                return result
-        except RuntimeError:
-            # No running event loop; run directly
-            result = asyncio.run(get_embedding(s))
-            if result is None:
+            # Create a new LLMRequest instance
+            from src.llm_models.utils_model import LLMRequest
+            from src.config.config import model_config
+
+            llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
+
+            # Run the async method on the new event loop
+            embedding, _ = loop.run_until_complete(llm.get_embedding(s))
+
+            if embedding and len(embedding) > 0:
+                return embedding
+            else:
                 logger.error(f"Failed to get embedding: {s}")
                 return []
-            return result
+        except Exception as e:
+            logger.error(f"Exception while getting embedding: {s}, error: {e}")
+            return []
+        finally:
+            # Make sure the event loop is closed properly
+            try:
+                loop.close()
+            except Exception:
+                pass

     def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
         """Fetch embedding vectors in batches using multiple threads
@@ -181,8 +187,14 @@ class EmbeddingStore:

             for i, s in enumerate(chunk_strs):
                 try:
-                    # Use the async function directly
-                    embedding = asyncio.run(llm.get_embedding(s))
+                    # Create an independent event loop inside the thread
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                    try:
+                        embedding = loop.run_until_complete(llm.get_embedding(s))
+                    finally:
+                        loop.close()
+
                     if embedding and len(embedding) > 0:
                         chunk_results.append((start_idx + i, s, embedding[0]))  # embedding[0] is the actual vector
                     else:
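The per-thread pattern above can be factored into a small helper; this is a hypothetical generalization, not code from the commit:

    import asyncio
    from typing import Any, Coroutine

    def run_coro_in_fresh_loop(coro: Coroutine) -> Any:
        # Give each worker thread its own event loop, so no loop is
        # shared across threads or reused after it has been closed
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()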
@@ -113,6 +113,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:

 async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
     """Get the embedding vector for a text"""
+    # Create a new LLMRequest instance each time to avoid event loop conflicts
     llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
     try:
         embedding, _ = await llm.get_embedding(text)
@@ -159,14 +159,23 @@ class ClientRegistry:

         return decorator

-    def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
+    def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
         """
         Get the registered API client instance
         Args:
             api_provider: an APIProvider instance
+            force_new: whether to force creation of a new instance (used to work around event loop problems)
         Returns:
             BaseClient: the registered API client instance
         """
+        # If a new instance is forced, create it directly without using the cache
+        if force_new:
+            if client_class := self.client_registry.get(api_provider.client_type):
+                return client_class(api_provider)
+            else:
+                raise KeyError(f"No Client registered for type '{api_provider.client_type}'")
+
+        # Normal caching logic
         if api_provider.name not in self.client_instance_cache:
             if client_class := self.client_registry.get(api_provider.client_type):
                 self.client_instance_cache[api_provider.name] = client_class(api_provider)
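A sketch of how a caller might opt out of the shared cache (hypothetical call site; the actual one appears in the LLMRequest hunk below):

    # Bypass the instance cache so the returned client is not bound to
    # a previously used (possibly closed) event loop
    client = client_registry.get_client_class_instance(api_provider, force_new=True)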
@@ -388,6 +388,7 @@ class OpenaiClient(BaseClient):
             base_url=api_provider.base_url,
             api_key=api_provider.api_key,
+            max_retries=0,
             timeout=api_provider.timeout,
         )

     async def get_response(
@@ -520,6 +521,11 @@ class OpenaiClient(BaseClient):
                 extra_body=extra_params,
             )
         except APIConnectionError as e:
+            # Add detailed error information for debugging
+            logger.error(f"OpenAI API connection error (embedding model): {str(e)}")
+            logger.error(f"Error type: {type(e)}")
+            if hasattr(e, '__cause__') and e.__cause__:
+                logger.error(f"Underlying error: {str(e.__cause__)}")
             raise NetworkConnectionError() from e
         except APIStatusError as e:
             # Re-wrap APIError as RespNotOkException
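The new logging leans on Python's exception chaining: "raise ... from e" stores the original error on __cause__, which is what the added e.__cause__ lookup reads. A minimal illustration (not from the repository):

    try:
        try:
            raise ConnectionError("socket closed")  # low-level transport failure
        except ConnectionError as low:
            # "from low" records the original error on __cause__
            raise RuntimeError("request failed") from low
    except RuntimeError as err:
        print(err.__cause__)  # -> socket closed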
@@ -248,7 +248,11 @@ class LLMRequest:
             )
             model_info = model_config.get_model_info(least_used_model_name)
             api_provider = model_config.get_provider(model_info.api_provider)
-            client = client_registry.get_client_class_instance(api_provider)
+
+            # For embedding tasks, force a new client instance to avoid event loop problems
+            force_new_client = (self.request_type == "embedding")
+            client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
+
             logger.debug(f"Selected model for this request: {model_info.name}")
             total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
             self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)  # Increase the usage penalty to discourage consecutive use of the same model