437 lines
15 KiB
Python
437 lines
15 KiB
Python
from pathlib import Path
|
||
from io import BytesIO
|
||
import re
|
||
import tempfile
|
||
import xml.etree.ElementTree as ET
|
||
from zipfile import BadZipFile, ZipFile
|
||
|
||
from agent_core.rag.ingest import ingest_document
|
||
from apps.chat.services import create_conversation_for_batch
|
||
from django.core.files.uploadedfile import SimpleUploadedFile
|
||
|
||
from .models import SubmissionBatch, UploadedDocument
|
||
|
||
|
||
def create_uploaded_document(
    scenario_id: str,
    uploaded_file,
    batch: SubmissionBatch | None = None,
    *,
    relative_path: str | None = None,
) -> UploadedDocument:
    """
    Persist the metadata record for one uploaded file.

    Only the file, its scenario/batch relation, display name, type, size and
    relative path are stored, with the status set to ``STATUS_UPLOADED``.
    No text extraction or indexing happens here; per the original author's
    note, indexing is triggered later so that uploading stays decoupled from
    the RAG pipeline.

    Args:
        scenario_id: Scenario identifier stored on the record.
        uploaded_file: Uploaded file object; ``name`` and ``size`` are read.
        batch: Optional submission batch the document belongs to.
        relative_path: Optional path of the file inside its package. When
            omitted (or empty), ``uploaded_file.name`` is used instead.

    Returns:
        The newly created ``UploadedDocument``.
    """
    # The type is always derived from the uploaded file's own name, even when
    # a relative path is supplied.
    file_type = _detect_extension(uploaded_file.name)
    stored_path = relative_path or uploaded_file.name
    record_fields = {
        "batch": batch,
        "scenario_id": scenario_id,
        # Only the final path component is kept as the display name.
        "original_name": Path(stored_path).name,
        "file": uploaded_file,
        "file_type": file_type,
        "size": uploaded_file.size,
        "relative_path": stored_path,
        "status": UploadedDocument.STATUS_UPLOADED,
    }
    return UploadedDocument.objects.create(**record_fields)
