From 18d045d4874c9c95c186fc2a970064f0f2aba333 Mon Sep 17 00:00:00 2001 From: bruce Date: Sat, 6 Jun 2026 01:20:26 +0800 Subject: [PATCH] =?UTF-8?q?feat(file-summary):=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=A4=84=E7=90=86=E6=8A=80=E8=83=BD=E9=93=BE?= =?UTF-8?q?=E8=B7=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- review_agent/file_summary/paths.py | 12 +++ .../file_summary/services/__init__.py | 1 + review_agent/file_summary/services/archive.py | 77 +++++++++++++++++++ .../file_summary/services/inventory.py | 49 ++++++++++++ .../file_summary/services/page_count.py | 59 ++++++++++++++ .../file_summary/services/product_detect.py | 31 ++++++++ review_agent/file_summary/skills/__init__.py | 1 + .../file_summary/skills/archive_extract.py | 26 +++++++ review_agent/file_summary/skills/base.py | 24 ++++++ .../skills/document_page_count.py | 64 +++++++++++++++ .../file_summary/skills/file_inventory.py | 21 +++++ .../file_summary/skills/product_detect.py | 12 +++ review_agent/file_summary/skills/registry.py | 22 ++++++ review_agent/file_summary/workflow.py | 43 ++++++++--- tests/test_file_summary_archive.py | 25 ++++++ tests/test_file_summary_inventory.py | 24 ++++++ tests/test_file_summary_page_count.py | 66 ++++++++++++++++ tests/test_file_summary_product_detect.py | 29 +++++++ tests/test_file_summary_skills.py | 27 +++++++ 19 files changed, 604 insertions(+), 9 deletions(-) create mode 100644 review_agent/file_summary/paths.py create mode 100644 review_agent/file_summary/services/__init__.py create mode 100644 review_agent/file_summary/services/archive.py create mode 100644 review_agent/file_summary/services/inventory.py create mode 100644 review_agent/file_summary/services/page_count.py create mode 100644 review_agent/file_summary/services/product_detect.py create mode 100644 review_agent/file_summary/skills/__init__.py create mode 100644 review_agent/file_summary/skills/archive_extract.py create mode 
100644 review_agent/file_summary/skills/base.py create mode 100644 review_agent/file_summary/skills/document_page_count.py create mode 100644 review_agent/file_summary/skills/file_inventory.py create mode 100644 review_agent/file_summary/skills/product_detect.py create mode 100644 review_agent/file_summary/skills/registry.py create mode 100644 tests/test_file_summary_archive.py create mode 100644 tests/test_file_summary_inventory.py create mode 100644 tests/test_file_summary_page_count.py create mode 100644 tests/test_file_summary_product_detect.py create mode 100644 tests/test_file_summary_skills.py diff --git a/review_agent/file_summary/paths.py b/review_agent/file_summary/paths.py new file mode 100644 index 0000000..8735825 --- /dev/null +++ b/review_agent/file_summary/paths.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from pathlib import Path + +from django.conf import settings + + +def resolve_storage_path(storage_path: str) -> Path: + path = Path(storage_path) + if path.is_absolute(): + return path + return Path(settings.MEDIA_ROOT) / path diff --git a/review_agent/file_summary/services/__init__.py b/review_agent/file_summary/services/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/review_agent/file_summary/services/__init__.py @@ -0,0 +1 @@ + diff --git a/review_agent/file_summary/services/archive.py b/review_agent/file_summary/services/archive.py new file mode 100644 index 0000000..9e554e8 --- /dev/null +++ b/review_agent/file_summary/services/archive.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import subprocess +from pathlib import Path +from zipfile import ZipFile + +import py7zr + + +ARCHIVE_EXTENSIONS = {"zip", "7z", "rar"} + + +def _ensure_inside_target(path: Path, target_dir: Path) -> None: + target = target_dir.resolve() + resolved = path.resolve() + if target != resolved and target not in resolved.parents: + raise ValueError("解压路径必须位于批次工作目录内。") + + +def _safe_member_path(target_dir: Path, 
def _safe_member_path(target_dir: Path, member_name: str) -> Path:
    """Map an archive member name to its destination below ``target_dir``.

    Args:
        target_dir: Directory the archive is being extracted into.
        member_name: Member path exactly as recorded in the archive.

    Returns:
        ``target_dir / member_name``.

    Raises:
        ValueError: Propagated from ``_ensure_inside_target`` when the joined
            path resolves outside ``target_dir`` (for example ``..`` segments
            or an absolute member name).
    """
    destination = target_dir / member_name
    _ensure_inside_target(destination, target_dir)
    return destination


def extract_archive(archive_path: str | Path, target_dir: str | Path) -> list[Path]:
    """Extract a zip, 7z or rar archive into ``target_dir``.

    The format is chosen from the file extension only. ``target_dir`` is
    created (with parents) before the extension is inspected.

    Args:
        archive_path: Archive file to unpack.
        target_dir: Directory that receives the archive contents.

    Returns:
        Paths of the regular files produced by this extraction, or an empty
        list when the extension is not in ``ARCHIVE_EXTENSIONS``.

    Raises:
        ValueError: If a member would be written outside ``target_dir``.
        RuntimeError: If the external ``7z`` command reports a failure (rar).
    """
    archive_path = Path(archive_path)
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    ext = archive_path.suffix.lower().lstrip(".")
    if ext not in ARCHIVE_EXTENSIONS:
        return []

    if ext == "zip":
        return _extract_zip(archive_path, target_dir)
    if ext == "7z":
        return _extract_7z(archive_path, target_dir)
    # Every remaining supported extension (currently only rar) goes to 7z.
    return _extract_rar(archive_path, target_dir)


def _extract_zip(archive_path: Path, target_dir: Path) -> list[Path]:
    """Extract a zip archive member by member.

    All member names are validated before anything is written, so an archive
    containing an unsafe name leaves ``target_dir`` untouched instead of
    half-extracted. Member data is streamed to disk rather than read into
    memory in one piece.

    Returns:
        Destination paths of the extracted regular files, in archive order.

    Raises:
        ValueError: If any member resolves outside ``target_dir``.
    """
    # Local import: keeps this block independent of the module import list.
    import shutil

    extracted: list[Path] = []
    with ZipFile(archive_path) as archive:
        planned = [
            (member, _safe_member_path(target_dir, member.filename))
            for member in archive.infolist()
        ]
        for member, destination in planned:
            if member.is_dir():
                destination.mkdir(parents=True, exist_ok=True)
                continue
            destination.parent.mkdir(parents=True, exist_ok=True)
            with archive.open(member) as source, destination.open("wb") as target:
                # Chunked copy: memory use stays flat for large members.
                shutil.copyfileobj(source, target)
            extracted.append(destination)
    return extracted


def _extract_7z(archive_path: Path, target_dir: Path) -> list[Path]:
    """Extract a 7z archive with py7zr after validating every member name.

    Returns:
        Paths of the listed members that exist as regular files afterwards.

    Raises:
        ValueError: If any member name resolves outside ``target_dir``.
    """
    with py7zr.SevenZipFile(archive_path, mode="r") as archive:
        names = archive.getnames()
        # Check the complete listing first so nothing is written on failure.
        for name in names:
            _safe_member_path(target_dir, name)
        archive.extractall(path=target_dir)
    return [target_dir / name for name in names if (target_dir / name).is_file()]


def _file_signatures(directory: Path) -> dict[Path, tuple[int, int]]:
    """Map each regular file below ``directory`` to ``(mtime_ns, size)``."""
    signatures: dict[Path, tuple[int, int]] = {}
    for path in directory.rglob("*"):
        if path.is_file():
            stat = path.stat()
            signatures[path] = (stat.st_mtime_ns, stat.st_size)
    return signatures


def _extract_rar(archive_path: Path, target_dir: Path) -> list[Path]:
    """Extract a rar archive by invoking the external ``7z`` executable.

    Only files that are new or changed compared with the state of
    ``target_dir`` before the call are reported, so files left behind by a
    previously extracted archive are not counted again. A file overwritten
    with identical size and modification time cannot be told apart from an
    untouched one and is therefore not reported.

    Returns:
        Sorted paths of the files created or modified by this extraction.

    Raises:
        RuntimeError: If ``7z`` exits with a non-zero status.
        ValueError: If an extracted path resolves outside ``target_dir``.
    """
    before = _file_signatures(target_dir)
    result = subprocess.run(
        ["7z", "x", f"-o{target_dir}", str(archive_path), "-y"],
        check=False,
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        raise RuntimeError(result.stderr or result.stdout or "rar 解压失败")
    after = _file_signatures(target_dir)
    extracted = sorted(
        path for path, signature in after.items() if before.get(path) != signature
    )
    # Post-extraction check: the member list is not inspected beforehand here,
    # so this catches entries that resolve outside the target afterwards.
    for path in extracted:
        _ensure_inside_target(path, target_dir)
    return extracted
SUPPORTED_EXTENSIONS = {"pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx"}


def _directory_level(relative_path: Path) -> str:
    """Return the directory part of a relative path, joined with ``/``.

    A path with zero or one component has no directory part and yields ``""``.
    """
    # Slicing off the last component leaves an empty tuple for short paths,
    # which joins to the empty string.
    return "/".join(relative_path.parts[:-1])


def scan_files_to_items(*, batch: FileSummaryBatch, roots: list[Path]) -> list[FileSummaryItem]:
    """Create one ``FileSummaryItem`` per file found under ``roots``.

    A root that is itself a file is taken as-is, relative to its parent
    directory. Any other root is walked recursively in sorted order, skipping
    files whose name starts with ``.`` and files of size zero. Items are
    numbered from 1 in discovery order and created with status ``SKIPPED``.
    The batch's ``total_files``, ``supported_files`` and ``unsupported_files``
    counters are updated and saved.

    Args:
        batch: Batch the created items belong to.
        roots: Files or directories to scan.

    Returns:
        The created items, in ``file_index`` order.
    """
    pending: list[tuple[Path, Path]] = []
    for entry in roots:
        base = Path(entry)
        if base.is_file():
            pending.append((base.parent, base))
            continue
        candidates = sorted(child for child in base.rglob("*") if child.is_file())
        pending.extend(
            (base, child)
            for child in candidates
            if not (child.name.startswith(".") or child.stat().st_size == 0)
        )

    records: list[FileSummaryItem] = []
    for position, (base, file_path) in enumerate(pending, start=1):
        rel = file_path.relative_to(base).as_posix()
        records.append(
            FileSummaryItem.objects.create(
                batch=batch,
                file_index=position,
                directory_level=_directory_level(Path(rel)),
                file_name=file_path.name,
                file_type=file_path.suffix.lower().lstrip("."),
                relative_path=rel,
                storage_path=str(file_path),
                statistics_status=FileSummaryItem.StatisticsStatus.SKIPPED,
            )
        )

    total = len(records)
    supported = sum(1 for record in records if record.file_type in SUPPORTED_EXTENSIONS)
    batch.total_files = total
    batch.supported_files = supported
    batch.unsupported_files = total - supported
    batch.save(update_fields=["total_files", "supported_files", "unsupported_files"])
    return records
SUPPORTED_EXTENSIONS = {"pdf", "doc", "docx", "xls", "xlsx", "ppt", "pptx"}

# Refuse to read an implausibly large docProps/app.xml from an uploaded file.
_MAX_APP_XML_BYTES = 1024 * 1024


@dataclass(frozen=True)
class PageCountResult:
    """Outcome of one page-count attempt.

    ``status`` is one of ``"success"``, ``"uncertain"``, ``"unsupported"`` or
    ``"failed"``; ``page_count`` is set only on success and ``error_message``
    only on failure.
    """

    status: str
    page_count: int | None = None
    error_message: str = ""


def _docx_extended_page_count(file_path: Path) -> int | None:
    """Return the page count recorded in a docx's ``docProps/app.xml``.

    The value is whatever the producing application stored when it last saved
    the file; it is not recomputed from the document body.

    Returns:
        A positive page count, or ``None`` when the part is missing, too
        large, unreadable, or holds no positive ``Pages`` value.
    """
    # Local imports: keep this block independent of the module import list.
    import re
    import zipfile

    try:
        with zipfile.ZipFile(file_path) as package:
            # assumes the extended-properties part uses its conventional
            # name; a package that relocates it simply yields None here.
            info = package.getinfo("docProps/app.xml")
            if info.file_size > _MAX_APP_XML_BYTES:
                return None
            app_xml = package.read(info)
    except (KeyError, OSError, zipfile.BadZipFile):
        return None

    match = re.search(rb"<(?:\w+:)?Pages>\s*(\d+)\s*</(?:\w+:)?Pages>", app_xml)
    if match is None:
        return None
    pages = int(match.group(1))
    return pages if pages > 0 else None


def count_document_pages(path: str | Path) -> PageCountResult:
    """Count the pages (or page-like units) of a document.

    The unit depends on the format: PDF pages, the stored page count for
    docx, worksheets for xls/xlsx and slides for pptx. Legacy doc/ppt files
    are only checked for being OLE containers and reported as uncertain.
    Parser libraries are imported lazily, so a missing library surfaces as a
    ``"failed"`` result for that format only.

    Args:
        path: Location of the document; the format is taken from its suffix.

    Returns:
        A ``PageCountResult``. Any exception raised while parsing is
        converted into ``status="failed"`` carrying the exception text.
    """
    file_path = Path(path)
    ext = file_path.suffix.lower().lstrip(".")
    if ext not in SUPPORTED_EXTENSIONS:
        return PageCountResult(status="unsupported")

    try:
        if ext == "pdf":
            from pypdf import PdfReader

            return PageCountResult(status="success", page_count=len(PdfReader(str(file_path)).pages))
        if ext == "docx":
            from docx import Document

            # Opening with python-docx still validates the package.
            properties = Document(str(file_path)).core_properties
            # NOTE(review): python-docx's documented core properties do not
            # appear to include a page count, so this getattr is kept only for
            # compatibility; the stored count lives in docProps/app.xml.
            pages = getattr(properties, "pages", None) or _docx_extended_page_count(file_path)
            if pages:
                return PageCountResult(status="success", page_count=pages)
            return PageCountResult(status="uncertain")
        if ext == "xlsx":
            from openpyxl import load_workbook

            workbook = load_workbook(str(file_path), read_only=True, data_only=True)
            try:
                sheet_total = len(workbook.sheetnames)
            finally:
                # Read-only workbooks keep the file open until closed.
                workbook.close()
            return PageCountResult(status="success", page_count=sheet_total)
        if ext == "xls":
            import xlrd

            workbook = xlrd.open_workbook(str(file_path), on_demand=True)
            try:
                sheet_total = workbook.nsheets
            finally:
                # on_demand keeps resources attached until released.
                workbook.release_resources()
            return PageCountResult(status="success", page_count=sheet_total)
        if ext == "pptx":
            from pptx import Presentation

            return PageCountResult(status="success", page_count=len(Presentation(str(file_path)).slides))
        if ext in {"doc", "ppt"}:
            import olefile

            if olefile.isOleFile(str(file_path)):
                return PageCountResult(status="uncertain")
            return PageCountResult(status="failed", error_message="不是有效的 OLE 文件。")
    except Exception as exc:
        # Boundary: one unreadable document must not abort the whole batch.
        return PageCountResult(status="failed", error_message=str(exc))

    # Defensive fallback for a supported extension without a branch above.
    return PageCountResult(status="uncertain")
class ArchiveExtractSkill(BaseSkill):
    """Workflow skill that unpacks a batch's archive attachments into its work dir."""

    name = "archive_extract"

    def run(self, context: WorkflowContext) -> SkillResult:
        """Extract every archive attachment bound to the batch.

        Args:
            context: Workflow context carrying the batch being processed.

        Returns:
            A successful ``SkillResult`` whose ``data["extracted_count"]`` is
            the number of files reported by ``extract_archive`` across all
            archive attachments; ``0`` when the batch has no work directory.
        """
        work_dir = context.batch.work_dir
        if not work_dir:
            # Test the raw value: Path("") becomes Path("."), which is truthy,
            # so checking the Path would extract into the current directory.
            return SkillResult(success=True, data={"extracted_count": 0})

        target_dir = Path(work_dir)
        extracted_count = 0
        for binding in FileSummaryBatchAttachment.objects.filter(batch=context.batch):
            path = resolve_storage_path(binding.attachment.storage_path)
            if path.suffix.lower().lstrip(".") not in ARCHIVE_EXTENSIONS:
                continue
            extracted_count += len(extract_archive(path, target_dir))
        return SkillResult(success=True, data={"extracted_count": extracted_count})
class FileInventorySkill(BaseSkill):
    """Workflow skill that records the batch's files as ``FileSummaryItem`` rows."""

    name = "file_inventory"

    def run(self, context: WorkflowContext) -> SkillResult:
        """Scan the batch's attachments and its work directory.

        The archive-extract skill unpacks archives into ``batch.work_dir``;
        scanning only the attachment paths would never list those extracted
        files, so the work directory is added as a scan root when it exists.

        Args:
            context: Workflow context carrying the batch being processed.

        Returns:
            A successful ``SkillResult`` whose ``data["total_files"]`` is the
            number of items created.
        """
        roots = [
            resolve_storage_path(binding.attachment.storage_path)
            for binding in FileSummaryBatchAttachment.objects.filter(batch=context.batch)
        ]

        work_dir = context.batch.work_dir
        if work_dir and Path(work_dir).is_dir():
            work_root = Path(work_dir)
            resolved_work_root = work_root.resolve()
            # Attachments stored below the work dir are found by the directory
            # scan; drop them as separate roots so they are not listed twice.
            # NOTE(review): assumes work_dir holds only this batch's files at
            # this point in the workflow; confirm against the upload node.
            roots = [
                root
                for root in roots
                if resolved_work_root not in root.resolve().parents
            ]
            roots.append(work_root)

        items = scan_files_to_items(batch=context.batch, roots=roots)
        return SkillResult(success=True, data={"total_files": len(items)})
review_agent.models import ( ) from .events import record_event +from .skills.archive_extract import ArchiveExtractSkill +from .skills.base import WorkflowContext +from .skills.document_page_count import DocumentPageCountSkill +from .skills.file_inventory import FileInventorySkill +from .skills.product_detect import ProductDetectSkill +from .skills.registry import SkillRegistry NODE_DEFINITIONS = [ - ("upload", "附件固化"), - ("extract", "压缩包解压"), - ("inventory", "文件扫描"), - ("page_count", "页数统计"), - ("product_detect", "产品识别"), - ("report", "报告输出"), - ("complete", "完成"), + ("upload", "附件固化", ""), + ("extract", "压缩包解压", "archive_extract"), + ("inventory", "文件扫描", "file_inventory"), + ("page_count", "页数统计", "document_page_count"), + ("product_detect", "产品识别", "product_detect"), + ("report", "报告输出", ""), + ("complete", "完成", ""), ] +def default_skill_registry() -> SkillRegistry: + registry = SkillRegistry() + registry.register(ArchiveExtractSkill()) + registry.register(FileInventorySkill()) + registry.register(DocumentPageCountSkill()) + registry.register(ProductDetectSkill()) + return registry + + def build_batch_no() -> str: return f"FS-{timezone.localtime().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[:6]}" @@ -61,7 +76,7 @@ def create_file_summary_batch( attachment.upload_status = FileAttachment.UploadStatus.BOUND attachment.save(update_fields=["upload_status"]) - for code, name in NODE_DEFINITIONS: + for code, name, _skill_name in NODE_DEFINITIONS: WorkflowNodeRun.objects.create(batch=batch, node_code=code, node_name=name) record_event(batch, "workflow_created", {"batch_id": batch.pk, "batch_no": batch.batch_no}) @@ -69,8 +84,9 @@ def create_file_summary_batch( class WorkflowExecutor: - def __init__(self, batch: FileSummaryBatch): + def __init__(self, batch: FileSummaryBatch, registry: SkillRegistry | None = None): self.batch = batch + self.registry = registry or default_skill_registry() def run(self) -> None: self.batch.status = FileSummaryBatch.Status.RUNNING @@ -107,6 
+123,15 @@ class WorkflowExecutor: {"node_code": node.node_code, "status": node.status, "progress": node.progress}, ) + skill_name = next( + (skill for code, _name, skill in NODE_DEFINITIONS if code == node.node_code), + "", + ) + if skill_name: + result = self.registry.execute(skill_name, WorkflowContext(batch=self.batch)) + if not result.success: + raise RuntimeError(result.message or f"{node.node_name}执行失败") + node.status = WorkflowNodeRun.Status.SUCCESS node.progress = 100 node.finished_at = timezone.now() diff --git a/tests/test_file_summary_archive.py b/tests/test_file_summary_archive.py new file mode 100644 index 0000000..29a1a80 --- /dev/null +++ b/tests/test_file_summary_archive.py @@ -0,0 +1,25 @@ +from zipfile import ZipFile +import pytest + +from review_agent.file_summary.services.archive import extract_archive + + +def test_extract_zip_preserves_safe_paths(tmp_path): + archive_path = tmp_path / "safe.zip" + with ZipFile(archive_path, "w") as archive: + archive.writestr("dir/a.txt", "content") + + target = tmp_path / "out" + extracted = extract_archive(archive_path, target) + + assert extracted == [target / "dir" / "a.txt"] + assert (target / "dir" / "a.txt").read_text(encoding="utf-8") == "content" + + +def test_extract_zip_rejects_path_traversal(tmp_path): + archive_path = tmp_path / "evil.zip" + with ZipFile(archive_path, "w") as archive: + archive.writestr("../evil.txt", "bad") + + with pytest.raises(ValueError): + extract_archive(archive_path, tmp_path / "out") diff --git a/tests/test_file_summary_inventory.py b/tests/test_file_summary_inventory.py new file mode 100644 index 0000000..74758a5 --- /dev/null +++ b/tests/test_file_summary_inventory.py @@ -0,0 +1,24 @@ +from pathlib import Path +import pytest + +from review_agent.file_summary.services.inventory import scan_files_to_items +from review_agent.models import Conversation, FileSummaryBatch, FileSummaryItem + + +pytestmark = pytest.mark.django_db + + +def 
test_scan_files_to_items_preserves_relative_paths(tmp_path, django_user_model): + root = tmp_path / "work" + (root / "a").mkdir(parents=True) + (root / "a" / "one.pdf").write_bytes(b"pdf") + (root / "two.txt").write_text("x", encoding="utf-8") + user = django_user_model.objects.create_user(username="owner", password="pass") + conversation = Conversation.objects.create(user=user, title="会话") + batch = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-I") + + items = scan_files_to_items(batch=batch, roots=[root]) + + assert [item.relative_path for item in items] == ["a/one.pdf", "two.txt"] + assert FileSummaryItem.objects.filter(batch=batch).count() == 2 + assert items[0].statistics_status == FileSummaryItem.StatisticsStatus.SKIPPED diff --git a/tests/test_file_summary_page_count.py b/tests/test_file_summary_page_count.py new file mode 100644 index 0000000..e3c6077 --- /dev/null +++ b/tests/test_file_summary_page_count.py @@ -0,0 +1,66 @@ +import pytest +from docx import Document +from openpyxl import Workbook +from pptx import Presentation + +from review_agent.file_summary.services.page_count import count_document_pages +from review_agent.file_summary.skills.document_page_count import DocumentPageCountSkill +from review_agent.file_summary.skills.base import WorkflowContext +from review_agent.models import Conversation, FileSummaryBatch, FileSummaryItem + + +pytestmark = pytest.mark.django_db + + +def test_count_document_pages_for_office_formats(tmp_path): + docx_path = tmp_path / "a.docx" + Document().save(docx_path) + + xlsx_path = tmp_path / "a.xlsx" + workbook = Workbook() + workbook.create_sheet("第二页") + workbook.save(xlsx_path) + + pptx_path = tmp_path / "a.pptx" + presentation = Presentation() + presentation.slides.add_slide(presentation.slide_layouts[6]) + presentation.save(pptx_path) + + assert count_document_pages(docx_path).status in {"success", "uncertain"} + assert count_document_pages(xlsx_path).page_count == 2 + assert 
count_document_pages(pptx_path).page_count == 1 + + +def test_document_page_count_skill_marks_unsupported_and_success(tmp_path, django_user_model): + xlsx_path = tmp_path / "a.xlsx" + workbook = Workbook() + workbook.save(xlsx_path) + txt_path = tmp_path / "a.txt" + txt_path.write_text("x", encoding="utf-8") + user = django_user_model.objects.create_user(username="owner", password="pass") + conversation = Conversation.objects.create(user=user, title="会话") + batch = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-P") + xlsx_item = FileSummaryItem.objects.create( + batch=batch, + file_index=1, + file_name="a.xlsx", + file_type="xlsx", + relative_path="a.xlsx", + storage_path=str(xlsx_path), + ) + txt_item = FileSummaryItem.objects.create( + batch=batch, + file_index=2, + file_name="a.txt", + file_type="txt", + relative_path="a.txt", + storage_path=str(txt_path), + ) + + result = DocumentPageCountSkill().run(WorkflowContext(batch=batch)) + + xlsx_item.refresh_from_db() + txt_item.refresh_from_db() + assert result.success is True + assert xlsx_item.statistics_status == FileSummaryItem.StatisticsStatus.SUCCESS + assert txt_item.statistics_status == FileSummaryItem.StatisticsStatus.UNSUPPORTED diff --git a/tests/test_file_summary_product_detect.py b/tests/test_file_summary_product_detect.py new file mode 100644 index 0000000..8cf895c --- /dev/null +++ b/tests/test_file_summary_product_detect.py @@ -0,0 +1,29 @@ +import pytest + +from review_agent.file_summary.services.product_detect import detect_product_name +from review_agent.models import Conversation, FileSummaryBatch, FileSummaryItem + + +pytestmark = pytest.mark.django_db + + +def test_detect_product_name_from_top_level_directory(django_user_model): + user = django_user_model.objects.create_user(username="owner", password="pass") + conversation = Conversation.objects.create(user=user, title="新对话 06-06") + batch = FileSummaryBatch.objects.create(conversation=conversation, 
user=user, batch_no="FS-D") + FileSummaryItem.objects.create( + batch=batch, + file_index=1, + file_name="说明书.docx", + file_type="docx", + relative_path="甲型试剂盒/说明书.docx", + storage_path="x", + ) + + product_name = detect_product_name(batch) + + batch.refresh_from_db() + conversation.refresh_from_db() + assert product_name == "甲型试剂盒" + assert batch.product_name == "甲型试剂盒" + assert conversation.title == "甲型试剂盒-文件汇总" diff --git a/tests/test_file_summary_skills.py b/tests/test_file_summary_skills.py new file mode 100644 index 0000000..a700155 --- /dev/null +++ b/tests/test_file_summary_skills.py @@ -0,0 +1,27 @@ +import pytest + +from review_agent.file_summary.skills.base import BaseSkill, SkillResult, WorkflowContext +from review_agent.file_summary.skills.registry import SkillRegistry + + +class EchoSkill(BaseSkill): + name = "echo" + + def run(self, context): + return SkillResult(success=True, data={"batch_id": context.batch.id}) + + +@pytest.mark.django_db +def test_skill_registry_executes_registered_skill(django_user_model): + from review_agent.models import Conversation, FileSummaryBatch + + user = django_user_model.objects.create_user(username="owner", password="pass") + conversation = Conversation.objects.create(user=user, title="会话") + batch = FileSummaryBatch.objects.create(conversation=conversation, user=user, batch_no="FS-X") + registry = SkillRegistry() + registry.register(EchoSkill()) + + result = registry.execute("echo", WorkflowContext(batch=batch)) + + assert result.success is True + assert result.data == {"batch_id": batch.id}