chore(master): 清理V2合并后的旧版遗留文件
This commit is contained in:
@@ -1,250 +0,0 @@
|
||||
from agent_core.orchestrator import build_messages, run_agent
|
||||
from agent_core.rag.ingest import _split_text, ingest_document
|
||||
from agent_core.rag.retriever import retrieve
|
||||
|
||||
|
||||
def test_run_agent_returns_structured_result_from_llm_output():
|
||||
scenario = {
|
||||
"id": "knowledge_qa",
|
||||
"name": "知识库问答助手",
|
||||
"agent": {
|
||||
"role": "知识库助手",
|
||||
"goal": "基于资料回答问题",
|
||||
"instructions": ["仅根据证据回答"],
|
||||
},
|
||||
"rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
|
||||
"tools": ["generate_action_items"],
|
||||
"output": {"type": "general_answer"},
|
||||
}
|
||||
provider_response = """
|
||||
{
|
||||
"answer": "请先隔离异常现场,再通知负责人。",
|
||||
"confidence": "high",
|
||||
"references": [
|
||||
{"source": "sop.md", "excerpt": "异常处理 SOP:先隔离现场"}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
class FakeProvider:
|
||||
def generate(self, messages, response_format=None):
|
||||
from agent_core.llm_provider import LLMResponse
|
||||
|
||||
return LLMResponse(
|
||||
content=provider_response,
|
||||
model_name="demo-model",
|
||||
success=True,
|
||||
)
|
||||
|
||||
result = run_agent(
|
||||
scenario,
|
||||
"如何处理异常?",
|
||||
options={"llm_provider": FakeProvider()},
|
||||
)
|
||||
|
||||
assert result.status == "success"
|
||||
assert result.answer == "请先隔离异常现场,再通知负责人。"
|
||||
assert result.structured_output["output_type"] == "general_answer"
|
||||
assert result.structured_output["confidence"] == "high"
|
||||
assert isinstance(result.references, list)
|
||||
assert result.tool_calls[0]["tool_name"] == "generate_action_items"
|
||||
assert result.model_name == "demo-model"
|
||||
|
||||
|
||||
def test_run_agent_falls_back_when_llm_returns_non_json():
|
||||
scenario = {
|
||||
"id": "document_review",
|
||||
"name": "文档审核助手",
|
||||
"agent": {
|
||||
"role": "审核助手",
|
||||
"goal": "总结审核意见",
|
||||
"instructions": ["输出重点问题"],
|
||||
},
|
||||
"rag": {"enabled": False},
|
||||
"tools": [],
|
||||
"output": {"type": "document_review_report"},
|
||||
}
|
||||
|
||||
class FakeProvider:
|
||||
def generate(self, messages, response_format=None):
|
||||
from agent_core.llm_provider import LLMResponse
|
||||
|
||||
return LLMResponse(
|
||||
content="这是非 JSON 的普通回答",
|
||||
model_name="demo-model",
|
||||
success=True,
|
||||
)
|
||||
|
||||
result = run_agent(
|
||||
scenario,
|
||||
"请检查合同风险",
|
||||
options={"llm_provider": FakeProvider()},
|
||||
)
|
||||
|
||||
assert result.status == "success"
|
||||
assert result.answer == "这是非 JSON 的普通回答"
|
||||
assert result.structured_output["output_type"] == "document_review_report"
|
||||
assert result.structured_output["summary"] == "这是非 JSON 的普通回答"
|
||||
assert result.structured_output["parse_mode"] == "fallback"
|
||||
|
||||
|
||||
def test_build_messages_contains_role_goal_references_and_tool_results():
|
||||
scenario = {
|
||||
"name": "质量异常分析助手",
|
||||
"agent": {
|
||||
"role": "质量管理专家",
|
||||
"goal": "生成结构化质量分析报告",
|
||||
"instructions": ["必须引用知识库", "缺失信息要说明"],
|
||||
},
|
||||
"output": {"type": "quality_report"},
|
||||
}
|
||||
|
||||
messages = build_messages(
|
||||
scenario_config=scenario,
|
||||
user_input="分析 A 线异常",
|
||||
references=[{"source": "sop.md", "content": "先隔离现场"}],
|
||||
tool_calls=[
|
||||
{
|
||||
"tool_name": "query_demo_records",
|
||||
"success": True,
|
||||
"result": {"records": [{"title": "A线缺陷"}]},
|
||||
"error": "",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
assert messages[0]["role"] == "system"
|
||||
assert "质量管理专家" in messages[0]["content"]
|
||||
assert "生成结构化质量分析报告" in messages[0]["content"]
|
||||
assert "quality_report" in messages[0]["content"]
|
||||
assert "先隔离现场" in messages[1]["content"]
|
||||
assert "A线缺陷" in messages[1]["content"]
|
||||
assert "分析 A 线异常" in messages[2]["content"]
|
||||
|
||||
|
||||
def test_rag_ingest_and_retrieve_filters_by_scenario_and_query(tmp_path):
|
||||
store_path = tmp_path / "rag_store.json"
|
||||
text = "设备点检需要先断电挂牌。质量异常需要记录批次、工位和缺陷现象。"
|
||||
|
||||
result = ingest_document(
|
||||
scenario_id="quality_analysis",
|
||||
source_file="quality.md",
|
||||
text=text,
|
||||
collection="quality_analysis",
|
||||
store_path=store_path,
|
||||
)
|
||||
ingest_document(
|
||||
scenario_id="risk_audit",
|
||||
source_file="risk.md",
|
||||
text="报销审核需要检查发票、金额和审批链。",
|
||||
collection="risk_audit",
|
||||
store_path=store_path,
|
||||
)
|
||||
|
||||
chunks = retrieve(
|
||||
scenario_id="quality_analysis",
|
||||
query="质量异常批次",
|
||||
collection="quality_analysis",
|
||||
top_k=3,
|
||||
store_path=store_path,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.chunks_count >= 1
|
||||
assert chunks
|
||||
assert chunks[0]["source"] == "quality.md"
|
||||
assert "质量异常" in chunks[0]["content"]
|
||||
assert all(chunk["scenario_id"] == "quality_analysis" for chunk in chunks)
|
||||
|
||||
|
||||
def test_rag_reingest_replaces_same_document_and_retrieve_filters_document_ids(tmp_path):
|
||||
store_path = tmp_path / "rag_store.json"
|
||||
|
||||
ingest_document(
|
||||
document_id=1,
|
||||
scenario_id="knowledge_qa",
|
||||
source_file="old.md",
|
||||
text="旧制度要求人工登记。",
|
||||
collection="knowledge_qa",
|
||||
store_path=store_path,
|
||||
)
|
||||
ingest_document(
|
||||
document_id=1,
|
||||
scenario_id="knowledge_qa",
|
||||
source_file="new.md",
|
||||
text="新制度要求系统自动登记。",
|
||||
collection="knowledge_qa",
|
||||
store_path=store_path,
|
||||
)
|
||||
ingest_document(
|
||||
document_id=2,
|
||||
scenario_id="knowledge_qa",
|
||||
source_file="other.md",
|
||||
text="系统自动登记后需要生成审计记录。",
|
||||
collection="knowledge_qa",
|
||||
store_path=store_path,
|
||||
)
|
||||
|
||||
chunks = retrieve(
|
||||
scenario_id="knowledge_qa",
|
||||
query="系统自动登记",
|
||||
collection="knowledge_qa",
|
||||
top_k=5,
|
||||
document_ids=[1],
|
||||
store_path=store_path,
|
||||
)
|
||||
|
||||
assert chunks
|
||||
assert {chunk["document_id"] for chunk in chunks} == {1}
|
||||
assert all(chunk["source"] == "new.md" for chunk in chunks)
|
||||
assert all("旧制度" not in chunk["content"] for chunk in chunks)
|
||||
|
||||
|
||||
def test_run_agent_uses_retrieved_document_chunks(tmp_path):
|
||||
store_path = tmp_path / "rag_store.json"
|
||||
ingest_document(
|
||||
scenario_id="knowledge_qa",
|
||||
source_file="sop.md",
|
||||
text="异常处理 SOP:先隔离现场,再通知负责人。",
|
||||
collection="knowledge_qa",
|
||||
store_path=store_path,
|
||||
)
|
||||
scenario = {
|
||||
"id": "knowledge_qa",
|
||||
"name": "知识库问答助手",
|
||||
"rag": {"enabled": True, "collection": "knowledge_qa", "top_k": 3},
|
||||
"tools": [],
|
||||
"output": {"type": "general_answer"},
|
||||
}
|
||||
|
||||
result = run_agent(scenario, "异常处理怎么做?", options={"rag_store_path": store_path})
|
||||
|
||||
assert result.references[0]["source"] == "sop.md"
|
||||
assert "隔离现场" in result.references[0]["content"]
|
||||
|
||||
|
||||
def test_rag_split_text_keeps_overlap_and_non_empty_chunks():
|
||||
chunks = _split_text("A" * 20, chunk_size=8, overlap=3)
|
||||
|
||||
assert chunks == ["AAAAAAAA", "AAAAAAAA", "AAAAAAAA", "AAAAA"]
|
||||
|
||||
|
||||
def test_retrieve_returns_empty_when_query_has_no_overlap(tmp_path):
|
||||
store_path = tmp_path / "rag_store.json"
|
||||
ingest_document(
|
||||
scenario_id="knowledge_qa",
|
||||
source_file="rules.md",
|
||||
text="这里描述的是报销流程和审批链。",
|
||||
collection="knowledge_qa",
|
||||
store_path=store_path,
|
||||
)
|
||||
|
||||
chunks = retrieve(
|
||||
scenario_id="knowledge_qa",
|
||||
query="设备点检",
|
||||
collection="knowledge_qa",
|
||||
top_k=3,
|
||||
store_path=store_path,
|
||||
)
|
||||
|
||||
assert chunks == []
|
||||
@@ -1,119 +0,0 @@
|
||||
from django.urls import reverse
|
||||
|
||||
from agent_core.results import AgentResult
|
||||
from apps.audit.models import AgentAuditLog, DemoBusinessRecord
|
||||
from apps.audit.services import create_audit_log
|
||||
from agent_core.tools.builtin_tools import query_demo_records
|
||||
|
||||
|
||||
def test_create_audit_log_records_success_result(db):
|
||||
result = AgentResult(answer="回答", structured_output={"x": 1}, status="success")
|
||||
|
||||
log = create_audit_log("knowledge_qa", "知识库问答助手", "问题", result)
|
||||
|
||||
assert AgentAuditLog.objects.count() == 1
|
||||
assert log.final_answer == "回答"
|
||||
assert log.structured_output == {"x": 1}
|
||||
assert log.status == "success"
|
||||
|
||||
|
||||
def test_audit_list_page_shows_log(client, db):
|
||||
result = AgentResult(answer="回答", status="success")
|
||||
create_audit_log("knowledge_qa", "知识库问答助手", "问题", result)
|
||||
|
||||
response = client.get(reverse("audit:list"))
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "知识库问答助手" in response.content.decode("utf-8")
|
||||
|
||||
|
||||
def test_audit_list_can_filter_by_scenario(client, db):
|
||||
create_audit_log(
|
||||
"knowledge_qa",
|
||||
"知识库问答助手",
|
||||
"制度问题",
|
||||
AgentResult(answer="回答一", status="success"),
|
||||
)
|
||||
create_audit_log(
|
||||
"quality_analysis",
|
||||
"质量异常分析助手",
|
||||
"质量问题",
|
||||
AgentResult(answer="回答二", status="success"),
|
||||
)
|
||||
|
||||
response = client.get(reverse("audit:list"), {"scenario_id": "knowledge_qa"})
|
||||
|
||||
content = response.content.decode("utf-8")
|
||||
assert response.status_code == 200
|
||||
assert "知识库问答助手" in content
|
||||
assert "质量异常分析助手" not in content
|
||||
|
||||
|
||||
def test_audit_list_page_shows_user_input_summary(client, db):
|
||||
create_audit_log(
|
||||
"knowledge_qa",
|
||||
"知识库问答助手",
|
||||
"这是一个比较长的用户输入,用于确认列表页会展示输入摘要。",
|
||||
AgentResult(answer="回答", status="success"),
|
||||
)
|
||||
|
||||
response = client.get(reverse("audit:list"))
|
||||
|
||||
assert "这是一个比较长的用户输入" in response.content.decode("utf-8")
|
||||
|
||||
|
||||
def test_audit_detail_page_shows_raw_output(client, db):
|
||||
result = AgentResult(
|
||||
answer="结构化回答",
|
||||
raw_output='{"answer":"结构化回答","confidence":"high"}',
|
||||
status="success",
|
||||
)
|
||||
log = create_audit_log("knowledge_qa", "知识库问答助手", "问题", result)
|
||||
|
||||
response = client.get(reverse("audit:detail", args=[log.id]))
|
||||
|
||||
content = response.content.decode("utf-8")
|
||||
assert response.status_code == 200
|
||||
assert "原始输出" in content
|
||||
assert "confidence" in content
|
||||
assert "high" in content
|
||||
|
||||
|
||||
def test_create_audit_log_masks_api_keys_from_error_message(db):
|
||||
result = AgentResult(
|
||||
answer="",
|
||||
status="failed",
|
||||
error="LLM_API_KEY=sk-secret-value 调用失败",
|
||||
)
|
||||
|
||||
log = create_audit_log("knowledge_qa", "知识库问答助手", "问题", result)
|
||||
|
||||
assert "sk-secret-value" not in log.error_message
|
||||
assert "sk-***" in log.error_message
|
||||
|
||||
|
||||
def test_create_audit_log_masks_embedding_api_keys_from_error_message(db):
|
||||
result = AgentResult(
|
||||
answer="",
|
||||
status="failed",
|
||||
error="EMBEDDING_API_KEY=embed-secret 调用失败",
|
||||
)
|
||||
|
||||
log = create_audit_log("knowledge_qa", "知识库问答助手", "问题", result)
|
||||
|
||||
assert "embed-secret" not in log.error_message
|
||||
assert "EMBEDDING_API_KEY=***" in log.error_message
|
||||
|
||||
|
||||
def test_query_demo_records_reads_demo_business_record_table(db):
|
||||
DemoBusinessRecord.objects.create(
|
||||
scenario_id="quality_analysis",
|
||||
record_type="defect",
|
||||
title="A线缺陷",
|
||||
payload={"rate": 0.12},
|
||||
)
|
||||
|
||||
result = query_demo_records(user_input="quality_analysis defect")
|
||||
|
||||
assert result["records"][0]["title"] == "A线缺陷"
|
||||
assert result["records"][0]["payload"] == {"rate": 0.12}
|
||||
@@ -1,107 +0,0 @@
|
||||
from django.urls import reverse
|
||||
|
||||
from agent_core.results import AgentResult
|
||||
from apps.audit.models import AgentAuditLog
|
||||
from apps.documents.models import UploadedDocument
|
||||
|
||||
|
||||
def test_chat_post_returns_agent_result_and_audit_log(client, db):
|
||||
response = client.post(
|
||||
reverse("chat:index", args=["knowledge_qa"]),
|
||||
{"message": "如何处理异常?"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
content = response.content.decode("utf-8")
|
||||
assert "mock-model" in content
|
||||
assert "模拟回答" in content
|
||||
assert AgentAuditLog.objects.count() == 1
|
||||
|
||||
|
||||
def test_chat_rejects_empty_message(client, db):
|
||||
response = client.post(reverse("chat:index", args=["knowledge_qa"]), {"message": ""})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert AgentAuditLog.objects.count() == 0
|
||||
assert "请输入要咨询的问题" in response.content.decode("utf-8")
|
||||
|
||||
|
||||
def test_chat_passes_selected_document_ids_to_agent_core(client, db, monkeypatch):
|
||||
selected = UploadedDocument.objects.create(
|
||||
scenario_id="knowledge_qa",
|
||||
original_name="selected.md",
|
||||
file_type="md",
|
||||
size=1,
|
||||
status=UploadedDocument.STATUS_INDEXED,
|
||||
)
|
||||
other = UploadedDocument.objects.create(
|
||||
scenario_id="knowledge_qa",
|
||||
original_name="other.md",
|
||||
file_type="md",
|
||||
size=1,
|
||||
status=UploadedDocument.STATUS_INDEXED,
|
||||
)
|
||||
captured = {}
|
||||
|
||||
def fake_run_agent(scenario_config, user_input, options=None):
|
||||
captured["options"] = options or {}
|
||||
from agent_core.results import AgentResult
|
||||
|
||||
return AgentResult(answer="ok", status="success")
|
||||
|
||||
monkeypatch.setattr("apps.chat.views.run_agent", fake_run_agent)
|
||||
|
||||
response = client.post(
|
||||
reverse("chat:index", args=["knowledge_qa"]),
|
||||
{"message": "只查选中文档", "document_ids": [str(selected.id)]},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert captured["options"]["document_ids"] == [selected.id]
|
||||
assert other.id not in captured["options"]["document_ids"]
|
||||
|
||||
|
||||
def test_chat_renders_structured_output_references_and_tool_calls(client, db, monkeypatch):
|
||||
def fake_run_agent(scenario_config, user_input, options=None):
|
||||
return AgentResult(
|
||||
answer="建议先隔离现场。",
|
||||
structured_output={
|
||||
"output_type": "quality_report",
|
||||
"summary": "发现异常批次需要立即处置。",
|
||||
"risk_level": "high",
|
||||
"suggested_actions": ["隔离现场", "通知负责人"],
|
||||
},
|
||||
references=[
|
||||
{
|
||||
"source": "sop.md",
|
||||
"content": "异常处理 SOP:先隔离现场,再通知负责人。",
|
||||
}
|
||||
],
|
||||
tool_calls=[
|
||||
{
|
||||
"tool_name": "query_demo_records",
|
||||
"success": True,
|
||||
"result": {"records": [{"title": "A线缺陷"}]},
|
||||
"error": "",
|
||||
}
|
||||
],
|
||||
model_name="mock-model",
|
||||
status="success",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("apps.chat.views.run_agent", fake_run_agent)
|
||||
|
||||
response = client.post(
|
||||
reverse("chat:index", args=["quality_analysis"]),
|
||||
{"message": "分析 A 线异常"},
|
||||
)
|
||||
|
||||
content = response.content.decode("utf-8")
|
||||
assert response.status_code == 200
|
||||
assert "结构化结果" in content
|
||||
assert "发现异常批次需要立即处置" in content
|
||||
assert "引用片段" in content
|
||||
assert "sop.md" in content
|
||||
assert "工具调用" in content
|
||||
assert "query_demo_records" in content
|
||||
assert "查看本次审计日志" in content
|
||||
@@ -1,147 +0,0 @@
|
||||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||||
from django.urls import reverse
|
||||
|
||||
from apps.documents.forms import DocumentUploadForm
|
||||
from apps.documents.models import UploadedDocument
|
||||
from apps.documents.services import extract_text, index_document
|
||||
|
||||
|
||||
def test_upload_txt_document_creates_uploaded_record(client, db):
|
||||
file = SimpleUploadedFile("rules.txt", "hello".encode("utf-8"), content_type="text/plain")
|
||||
|
||||
response = client.post(
|
||||
reverse("documents:upload"),
|
||||
{"scenario_id": "knowledge_qa", "file": file},
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
document = UploadedDocument.objects.get()
|
||||
assert document.status == "uploaded"
|
||||
assert document.file_type == "txt"
|
||||
assert document.scenario_id == "knowledge_qa"
|
||||
|
||||
|
||||
def test_upload_redirect_shows_success_message(client, db):
|
||||
file = SimpleUploadedFile("notice.txt", "hello".encode("utf-8"), content_type="text/plain")
|
||||
|
||||
response = client.post(
|
||||
reverse("documents:upload"),
|
||||
{"scenario_id": "knowledge_qa", "file": file},
|
||||
follow=True,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "文件已上传,可继续执行入库" in response.content.decode("utf-8")
|
||||
|
||||
|
||||
def test_upload_accepts_pdf_and_docx_documents(client, db):
|
||||
for filename, payload in [
|
||||
("policy.pdf", b"%PDF-1.4\nplain policy text"),
|
||||
("contract.docx", b"fake-docx-body"),
|
||||
]:
|
||||
file = SimpleUploadedFile(filename, payload)
|
||||
|
||||
response = client.post(
|
||||
reverse("documents:upload"),
|
||||
{"scenario_id": "knowledge_qa", "file": file},
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
assert set(UploadedDocument.objects.values_list("file_type", flat=True)) == {"pdf", "docx"}
|
||||
|
||||
|
||||
def test_index_document_updates_status_to_indexed(client, db):
|
||||
document = UploadedDocument.objects.create(
|
||||
scenario_id="knowledge_qa",
|
||||
original_name="rules.md",
|
||||
file="knowledge_qa/rules.md",
|
||||
file_type="md",
|
||||
size=5,
|
||||
status="uploaded",
|
||||
)
|
||||
document.file.save("rules.md", SimpleUploadedFile("rules.md", b"# rule").file)
|
||||
|
||||
response = client.post(reverse("documents:index", args=[document.id]))
|
||||
|
||||
assert response.status_code == 302
|
||||
document.refresh_from_db()
|
||||
assert document.status == "indexed"
|
||||
assert document.error_message == ""
|
||||
|
||||
|
||||
def test_extract_text_supports_pdf_and_docx_plain_text_fallback(db):
|
||||
pdf_document = UploadedDocument.objects.create(
|
||||
scenario_id="knowledge_qa",
|
||||
original_name="policy.pdf",
|
||||
file_type="pdf",
|
||||
size=10,
|
||||
status="uploaded",
|
||||
)
|
||||
pdf_document.file.save("policy.pdf", SimpleUploadedFile("policy.pdf", b"%PDF-1.4\nSafety policy"))
|
||||
|
||||
docx_document = UploadedDocument.objects.create(
|
||||
scenario_id="knowledge_qa",
|
||||
original_name="contract.docx",
|
||||
file_type="docx",
|
||||
size=10,
|
||||
status="uploaded",
|
||||
)
|
||||
docx_document.file.save(
|
||||
"contract.docx",
|
||||
SimpleUploadedFile("contract.docx", b"Contract clause review"),
|
||||
)
|
||||
|
||||
assert "Safety policy" in extract_text(pdf_document)
|
||||
assert "Contract clause review" in extract_text(docx_document)
|
||||
|
||||
|
||||
def test_document_upload_form_builds_scenario_choices():
|
||||
form = DocumentUploadForm()
|
||||
|
||||
choice_values = [value for value, _label in form.fields["scenario_id"].choices]
|
||||
|
||||
assert "knowledge_qa" in choice_values
|
||||
assert "quality_analysis" in choice_values
|
||||
|
||||
|
||||
def test_index_failure_message_is_visible_on_document_list(client, db, monkeypatch):
|
||||
document = UploadedDocument.objects.create(
|
||||
scenario_id="knowledge_qa",
|
||||
original_name="broken.md",
|
||||
file_type="md",
|
||||
size=5,
|
||||
status="uploaded",
|
||||
)
|
||||
|
||||
def fake_index_document(target_document):
|
||||
target_document.status = UploadedDocument.STATUS_FAILED
|
||||
target_document.error_message = "模拟入库失败"
|
||||
target_document.save(update_fields=["status", "error_message", "updated_at"])
|
||||
return target_document
|
||||
|
||||
monkeypatch.setattr("apps.documents.views.index_document", fake_index_document)
|
||||
|
||||
response = client.post(reverse("documents:index", args=[document.id]), follow=True)
|
||||
|
||||
content = response.content.decode("utf-8")
|
||||
assert response.status_code == 200
|
||||
assert "文档入库失败,请检查错误原因后重试" in content
|
||||
assert "模拟入库失败" in content
|
||||
|
||||
|
||||
def test_index_document_marks_failed_when_extracted_text_is_empty(db, monkeypatch):
|
||||
document = UploadedDocument.objects.create(
|
||||
scenario_id="knowledge_qa",
|
||||
original_name="empty.md",
|
||||
file_type="md",
|
||||
size=0,
|
||||
status="uploaded",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("apps.documents.services.extract_text", lambda target: " ")
|
||||
|
||||
updated_document = index_document(document)
|
||||
|
||||
assert updated_document.status == UploadedDocument.STATUS_FAILED
|
||||
assert "文档内容为空" in updated_document.error_message
|
||||
@@ -1,126 +0,0 @@
|
||||
from agent_core.llm_provider import (
|
||||
EmbeddingConfigurationError,
|
||||
LLMConfigurationError,
|
||||
create_embedding_provider,
|
||||
create_llm_provider,
|
||||
get_runtime_llm_config,
|
||||
)
|
||||
|
||||
|
||||
def test_create_llm_provider_requires_api_key_for_openai_compatible():
|
||||
provider = create_llm_provider(
|
||||
{
|
||||
"LLM_API_KEY": "",
|
||||
"LLM_BASE_URL": "https://api.openai.com/v1",
|
||||
"LLM_MODEL": "gpt-4.1-mini",
|
||||
"LLM_PROVIDER": "openai_compatible",
|
||||
}
|
||||
)
|
||||
|
||||
response = provider.generate([{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.success is False
|
||||
assert isinstance(response.error, LLMConfigurationError)
|
||||
assert "LLM_API_KEY" in str(response.error)
|
||||
|
||||
|
||||
def test_mock_provider_returns_deterministic_content():
|
||||
provider = create_llm_provider({"LLM_PROVIDER": "mock", "LLM_MODEL": "demo-model"})
|
||||
|
||||
response = provider.generate([{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.success is True
|
||||
assert response.model_name == "demo-model"
|
||||
assert "hello" in response.content
|
||||
|
||||
|
||||
def test_openai_compatible_provider_posts_chat_completion(monkeypatch):
|
||||
captured = {}
|
||||
|
||||
class FakeResponse:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, traceback):
|
||||
return False
|
||||
|
||||
def read(self):
|
||||
return b'{"choices":[{"message":{"content":"ok"}}],"model":"demo-model"}'
|
||||
|
||||
def fake_urlopen(request, timeout):
|
||||
captured["url"] = request.full_url
|
||||
captured["headers"] = dict(request.header_items())
|
||||
captured["body"] = request.data.decode("utf-8")
|
||||
return FakeResponse()
|
||||
|
||||
monkeypatch.setattr("agent_core.llm_provider.urlopen", fake_urlopen)
|
||||
provider = create_llm_provider(
|
||||
{
|
||||
"LLM_PROVIDER": "openai_compatible",
|
||||
"LLM_API_KEY": "sk-test",
|
||||
"LLM_BASE_URL": "https://api.siliconflow.cn/v1",
|
||||
"LLM_MODEL": "demo-model",
|
||||
}
|
||||
)
|
||||
|
||||
response = provider.generate([{"role": "user", "content": "hello"}])
|
||||
|
||||
assert response.success is True
|
||||
assert response.content == "ok"
|
||||
assert captured["url"] == "https://api.siliconflow.cn/v1/chat/completions"
|
||||
assert '"model": "demo-model"' in captured["body"]
|
||||
assert captured["headers"]["Authorization"] == "Bearer sk-test"
|
||||
|
||||
|
||||
def test_embedding_provider_uses_openai_compatible_embeddings(monkeypatch):
|
||||
class FakeResponse:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, traceback):
|
||||
return False
|
||||
|
||||
def read(self):
|
||||
return b'{"data":[{"embedding":[0.1,0.2]},{"embedding":[0.3,0.4]}]}'
|
||||
|
||||
monkeypatch.setattr("agent_core.llm_provider.urlopen", lambda request, timeout: FakeResponse())
|
||||
provider = create_embedding_provider(
|
||||
{
|
||||
"EMBEDDING_API_KEY": "sk-test",
|
||||
"EMBEDDING_BASE_URL": "https://api.siliconflow.cn/v1",
|
||||
"EMBEDDING_MODEL": "demo-embedding",
|
||||
}
|
||||
)
|
||||
|
||||
assert provider.embed_texts(["a", "b"]) == [[0.1, 0.2], [0.3, 0.4]]
|
||||
|
||||
|
||||
def test_embedding_provider_requires_api_key():
|
||||
provider = create_embedding_provider(
|
||||
{
|
||||
"EMBEDDING_API_KEY": "",
|
||||
"EMBEDDING_BASE_URL": "https://api.siliconflow.cn/v1",
|
||||
"EMBEDDING_MODEL": "demo-embedding",
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
provider.embed_texts(["a"])
|
||||
except EmbeddingConfigurationError as exc:
|
||||
assert "EMBEDDING_API_KEY" in str(exc)
|
||||
else:
|
||||
raise AssertionError("expected EmbeddingConfigurationError")
|
||||
|
||||
|
||||
def test_get_runtime_llm_config_uses_environment_and_overrides(monkeypatch):
|
||||
monkeypatch.setenv("LLM_PROVIDER", "mock")
|
||||
monkeypatch.setenv("LLM_API_KEY", "sk-env")
|
||||
monkeypatch.setenv("LLM_BASE_URL", "https://env.example/v1")
|
||||
monkeypatch.setenv("LLM_MODEL", "env-model")
|
||||
|
||||
config = get_runtime_llm_config({"LLM_MODEL": "override-model"})
|
||||
|
||||
assert config["LLM_PROVIDER"] == "mock"
|
||||
assert config["LLM_API_KEY"] == "sk-env"
|
||||
assert config["LLM_BASE_URL"] == "https://env.example/v1"
|
||||
assert config["LLM_MODEL"] == "override-model"
|
||||
@@ -1,21 +0,0 @@
|
||||
import os
|
||||
|
||||
from django.conf import settings
|
||||
from django.urls import reverse
|
||||
|
||||
|
||||
def test_core_settings_expose_documented_paths():
|
||||
assert settings.SCENARIO_CONFIG_DIR.name == "configs"
|
||||
assert settings.CHROMA_PATH.name == "chroma"
|
||||
assert settings.MEDIA_ROOT.name == "uploads"
|
||||
assert settings.EMBEDDING_MODEL == os.environ.get(
|
||||
"EMBEDDING_MODEL",
|
||||
"text-embedding-3-small",
|
||||
)
|
||||
assert settings.EMBEDDING_BASE_URL == settings.LLM_BASE_URL
|
||||
assert settings.EMBEDDING_API_KEY == settings.LLM_API_KEY
|
||||
|
||||
|
||||
def test_home_url_is_registered(client):
|
||||
response = client.get(reverse("scenarios:index"))
|
||||
assert response.status_code == 200
|
||||
@@ -1,129 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from apps.scenarios.services import (
|
||||
ScenarioNotFound,
|
||||
get_scenario,
|
||||
list_scenario_issues,
|
||||
list_scenarios,
|
||||
)
|
||||
|
||||
|
||||
def test_list_scenarios_loads_five_configs():
|
||||
scenarios = list_scenarios()
|
||||
assert [scenario["id"] for scenario in scenarios] == [
|
||||
"document_review",
|
||||
"knowledge_qa",
|
||||
"quality_analysis",
|
||||
"risk_audit",
|
||||
"ticket_assistant",
|
||||
]
|
||||
|
||||
|
||||
def test_get_scenario_returns_full_agent_config():
|
||||
scenario = get_scenario("quality_analysis")
|
||||
assert scenario["agent"]["role"] == "质量管理专家"
|
||||
assert scenario["rag"]["enabled"] is True
|
||||
assert scenario["output"]["type"] == "quality_report"
|
||||
assert "质量异常分析" in scenario["applicable_questions"][0]
|
||||
|
||||
|
||||
def test_get_scenario_raises_clear_error_for_missing_id():
|
||||
with pytest.raises(ScenarioNotFound, match="场景不存在"):
|
||||
get_scenario("missing")
|
||||
|
||||
|
||||
def test_home_page_shows_applicable_questions(client):
|
||||
response = client.get("/")
|
||||
|
||||
content = response.content.decode("utf-8")
|
||||
assert response.status_code == 200
|
||||
assert "适用题型" in content
|
||||
assert "SOP 问答" in content
|
||||
|
||||
|
||||
def test_list_scenarios_skips_invalid_config_and_collects_issues(settings, tmp_path):
|
||||
valid_file = tmp_path / "valid.yaml"
|
||||
invalid_file = tmp_path / "invalid.yaml"
|
||||
valid_file.write_text(
|
||||
"""
|
||||
id: demo_valid
|
||||
name: 有效场景
|
||||
description: 用于测试
|
||||
agent:
|
||||
role: 测试助手
|
||||
goal: 正常返回
|
||||
instructions:
|
||||
- 输出结果
|
||||
rag:
|
||||
enabled: false
|
||||
tools: []
|
||||
output:
|
||||
type: general_answer
|
||||
audit:
|
||||
enabled: true
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
invalid_file.write_text(
|
||||
"""
|
||||
id: broken
|
||||
name: 非法场景
|
||||
description: 缺少 agent.goal
|
||||
agent:
|
||||
role: 测试助手
|
||||
instructions:
|
||||
- 输出结果
|
||||
rag:
|
||||
enabled: true
|
||||
tools: []
|
||||
output:
|
||||
type: general_answer
|
||||
audit:
|
||||
enabled: true
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
settings.SCENARIO_CONFIG_DIR = tmp_path
|
||||
|
||||
scenarios = list_scenarios()
|
||||
issues = list_scenario_issues()
|
||||
|
||||
assert [scenario["id"] for scenario in scenarios] == ["demo_valid"]
|
||||
assert len(issues) == 1
|
||||
assert issues[0]["file_name"] == "invalid.yaml"
|
||||
assert "agent.goal" in issues[0]["message"]
|
||||
|
||||
|
||||
def test_home_page_shows_invalid_scenario_issues_instead_of_500(client, settings, tmp_path):
|
||||
valid_file = tmp_path / "valid.yaml"
|
||||
invalid_file = tmp_path / "invalid.yaml"
|
||||
valid_file.write_text(
|
||||
"""
|
||||
id: demo_valid
|
||||
name: 有效场景
|
||||
description: 用于测试
|
||||
agent:
|
||||
role: 测试助手
|
||||
goal: 正常返回
|
||||
instructions:
|
||||
- 输出结果
|
||||
rag:
|
||||
enabled: false
|
||||
tools: []
|
||||
output:
|
||||
type: general_answer
|
||||
audit:
|
||||
enabled: true
|
||||
""".strip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
invalid_file.write_text("id: broken\nname: 缺失结构", encoding="utf-8")
|
||||
settings.SCENARIO_CONFIG_DIR = tmp_path
|
||||
|
||||
response = client.get("/")
|
||||
|
||||
content = response.content.decode("utf-8")
|
||||
assert response.status_code == 200
|
||||
assert "有效场景" in content
|
||||
assert "配置异常" in content
|
||||
assert "invalid.yaml" in content
|
||||
@@ -1,54 +0,0 @@
|
||||
from agent_core.tool_registry import ToolRegistry, run_declared_tools
|
||||
from agent_core.tools.builtin_tools import calculate_rate, check_required_fields
|
||||
|
||||
|
||||
def test_tool_registry_register_get_and_run():
|
||||
registry = ToolRegistry()
|
||||
|
||||
def hello_tool(user_input: str) -> dict:
|
||||
return {"echo": user_input}
|
||||
|
||||
registry.register("hello", hello_tool)
|
||||
|
||||
assert registry.get("hello") is hello_tool
|
||||
assert registry.run("hello", user_input="demo") == {
|
||||
"tool_name": "hello",
|
||||
"success": True,
|
||||
"arguments": {"user_input": "demo"},
|
||||
"result": {"echo": "demo"},
|
||||
"error": "",
|
||||
}
|
||||
|
||||
|
||||
def test_tool_registry_returns_failed_result_for_missing_tool():
|
||||
registry = ToolRegistry()
|
||||
|
||||
result = registry.run("missing", user_input="demo")
|
||||
|
||||
assert result["tool_name"] == "missing"
|
||||
assert result["success"] is False
|
||||
assert result["error"] == "工具未注册"
|
||||
|
||||
|
||||
def test_run_declared_tools_executes_multiple_tools_in_order():
|
||||
results = run_declared_tools(["generate_action_items", "missing_tool"], "请生成行动项")
|
||||
|
||||
assert [item["tool_name"] for item in results] == ["generate_action_items", "missing_tool"]
|
||||
assert results[0]["success"] is True
|
||||
assert results[1]["success"] is False
|
||||
|
||||
|
||||
def test_calculate_rate_extracts_fraction_like_numbers():
|
||||
result = calculate_rate("产线合格率,已完成 18 件,总数 24 件")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["numerator"] == 18.0
|
||||
assert result["denominator"] == 24.0
|
||||
assert result["rate"] == 0.75
|
||||
|
||||
|
||||
def test_check_required_fields_reports_missing_fields():
|
||||
result = check_required_fields("请检查必填项:合同编号、供应商、金额。当前只提供了合同编号和金额。")
|
||||
|
||||
assert "供应商" in result["missing_fields"]
|
||||
assert "合同编号" not in result["missing_fields"]
|
||||
Reference in New Issue
Block a user