mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'dev' of github.com:MaiM-with-u/MaiBot into dev
commit
11462193cd
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
@ -172,7 +173,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def main(): # sourcery skip: dict-comprehension
|
async def main_async(): # sourcery skip: dict-comprehension
|
||||||
# 新增确认提示
|
# 新增确认提示
|
||||||
print("=== 重要操作确认 ===")
|
print("=== 重要操作确认 ===")
|
||||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||||
|
|
@ -239,6 +240,29 @@ def main(): # sourcery skip: dict-comprehension
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数 - 设置新的事件循环并运行异步主函数"""
|
||||||
|
# 检查是否有现有的事件循环
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
if loop.is_closed():
|
||||||
|
# 如果事件循环已关闭,创建新的
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
except RuntimeError:
|
||||||
|
# 没有运行的事件循环,创建新的
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 在新的事件循环中运行异步主函数
|
||||||
|
loop.run_until_complete(main_async())
|
||||||
|
finally:
|
||||||
|
# 确保事件循环被正确关闭
|
||||||
|
if not loop.is_closed():
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# logger.info(f"111111111111111111111111{ROOT_PATH}")
|
# logger.info(f"111111111111111111111111{ROOT_PATH}")
|
||||||
main()
|
main()
|
||||||
|
|
|
||||||
|
|
@ -117,30 +117,36 @@ class EmbeddingStore:
|
||||||
self.idx2hash = None
|
self.idx2hash = None
|
||||||
|
|
||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
"""获取字符串的嵌入向量,处理异步调用"""
|
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||||
|
# 创建新的事件循环并在完成后立即关闭
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试获取当前事件循环
|
# 创建新的LLMRequest实例
|
||||||
asyncio.get_running_loop()
|
from src.llm_models.utils_model import LLMRequest
|
||||||
# 如果在事件循环中,使用线程池执行
|
from src.config.config import model_config
|
||||||
import concurrent.futures
|
|
||||||
|
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||||
def run_in_thread():
|
|
||||||
return asyncio.run(get_embedding(s))
|
# 使用新的事件循环运行异步方法
|
||||||
|
embedding, _ = loop.run_until_complete(llm.get_embedding(s))
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
||||||
future = executor.submit(run_in_thread)
|
if embedding and len(embedding) > 0:
|
||||||
result = future.result()
|
return embedding
|
||||||
if result is None:
|
else:
|
||||||
logger.error(f"获取嵌入失败: {s}")
|
|
||||||
return []
|
|
||||||
return result
|
|
||||||
except RuntimeError:
|
|
||||||
# 没有运行的事件循环,直接运行
|
|
||||||
result = asyncio.run(get_embedding(s))
|
|
||||||
if result is None:
|
|
||||||
logger.error(f"获取嵌入失败: {s}")
|
logger.error(f"获取嵌入失败: {s}")
|
||||||
return []
|
return []
|
||||||
return result
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||||
|
return []
|
||||||
|
finally:
|
||||||
|
# 确保事件循环被正确关闭
|
||||||
|
try:
|
||||||
|
loop.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
|
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
|
||||||
"""使用多线程批量获取嵌入向量
|
"""使用多线程批量获取嵌入向量
|
||||||
|
|
@ -181,8 +187,14 @@ class EmbeddingStore:
|
||||||
|
|
||||||
for i, s in enumerate(chunk_strs):
|
for i, s in enumerate(chunk_strs):
|
||||||
try:
|
try:
|
||||||
# 直接使用异步函数
|
# 在线程中创建独立的事件循环
|
||||||
embedding = asyncio.run(llm.get_embedding(s))
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
embedding = loop.run_until_complete(llm.get_embedding(s))
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
if embedding and len(embedding) > 0:
|
if embedding and len(embedding) > 0:
|
||||||
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -113,6 +113,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
|
||||||
|
|
||||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
||||||
"""获取文本的embedding向量"""
|
"""获取文本的embedding向量"""
|
||||||
|
# 每次都创建新的LLMRequest实例以避免事件循环冲突
|
||||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
|
||||||
try:
|
try:
|
||||||
embedding, _ = await llm.get_embedding(text)
|
embedding, _ = await llm.get_embedding(text)
|
||||||
|
|
|
||||||
|
|
@ -159,14 +159,23 @@ class ClientRegistry:
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
|
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
|
||||||
"""
|
"""
|
||||||
获取注册的API客户端实例
|
获取注册的API客户端实例
|
||||||
Args:
|
Args:
|
||||||
api_provider: APIProvider实例
|
api_provider: APIProvider实例
|
||||||
|
force_new: 是否强制创建新实例(用于解决事件循环问题)
|
||||||
Returns:
|
Returns:
|
||||||
BaseClient: 注册的API客户端实例
|
BaseClient: 注册的API客户端实例
|
||||||
"""
|
"""
|
||||||
|
# 如果强制创建新实例,直接创建不使用缓存
|
||||||
|
if force_new:
|
||||||
|
if client_class := self.client_registry.get(api_provider.client_type):
|
||||||
|
return client_class(api_provider)
|
||||||
|
else:
|
||||||
|
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||||
|
|
||||||
|
# 正常的缓存逻辑
|
||||||
if api_provider.name not in self.client_instance_cache:
|
if api_provider.name not in self.client_instance_cache:
|
||||||
if client_class := self.client_registry.get(api_provider.client_type):
|
if client_class := self.client_registry.get(api_provider.client_type):
|
||||||
self.client_instance_cache[api_provider.name] = client_class(api_provider)
|
self.client_instance_cache[api_provider.name] = client_class(api_provider)
|
||||||
|
|
|
||||||
|
|
@ -388,6 +388,7 @@ class OpenaiClient(BaseClient):
|
||||||
base_url=api_provider.base_url,
|
base_url=api_provider.base_url,
|
||||||
api_key=api_provider.api_key,
|
api_key=api_provider.api_key,
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
|
timeout=api_provider.timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_response(
|
async def get_response(
|
||||||
|
|
@ -520,6 +521,11 @@ class OpenaiClient(BaseClient):
|
||||||
extra_body=extra_params,
|
extra_body=extra_params,
|
||||||
)
|
)
|
||||||
except APIConnectionError as e:
|
except APIConnectionError as e:
|
||||||
|
# 添加详细的错误信息以便调试
|
||||||
|
logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}")
|
||||||
|
logger.error(f"错误类型: {type(e)}")
|
||||||
|
if hasattr(e, '__cause__') and e.__cause__:
|
||||||
|
logger.error(f"底层错误: {str(e.__cause__)}")
|
||||||
raise NetworkConnectionError() from e
|
raise NetworkConnectionError() from e
|
||||||
except APIStatusError as e:
|
except APIStatusError as e:
|
||||||
# 重封装APIError为RespNotOkException
|
# 重封装APIError为RespNotOkException
|
||||||
|
|
|
||||||
|
|
@ -248,7 +248,11 @@ class LLMRequest:
|
||||||
)
|
)
|
||||||
model_info = model_config.get_model_info(least_used_model_name)
|
model_info = model_config.get_model_info(least_used_model_name)
|
||||||
api_provider = model_config.get_provider(model_info.api_provider)
|
api_provider = model_config.get_provider(model_info.api_provider)
|
||||||
client = client_registry.get_client_class_instance(api_provider)
|
|
||||||
|
# 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
|
||||||
|
force_new_client = (self.request_type == "embedding")
|
||||||
|
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
|
||||||
|
|
||||||
logger.debug(f"选择请求模型: {model_info.name}")
|
logger.debug(f"选择请求模型: {model_info.name}")
|
||||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用
|
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue