diff --git a/agent_core/__init__.py b/agent_core/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/agent_core/__init__.py @@ -0,0 +1 @@ + diff --git a/agent_core/llm_provider.py b/agent_core/llm_provider.py new file mode 100644 index 0000000..1bb568d --- /dev/null +++ b/agent_core/llm_provider.py @@ -0,0 +1,132 @@ +from dataclasses import dataclass +import json +from urllib.error import URLError +from urllib.request import Request, urlopen + + +class LLMConfigurationError(ValueError): + pass + + +class EmbeddingConfigurationError(ValueError): + pass + + +@dataclass +class LLMResponse: + content: str = "" + model_name: str = "" + success: bool = True + error: Exception | None = None + + +class MockLLMProvider: + def __init__(self, model_name: str = "mock-model"): + self.model_name = model_name or "mock-model" + + def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse: + user_content = "" + for message in reversed(messages): + if message.get("role") == "user": + user_content = message.get("content", "") + break + return LLMResponse( + content=f"模拟模型回答:{user_content}", + model_name=self.model_name, + success=True, + ) + + +class OpenAICompatibleProvider: + def __init__(self, api_key: str, base_url: str, model_name: str): + self.api_key = api_key + self.base_url = base_url + self.model_name = model_name + + def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse: + if not self.api_key: + return LLMResponse( + model_name=self.model_name, + success=False, + error=LLMConfigurationError("LLM_API_KEY 未配置,无法调用 OpenAI 兼容模型接口"), + ) + payload = { + "model": self.model_name, + "messages": messages, + } + if response_format: + payload["response_format"] = response_format + try: + data = _post_json( + base_url=self.base_url, + endpoint="chat/completions", + api_key=self.api_key, + payload=payload, + ) + choice = data.get("choices", [{}])[0] + content = choice.get("message", {}).get("content", "") + return LLMResponse( + content=content, + model_name=data.get("model", self.model_name), + success=True, + ) + except Exception as exc: + return LLMResponse(model_name=self.model_name, success=False, error=exc) + + +class OpenAICompatibleEmbeddingProvider: + def __init__(self, api_key: str, base_url: str, model_name: str): + self.api_key = api_key + self.base_url = base_url + self.model_name = model_name + + def embed_texts(self, texts: list[str]) -> list[list[float]]: + if not self.api_key: + raise EmbeddingConfigurationError("EMBEDDING_API_KEY 未配置,无法调用 OpenAI 兼容 Embedding 接口") + data = _post_json( + base_url=self.base_url, + endpoint="embeddings", + api_key=self.api_key, + payload={"model": self.model_name, "input": texts}, + ) + return [item.get("embedding", []) for item in data.get("data", [])] + + +def _post_json(base_url: str, endpoint: str, api_key: str, payload: dict) -> dict: + url = f"{base_url.rstrip('/')}/{endpoint}" + request = Request( + url, + data=json.dumps(payload, ensure_ascii=False).encode("utf-8"), + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + method="POST", + ) + try: + with urlopen(request, timeout=60) as response: + return json.loads(response.read().decode("utf-8")) + except URLError as exc: + raise RuntimeError(f"OpenAI 兼容接口调用失败:{exc}") from exc + + +def create_llm_provider(config: dict | None = None): + config = config or {} + provider_name = config.get("LLM_PROVIDER", "mock") + model_name = config.get("LLM_MODEL", "mock-model") + if provider_name == "mock": + return MockLLMProvider(model_name=model_name) + return OpenAICompatibleProvider( + api_key=config.get("LLM_API_KEY", ""), + base_url=config.get("LLM_BASE_URL", "https://api.openai.com/v1"), + model_name=model_name, + ) + + +def create_embedding_provider(config: dict | None = None): + config = config or {} + return OpenAICompatibleEmbeddingProvider( + api_key=config.get("EMBEDDING_API_KEY", config.get("LLM_API_KEY", "")), + base_url=config.get("EMBEDDING_BASE_URL", config.get("LLM_BASE_URL", "https://api.openai.com/v1")), + model_name=config.get("EMBEDDING_MODEL", "text-embedding-3-small"), + ) diff --git a/agent_core/orchestrator.py b/agent_core/orchestrator.py new file mode 100644 index 0000000..a8e6c23 --- /dev/null +++ b/agent_core/orchestrator.py @@ -0,0 +1,40 @@ +import time + +from .results import AgentResult +from .structured_output import build_mock_structured_output +from .tool_registry import run_declared_tools +from .rag.retriever import retrieve + + +def run_agent(scenario_config: dict, user_input: str, options: dict | None = None) -> AgentResult: + started_at = time.perf_counter() + options = options or {} + output_type = scenario_config.get("output", {}).get("type", "general_answer") + + references = [] + rag_config = scenario_config.get("rag", {}) + if rag_config.get("enabled"): + references = retrieve( + scenario_id=scenario_config.get("id", ""), + query=user_input, + collection=rag_config.get("collection", scenario_config.get("id", "")), + top_k=rag_config.get("top_k", 5), + document_ids=options.get("document_ids"), + store_path=options.get("rag_store_path"), + ) + + tool_calls = run_declared_tools(scenario_config.get("tools", []), user_input) + structured_output = build_mock_structured_output(output_type, user_input, references) + answer = f"已根据「{scenario_config.get('name', '当前场景')}」生成模拟回答:{user_input}" + latency_ms = int((time.perf_counter() - started_at) * 1000) + + return AgentResult( + answer=answer, + structured_output=structured_output, + references=references, + tool_calls=tool_calls, + raw_output=answer, + model_name=options.get("model_name", "mock-model"), + latency_ms=latency_ms, + status="success", + ) diff --git a/agent_core/rag/__init__.py b/agent_core/rag/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/agent_core/rag/__init__.py @@ -0,0 +1 @@ + diff --git a/agent_core/rag/chroma_store.py b/agent_core/rag/chroma_store.py new file mode 100644 index 0000000..01cfb6f --- /dev/null +++ b/agent_core/rag/chroma_store.py @@ -0,0 +1,96 @@ +from pathlib import Path + +from django.conf import settings + +from agent_core.llm_provider import create_embedding_provider + + +def _client(path: str | Path | None = None): + import chromadb + + resolved_path = str(path or settings.CHROMA_PATH) + return chromadb.PersistentClient(path=resolved_path) + + +def _embedding_provider(): + return create_embedding_provider( + { + "EMBEDDING_API_KEY": settings.EMBEDDING_API_KEY, + "EMBEDDING_BASE_URL": settings.EMBEDDING_BASE_URL, + "EMBEDDING_MODEL": settings.EMBEDDING_MODEL, + } + ) + + +def upsert_chunks( + collection: str, + chunks: list[dict], + store_path: str | Path | None = None, +) -> None: + client = _client(store_path) + chroma_collection = client.get_or_create_collection(collection) + document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None} + for document_id in document_ids: + chroma_collection.delete(where={"document_id": document_id}) + texts = [chunk["content"] for chunk in chunks] + embeddings = _embedding_provider().embed_texts(texts) + chroma_collection.upsert( + ids=[chunk["chunk_id"] for chunk in chunks], + documents=texts, + embeddings=embeddings, + metadatas=[ + { + "scenario_id": chunk["scenario_id"], + "document_id": chunk["document_id"], + "source": chunk["source"], + "chunk_id": chunk["chunk_id"], + "created_at": chunk["created_at"], + } + for chunk in chunks + ], + ) + + +def query_chunks( + scenario_id: str, + query: str, + collection: str, + top_k: int = 5, + document_ids: list[int] | None = None, + store_path: str | Path | None = None, +) -> list[dict]: + client = _client(store_path) + chroma_collection = client.get_or_create_collection(collection) + where: dict = {"scenario_id": scenario_id} + if document_ids: + where = { + "$and": [ + {"scenario_id": scenario_id}, + {"document_id": {"$in": document_ids}}, + ] + } + embedding = _embedding_provider().embed_texts([query])[0] + result = chroma_collection.query( + query_embeddings=[embedding], + n_results=top_k, + where=where, + include=["documents", "metadatas", "distances"], + ) + chunks = [] + documents = result.get("documents", [[]])[0] + metadatas = result.get("metadatas", [[]])[0] + distances = result.get("distances", [[]])[0] + for content, metadata, distance in zip(documents, metadatas, distances): + chunks.append( + { + "scenario_id": metadata.get("scenario_id"), + "document_id": metadata.get("document_id"), + "collection": collection, + "source": metadata.get("source"), + "chunk_id": metadata.get("chunk_id"), + "content": content, + "created_at": metadata.get("created_at"), + "score": round(1 / (1 + float(distance)), 4), + } + ) + return chunks diff --git a/agent_core/rag/ingest.py b/agent_core/rag/ingest.py new file mode 100644 index 0000000..a3187a6 --- /dev/null +++ b/agent_core/rag/ingest.py @@ -0,0 +1,116 @@ +import json +import re +import importlib.util +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path + +from django.conf import settings + +from .chroma_store import upsert_chunks + + +@dataclass +class IngestResult: + success: bool + chunks_count: int = 0 + error: str = "" + + +def _default_store_path() -> Path: + return Path(settings.CHROMA_PATH) / "rag_store.json" + + +def _load_store(store_path: Path) -> list[dict]: + if not store_path.exists(): + return [] + with store_path.open("r", encoding="utf-8") as file: + return json.load(file) + + +def _save_store(store_path: Path, chunks: list[dict]) -> None: + store_path.parent.mkdir(parents=True, exist_ok=True) + with store_path.open("w", encoding="utf-8") as file: + json.dump(chunks, file, ensure_ascii=False, indent=2) + + +def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[str]: + normalized = re.sub(r"\s+", " ", text).strip() + if not normalized: + return [] + chunks = [] + start = 0 + while start < len(normalized): + end = start + chunk_size + chunks.append(normalized[start:end]) + if end >= len(normalized): + break + start = max(end - overlap, start + 1) + return chunks + + +def ingest_document( + scenario_id: str, + source_file: str, + text: str, + collection: str, + document_id: int | None = None, + store_path: str | Path | None = None, +) -> IngestResult: + if not text.strip(): + return IngestResult(success=False, error="文档内容为空") + if store_path is None and importlib.util.find_spec("chromadb") is not None: + return _ingest_chroma_document(document_id, scenario_id, source_file, text, collection) + resolved_store_path = Path(store_path) if store_path else _default_store_path() + existing_chunks = [ + chunk + for chunk in _load_store(resolved_store_path) + if not ( + chunk.get("document_id") == document_id + and chunk.get("scenario_id") == scenario_id + and chunk.get("collection") == collection + ) + ] + created_at = datetime.now(timezone.utc).isoformat() + new_chunks = [] + for index, chunk_text in enumerate(_split_text(text), start=1): + new_chunks.append( + { + "scenario_id": scenario_id, + "document_id": document_id, + "collection": collection, + "source": source_file, + "chunk_id": f"{scenario_id}:{source_file}:{index}", + "content": chunk_text, + "created_at": created_at, + } + ) + _save_store(resolved_store_path, [*existing_chunks, *new_chunks]) + return IngestResult(success=True, chunks_count=len(new_chunks)) + + +def _ingest_chroma_document( + document_id: int | None, + scenario_id: str, + source_file: str, + text: str, + collection: str, +) -> IngestResult: + created_at = datetime.now(timezone.utc).isoformat() + chunks = [ + { + "scenario_id": scenario_id, + "document_id": document_id, + "collection": collection, + "source": source_file, + "chunk_id": f"{scenario_id}:{document_id or source_file}:{index}", + "content": chunk_text, + "created_at": created_at, + } + for index, chunk_text in enumerate(_split_text(text), start=1) + ] + try: + upsert_chunks(collection=collection, chunks=chunks) + except Exception as exc: + return IngestResult(success=False, error=str(exc)) + return IngestResult(success=True, chunks_count=len(chunks)) diff --git a/agent_core/rag/retriever.py b/agent_core/rag/retriever.py new file mode 100644 index 0000000..d03a99b --- /dev/null +++ b/agent_core/rag/retriever.py @@ -0,0 +1,69 @@ +import json +import re +import importlib.util +from pathlib import Path + +from django.conf import settings + +from .chroma_store import query_chunks + + +def _default_store_path() -> Path: + return Path(settings.CHROMA_PATH) / "rag_store.json" + + +def _load_store(store_path: Path) -> list[dict]: + if not store_path.exists(): + return [] + with store_path.open("r", encoding="utf-8") as file: + return json.load(file) + + +def _tokens(text: str) -> set[str]: + lowered = text.lower() + ascii_tokens = set(re.findall(r"[a-z0-9_]+", lowered)) + cjk_tokens = set(re.findall(r"[\u4e00-\u9fff]{2,}", lowered)) + chars = {char for char in lowered if "\u4e00" <= char <= "\u9fff"} + return ascii_tokens | cjk_tokens | chars + + +def _score(query_tokens: set[str], content: str) -> float: + content_tokens = _tokens(content) + if not query_tokens or not content_tokens: + return 0.0 + overlap = query_tokens & content_tokens + return round(len(overlap) / len(query_tokens), 4) + + +def retrieve( + scenario_id: str, + query: str, + collection: str, + top_k: int = 5, + document_ids: list[int] | None = None, + store_path: str | Path | None = None, +) -> list[dict]: + if store_path is None and importlib.util.find_spec("chromadb") is not None: + return query_chunks( + scenario_id=scenario_id, + query=query, + collection=collection, + top_k=top_k, + document_ids=document_ids, + ) + resolved_store_path = Path(store_path) if store_path else _default_store_path() + query_tokens = _tokens(query) + allowed_document_ids = set(document_ids or []) + scored_chunks = [] + for chunk in _load_store(resolved_store_path): + if chunk.get("scenario_id") != scenario_id: + continue + if chunk.get("collection") != collection: + continue + if allowed_document_ids and chunk.get("document_id") not in allowed_document_ids: + continue + score = _score(query_tokens, chunk.get("content", "")) + if score <= 0: + continue + scored_chunks.append({**chunk, "score": score}) + return sorted(scored_chunks, key=lambda item: item["score"], reverse=True)[:top_k] diff --git a/agent_core/results.py b/agent_core/results.py new file mode 100644 index 0000000..d6cc08c --- /dev/null +++ b/agent_core/results.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass, field + + +@dataclass +class AgentResult: + answer: str = "" + structured_output: dict = field(default_factory=dict) + references: list = field(default_factory=list) + tool_calls: list = field(default_factory=list) + raw_output: str = "" + model_name: str = "mock-model" + latency_ms: int = 0 + status: str = "success" + error: str = "" diff --git a/agent_core/schemas/__init__.py b/agent_core/schemas/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/agent_core/schemas/__init__.py @@ -0,0 +1 @@ + diff --git a/agent_core/schemas/outputs.py b/agent_core/schemas/outputs.py new file mode 100644 index 0000000..c3a7be7 --- /dev/null +++ b/agent_core/schemas/outputs.py @@ -0,0 +1,7 @@ +SUPPORTED_OUTPUT_TYPES = { + "general_answer", + "document_review_report", + "ticket_response", + "quality_report", + "risk_audit_report", +} diff --git a/agent_core/structured_output.py b/agent_core/structured_output.py new file mode 100644 index 0000000..1214a0d --- /dev/null +++ b/agent_core/structured_output.py @@ -0,0 +1,8 @@ +def build_mock_structured_output(output_type: str, user_input: str, references: list) -> dict: + return { + "output_type": output_type, + "summary": f"模拟结构化输出:{user_input}", + "references_count": len(references), + "risk_level": "low", + "suggested_actions": ["补充真实 LLM Provider 后替换模拟结果"], + } diff --git a/agent_core/tool_registry.py b/agent_core/tool_registry.py new file mode 100644 index 0000000..270737f --- /dev/null +++ b/agent_core/tool_registry.py @@ -0,0 +1,40 @@ +from .tools.builtin_tools import BUILTIN_TOOLS + + +def run_declared_tools(tool_names: list[str], user_input: str) -> list[dict]: + results = [] + for tool_name in tool_names: + tool = BUILTIN_TOOLS.get(tool_name) + if tool is None: + results.append( + { + "tool_name": tool_name, + "success": False, + "arguments": {"user_input": user_input}, + "result": {}, + "error": "工具未注册", + } + ) + continue + try: + result = tool(user_input=user_input) + results.append( + { + "tool_name": tool_name, + "success": True, + "arguments": {"user_input": user_input}, + "result": result, + "error": "", + } + ) + except Exception as exc: + results.append( + { + "tool_name": tool_name, + "success": False, + "arguments": {"user_input": user_input}, + "result": {}, + "error": str(exc), + } + ) + return results diff --git a/agent_core/tools/__init__.py b/agent_core/tools/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/agent_core/tools/__init__.py @@ -0,0 +1 @@ + diff --git a/agent_core/tools/builtin_tools.py b/agent_core/tools/builtin_tools.py new file mode 100644 index 0000000..6c7464d --- /dev/null +++ b/agent_core/tools/builtin_tools.py @@ -0,0 +1,47 @@ +def calculate_rate(user_input: str) -> dict: + return {"rate": 1.0, "note": "模拟比例计算结果"} + + +def query_demo_records(user_input: str) -> dict: + try: + from apps.audit.models import DemoBusinessRecord + except Exception as exc: + return {"records": [], "error": str(exc)} + + queryset = DemoBusinessRecord.objects.all() + tokens = {token.strip().lower() for token in user_input.split() if token.strip()} + scenario_ids = set(queryset.values_list("scenario_id", flat=True)) + record_types = set(queryset.values_list("record_type", flat=True)) + matched_scenario_ids = scenario_ids & tokens + matched_record_types = record_types & tokens + if matched_scenario_ids: + queryset = queryset.filter(scenario_id__in=matched_scenario_ids) + if matched_record_types: + queryset = queryset.filter(record_type__in=matched_record_types) + records = [ + { + "id": record.id, + "scenario_id": record.scenario_id, + "record_type": record.record_type, + "title": record.title, + "payload": record.payload, + } + for record in queryset[:20] + ] + return {"records": records} + + +def check_required_fields(user_input: str) -> dict: + return {"missing_fields": [], "note": "模拟必填项检查结果"} + + +def generate_action_items(user_input: str) -> dict: + return {"items": [f"围绕问题继续核实:{user_input}"]} + + +BUILTIN_TOOLS = { + "calculate_rate": calculate_rate, + "query_demo_records": query_demo_records, + "check_required_fields": check_required_fields, + "generate_action_items": generate_action_items, +} diff --git a/tests/test_agent_core.py b/tests/test_agent_core.py new file mode 100644 index 0000000..579a089 --- /dev/null +++ b/tests/test_agent_core.py @@ -0,0 +1,122 @@ +from agent_core.orchestrator import run_agent +from agent_core.rag.ingest import ingest_document +from agent_core.rag.retriever import retrieve + + +def test_run_agent_returns_structured_mock_result(): + scenario = { + "id": "knowledge_qa", + "name": "知识库问答助手", + "rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3}, + "tools": ["generate_action_items"], + "output": {"type": "general_answer"}, + } + + result = run_agent(scenario, "如何处理异常?") + + assert result.status == "success" + assert result.answer + assert result.structured_output["output_type"] == "general_answer" + assert isinstance(result.references, list) + assert result.tool_calls[0]["tool_name"] == "generate_action_items" + + +def test_rag_ingest_and_retrieve_filters_by_scenario_and_query(tmp_path): + store_path = tmp_path / "rag_store.json" + text = "设备点检需要先断电挂牌。质量异常需要记录批次、工位和缺陷现象。" + + result = ingest_document( + scenario_id="quality_analysis", + source_file="quality.md", + text=text, + collection="quality_analysis", + store_path=store_path, + ) + ingest_document( + scenario_id="risk_audit", + source_file="risk.md", + text="报销审核需要检查发票、金额和审批链。", + collection="risk_audit", + store_path=store_path, + ) + + chunks = retrieve( + scenario_id="quality_analysis", + query="质量异常批次", + collection="quality_analysis", + top_k=3, + store_path=store_path, + ) + + assert result.success is True + assert result.chunks_count >= 1 + assert chunks + assert chunks[0]["source"] == "quality.md" + assert "质量异常" in chunks[0]["content"] + assert all(chunk["scenario_id"] == "quality_analysis" for chunk in chunks) + + +def test_rag_reingest_replaces_same_document_and_retrieve_filters_document_ids(tmp_path): + store_path = tmp_path / "rag_store.json" + + ingest_document( + document_id=1, + scenario_id="knowledge_qa", + source_file="old.md", + text="旧制度要求人工登记。", + collection="knowledge_qa", + store_path=store_path, + ) + ingest_document( + document_id=1, + scenario_id="knowledge_qa", + source_file="new.md", + text="新制度要求系统自动登记。", + collection="knowledge_qa", + store_path=store_path, + ) + ingest_document( + document_id=2, + scenario_id="knowledge_qa", + source_file="other.md", + text="系统自动登记后需要生成审计记录。", + collection="knowledge_qa", + store_path=store_path, + ) + + chunks = retrieve( + scenario_id="knowledge_qa", + query="系统自动登记", + collection="knowledge_qa", + top_k=5, + document_ids=[1], + store_path=store_path, + ) + + assert chunks + assert {chunk["document_id"] for chunk in chunks} == {1} + assert all(chunk["source"] == "new.md" for chunk in chunks) + assert all("旧制度" not in chunk["content"] for chunk in chunks) + + +def test_run_agent_uses_retrieved_document_chunks(tmp_path): + store_path = tmp_path / "rag_store.json" + ingest_document( + scenario_id="knowledge_qa", + source_file="sop.md", + text="异常处理 SOP:先隔离现场,再通知负责人。", + collection="knowledge_qa", + store_path=store_path, + ) + scenario = { + "id": "knowledge_qa", + "name": "知识库问答助手", + "rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3}, + "tools": [], + "output": {"type": "general_answer"}, + } + + result = run_agent(scenario, "异常处理怎么做?", options={"rag_store_path": store_path}) + + assert result.references[0]["source"] == "sop.md" + assert "隔离现场" in result.references[0]["content"] diff --git a/tests/test_llm_provider.py b/tests/test_llm_provider.py new file mode 100644 index 0000000..5ac0e78 --- /dev/null +++ b/tests/test_llm_provider.py @@ -0,0 +1,111 @@ +from agent_core.llm_provider import ( + EmbeddingConfigurationError, + LLMConfigurationError, + create_embedding_provider, + create_llm_provider, +) + + +def test_create_llm_provider_requires_api_key_for_openai_compatible(): + provider = create_llm_provider( + { + "LLM_API_KEY": "", + "LLM_BASE_URL": "https://api.openai.com/v1", + "LLM_MODEL": "gpt-4.1-mini", + "LLM_PROVIDER": "openai_compatible", + } + ) + + response = provider.generate([{"role": "user", "content": "hello"}]) + + assert response.success is False + assert isinstance(response.error, LLMConfigurationError) + assert "LLM_API_KEY" in str(response.error) + + +def test_mock_provider_returns_deterministic_content(): + provider = create_llm_provider({"LLM_PROVIDER": "mock", "LLM_MODEL": "demo-model"}) + + response = provider.generate([{"role": "user", "content": "hello"}]) + + assert response.success is True + assert response.model_name == "demo-model" + assert "hello" in response.content + + +def test_openai_compatible_provider_posts_chat_completion(monkeypatch): + captured = {} + + class FakeResponse: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, traceback): + return False + + def read(self): + return b'{"choices":[{"message":{"content":"ok"}}],"model":"demo-model"}' + + def fake_urlopen(request, timeout): + captured["url"] = request.full_url + captured["headers"] = dict(request.header_items()) + captured["body"] = request.data.decode("utf-8") + return FakeResponse() + + monkeypatch.setattr("agent_core.llm_provider.urlopen", fake_urlopen) + provider = create_llm_provider( + { + "LLM_PROVIDER": "openai_compatible", + "LLM_API_KEY": "sk-test", + "LLM_BASE_URL": "https://api.siliconflow.cn/v1", + "LLM_MODEL": "demo-model", + } + ) + + response = provider.generate([{"role": "user", "content": "hello"}]) + + assert response.success is True + assert response.content == "ok" + assert captured["url"] == "https://api.siliconflow.cn/v1/chat/completions" + assert '"model": "demo-model"' in captured["body"] + assert captured["headers"]["Authorization"] == "Bearer sk-test" + + +def test_embedding_provider_uses_openai_compatible_embeddings(monkeypatch): + class FakeResponse: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, traceback): + return False + + def read(self): + return b'{"data":[{"embedding":[0.1,0.2]},{"embedding":[0.3,0.4]}]}' + + monkeypatch.setattr("agent_core.llm_provider.urlopen", lambda request, timeout: FakeResponse()) + provider = create_embedding_provider( + { + "EMBEDDING_API_KEY": "sk-test", + "EMBEDDING_BASE_URL": "https://api.siliconflow.cn/v1", + "EMBEDDING_MODEL": "demo-embedding", + } + ) + + assert provider.embed_texts(["a", "b"]) == [[0.1, 0.2], [0.3, 0.4]] + + +def test_embedding_provider_requires_api_key(): + provider = create_embedding_provider( + { + "EMBEDDING_API_KEY": "", + "EMBEDDING_BASE_URL": "https://api.siliconflow.cn/v1", + "EMBEDDING_MODEL": "demo-embedding", + } + ) + + try: + provider.embed_texts(["a"]) + except EmbeddingConfigurationError as exc: + assert "EMBEDDING_API_KEY" in str(exc) + else: + raise AssertionError("expected EmbeddingConfigurationError")