|
||
|
||
|
||
def import_submission_batch(scenario_id: str, uploaded_files: list) -> dict:
    """
    Import a submission package and build its batch, documents, chapter
    summary and bound conversation.

    Steps, in order:

    1. Create a ``SubmissionBatch`` in ``STATUS_PROCESSING``.
    2. Expand the uploads through ``_expand_uploaded_files``.
    3. For each expanded file: create the ``UploadedDocument``, extract its
       text, then save the estimated page count, document role and chapter
       code on it.
    4. Choose a product name from the collected candidates and create the
       conversation for the batch.
    5. Save the aggregated figures and the final status on the batch.

    The final status is ``STATUS_FAILED`` when no document was imported,
    ``STATUS_REVIEW_REQUIRED`` when at least one warning was collected, and
    ``STATUS_COMPLETED`` otherwise.

    Args:
        scenario_id: Scenario identifier stored on every created document.
        uploaded_files: Uploaded file objects to import.

    Returns:
        A dict with ``batch_id``, ``conversation_id``, ``product_name`` and
        ``registration_overview_report`` (batch totals, per-chapter counts,
        one entry per document, and the collected warnings).
    """
    batch = SubmissionBatch.objects.create(
        batch_id=_generate_batch_id(),
        workflow_type="registration",
        import_status=SubmissionBatch.STATUS_PROCESSING,
    )

    expansion = _expand_uploaded_files(uploaded_files)
    warnings = list(expansion["warnings"])
    documents = []
    name_candidates = []
    chapter_counts = {}
    page_total = 0

    for entry in expansion["files"]:
        document = create_uploaded_document(
            scenario_id,
            entry["uploaded_file"],
            batch=batch,
            relative_path=entry["relative_path"],
        )
        text = extract_text(document)
        pages = _estimate_page_count(text)

        document.page_count = pages
        document.page_count_confidence = "estimated"
        document.document_role = _detect_document_role(document.relative_path)
        document.chapter_code = _detect_chapter_code(document.relative_path, text)
        has_chapter = bool(document.chapter_code)
        document.chapter_match_status = "matched" if has_chapter else "unknown"
        # Documents without a recognised chapter are flagged for a human.
        document.needs_manual_review = not has_chapter
        document.save(
            update_fields=[
                "page_count",
                "page_count_confidence",
                "document_role",
                "chapter_code",
                "chapter_match_status",
                "needs_manual_review",
                "updated_at",
            ]
        )

        documents.append(document)
        page_total += pages
        chapter_key = document.chapter_code or "UNCLASSIFIED"
        chapter_counts[chapter_key] = chapter_counts.get(chapter_key, 0) + 1
        name_candidates.extend(
            _extract_product_candidates(document.relative_path, text)
        )

    product_name, product_warnings = _select_product_name(name_candidates)
    warnings.extend(product_warnings)
    # The conversation is created even when nothing could be imported.
    conversation = create_conversation_for_batch(batch.batch_id, product_name)

    if not documents:
        warnings.append("未发现可导入的支持文件,请检查资料包格式或补充 PDF/DOCX/MD/TXT 文件。")
        final_status = SubmissionBatch.STATUS_FAILED
    elif warnings:
        final_status = SubmissionBatch.STATUS_REVIEW_REQUIRED
    else:
        final_status = SubmissionBatch.STATUS_COMPLETED

    batch.product_name = product_name
    batch.conversation_id = conversation.conversation_id
    batch.file_count = len(documents)
    batch.page_count = page_total
    batch.chapter_summary = [
        {"chapter_code": code, "document_count": amount}
        for code, amount in sorted(chapter_counts.items())
    ]
    # Every warning counts as one exception on the batch.
    batch.exception_count = len(warnings)
    batch.import_status = final_status
    batch.save(
        update_fields=[
            "product_name",
            "conversation_id",
            "file_count",
            "page_count",
            "chapter_summary",
            "exception_count",
            "import_status",
            "updated_at",
        ]
    )

    document_rows = [
        {
            "document_id": item.id,
            "original_name": item.original_name,
            "chapter_code": item.chapter_code,
            "page_count": item.page_count,
            "document_role": item.document_role,
        }
        for item in documents
    ]
    overview_report = {
        "batch_id": batch.batch_id,
        "product_name": batch.product_name,
        "file_count": batch.file_count,
        "total_page_count": batch.page_count,
        "chapter_summary": batch.chapter_summary,
        "documents": document_rows,
        "warnings": warnings,
    }
    return {
        "batch_id": batch.batch_id,
        "conversation_id": conversation.conversation_id,
        "product_name": batch.product_name,
        "registration_overview_report": overview_report,
    }
|
||
|
||
|
||
def extract_text(document: UploadedDocument) -> str:
    """
    Extract plain text from a stored document, choosing the strategy by its
    recorded file type.

    - ``pdf`` is handled by ``_extract_pdf_text``.
    - ``docx`` is handled by ``_extract_docx_text``.
    - Any other type is read as a text file by ``_read_text_file``.

    Args:
        document: The stored document; ``file.path`` and ``file_type`` are read.

    Returns:
        The extracted text.
    """
    source_path = Path(document.file.path)
    # file_type may be stored with or without a leading dot; normalise both.
    file_type = document.file_type.lower().lstrip(".")
    extractors = {
        "pdf": _extract_pdf_text,
        "docx": _extract_docx_text,
    }
    extractor = extractors.get(file_type, _read_text_file)
    return extractor(source_path)
