feat(agent-core): 增加智能编排与模型工具基础
This commit is contained in:
1
agent_core/__init__.py
Normal file
1
agent_core/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
132
agent_core/llm_provider.py
Normal file
132
agent_core/llm_provider.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
import json
|
||||||
|
from urllib.error import URLError
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
|
||||||
|
class LLMConfigurationError(ValueError):
    """Signal that a chat-model provider lacks required configuration."""
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingConfigurationError(ValueError):
    """Signal that an embedding provider lacks required configuration."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
content: str = ""
|
||||||
|
model_name: str = ""
|
||||||
|
success: bool = True
|
||||||
|
error: Exception | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MockLLMProvider:
    """Offline stand-in for a chat model that echoes the latest user message."""

    def __init__(self, model_name: str = "mock-model"):
        # A blank model name falls back to the default label.
        self.model_name = model_name if model_name else "mock-model"

    def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
        """Build a canned reply from the most recent ``user`` message.

        Args:
            messages: Chat messages as dicts with ``role`` and ``content`` keys.
            response_format: Accepted for interface parity; not used here.

        Returns:
            A successful ``LLMResponse`` whose content embeds the last user
            message (an empty string when no user message exists).
        """
        latest_user_text = next(
            (
                entry.get("content", "")
                for entry in reversed(messages)
                if entry.get("role") == "user"
            ),
            "",
        )
        reply = f"模拟模型回答:{latest_user_text}"
        return LLMResponse(content=reply, model_name=self.model_name, success=True)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatibleProvider:
    """Chat client for services exposing an OpenAI-style ``chat/completions`` endpoint."""

    def __init__(self, api_key: str, base_url: str, model_name: str):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

    def generate(self, messages: list[dict], response_format: dict | None = None) -> LLMResponse:
        """Send ``messages`` to the chat endpoint and wrap the reply.

        This method does not raise: every failure is reported through
        ``LLMResponse.success`` / ``LLMResponse.error``.

        Args:
            messages: Chat messages forwarded verbatim in the request payload.
            response_format: Optional ``response_format`` value; only sent
                when truthy.

        Returns:
            ``LLMResponse`` with the first choice's message content on
            success, or with ``success=False`` and the causing exception.
        """
        if not self.api_key:
            return LLMResponse(
                model_name=self.model_name,
                success=False,
                error=LLMConfigurationError("LLM_API_KEY 未配置,无法调用 OpenAI 兼容模型接口"),
            )
        payload = {
            "model": self.model_name,
            "messages": messages,
        }
        if response_format:
            payload["response_format"] = response_format
        try:
            data = _post_json(
                base_url=self.base_url,
                endpoint="chat/completions",
                api_key=self.api_key,
                payload=payload,
            )
            choices = data.get("choices") or []
            if not choices:
                # A missing list used to look like a successful empty answer and
                # an empty list surfaced as a bare IndexError; report both clearly.
                raise RuntimeError("OpenAI 兼容接口响应缺少 choices")
            message = choices[0].get("message") or {}
            # ``content`` can be null in a reply; keep the ``str`` contract.
            content = message.get("content") or ""
            return LLMResponse(
                content=content,
                model_name=data.get("model") or self.model_name,
                success=True,
            )
        except Exception as exc:
            # Boundary of the provider: convert any failure into a result object.
            return LLMResponse(model_name=self.model_name, success=False, error=exc)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatibleEmbeddingProvider:
    """Embedding client for services exposing an OpenAI-style ``embeddings`` endpoint."""

    def __init__(self, api_key: str, base_url: str, model_name: str):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per input text, in input order.

        Args:
            texts: Strings to embed. An empty list returns ``[]`` without
                contacting the service.

        Returns:
            A list of vectors aligned with ``texts``.

        Raises:
            EmbeddingConfigurationError: If no API key is configured.
            RuntimeError: If the request fails or the service returns a
                different number of vectors than texts were sent.
        """
        if not self.api_key:
            raise EmbeddingConfigurationError("EMBEDDING_API_KEY 未配置,无法调用 OpenAI 兼容 Embedding 接口")
        if not texts:
            # Nothing to embed: skip the network round trip entirely.
            return []
        data = _post_json(
            base_url=self.base_url,
            endpoint="embeddings",
            api_key=self.api_key,
            payload={"model": self.model_name, "input": texts},
        )
        items = data.get("data") or []
        if len(items) != len(texts):
            # A mismatch would silently pair texts with the wrong vectors.
            raise RuntimeError(
                f"Embedding 接口返回 {len(items)} 条向量,期望 {len(texts)} 条"
            )
        # Honour the per-item ``index`` when present; fall back to arrival order.
        ordered = sorted(
            enumerate(items),
            key=lambda pair: pair[1].get("index", pair[0]),
        )
        return [item.get("embedding", []) for _, item in ordered]
