# Tests for agent_core.llm_provider: LLM/embedding provider factories and
# runtime configuration. (Original listing: 176 lines, 5.5 KiB, Python.)
import json
from io import BytesIO
from unittest.mock import patch
from urllib.error import HTTPError

from agent_core.llm_provider import (
    EmbeddingConfigurationError,
    LLMConfigurationError,
    create_embedding_provider,
    create_llm_provider,
    get_runtime_llm_config,
)


def test_create_llm_provider_requires_api_key_for_openai_compatible():
    """An empty LLM_API_KEY yields a configuration error instead of a request.

    ``urlopen`` is patched so that a regression in the key check cannot turn
    into a real HTTP call to api.openai.com; the patched mock must never be
    invoked.
    """
    with patch("agent_core.llm_provider.urlopen") as fake_urlopen:
        provider = create_llm_provider(
            {
                "LLM_API_KEY": "",
                "LLM_BASE_URL": "https://api.openai.com/v1",
                "LLM_MODEL": "gpt-4.1-mini",
                "LLM_PROVIDER": "openai_compatible",
            }
        )

        response = provider.generate([{"role": "user", "content": "hello"}])

    assert response.success is False
    assert isinstance(response.error, LLMConfigurationError)
    assert "LLM_API_KEY" in str(response.error)
    # The missing key must be rejected before any HTTP request is attempted.
    fake_urlopen.assert_not_called()


def test_mock_provider_returns_deterministic_content():
    """The mock provider echoes the prompt and answers identically every time."""
    provider = create_llm_provider({"LLM_PROVIDER": "mock", "LLM_MODEL": "demo-model"})

    def ask():
        # Build a fresh message list per call so that any in-place mutation
        # by the provider cannot make the two calls see different input.
        return provider.generate([{"role": "user", "content": "hello"}])

    response = ask()
    repeat = ask()

    assert response.success is True
    assert response.model_name == "demo-model"
    assert "hello" in response.content
    # "Deterministic" means identical input produces identical output.
    assert repeat.success is True
    assert repeat.content == response.content


def test_openai_compatible_provider_posts_chat_completion(monkeypatch):
    """The OpenAI-compatible provider POSTs to ``<base_url>/chat/completions``.

    ``urlopen`` is replaced with a fake that records the outgoing request and
    returns a canned completion, so no network access happens.
    """
    captured = {}

    class FakeResponse:
        # Minimal context-manager stand-in for the object urlopen() returns.
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, traceback):
            return False

        def read(self):
            return b'{"choices":[{"message":{"content":"ok"}}],"model":"demo-model"}'

    def fake_urlopen(request, timeout):
        captured["url"] = request.full_url
        captured["method"] = request.get_method()
        captured["headers"] = dict(request.header_items())
        captured["body"] = request.data.decode("utf-8")
        return FakeResponse()

    monkeypatch.setattr("agent_core.llm_provider.urlopen", fake_urlopen)
    provider = create_llm_provider(
        {
            "LLM_PROVIDER": "openai_compatible",
            "LLM_API_KEY": "sk-test",
            "LLM_BASE_URL": "https://api.siliconflow.cn/v1",
            "LLM_MODEL": "demo-model",
        }
    )

    response = provider.generate([{"role": "user", "content": "hello"}])

    assert response.success is True
    assert response.content == "ok"
    assert captured["url"] == "https://api.siliconflow.cn/v1/chat/completions"
    assert captured["method"] == "POST"
    # Parse the body instead of substring-matching it, so the assertion does
    # not depend on the provider's json.dumps separator/spacing choices.
    assert json.loads(captured["body"])["model"] == "demo-model"
    assert captured["headers"]["Authorization"] == "Bearer sk-test"


