mirror of https://github.com/Mai-with-u/MaiBot.git
505 lines
19 KiB
Python
505 lines
19 KiB
Python
"""Expression routes pytest tests"""
|
|
|
|
from datetime import datetime
|
|
from typing import Generator
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
from fastapi import FastAPI, APIRouter
|
|
from fastapi.testclient import TestClient
|
|
from sqlalchemy.pool import StaticPool
|
|
from sqlalchemy import text
|
|
from sqlmodel import Session, SQLModel, create_engine, select
|
|
|
|
from src.common.database.database_model import Expression
|
|
from src.common.database.database import get_db_session
|
|
|
|
|
|
def create_test_app() -> FastAPI:
|
|
"""Create minimal test app with only expression router"""
|
|
app = FastAPI(title="Test App")
|
|
from src.webui.routers.expression import router as expression_router
|
|
|
|
main_router = APIRouter(prefix="/api/webui")
|
|
main_router.include_router(expression_router)
|
|
app.include_router(main_router)
|
|
|
|
return app
|
|
|
|
|
|
app = create_test_app()
|
|
|
|
|
|
# Test database setup
|
|
@pytest.fixture(name="test_engine")
|
|
def test_engine_fixture():
|
|
"""Create in-memory SQLite database for testing"""
|
|
engine = create_engine(
|
|
"sqlite://",
|
|
connect_args={"check_same_thread": False},
|
|
poolclass=StaticPool,
|
|
)
|
|
SQLModel.metadata.create_all(engine)
|
|
return engine
|
|
|
|
|
|
@pytest.fixture(name="test_session")
|
|
def test_session_fixture(test_engine) -> Generator[Session, None, None]:
|
|
"""Create a test database session with transaction rollback"""
|
|
connection = test_engine.connect()
|
|
transaction = connection.begin()
|
|
session = Session(bind=connection)
|
|
|
|
yield session
|
|
|
|
session.close()
|
|
transaction.rollback()
|
|
connection.close()
|
|
|
|
|
|
@pytest.fixture(name="client")
|
|
def client_fixture(test_session: Session, monkeypatch) -> Generator[TestClient, None, None]:
|
|
"""Create TestClient with overridden database session"""
|
|
from contextlib import contextmanager
|
|
|
|
@contextmanager
|
|
def get_test_db_session():
|
|
yield test_session
|
|
|
|
monkeypatch.setattr("src.webui.routers.expression.get_db_session", get_test_db_session)
|
|
|
|
with TestClient(app) as client:
|
|
yield client
|
|
|
|
|
|
@pytest.fixture(name="mock_auth")
|
|
def mock_auth_fixture(monkeypatch):
|
|
"""Mock authentication to always return True"""
|
|
mock_verify = MagicMock(return_value=True)
|
|
monkeypatch.setattr("src.webui.routers.expression.verify_auth_token_from_cookie_or_header", mock_verify)
|
|
|
|
|
|
@pytest.fixture(name="sample_expression")
|
|
def sample_expression_fixture(test_session: Session) -> Expression:
|
|
"""Insert a sample expression into test database"""
|
|
test_session.execute(
|
|
text(
|
|
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
|
|
"VALUES (1, '测试情景', '测试风格', '测试上下文', '测试上文', '[\"测试内容1\", \"测试内容2\"]', 10, '2026-02-17 12:00:00', '2026-02-15 10:00:00', 'test_chat_001')"
|
|
)
|
|
)
|
|
test_session.commit()
|
|
|
|
expression = test_session.exec(select(Expression).where(Expression.id == 1)).first()
|
|
assert expression is not None
|
|
return expression
|
|
|
|
|
|
# ============ Tests ============
|
|
|
|
|
|
def test_list_expressions_empty(client: TestClient, mock_auth):
|
|
"""Test GET /expression/list with empty database"""
|
|
response = client.get("/api/webui/expression/list")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["total"] == 0
|
|
assert data["page"] == 1
|
|
assert data["page_size"] == 20
|
|
assert data["data"] == []
|
|
|
|
|
|
def test_list_expressions_with_data(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test GET /expression/list returns expression data"""
|
|
response = client.get("/api/webui/expression/list")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["total"] == 1
|
|
assert len(data["data"]) == 1
|
|
|
|
expr_data = data["data"][0]
|
|
assert expr_data["id"] == sample_expression.id
|
|
assert expr_data["situation"] == "测试情景"
|
|
assert expr_data["style"] == "测试风格"
|
|
assert expr_data["chat_id"] == "test_chat_001"
|
|
|
|
|
|
def test_list_expressions_pagination(client: TestClient, mock_auth, test_session: Session):
|
|
"""Test GET /expression/list pagination works correctly"""
|
|
for i in range(5):
|
|
test_session.execute(
|
|
text(
|
|
f"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
|
|
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '', '', '[]', 0, '2026-02-17 12:0{i}:00', '2026-02-15 10:00:00', 'chat_{i}')"
|
|
)
|
|
)
|
|
test_session.commit()
|
|
|
|
# Request page 1 with page_size=2
|
|
response = client.get("/api/webui/expression/list?page=1&page_size=2")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["total"] == 5
|
|
assert data["page"] == 1
|
|
assert data["page_size"] == 2
|
|
assert len(data["data"]) == 2
|
|
|
|
# Request page 2
|
|
response = client.get("/api/webui/expression/list?page=2&page_size=2")
|
|
data = response.json()
|
|
assert data["page"] == 2
|
|
assert len(data["data"]) == 2
|
|
|
|
|
|
def test_list_expressions_search(client: TestClient, mock_auth, test_session: Session):
|
|
"""Test GET /expression/list with search filter"""
|
|
test_session.execute(
|
|
text(
|
|
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
|
|
"VALUES (1, '找人吃饭', '热情', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_001')"
|
|
)
|
|
)
|
|
test_session.execute(
|
|
text(
|
|
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
|
|
"VALUES (2, '拒绝邀请', '礼貌', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_002')"
|
|
)
|
|
)
|
|
test_session.commit()
|
|
|
|
# Search for "吃饭"
|
|
response = client.get("/api/webui/expression/list?search=吃饭")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["total"] == 1
|
|
assert data["data"][0]["situation"] == "找人吃饭"
|
|
|
|
|
|
def test_list_expressions_chat_filter(client: TestClient, mock_auth, test_session: Session):
|
|
"""Test GET /expression/list with chat_id filter"""
|
|
test_session.execute(
|
|
text(
|
|
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
|
|
"VALUES (1, '情景A', '风格A', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_A')"
|
|
)
|
|
)
|
|
test_session.execute(
|
|
text(
|
|
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
|
|
"VALUES (2, '情景B', '风格B', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_B')"
|
|
)
|
|
)
|
|
test_session.commit()
|
|
|
|
# Filter by chat_A
|
|
response = client.get("/api/webui/expression/list?chat_id=chat_A")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["total"] == 1
|
|
assert data["data"][0]["situation"] == "情景A"
|
|
assert data["data"][0]["chat_id"] == "chat_A"
|
|
|
|
|
|
def test_get_expression_detail_success(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test GET /expression/{id} returns correct detail"""
|
|
response = client.get(f"/api/webui/expression/{sample_expression.id}")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["data"]["id"] == sample_expression.id
|
|
assert data["data"]["situation"] == "测试情景"
|
|
assert data["data"]["style"] == "测试风格"
|
|
assert data["data"]["chat_id"] == "test_chat_001"
|
|
|
|
|
|
def test_get_expression_detail_not_found(client: TestClient, mock_auth):
|
|
"""Test GET /expression/{id} returns 404 for non-existent ID"""
|
|
response = client.get("/api/webui/expression/99999")
|
|
assert response.status_code == 404
|
|
|
|
data = response.json()
|
|
assert "未找到" in data["detail"]
|
|
|
|
|
|
def test_expression_response_has_legacy_fields(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test that ExpressionResponse includes legacy fields (checked/rejected/modified_by)"""
|
|
response = client.get(f"/api/webui/expression/{sample_expression.id}")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()["data"]
|
|
|
|
# Verify legacy fields exist and have default values
|
|
assert "checked" in data
|
|
assert "rejected" in data
|
|
assert "modified_by" in data
|
|
|
|
# Verify hardcoded default values (from expression_to_response)
|
|
assert data["checked"] is False
|
|
assert data["rejected"] is False
|
|
assert data["modified_by"] is None
|
|
|
|
|
|
def test_update_expression_without_removed_fields(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test PATCH /expression/{id} does not accept checked/rejected fields"""
|
|
# Valid update request (only allowed fields)
|
|
update_payload = {
|
|
"situation": "更新后的情景",
|
|
"style": "更新后的风格",
|
|
}
|
|
|
|
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["data"]["situation"] == "更新后的情景"
|
|
assert data["data"]["style"] == "更新后的风格"
|
|
|
|
# Verify legacy fields still returned (hardcoded values)
|
|
assert data["data"]["checked"] is False
|
|
assert data["data"]["rejected"] is False
|
|
|
|
|
|
def test_update_expression_ignores_invalid_fields(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test PATCH /expression/{id} ignores fields not in ExpressionUpdateRequest"""
|
|
# Request with invalid field (checked not in schema)
|
|
update_payload = {
|
|
"situation": "新情景",
|
|
"checked": True, # This field should be ignored by Pydantic
|
|
"rejected": True, # This field should be ignored
|
|
}
|
|
|
|
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["data"]["situation"] == "新情景"
|
|
|
|
# Response should have hardcoded False values (not True from request)
|
|
assert data["data"]["checked"] is False
|
|
assert data["data"]["rejected"] is False
|
|
|
|
|
|
def test_update_expression_chat_id_mapping(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test PATCH /expression/{id} correctly maps chat_id to session_id"""
|
|
update_payload = {"chat_id": "updated_chat_999"}
|
|
|
|
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
|
|
# Verify chat_id is returned in response (mapped from session_id)
|
|
assert data["data"]["chat_id"] == "updated_chat_999"
|
|
|
|
|
|
def test_update_expression_not_found(client: TestClient, mock_auth):
|
|
"""Test PATCH /expression/{id} returns 404 for non-existent ID"""
|
|
update_payload = {"situation": "新情景"}
|
|
|
|
response = client.patch("/api/webui/expression/99999", json=update_payload)
|
|
assert response.status_code == 404
|
|
|
|
data = response.json()
|
|
assert "未找到" in data["detail"]
|
|
|
|
|
|
def test_update_expression_empty_request(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test PATCH /expression/{id} returns 400 for empty update request"""
|
|
update_payload = {}
|
|
|
|
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
|
|
assert response.status_code == 400
|
|
|
|
data = response.json()
|
|
assert "未提供任何需要更新的字段" in data["detail"]
|
|
|
|
|
|
def test_delete_expression_success(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test DELETE /expression/{id} successfully deletes expression"""
|
|
expression_id = sample_expression.id
|
|
|
|
response = client.delete(f"/api/webui/expression/{expression_id}")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert "成功删除" in data["message"]
|
|
|
|
# Verify expression is deleted
|
|
get_response = client.get(f"/api/webui/expression/{expression_id}")
|
|
assert get_response.status_code == 404
|
|
|
|
|
|
def test_delete_expression_not_found(client: TestClient, mock_auth):
|
|
"""Test DELETE /expression/{id} returns 404 for non-existent ID"""
|
|
response = client.delete("/api/webui/expression/99999")
|
|
assert response.status_code == 404
|
|
|
|
data = response.json()
|
|
assert "未找到" in data["detail"]
|
|
|
|
|
|
def test_create_expression_success(client: TestClient, mock_auth):
|
|
"""Test POST /expression/ successfully creates expression"""
|
|
create_payload = {
|
|
"situation": "新建情景",
|
|
"style": "新建风格",
|
|
"chat_id": "new_chat_123",
|
|
}
|
|
|
|
response = client.post("/api/webui/expression/", json=create_payload)
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert "创建成功" in data["message"]
|
|
assert data["data"]["situation"] == "新建情景"
|
|
assert data["data"]["style"] == "新建风格"
|
|
assert data["data"]["chat_id"] == "new_chat_123"
|
|
|
|
# Verify legacy fields
|
|
assert data["data"]["checked"] is False
|
|
assert data["data"]["rejected"] is False
|
|
assert data["data"]["modified_by"] is None
|
|
|
|
|
|
def test_batch_delete_expressions_success(client: TestClient, mock_auth, test_session: Session):
|
|
"""Test POST /expression/batch/delete successfully deletes multiple expressions"""
|
|
expression_ids = []
|
|
for i in range(3):
|
|
test_session.execute(
|
|
text(
|
|
f"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
|
|
f"VALUES ({i + 1}, '批量删除{i}', '风格{i}', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_{i}')"
|
|
)
|
|
)
|
|
expression_ids.append(i + 1)
|
|
test_session.commit()
|
|
|
|
delete_payload = {"ids": expression_ids}
|
|
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert "成功删除 3 个" in data["message"]
|
|
|
|
for expr_id in expression_ids:
|
|
get_response = client.get(f"/api/webui/expression/{expr_id}")
|
|
assert get_response.status_code == 404
|
|
|
|
|
|
def test_batch_delete_partial_not_found(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test POST /expression/batch/delete handles partial not found IDs"""
|
|
delete_payload = {"ids": [sample_expression.id, 88888, 99999]}
|
|
|
|
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
# Should delete only the 1 valid ID
|
|
assert "成功删除 1 个" in data["message"]
|
|
|
|
|
|
def test_get_expression_stats(client: TestClient, mock_auth, test_session: Session):
|
|
"""Test GET /expression/stats/summary returns correct statistics"""
|
|
for i in range(3):
|
|
test_session.execute(
|
|
text(
|
|
f"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
|
|
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_{i % 2}')"
|
|
)
|
|
)
|
|
test_session.commit()
|
|
|
|
response = client.get("/api/webui/expression/stats/summary")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["data"]["total"] == 3
|
|
assert data["data"]["chat_count"] == 2
|
|
|
|
|
|
def test_get_review_stats(client: TestClient, mock_auth, test_session: Session):
|
|
"""Test GET /expression/review/stats returns hardcoded 0 counts"""
|
|
test_session.execute(
|
|
text(
|
|
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
|
|
"VALUES (1, '待审核', '风格', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_001')"
|
|
)
|
|
)
|
|
test_session.commit()
|
|
|
|
response = client.get("/api/webui/expression/review/stats")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
# Verify all review counts are 0 (hardcoded in refactored code)
|
|
assert data["total"] == 1 # Total expressions exists
|
|
assert data["unchecked"] == 0
|
|
assert data["passed"] == 0
|
|
assert data["rejected"] == 0
|
|
assert data["ai_checked"] == 0
|
|
assert data["user_checked"] == 0
|
|
|
|
|
|
def test_get_review_list_filter_unchecked(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test GET /expression/review/list with filter_type=unchecked returns empty (legacy behavior)"""
|
|
# filter_type=unchecked should return no results (legacy removed)
|
|
response = client.get("/api/webui/expression/review/list?filter_type=unchecked")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["total"] == 0 # No results (legacy fields removed)
|
|
|
|
|
|
def test_get_review_list_filter_all(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test GET /expression/review/list with filter_type=all returns all expressions"""
|
|
response = client.get("/api/webui/expression/review/list?filter_type=all")
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["total"] == 1
|
|
assert len(data["data"]) == 1
|
|
|
|
|
|
def test_batch_review_expressions_unsupported(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test POST /expression/review/batch returns failure for require_unchecked=True"""
|
|
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}
|
|
|
|
response = client.post("/api/webui/expression/review/batch", json=review_payload)
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["failed"] == 1 # Should fail because require_unchecked=True
|
|
assert "不支持审核状态过滤" in data["results"][0]["message"]
|
|
|
|
|
|
def test_batch_review_expressions_no_unchecked_check(client: TestClient, mock_auth, sample_expression: Expression):
|
|
"""Test POST /expression/review/batch succeeds when require_unchecked=False"""
|
|
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": False}]}
|
|
|
|
response = client.post("/api/webui/expression/review/batch", json=review_payload)
|
|
assert response.status_code == 200
|
|
|
|
data = response.json()
|
|
assert data["success"] is True
|
|
assert data["succeeded"] == 1
|
|
assert data["results"][0]["success"] is True
|