|
||||||
|
|
||||||
|
|
||||||
|
def _post_json(base_url: str, endpoint: str, api_key: str, payload: dict) -> dict:
    """POST ``payload`` as JSON to ``base_url/endpoint`` and decode the JSON reply.

    Args:
        base_url: Service root; trailing slashes are removed before joining.
        endpoint: Path appended to ``base_url``.
        api_key: Sent as a ``Bearer`` token in the ``Authorization`` header.
        payload: JSON-serializable request body.

    Returns:
        The decoded JSON response body.

    Raises:
        RuntimeError: If the request fails, times out, or the response body
            is not valid UTF-8 encoded JSON. The original error is chained.
    """
    url = f"{base_url.rstrip('/')}/{endpoint}"
    body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
    request = Request(
        url,
        data=body,
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
    try:
        with urlopen(request, timeout=60) as response:
            raw = response.read()
    except (URLError, TimeoutError) as exc:
        # A timeout while reading is not wrapped in URLError by urllib, so it
        # is caught explicitly to keep one error type for callers.
        raise RuntimeError(f"OpenAI 兼容接口调用失败:{exc}") from exc
    try:
        return json.loads(raw.decode("utf-8"))
    except ValueError as exc:
        # Covers both UnicodeDecodeError and json.JSONDecodeError.
        raise RuntimeError(f"OpenAI 兼容接口调用失败:{exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def create_llm_provider(config: dict | None = None):
    """Build the chat provider described by ``config``.

    Recognised keys: ``LLM_PROVIDER``, ``LLM_MODEL``, ``LLM_API_KEY`` and
    ``LLM_BASE_URL``. A missing or blank provider selects the mock provider;
    a missing or blank base URL falls back to the public OpenAI endpoint
    (a blank value would otherwise produce an unusable request URL).

    Args:
        config: Settings mapping; ``None`` behaves like an empty dict.

    Returns:
        ``MockLLMProvider`` when the provider is ``"mock"``, otherwise an
        ``OpenAICompatibleProvider``.
    """
    config = config or {}
    provider_name = config.get("LLM_PROVIDER") or "mock"
    model_name = config.get("LLM_MODEL", "mock-model")
    if provider_name == "mock":
        return MockLLMProvider(model_name=model_name)
    return OpenAICompatibleProvider(
        api_key=config.get("LLM_API_KEY") or "",
        base_url=config.get("LLM_BASE_URL") or "https://api.openai.com/v1",
        model_name=model_name,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_provider(config: dict | None = None):
    """Build the embedding provider described by ``config``.

    ``EMBEDDING_API_KEY`` / ``EMBEDDING_BASE_URL`` fall back to the matching
    ``LLM_*`` keys, then to built-in defaults. The fallback also applies
    when a key is present but blank, which is what a caller passing every
    setting through unconditionally produces for unset values.

    Args:
        config: Settings mapping; ``None`` behaves like an empty dict.

    Returns:
        A configured ``OpenAICompatibleEmbeddingProvider``.
    """
    config = config or {}
    api_key = config.get("EMBEDDING_API_KEY") or config.get("LLM_API_KEY") or ""
    base_url = (
        config.get("EMBEDDING_BASE_URL")
        or config.get("LLM_BASE_URL")
        or "https://api.openai.com/v1"
    )
    model_name = config.get("EMBEDDING_MODEL") or "text-embedding-3-small"
    return OpenAICompatibleEmbeddingProvider(
        api_key=api_key,
        base_url=base_url,
        model_name=model_name,
    )
|
||||||
40
agent_core/orchestrator.py
Normal file
40
agent_core/orchestrator.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from .results import AgentResult
|
||||||
|
from .structured_output import build_mock_structured_output
|
||||||
|
from .tool_registry import run_declared_tools
|
||||||
|
from .rag.retriever import retrieve
|
||||||
|
|
||||||
|
|
||||||
|
def run_agent(scenario_config: dict, user_input: str, options: dict | None = None) -> AgentResult:
    """Run one mock agent turn for a scenario and package the outcome.

    Steps: optionally retrieve references (when ``rag.enabled`` is truthy),
    run the tools declared by the scenario, build the mock structured
    output, and compose a templated answer.

    Args:
        scenario_config: Scenario definition; the keys read here are ``id``,
            ``name``, ``rag``, ``tools`` and ``output``.
        user_input: The user's question.
        options: Optional overrides; the keys read here are ``document_ids``,
            ``rag_store_path`` and ``model_name``.

    Returns:
        An ``AgentResult`` with status ``"success"`` and the measured latency
        in milliseconds.
    """
    clock_start = time.perf_counter()
    opts = options if options else {}
    output_type = scenario_config.get("output", {}).get("type", "general_answer")

    rag_settings = scenario_config.get("rag", {})
    if rag_settings.get("enabled"):
        scenario_key = scenario_config.get("id", "")
        references = retrieve(
            scenario_id=scenario_key,
            query=user_input,
            # The collection defaults to the scenario id.
            collection=rag_settings.get("collection", scenario_key),
            top_k=rag_settings.get("top_k", 5),
            document_ids=opts.get("document_ids"),
            store_path=opts.get("rag_store_path"),
        )
    else:
        references = []

    tool_calls = run_declared_tools(scenario_config.get("tools", []), user_input)
    structured_output = build_mock_structured_output(output_type, user_input, references)
    scenario_name = scenario_config.get("name", "当前场景")
    answer = f"已根据「{scenario_name}」生成模拟回答:{user_input}"
    elapsed_ms = int((time.perf_counter() - clock_start) * 1000)

    return AgentResult(
        answer=answer,
        structured_output=structured_output,
        references=references,
        tool_calls=tool_calls,
        raw_output=answer,
        model_name=opts.get("model_name", "mock-model"),
        latency_ms=elapsed_ms,
        status="success",
    )
|
||||||
1
agent_core/rag/__init__.py
Normal file
1
agent_core/rag/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
96
agent_core/rag/chroma_store.py
Normal file
96
agent_core/rag/chroma_store.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
from agent_core.llm_provider import create_embedding_provider
|
||||||
|
|
||||||
|
|
||||||
|
def _client(path: str | Path | None = None):
    """Open a persistent Chroma client at ``path`` or ``settings.CHROMA_PATH``.

    ``chromadb`` is imported inside the function, so importing this module
    does not itself require the package.
    """
    import chromadb

    storage_dir = settings.CHROMA_PATH if not path else path
    return chromadb.PersistentClient(path=str(storage_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def _embedding_provider():
    """Create the embedding provider from the Django ``EMBEDDING_*`` settings."""
    embedding_config = {
        "EMBEDDING_API_KEY": settings.EMBEDDING_API_KEY,
        "EMBEDDING_BASE_URL": settings.EMBEDDING_BASE_URL,
        "EMBEDDING_MODEL": settings.EMBEDDING_MODEL,
    }
    return create_embedding_provider(embedding_config)
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_chunks(
    collection: str,
    chunks: list[dict],
    store_path: str | Path | None = None,
) -> None:
    """Replace the stored chunks of the affected documents with ``chunks``.

    Embeddings and metadata are prepared first; only then are the existing
    entries of each non-``None`` ``document_id`` deleted and the new chunks
    written. A failing embedding request or a malformed chunk therefore
    leaves the previously stored data untouched.

    Args:
        collection: Name of the Chroma collection (created if missing).
        chunks: Dicts providing ``chunk_id``, ``content``, ``scenario_id``,
            ``document_id``, ``source`` and ``created_at``. An empty list is
            a no-op apart from ensuring the collection exists.
        store_path: Optional storage directory forwarded to ``_client``.
    """
    client = _client(store_path)
    chroma_collection = client.get_or_create_collection(collection)
    if not chunks:
        # Nothing to embed or write; avoid a pointless embedding request.
        return

    texts = [chunk["content"] for chunk in chunks]
    ids = [chunk["chunk_id"] for chunk in chunks]
    metadatas = [
        {
            "scenario_id": chunk["scenario_id"],
            "document_id": chunk["document_id"],
            "source": chunk["source"],
            "chunk_id": chunk["chunk_id"],
            "created_at": chunk["created_at"],
        }
        for chunk in chunks
    ]
    # Compute embeddings before touching stored data so a failure here
    # cannot leave a document without any chunks.
    embeddings = _embedding_provider().embed_texts(texts)

    document_ids = {chunk["document_id"] for chunk in chunks if chunk.get("document_id") is not None}
    for document_id in document_ids:
        chroma_collection.delete(where={"document_id": document_id})
    chroma_collection.upsert(
        ids=ids,
        documents=texts,
        embeddings=embeddings,
        metadatas=metadatas,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def query_chunks(
    scenario_id: str,
    query: str,
    collection: str,
    top_k: int = 5,
    document_ids: list[int] | None = None,
    store_path: str | Path | None = None,
) -> list[dict]:
    """Search a Chroma collection for the chunks closest to ``query``.

    Results are always restricted to ``scenario_id`` and, when
    ``document_ids`` is non-empty, additionally to those documents.

    Args:
        scenario_id: Scenario whose chunks may be returned.
        query: Text to embed and search with.
        collection: Name of the Chroma collection (created if missing).
        top_k: Number of results requested from Chroma.
        document_ids: Optional list of allowed document ids.
        store_path: Optional storage directory forwarded to ``_client``.

    Returns:
        One dict per hit with the stored metadata, the chunk text and a
        ``score`` computed as ``1 / (1 + distance)`` rounded to 4 decimals.
    """
    chroma_collection = _client(store_path).get_or_create_collection(collection)

    scenario_clause = {"scenario_id": scenario_id}
    if document_ids:
        where: dict = {
            "$and": [
                scenario_clause,
                {"document_id": {"$in": document_ids}},
            ]
        }
    else:
        where = scenario_clause

    query_vector = _embedding_provider().embed_texts([query])[0]
    result = chroma_collection.query(
        query_embeddings=[query_vector],
        n_results=top_k,
        where=where,
        include=["documents", "metadatas", "distances"],
    )

    # Each field is a list per query embedding; only one query was issued.
    hits = zip(
        result.get("documents", [[]])[0],
        result.get("metadatas", [[]])[0],
        result.get("distances", [[]])[0],
    )
    return [
        {
            "scenario_id": meta.get("scenario_id"),
            "document_id": meta.get("document_id"),
            "collection": collection,
            "source": meta.get("source"),
            "chunk_id": meta.get("chunk_id"),
            "content": text,
            "created_at": meta.get("created_at"),
            "score": round(1 / (1 + float(gap)), 4),
        }
        for text, meta, gap in hits
    ]
|
||||||
116
agent_core/rag/ingest.py
Normal file
116
agent_core/rag/ingest.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
import importlib.util
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
from .chroma_store import upsert_chunks
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class IngestResult:
    """Outcome of ingesting one document."""

    # Whether the document was stored.
    success: bool
    # Number of chunks written for the document.
    chunks_count: int = 0
    # Failure description; empty when there is none.
    error: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def _default_store_path() -> Path:
    """Return the JSON store location inside ``settings.CHROMA_PATH``."""
    base_dir = Path(settings.CHROMA_PATH)
    return base_dir / "rag_store.json"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_store(store_path: Path) -> list[dict]:
    """Read the JSON store at ``store_path``; a missing file yields ``[]``."""
    if store_path.exists():
        return json.loads(store_path.read_text(encoding="utf-8"))
    return []
|
||||||
|
|
||||||
|
|
||||||
|
def _save_store(store_path: Path, chunks: list[dict]) -> None:
    """Write ``chunks`` to ``store_path`` as indented UTF-8 JSON.

    The data is written to a sibling temporary file and then moved over
    ``store_path``, so a serialization or write failure leaves the existing
    store unchanged instead of truncated. Parent directories are created
    when missing.

    Args:
        store_path: Destination JSON file.
        chunks: JSON-serializable chunk records.

    Raises:
        TypeError: If ``chunks`` contains a value ``json`` cannot serialize.
        OSError: If the file cannot be written or replaced.
    """
    store_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path = store_path.with_name(store_path.name + ".tmp")
    try:
        with temp_path.open("w", encoding="utf-8") as file:
            json.dump(chunks, file, ensure_ascii=False, indent=2)
        temp_path.replace(store_path)
    except BaseException:
        # Do not leave a partial temporary file behind.
        temp_path.unlink(missing_ok=True)
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def _split_text(text: str, chunk_size: int = 800, overlap: int = 120) -> list[str]:
    """Split ``text`` into overlapping, whitespace-normalized chunks.

    Runs of whitespace are collapsed to single spaces before splitting.
    Consecutive chunks start ``chunk_size - overlap`` characters apart, and
    at least one character apart so the loop always advances.

    Args:
        text: Source text.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Characters shared by neighbouring chunks; must not be
            negative.

    Returns:
        The chunks in order, or ``[]`` when ``text`` holds only whitespace.

    Raises:
        ValueError: If ``chunk_size`` is not positive (this used to yield
            empty chunks) or ``overlap`` is negative (this used to skip text).
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size 必须为正整数")
    if overlap < 0:
        raise ValueError("overlap 不能为负数")
    normalized = re.sub(r"\s+", " ", text).strip()
    if not normalized:
        return []
    step = max(chunk_size - overlap, 1)
    total = len(normalized)
    chunks = []
    for start in range(0, total, step):
        end = start + chunk_size
        chunks.append(normalized[start:end])
        if end >= total:
            # This chunk already reaches the end of the text.
            break
    return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_document(
    scenario_id: str,
    source_file: str,
    text: str,
    collection: str,
    document_id: int | None = None,
    store_path: str | Path | None = None,
) -> IngestResult:
    """Chunk ``text`` and store it, replacing earlier chunks of the same document.

    When ``store_path`` is ``None`` and ``chromadb`` is importable the chunks
    go to Chroma; otherwise they go to the JSON store (``store_path`` or the
    default location).

    In the JSON store a document is identified, within its scenario and
    collection, by ``document_id``. When ``document_id`` is ``None`` the
    source file name is used instead, so ingesting a second id-less file no
    longer removes the chunks of every other id-less document.

    Args:
        scenario_id: Scenario the document belongs to.
        source_file: Name recorded as the chunk ``source``.
        text: Document text; empty or whitespace-only text is rejected.
        collection: Collection name recorded on each chunk.
        document_id: Optional document identifier.
        store_path: Optional JSON store path; forces the JSON backend.

    Returns:
        ``IngestResult`` with the number of chunks written, or with
        ``success=False`` and an error message.
    """
    if not text or not text.strip():
        return IngestResult(success=False, error="文档内容为空")
    if store_path is None and importlib.util.find_spec("chromadb") is not None:
        return _ingest_chroma_document(document_id, scenario_id, source_file, text, collection)
    resolved_store_path = Path(store_path) if store_path else _default_store_path()

    def _is_previous_version(chunk: dict) -> bool:
        """Tell whether ``chunk`` belongs to the document being re-ingested."""
        if chunk.get("scenario_id") != scenario_id or chunk.get("collection") != collection:
            return False
        if document_id is None:
            # Without an id, fall back to the source name as the identity.
            return chunk.get("document_id") is None and chunk.get("source") == source_file
        return chunk.get("document_id") == document_id

    existing_chunks = [
        chunk
        for chunk in _load_store(resolved_store_path)
        if not _is_previous_version(chunk)
    ]
    created_at = datetime.now(timezone.utc).isoformat()
    new_chunks = [
        {
            "scenario_id": scenario_id,
            "document_id": document_id,
            "collection": collection,
            "source": source_file,
            "chunk_id": f"{scenario_id}:{source_file}:{index}",
            "content": chunk_text,
            "created_at": created_at,
        }
        for index, chunk_text in enumerate(_split_text(text), start=1)
    ]
    _save_store(resolved_store_path, [*existing_chunks, *new_chunks])
    return IngestResult(success=True, chunks_count=len(new_chunks))
|
||||||
|
|
||||||
|
|
||||||
|
def _ingest_chroma_document(
    document_id: int | None,
    scenario_id: str,
    source_file: str,
    text: str,
    collection: str,
) -> IngestResult:
    """Chunk ``text`` and write the chunks to Chroma via ``upsert_chunks``.

    Chunk ids are ``scenario:document:index``, where the document part is
    ``document_id`` when truthy and ``source_file`` otherwise.

    Returns:
        ``IngestResult`` with the chunk count, or ``success=False`` carrying
        the message of whatever exception ``upsert_chunks`` raised.
    """
    timestamp = datetime.now(timezone.utc).isoformat()
    pieces = []
    for position, piece in enumerate(_split_text(text), start=1):
        pieces.append(
            {
                "scenario_id": scenario_id,
                "document_id": document_id,
                "collection": collection,
                "source": source_file,
                "chunk_id": f"{scenario_id}:{document_id or source_file}:{position}",
                "content": piece,
                "created_at": timestamp,
            }
        )
    try:
        upsert_chunks(collection=collection, chunks=pieces)
    except Exception as failure:
        return IngestResult(success=False, error=str(failure))
    else:
        return IngestResult(success=True, chunks_count=len(pieces))
|
||||||
69
agent_core/rag/retriever.py
Normal file
69
agent_core/rag/retriever.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
import importlib.util
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
from .chroma_store import query_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _default_store_path() -> Path:
    """Return the JSON store location inside ``settings.CHROMA_PATH``."""
    return Path(settings.CHROMA_PATH).joinpath("rag_store.json")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_store(store_path: Path) -> list[dict]:
    """Read the JSON store at ``store_path``; a missing file yields ``[]``."""
    if store_path.exists():
        with store_path.open("r", encoding="utf-8") as handle:
            stored = json.load(handle)
        return stored
    return []
|
||||||
|
|
||||||
|
|
||||||
|
def _tokens(text: str) -> set[str]:
    """Extract the lower-cased match tokens of ``text``.

    The set combines runs of ``a-z``, ``0-9`` and ``_``; runs of two or more
    characters in the ``U+4E00``-``U+9FFF`` range; and every single character
    in that range.
    """
    lowered = text.lower()
    collected: set[str] = set()
    collected.update(re.findall(r"[a-z0-9_]+", lowered))
    collected.update(re.findall(r"[\u4e00-\u9fff]{2,}", lowered))
    collected.update(symbol for symbol in lowered if "\u4e00" <= symbol <= "\u9fff")
    return collected
|
||||||
|
|
||||||
|
|
||||||
|
def _score(query_tokens: set[str], content: str) -> float:
    """Return the share of ``query_tokens`` found in ``content``.

    The value is rounded to 4 decimals; it is ``0.0`` when either side has
    no tokens.
    """
    content_tokens = _tokens(content)
    if query_tokens and content_tokens:
        shared = query_tokens.intersection(content_tokens)
        return round(len(shared) / len(query_tokens), 4)
    return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def retrieve(
    scenario_id: str,
    query: str,
    collection: str,
    top_k: int = 5,
    document_ids: list[int] | None = None,
    store_path: str | Path | None = None,
) -> list[dict]:
    """Return up to ``top_k`` chunks relevant to ``query``.

    When ``store_path`` is ``None`` and ``chromadb`` is importable the search
    is delegated to ``query_chunks``. Otherwise chunks from the JSON store
    are filtered by scenario, collection and (if given) document ids, scored
    by token overlap, and returned best first; chunks scoring zero are
    dropped.

    Args:
        scenario_id: Scenario whose chunks may be returned.
        query: Search text.
        collection: Collection the chunks must belong to.
        top_k: Maximum number of chunks returned.
        document_ids: Optional list of allowed document ids.
        store_path: Optional JSON store path; forces the JSON backend.

    Returns:
        Chunk dicts, each extended with a ``score`` key.
    """
    if store_path is None and importlib.util.find_spec("chromadb") is not None:
        return query_chunks(
            scenario_id=scenario_id,
            query=query,
            collection=collection,
            top_k=top_k,
            document_ids=document_ids,
        )

    json_store = _default_store_path() if not store_path else Path(store_path)
    wanted_tokens = _tokens(query)
    permitted_ids = set(document_ids or [])
    ranked = []
    for record in _load_store(json_store):
        in_scope = (
            record.get("scenario_id") == scenario_id
            and record.get("collection") == collection
        )
        if not in_scope:
            continue
        # An empty id filter means every document is allowed.
        if permitted_ids and record.get("document_id") not in permitted_ids:
            continue
        relevance = _score(wanted_tokens, record.get("content", ""))
        if relevance > 0:
            ranked.append({**record, "score": relevance})
    ranked.sort(key=lambda item: item["score"], reverse=True)
    return ranked[:top_k]
|
||||||
14
agent_core/results.py
Normal file
14
agent_core/results.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AgentResult:
    """Everything one agent run reports back to its caller."""

    # Answer text.
    answer: str = ""
    # Structured payload accompanying the answer.
    structured_output: dict = field(default_factory=dict)
    # Retrieved reference chunks.
    references: list = field(default_factory=list)
    # One record per tool invocation.
    tool_calls: list = field(default_factory=list)
    # Unprocessed output text.
    raw_output: str = ""
    # Name of the model credited with the answer.
    model_name: str = "mock-model"
    # Run duration in milliseconds.
    latency_ms: int = 0
    # Run status label.
    status: str = "success"
    # Error description; empty by default.
    error: str = ""
|
||||||
1
agent_core/schemas/__init__.py
Normal file
1
agent_core/schemas/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
7
agent_core/schemas/outputs.py
Normal file
7
agent_core/schemas/outputs.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Names of the structured-output types declared by this module.
SUPPORTED_OUTPUT_TYPES = set(
    (
        "general_answer",
        "document_review_report",
        "ticket_response",
        "quality_report",
        "risk_audit_report",
    )
)
|
||||||
8
agent_core/structured_output.py
Normal file
8
agent_core/structured_output.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
def build_mock_structured_output(output_type: str, user_input: str, references: list) -> dict:
    """Assemble a placeholder structured result.

    Args:
        output_type: Echoed back as ``output_type``.
        user_input: Embedded in the ``summary`` text.
        references: Only its length is reported, as ``references_count``.

    Returns:
        A dict with fixed ``risk_level`` and ``suggested_actions`` values.
    """
    summary = f"模拟结构化输出:{user_input}"
    return dict(
        output_type=output_type,
        summary=summary,
        references_count=len(references),
        risk_level="low",
        suggested_actions=["补充真实 LLM Provider 后替换模拟结果"],
    )
|
||||||
40
agent_core/tool_registry.py
Normal file
40
agent_core/tool_registry.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
from .tools.builtin_tools import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
|
||||||
|
def run_declared_tools(tool_names: list[str], user_input: str) -> list[dict]:
    """Invoke each named built-in tool with ``user_input`` and log the outcome.

    Unregistered names and tools that raise are recorded as failures; the
    remaining tools still run.

    Args:
        tool_names: Names looked up in ``BUILTIN_TOOLS``, processed in order.
        user_input: Passed to every tool as the ``user_input`` keyword.

    Returns:
        One record per name with ``tool_name``, ``success``, ``arguments``,
        ``result`` and ``error`` keys.
    """

    def _record(name: str, succeeded: bool, payload, message: str) -> dict:
        """Build one call record in the shared layout."""
        return {
            "tool_name": name,
            "success": succeeded,
            "arguments": {"user_input": user_input},
            "result": payload,
            "error": message,
        }

    outcomes = []
    for name in tool_names:
        handler = BUILTIN_TOOLS.get(name)
        if handler is None:
            outcomes.append(_record(name, False, {}, "工具未注册"))
            continue
        try:
            payload = handler(user_input=user_input)
        except Exception as failure:
            outcomes.append(_record(name, False, {}, str(failure)))
        else:
            outcomes.append(_record(name, True, payload, ""))
    return outcomes
|
||||||
1
agent_core/tools/__init__.py
Normal file
1
agent_core/tools/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
47
agent_core/tools/builtin_tools.py
Normal file
47
agent_core/tools/builtin_tools.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
def calculate_rate(user_input: str) -> dict:
    """Return a fixed placeholder rate; ``user_input`` is not inspected."""
    placeholder = {"rate": 1.0, "note": "模拟比例计算结果"}
    return placeholder
|
||||||
|
|
||||||
|
|
||||||
|
def query_demo_records(user_input: str) -> dict:
    """Look up demo business records hinted at by ``user_input``.

    The input is split on whitespace and lower-cased. Tokens equal to a
    stored ``scenario_id`` or ``record_type`` narrow the query; with no match
    the query stays unfiltered. At most 20 records are returned.

    Returns:
        ``{"records": [...]}``, or ``{"records": [], "error": ...}`` when the
        model cannot be imported.
    """
    try:
        from apps.audit.models import DemoBusinessRecord
    except Exception as failure:
        return {"records": [], "error": str(failure)}

    all_records = DemoBusinessRecord.objects.all()
    words = {part.strip().lower() for part in user_input.split() if part.strip()}
    known_scenarios = set(all_records.values_list("scenario_id", flat=True))
    known_types = set(all_records.values_list("record_type", flat=True))

    selection = all_records
    scenario_hits = known_scenarios & words
    type_hits = known_types & words
    if scenario_hits:
        selection = selection.filter(scenario_id__in=scenario_hits)
    if type_hits:
        selection = selection.filter(record_type__in=type_hits)

    rows = []
    for entry in selection[:20]:
        rows.append(
            {
                "id": entry.id,
                "scenario_id": entry.scenario_id,
                "record_type": entry.record_type,
                "title": entry.title,
                "payload": entry.payload,
            }
        )
    return {"records": rows}
|
||||||
|
|
||||||
|
|
||||||
|
def check_required_fields(user_input: str) -> dict:
    """Return a fixed placeholder check result; ``user_input`` is not inspected."""
    placeholder = {"missing_fields": [], "note": "模拟必填项检查结果"}
    return placeholder
|
||||||
|
|
||||||
|
|
||||||
|
def generate_action_items(user_input: str) -> dict:
    """Return a single follow-up item that quotes ``user_input``."""
    follow_up = f"围绕问题继续核实:{user_input}"
    return {"items": [follow_up]}
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping each tool's name to its callable; every tool is
# registered under its own function name.
BUILTIN_TOOLS = {
    tool.__name__: tool
    for tool in (
        calculate_rate,
        query_demo_records,
        check_required_fields,
        generate_action_items,
    )
}
|
||||||
122
tests/test_agent_core.py
Normal file
122
tests/test_agent_core.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
from agent_core.orchestrator import run_agent
|
||||||
|
from agent_core.rag.ingest import ingest_document
|
||||||
|
from agent_core.rag.retriever import retrieve
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_agent_returns_structured_mock_result():
    """A RAG-enabled scenario with one tool yields a complete mock result."""
    scenario = dict(
        id="knowledge_qa",
        name="知识库问答助手",
        rag={"enabled": True, "collection": "knowledge_qa", "top_k": 3},
        tools=["generate_action_items"],
        output={"type": "general_answer"},
    )

    outcome = run_agent(scenario, "如何处理异常?")

    assert outcome.status == "success"
    assert outcome.answer
    assert outcome.structured_output["output_type"] == "general_answer"
    assert isinstance(outcome.references, list)
    first_call = outcome.tool_calls[0]
    assert first_call["tool_name"] == "generate_action_items"
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_ingest_and_retrieve_filters_by_scenario_and_query(tmp_path):
    """Retrieval only returns chunks of the requested scenario that match the query."""
    store_path = tmp_path / "rag_store.json"
    quality_text = "设备点检需要先断电挂牌。质量异常需要记录批次、工位和缺陷现象。"

    quality_result = ingest_document(
        scenario_id="quality_analysis",
        source_file="quality.md",
        text=quality_text,
        collection="quality_analysis",
        store_path=store_path,
    )
    # A document of another scenario must never show up in the results.
    ingest_document(
        scenario_id="risk_audit",
        source_file="risk.md",
        text="报销审核需要检查发票、金额和审批链。",
        collection="risk_audit",
        store_path=store_path,
    )

    found = retrieve(
        scenario_id="quality_analysis",
        query="质量异常批次",
        collection="quality_analysis",
        top_k=3,
        store_path=store_path,
    )

    assert quality_result.success is True
    assert quality_result.chunks_count >= 1
    assert found
    best = found[0]
    assert best["source"] == "quality.md"
    assert "质量异常" in best["content"]
    assert all(item["scenario_id"] == "quality_analysis" for item in found)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_reingest_replaces_same_document_and_retrieve_filters_document_ids(tmp_path):
    """Re-ingesting a document id replaces its chunks; ``document_ids`` limits results."""
    store_path = tmp_path / "rag_store.json"
    shared = {
        "scenario_id": "knowledge_qa",
        "collection": "knowledge_qa",
        "store_path": store_path,
    }

    ingest_document(
        document_id=1,
        source_file="old.md",
        text="旧制度要求人工登记。",
        **shared,
    )
    # Same document id again: the earlier chunks must be replaced.
    ingest_document(
        document_id=1,
        source_file="new.md",
        text="新制度要求系统自动登记。",
        **shared,
    )
    ingest_document(
        document_id=2,
        source_file="other.md",
        text="系统自动登记后需要生成审计记录。",
        **shared,
    )

    found = retrieve(
        scenario_id="knowledge_qa",
        query="系统自动登记",
        collection="knowledge_qa",
        top_k=5,
        document_ids=[1],
        store_path=store_path,
    )

    assert found
    assert {item["document_id"] for item in found} == {1}
    assert all(item["source"] == "new.md" for item in found)
    assert all("旧制度" not in item["content"] for item in found)
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_agent_uses_retrieved_document_chunks(tmp_path):
    """``run_agent`` surfaces chunks retrieved from the store given in its options."""
    store_path = tmp_path / "rag_store.json"
    ingest_document(
        scenario_id="knowledge_qa",
        source_file="sop.md",
        text="异常处理 SOP:先隔离现场,再通知负责人。",
        collection="knowledge_qa",
        store_path=store_path,
    )
    scenario = dict(
        id="knowledge_qa",
        name="知识库问答助手",
        rag={"enabled": True, "collection": "knowledge_qa", "top_k": 3},
        tools=[],
        output={"type": "general_answer"},
    )

    outcome = run_agent(scenario, "异常处理怎么做?", options={"rag_store_path": store_path})

    top_reference = outcome.references[0]
    assert top_reference["source"] == "sop.md"
    assert "隔离现场" in top_reference["content"]
|
||||||
111
tests/test_llm_provider.py
Normal file
111
tests/test_llm_provider.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
from agent_core.llm_provider import (
|
||||||
|
EmbeddingConfigurationError,
|
||||||
|
LLMConfigurationError,
|
||||||
|
create_embedding_provider,
|
||||||
|
create_llm_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_provider_requires_api_key_for_openai_compatible():
    """Without an API key the OpenAI-compatible provider reports a configuration error."""
    settings_map = dict(
        LLM_API_KEY="",
        LLM_BASE_URL="https://api.openai.com/v1",
        LLM_MODEL="gpt-4.1-mini",
        LLM_PROVIDER="openai_compatible",
    )
    provider = create_llm_provider(settings_map)

    reply = provider.generate([{"role": "user", "content": "hello"}])

    assert reply.success is False
    assert isinstance(reply.error, LLMConfigurationError)
    assert "LLM_API_KEY" in str(reply.error)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_provider_returns_deterministic_content():
    """The mock provider echoes the user message under the configured model name."""
    settings_map = {"LLM_PROVIDER": "mock", "LLM_MODEL": "demo-model"}
    provider = create_llm_provider(settings_map)

    reply = provider.generate([{"role": "user", "content": "hello"}])

    assert reply.success is True
    assert reply.model_name == "demo-model"
    assert "hello" in reply.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compatible_provider_posts_chat_completion(monkeypatch):
    """The provider posts to ``chat/completions`` with a bearer token and parses the reply."""
    seen = {}

    class FakeResponse:
        """Minimal context-manager stand-in for the ``urlopen`` result."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, traceback):
            return False

        def read(self):
            return b'{"choices":[{"message":{"content":"ok"}}],"model":"demo-model"}'

    def fake_urlopen(request, timeout):
        # Capture what the provider actually sent.
        seen.update(
            url=request.full_url,
            headers=dict(request.header_items()),
            body=request.data.decode("utf-8"),
        )
        return FakeResponse()

    monkeypatch.setattr("agent_core.llm_provider.urlopen", fake_urlopen)
    settings_map = dict(
        LLM_PROVIDER="openai_compatible",
        LLM_API_KEY="sk-test",
        LLM_BASE_URL="https://api.siliconflow.cn/v1",
        LLM_MODEL="demo-model",
    )
    provider = create_llm_provider(settings_map)

    reply = provider.generate([{"role": "user", "content": "hello"}])

    assert reply.success is True
    assert reply.content == "ok"
    assert seen["url"] == "https://api.siliconflow.cn/v1/chat/completions"
    assert '"model": "demo-model"' in seen["body"]
    assert seen["headers"]["Authorization"] == "Bearer sk-test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_provider_uses_openai_compatible_embeddings(monkeypatch):
    """The embedding provider returns the vectors found in the service response."""

    class FakeResponse:
        """Minimal context-manager stand-in for the ``urlopen`` result."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, traceback):
            return False

        def read(self):
            return b'{"data":[{"embedding":[0.1,0.2]},{"embedding":[0.3,0.4]}]}'

    def fake_urlopen(request, timeout):
        return FakeResponse()

    monkeypatch.setattr("agent_core.llm_provider.urlopen", fake_urlopen)
    settings_map = dict(
        EMBEDDING_API_KEY="sk-test",
        EMBEDDING_BASE_URL="https://api.siliconflow.cn/v1",
        EMBEDDING_MODEL="demo-embedding",
    )
    provider = create_embedding_provider(settings_map)

    vectors = provider.embed_texts(["a", "b"])

    assert vectors == [[0.1, 0.2], [0.3, 0.4]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_provider_requires_api_key():
    """Embedding without an API key raises ``EmbeddingConfigurationError``."""
    settings_map = dict(
        EMBEDDING_API_KEY="",
        EMBEDDING_BASE_URL="https://api.siliconflow.cn/v1",
        EMBEDDING_MODEL="demo-embedding",
    )
    provider = create_embedding_provider(settings_map)

    caught = None
    try:
        provider.embed_texts(["a"])
    except EmbeddingConfigurationError as exc:
        caught = exc

    if caught is None:
        raise AssertionError("expected EmbeddingConfigurationError")
    assert "EMBEDDING_API_KEY" in str(caught)
|
||||||
Reference in New Issue
Block a user