|
||
|
||
|
||
def index_document(document: UploadedDocument) -> UploadedDocument:
    """
    Ingest a single document and write the success/failure state back to it.

    Exceptions raised during extraction or ingestion are deliberately not
    propagated: they are converted into a failed status plus an error message
    on the model, so the caller only has to inspect the final state.

    Args:
        document: The document to extract and ingest.

    Returns:
        The same document, after ``status`` and ``error_message`` were saved.
    """
    try:
        extracted = extract_text(document)
        outcome = ingest_document(
            document_id=document.id,
            scenario_id=document.scenario_id,
            source_file=document.original_name,
            text=extracted,
            # The scenario id is also used as the collection name.
            collection=document.scenario_id,
        )
        _apply_ingest_result(
            document,
            success=outcome.success,
            error=outcome.error,
        )
    except Exception as exc:
        # Boundary handler: any failure becomes a stored error message.
        _apply_ingest_result(document, False, str(exc))

    document.save(update_fields=["status", "error_message", "updated_at"])
    return document
|
||
|
||
|
||
def _apply_ingest_result(document: UploadedDocument, success: bool, error: str = "") -> None:
    """
    Map an ingestion outcome onto the document's status fields.

    On success the status becomes ``STATUS_INDEXED`` and the error message is
    cleared; otherwise the status becomes ``STATUS_FAILED`` and ``error`` is
    stored. The document is only mutated here, not saved.
    """
    if not success:
        document.status = UploadedDocument.STATUS_FAILED
        document.error_message = error
    else:
        document.status = UploadedDocument.STATUS_INDEXED
        document.error_message = ""
|
||
|
||
|
||
def _detect_extension(file_name: str) -> str:
|
||
"""统一将扩展名转成小写且去掉前导点,便于模型字段存储。"""
|
||
return Path(file_name).suffix.lower().lstrip(".")
|
||
|
||
|
||
def _generate_batch_id() -> str:
    """Build a batch id of the form ``SUB-20260604-NNN``.

    ``NNN`` is the current number of ``SubmissionBatch`` rows plus one,
    zero-padded to at least three digits.
    """
    # NOTE(review): the date segment is a fixed literal and the sequence comes
    # from a row count, so ids could repeat after deletions or under
    # concurrent imports -- confirm whether this is intended.
    next_sequence = SubmissionBatch.objects.count() + 1
    return f"SUB-20260604-{next_sequence:03d}"
