diff --git a/scripts/import_openie.py b/scripts/import_openie.py
index fe9f5269..c4367892 100644
--- a/scripts/import_openie.py
+++ b/scripts/import_openie.py
@@ -6,6 +6,7 @@
 import sys
 import os
+import asyncio
 from time import sleep
 
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
@@ -172,7 +173,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
     return True
 
 
-def main(): # sourcery skip: dict-comprehension
+async def main_async(): # sourcery skip: dict-comprehension
    # Confirmation prompt
    print("=== Important Operation Confirmation ===")
    print("OpenIE import sends a large number of requests and may hit the request rate limit; mind which model you choose")
@@ -239,6 +240,29 @@ def main(): # sourcery skip: dict-comprehension
         return None
 
 
+def main():
+    """Main entry point - set up a new event loop and run the async main function"""
+    # Check whether an event loop already exists
+    try:
+        loop = asyncio.get_running_loop()
+        if loop.is_closed():
+            # The event loop is closed; create a new one
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+    except RuntimeError:
+        # No running event loop; create a new one
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    try:
+        # Run the async main function on the new event loop
+        loop.run_until_complete(main_async())
+    finally:
+        # Make sure the event loop is closed properly
+        if not loop.is_closed():
+            loop.close()
+
+
 if __name__ == "__main__":
     # logger.info(f"111111111111111111111111{ROOT_PATH}")
     main()
diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py
index d0f6e774..dec5b595 100644
--- a/src/chat/knowledge/embedding_store.py
+++ b/src/chat/knowledge/embedding_store.py
@@ -117,30 +117,36 @@ class EmbeddingStore:
         self.idx2hash = None
 
     def _get_embedding(self, s: str) -> List[float]:
-        """Get the embedding vector for a string, handling the async call"""
+        """Get the embedding vector for a string, fully synchronously, avoiding event loop issues"""
+        # Create a new event loop and close it as soon as the call completes
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
         try:
-            # Try to get the current event loop
-            asyncio.get_running_loop()
-            # We are inside an event loop, so execute via a thread pool
-            import concurrent.futures
-
-            def run_in_thread():
-                return asyncio.run(get_embedding(s))
-
-            with concurrent.futures.ThreadPoolExecutor() as executor:
-                future = executor.submit(run_in_thread)
-                result = future.result()
-                if result is None:
-                    logger.error(f"Failed to get embedding: {s}")
-                    return []
-                return result
-        except RuntimeError:
-            # No running event loop; run directly
-            result = asyncio.run(get_embedding(s))
-            if result is None:
+            # Create a fresh LLMRequest instance
+            from src.llm_models.utils_model import LLMRequest
+            from src.config.config import model_config
+
+            llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
+
+            # Run the async method on the new event loop
+            embedding, _ = loop.run_until_complete(llm.get_embedding(s))
+
+            if embedding and len(embedding) > 0:
+                return embedding
+            else:
                 logger.error(f"Failed to get embedding: {s}")
                 return []
-            return result
+
+        except Exception as e:
+            logger.error(f"Exception while getting embedding: {s}, error: {e}")
+            return []
+        finally:
+            # Make sure the event loop is closed properly
+            try:
+                loop.close()
+            except Exception:
+                pass
 
     def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
         """Batch-fetch embedding vectors using multiple threads
@@ -181,8 +187,14 @@ class EmbeddingStore:
 
             for i, s in enumerate(chunk_strs):
                 try:
-                    # Use the async function directly
-                    embedding = asyncio.run(llm.get_embedding(s))
+                    # Create an independent event loop inside this thread
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                    try:
+                        embedding = loop.run_until_complete(llm.get_embedding(s))
+                    finally:
+                        loop.close()
+
                     if embedding and len(embedding) > 0:
                         chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] is the actual vector
                     else:
diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py
index d0976e9c..e2e3088c 100644
--- a/src/chat/utils/utils.py
+++ b/src/chat/utils/utils.py
@@ -113,6 +113,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
 
 async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
     """Get the embedding vector for a text"""
+    # Create a fresh LLMRequest instance on every call to avoid event loop conflicts
     llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
     try:
         embedding, _ = await llm.get_embedding(text)
diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py
index 97c34546..807f6484 100644
--- a/src/llm_models/model_client/base_client.py
+++ b/src/llm_models/model_client/base_client.py
@@ -159,14 +159,23 @@ class ClientRegistry:
         return decorator
 
-    def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
+    def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
         """
         Get a registered API client instance
 
         Args:
             api_provider: APIProvider instance
+            force_new: whether to force creation of a new instance (used to work around event loop issues)
         Returns:
             BaseClient: the registered API client instance
         """
+        # When a new instance is forced, create it directly and skip the cache
+        if force_new:
+            if client_class := self.client_registry.get(api_provider.client_type):
+                return client_class(api_provider)
+            else:
+                raise KeyError(f"No Client registered for type '{api_provider.client_type}'")
+
+        # Normal caching logic
         if api_provider.name not in self.client_instance_cache:
             if client_class := self.client_registry.get(api_provider.client_type):
                 self.client_instance_cache[api_provider.name] = client_class(api_provider)
diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py
index c580899a..bba00f94 100644
--- a/src/llm_models/model_client/openai_client.py
+++ b/src/llm_models/model_client/openai_client.py
@@ -388,6 +388,7 @@ class OpenaiClient(BaseClient):
             base_url=api_provider.base_url,
             api_key=api_provider.api_key,
             max_retries=0,
+            timeout=api_provider.timeout,
         )
 
     async def get_response(
@@ -520,6 +521,11 @@ class OpenaiClient(BaseClient):
                 extra_body=extra_params,
             )
         except APIConnectionError as e:
+            # Log detailed error information for debugging
+            logger.error(f"OpenAI API connection error (embedding model): {str(e)}")
+            logger.error(f"Error type: {type(e)}")
+            if hasattr(e, '__cause__') and e.__cause__:
+                logger.error(f"Underlying error: {str(e.__cause__)}")
             raise NetworkConnectionError() from e
         except APIStatusError as e:
             # Re-wrap APIError as RespNotOkException
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index e8e4db5f..f0229c2c 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -248,7 +248,11 @@ class LLMRequest:
         )
         model_info = model_config.get_model_info(least_used_model_name)
         api_provider = model_config.get_provider(model_info.api_provider)
-        client = client_registry.get_client_class_instance(api_provider)
+
+        # For embedding tasks, force a new client instance to avoid event loop issues
+        force_new_client = (self.request_type == "embedding")
+        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
+
         logger.debug(f"Selected request model: {model_info.name}")
         total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
         self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # Increase the usage penalty to discourage consecutive use
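
Reviewer note: every touched call site hand-rolls the same pattern: create a private event loop, run the coroutine to completion, close the loop. (Separately, asyncio.get_running_loop() raises RuntimeError rather than ever returning a closed loop, so the is_closed() branch in the new main() of scripts/import_openie.py is dead code.) The pattern could be factored into one shared helper; a minimal sketch follows, where run_coro_blocking is a hypothetical name, not part of this patch:

    import asyncio
    import concurrent.futures
    from typing import Any, Coroutine, TypeVar

    T = TypeVar("T")

    def run_coro_blocking(coro: Coroutine[Any, Any, T]) -> T:
        """Run a coroutine to completion from synchronous code."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: a short-lived private loop is safe.
            loop = asyncio.new_event_loop()
            try:
                return loop.run_until_complete(coro)
            finally:
                loop.close()
        # A loop is already running: run_until_complete() would raise here,
        # so hand the coroutine to a worker thread that gets its own loop.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, coro).result()

With a helper like this, EmbeddingStore._get_embedding reduces to a single `embedding, _ = run_coro_blocking(llm.get_embedding(s))`, and the worker threads in _get_embeddings_batch_threaded can call the same helper: each worker thread has no running loop, so they take the private-loop branch.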
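
A second note, on the client cache: force_new=True works around what is most likely the underlying issue, namely that a cached async client (AsyncOpenAI wraps an httpx.AsyncClient) stays bound to the event loop it first ran on, but it pays for that by constructing a brand-new client for every embedding request. A sketch of an alternative, keying the cache by the running loop; this is an assumed design, not something this patch implements:

    import asyncio
    from typing import Callable, Dict, Tuple

    class LoopAwareClientCache:
        """Cache one client per (provider, event loop) pair."""

        def __init__(self) -> None:
            self._cache: Dict[Tuple[str, int], object] = {}

        def get(self, provider_name: str, factory: Callable[[], object]) -> object:
            try:
                # Key by the running loop so a client is never shared across loops.
                loop_key = id(asyncio.get_running_loop())
            except RuntimeError:
                loop_key = 0  # no running loop: shared bucket for sync callers
            key = (provider_name, loop_key)
            if key not in self._cache:
                self._cache[key] = factory()
            return self._cache[key]

Here factory would be something like `lambda: client_class(api_provider)`. The sketch never evicts entries for dead loops, so a production version would need cleanup, but it preserves client reuse in the common case, whereas force_new=True rebuilds the client on every call.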