def test_openai_compatible_provider_falls_back_when_response_format_is_rejected(monkeypatch):
    """A 400 rejecting ``response_format`` triggers one retry without it.

    The fake ``urlopen`` fails the first request with an HTTP 400 whose body
    says ``response_format`` is unsupported, and succeeds on any later request.
    """
    captured_bodies = []

    class FakeResponse:
        # Minimal context-manager stand-in for the object urlopen() returns.
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, traceback):
            return False

        def read(self):
            return b'{"choices":[{"message":{"content":"fallback ok"}}],"model":"demo-model"}'

    def fake_urlopen(request, timeout):
        captured_bodies.append(request.data.decode("utf-8"))
        if len(captured_bodies) == 1:
            raise HTTPError(
                request.full_url,
                400,
                "Bad Request",
                hdrs=None,
                fp=BytesIO(b'{"error":{"message":"response_format is not supported"}}'),
            )
        return FakeResponse()

    monkeypatch.setattr("agent_core.llm_provider.urlopen", fake_urlopen)
    provider = create_llm_provider(
        {
            "LLM_PROVIDER": "openai_compatible",
            "LLM_API_KEY": "sk-test",
            "LLM_BASE_URL": "https://api.siliconflow.cn/v1",
            "LLM_MODEL": "demo-model",
        }
    )

    response = provider.generate(
        [{"role": "user", "content": "hello"}],
        response_format={"type": "json_object"},
    )

    assert response.success is True
    assert response.content == "fallback ok"
    # Exactly one fallback retry: the rejected request plus the retry.
    assert len(captured_bodies) == 2
    first, retry = (json.loads(body) for body in captured_bodies)
    # The caller's response_format is forwarded on the first attempt ...
    assert first["response_format"] == {"type": "json_object"}
    # ... and dropped on the retry.
    assert "response_format" not in retry


def test_embedding_provider_uses_openai_compatible_embeddings(monkeypatch):
    """The embedding provider calls ``<base_url>/embeddings`` and returns vectors.

    The fake ``urlopen`` records the outgoing request and answers with two
    embeddings, one per input text.
    """
    captured = {}

    class FakeResponse:
        # Minimal context-manager stand-in for the object urlopen() returns.
        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, traceback):
            return False

        def read(self):
            return b'{"data":[{"embedding":[0.1,0.2]},{"embedding":[0.3,0.4]}]}'

    def fake_urlopen(request, timeout):
        captured["url"] = request.full_url
        captured["body"] = request.data.decode("utf-8")
        return FakeResponse()

    monkeypatch.setattr("agent_core.llm_provider.urlopen", fake_urlopen)
    provider = create_embedding_provider(
        {
            "EMBEDDING_API_KEY": "sk-test",
            "EMBEDDING_BASE_URL": "https://api.siliconflow.cn/v1",
            "EMBEDDING_MODEL": "demo-embedding",
        }
    )

    assert provider.embed_texts(["a", "b"]) == [[0.1, 0.2], [0.3, 0.4]]
    # Verify the request follows the OpenAI embeddings contract rather than
    # only checking that the canned response is parsed.
    assert captured["url"] == "https://api.siliconflow.cn/v1/embeddings"
    payload = json.loads(captured["body"])
    assert payload["model"] == "demo-embedding"
    assert payload["input"] == ["a", "b"]


def test_embedding_provider_requires_api_key():
    """An empty EMBEDDING_API_KEY raises EmbeddingConfigurationError.

    ``urlopen`` is patched so that a regression in the key check cannot turn
    into a real HTTP call to api.siliconflow.cn; the patched mock must never
    be invoked.
    """
    with patch("agent_core.llm_provider.urlopen") as fake_urlopen:
        provider = create_embedding_provider(
            {
                "EMBEDDING_API_KEY": "",
                "EMBEDDING_BASE_URL": "https://api.siliconflow.cn/v1",
                "EMBEDDING_MODEL": "demo-embedding",
            }
        )

        try:
            provider.embed_texts(["a"])
        except EmbeddingConfigurationError as exc:
            assert "EMBEDDING_API_KEY" in str(exc)
        else:
            raise AssertionError("expected EmbeddingConfigurationError")

    # The missing key must be rejected before any HTTP request is attempted.
    fake_urlopen.assert_not_called()


def test_get_runtime_llm_config_uses_environment_and_overrides(monkeypatch):
    """Environment values fill the config; explicit overrides take precedence."""
    environment = {
        "LLM_PROVIDER": "mock",
        "LLM_API_KEY": "sk-env",
        "LLM_BASE_URL": "https://env.example/v1",
        "LLM_MODEL": "env-model",
    }
    for variable, value in environment.items():
        monkeypatch.setenv(variable, value)

    config = get_runtime_llm_config({"LLM_MODEL": "override-model"})

    # Every key keeps its environment value except LLM_MODEL, which the
    # explicit override replaces (dict merge keeps the original key order).
    expected = {**environment, "LLM_MODEL": "override-model"}
    for key, wanted in expected.items():
        assert config[key] == wanted, key