|
||
|
||
|
||
def _estimate_page_count(text: str) -> int:
|
||
stripped = text.strip()
|
||
if not stripped:
|
||
return 0
|
||
line_count = len([line for line in stripped.splitlines() if line.strip()])
|
||
return max(1, line_count)
|
||
|
||
|
||
def _expand_uploaded_files(uploaded_files: list) -> list[dict]:
|
||
expanded_files = []
|
||
warnings = []
|
||
for uploaded_file in uploaded_files:
|
||
extension = Path(uploaded_file.name).suffix.lower()
|
||
if extension == ".zip":
|
||
extraction = _extract_zip_entries(uploaded_file)
|
||
expanded_files.extend(extraction["files"])
|
||
warnings.extend(extraction["warnings"])
|
||
continue
|
||
if extension == ".7z":
|
||
extraction = _extract_7z_entries(uploaded_file)
|
||
expanded_files.extend(extraction["files"])
|
||
warnings.extend(extraction["warnings"])
|
||
continue
|
||
expanded_files.append(
|
||
{
|
||
"relative_path": uploaded_file.name,
|
||
"uploaded_file": uploaded_file,
|
||
}
|
||
)
|
||
return {"files": expanded_files, "warnings": warnings}
|
||
|
||
|
||
def _extract_zip_entries(uploaded_file) -> dict:
|
||
archive_bytes = uploaded_file.read()
|
||
uploaded_file.seek(0)
|
||
entries = []
|
||
warnings = []
|
||
with ZipFile(BytesIO(archive_bytes)) as archive:
|
||
for info in archive.infolist():
|
||
if info.is_dir():
|
||
continue
|
||
relative_path = info.filename.replace("\\", "/")
|
||
extension = Path(relative_path).suffix.lower()
|
||
if extension not in {".txt", ".md", ".pdf", ".docx"}:
|
||
warnings.append(f"跳过不支持的文件:{relative_path}")
|
||
continue
|
||
file_data = archive.read(info.filename)
|
||
extracted_file = SimpleUploadedFile(
|
||
Path(relative_path).name,
|
||
file_data,
|
||
)
|
||
entries.append(
|
||
{
|
||
"relative_path": relative_path,
|
||
"uploaded_file": extracted_file,
|
||
}
|
||
)
|
||
return {"files": entries, "warnings": warnings}
|
||
|
||
|
||
def _extract_7z_entries(uploaded_file) -> dict:
    """
    Unpack the supported files of an uploaded 7z archive.

    The archive is extracted into a temporary directory, which is walked in
    sorted order. Files whose extension is not one of ``.txt``, ``.md``,
    ``.pdf`` or ``.docx`` are skipped with a warning; the rest are loaded
    into ``SimpleUploadedFile`` objects before the directory is removed.

    Args:
        uploaded_file: Uploaded archive; it is read fully and rewound to 0.

    Returns:
        ``{"files": [...], "warnings": [...]}`` where each file item has
        ``relative_path`` (POSIX style, relative to the archive root) and
        ``uploaded_file``.

    Raises:
        RuntimeError: If the optional ``py7zr`` package is not installed.
    """
    try:
        import py7zr
    except ImportError as exc:
        raise RuntimeError("处理 .7z 资料包需要安装 py7zr。") from exc

    payload = uploaded_file.read()
    # Rewind so the caller can still read or store the original upload.
    uploaded_file.seek(0)

    supported_extensions = {".txt", ".md", ".pdf", ".docx"}
    collected = []
    skipped_messages = []
    with tempfile.TemporaryDirectory() as temp_dir:
        with py7zr.SevenZipFile(BytesIO(payload), mode="r") as archive:
            archive.extractall(path=temp_dir)

        root = Path(temp_dir)
        regular_files = (item for item in sorted(root.rglob("*")) if item.is_file())
        for item in regular_files:
            relative_path = item.relative_to(root).as_posix()
            if Path(relative_path).suffix.lower() not in supported_extensions:
                skipped_messages.append(f"跳过不支持的文件:{relative_path}")
                continue
            # Content must be read while the temporary directory still exists.
            in_memory_file = SimpleUploadedFile(item.name, item.read_bytes())
            collected.append(
                {
                    "relative_path": relative_path,
                    "uploaded_file": in_memory_file,
                }
            )
    return {"files": collected, "warnings": skipped_messages}
|
||
|
||
|
||
def _detect_document_role(file_name: str) -> str:
|
||
normalized = file_name.lower()
|
||
if "申请表" in file_name:
|
||
return "application_form"
|
||
if "说明书" in file_name:
|
||
return "product_manual"
|
||
if "产品列表" in file_name:
|
||
return "product_list"
|
||
if "声明" in file_name:
|
||
return "declaration"
|
||
if normalized.endswith(".pdf"):
|
||
return "pdf_document"
|
||
return "general_document"
|
||
|
||
|
||
def _detect_chapter_code(file_name: str, text: str) -> str:
|
||
for source in (file_name, text):
|
||
match = re.search(r"(CH\d+(?:\.\d+)*)", source, flags=re.IGNORECASE)
|
||
if match:
|
||
return match.group(1).upper()
|
||
if "监管" in file_name or "申请表" in file_name or "说明书" in file_name:
|
||
return "CH1"
|
||
return ""
|
||
|
||
|
||
def _extract_product_candidates(file_name: str, text: str) -> list[dict]:
    """
    Collect product-name candidates from one document.

    Only documents whose name identifies them as a candidate source (see
    ``_detect_candidate_source``) contribute. The text is searched for a
    labelled product name first; if none is found, the file stem -- with the
    markers 目标产品 / 说明书 removed -- is used, unless it still names an
    application form or product list.

    Args:
        file_name: File name or relative path of the document.
        text: Extracted text of the document.

    Returns:
        A list with at most one ``{"source_type", "product_name"}`` dict.
    """
    source_type = _detect_candidate_source(file_name)
    if not source_type:
        return []

    patterns = [
        r"产品名称[::]\s*([^\n\r]+)",
        r"名称[::]\s*([^\n\r]+检测试剂盒[^\n\r]*)",
    ]
    for pattern in patterns:
        match = re.search(pattern, text)
        if not match:
            continue
        product_name = match.group(1).strip()
        # A label followed only by whitespace captures a blank name; skip it
        # so an empty string never becomes a candidate.
        if product_name:
            return [{"source_type": source_type, "product_name": product_name}]

    cleaned = Path(file_name).stem.replace("目标产品", "").replace("说明书", "").strip("-_ ")
    if cleaned and "申请表" not in cleaned and "产品列表" not in cleaned:
        return [{"source_type": source_type, "product_name": cleaned}]
    return []
|
||
|
||
|
||
def _detect_candidate_source(file_name: str) -> str:
|
||
if "申请表" in file_name:
|
||
return "application_form"
|
||
if "说明书" in file_name:
|
||
return "product_manual"
|
||
if "产品列表" in file_name:
|
||
return "product_list"
|
||
return ""
|
||
|
||
|
||
def _select_product_name(candidates: list[dict]) -> tuple[str, list[str]]:
|
||
if not candidates:
|
||
return "", ["未识别到产品名称,建议人工补录。"]
|
||
|
||
priority = {
|
||
"application_form": 1,
|
||
"product_manual": 2,
|
||
"product_list": 3,
|
||
}
|
||
sorted_candidates = sorted(
|
||
candidates,
|
||
key=lambda item: priority.get(item["source_type"], 99),
|
||
)
|
||
top_candidate = sorted_candidates[0]
|
||
warnings = []
|
||
conflict_names = {
|
||
item["product_name"]
|
||
for item in sorted_candidates
|
||
if item["product_name"] != top_candidate["product_name"]
|
||
}
|
||
if conflict_names:
|
||
warnings.append(
|
||
"产品名称来源冲突:"
|
||
+ " / ".join([top_candidate["product_name"], *sorted(conflict_names)])
|
||
)
|
||
return top_candidate["product_name"], warnings
|
||
|
||
|
||
def _read_text_file(path: Path) -> str:
|
||
"""优先按 UTF-8 读取;失败时回退到系统默认编码。"""
|
||
try:
|
||
return path.read_text(encoding="utf-8")
|
||
except UnicodeDecodeError:
|
||
return path.read_text()
|
||
|
||
|
||
def _extract_pdf_text(path: Path) -> str:
    """Extract PDF text with pypdf, one line block per page.

    Pages without extractable text contribute an empty string. Any failure
    -- including pypdf not being installed -- falls back to
    ``_read_binary_text_fallback``.
    """
    try:
        from pypdf import PdfReader

        pdf_pages = PdfReader(str(path)).pages
        page_texts = [page.extract_text() or "" for page in pdf_pages]
        return "\n".join(page_texts)
    except Exception:
        # Deliberately broad: extraction is best-effort and must not raise.
        return _read_binary_text_fallback(path)
|
||
|
||
|
||
def _extract_docx_text(path: Path) -> str:
|
||
"""提取 Word XML 中的可见文字内容,不追求保留样式。"""
|
||
try:
|
||
with ZipFile(path) as archive:
|
||
document_xml = archive.read("word/document.xml")
|
||
root = ET.fromstring(document_xml)
|
||
namespace = {"w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main"}
|
||
texts = [node.text for node in root.findall(".//w:t", namespace) if node.text]
|
||
return "\n".join(texts)
|
||
except (BadZipFile, KeyError, ET.ParseError):
|
||
return _read_binary_text_fallback(path)
|
||
|
||
|
||
def _read_binary_text_fallback(path: Path) -> str:
|
||
"""
|
||
当结构化抽取失败时,退回到“尽可能保留纯文本”的保底方案。
|
||
|
||
该方案不保证版式,但足以支撑 V1 入库和演示。
|
||
"""
|
||
data = path.read_bytes()
|
||
text = data.decode("utf-8", errors="ignore")
|
||
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]+", " ", text)
|
||
return text.strip()
|