refactor(webui): migrate jargon routes from Peewee to SQLModel

- 完全迁移到 SQLModel
- chat_id → session_id 映射
- ChatStreams → ChatSession 替代
- 移除 is_global 字段
- 使用 group_id 替代 group_name
pull/1496/head
DrSmoothl 2026-02-17 19:58:21 +08:00
parent 7da0811b5c
commit 390d1daefd
No known key found for this signature in database
1 changed files with 159 additions and 166 deletions

View File

@ -1,13 +1,16 @@
"""黑话(俚语)管理路由"""
import json
from typing import Optional, List, Annotated
from typing import Annotated, Any, List, Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import func as fn
from sqlmodel import Session, col, delete, select
import json
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession, Jargon
from src.common.logger import get_logger
from src.common.database.database_model import Jargon, ChatStreams
logger = get_logger("webui.jargon")
@ -43,7 +46,7 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
return [chat_id_str]
def get_display_name_for_chat_id(chat_id_str: str) -> str:
def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
"""
获取 chat_id 的显示名称
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
@ -51,19 +54,18 @@ def get_display_name_for_chat_id(chat_id_str: str) -> str:
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids:
return chat_id_str
return chat_id_str[:20]
# 查询所有 stream_id 对应的名称
names = []
for stream_id in stream_ids:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream and chat_stream.group_name:
names.append(chat_stream.group_name)
else:
# 如果没找到,显示截断的 stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
stream_id = stream_ids[0]
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
return ", ".join(names) if names else chat_id_str
if not chat_session:
return stream_id[:20]
if chat_session.group_id:
return str(chat_session.group_id)
return chat_session.session_id[:20]
# ==================== 请求/响应模型 ====================
@ -79,7 +81,6 @@ class JargonResponse(BaseModel):
chat_id: str
stream_id: Optional[str] = None # 解析后的 stream_id用于前端编辑时匹配
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
is_global: bool = False
count: int = 0
is_jargon: Optional[bool] = None
is_complete: bool = False
@ -94,7 +95,7 @@ class JargonListResponse(BaseModel):
total: int
page: int
page_size: int
data: List[JargonResponse]
data: List[dict[str, Any]]
class JargonDetailResponse(BaseModel):
@ -111,7 +112,6 @@ class JargonCreateRequest(BaseModel):
raw_content: Optional[str] = Field(None, description="原始内容")
meaning: Optional[str] = Field(None, description="含义")
chat_id: str = Field(..., description="聊天ID")
is_global: bool = Field(False, description="是否全局")
class JargonUpdateRequest(BaseModel):
@ -121,7 +121,6 @@ class JargonUpdateRequest(BaseModel):
raw_content: Optional[str] = None
meaning: Optional[str] = None
chat_id: Optional[str] = None
is_global: Optional[bool] = None
is_jargon: Optional[bool] = None
@ -159,7 +158,7 @@ class JargonStatsResponse(BaseModel):
"""黑话统计响应"""
success: bool = True
data: dict
data: dict[str, Any]
class ChatInfoResponse(BaseModel):
@ -181,27 +180,24 @@ class ChatListResponse(BaseModel):
# ==================== 工具函数 ====================
def jargon_to_dict(jargon: Jargon) -> dict:
def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
"""将 Jargon ORM 对象转换为字典"""
# 解析 chat_id 获取显示名称和 stream_id
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None
chat_id = jargon.session_id or ""
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
return {
"id": jargon.id,
"content": jargon.content,
"raw_content": jargon.raw_content,
"meaning": jargon.meaning,
"chat_id": jargon.chat_id,
"stream_id": stream_id,
"chat_id": chat_id,
"stream_id": jargon.session_id,
"chat_name": chat_name,
"is_global": jargon.is_global,
"count": jargon.count,
"is_jargon": jargon.is_jargon,
"is_complete": jargon.is_complete,
"inference_with_context": jargon.inference_with_context,
"inference_content_only": jargon.inference_content_only,
"inference_content_only": jargon.inference_with_content_only,
}
@ -215,49 +211,41 @@ async def get_jargon_list(
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
):
"""获取黑话列表"""
try:
# 构建查询
query = Jargon.select()
statement = select(Jargon)
count_statement = select(fn.count()).select_from(Jargon)
# 搜索过滤
if search:
query = query.where(
(Jargon.content.contains(search))
| (Jargon.meaning.contains(search))
| (Jargon.raw_content.contains(search))
search_filter = (
(col(Jargon.content).contains(search))
| (col(Jargon.meaning).contains(search))
| (col(Jargon.raw_content).contains(search))
)
statement = statement.where(search_filter)
count_statement = count_statement.where(search_filter)
# 按聊天ID筛选使用 contains 匹配,因为 chat_id 是 JSON 格式)
if chat_id:
# 从传入的 chat_id 中解析出 stream_id
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
# 使用第一个 stream_id 进行模糊匹配
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
chat_filter = col(Jargon.session_id).contains(stream_ids[0])
else:
# 如果无法解析,使用精确匹配
query = query.where(Jargon.chat_id == chat_id)
chat_filter = col(Jargon.session_id) == chat_id
statement = statement.where(chat_filter)
count_statement = count_statement.where(chat_filter)
# 按是否是黑话筛选
if is_jargon is not None:
query = query.where(Jargon.is_jargon == is_jargon)
statement = statement.where(col(Jargon.is_jargon) == is_jargon)
count_statement = count_statement.where(col(Jargon.is_jargon) == is_jargon)
# 按是否全局筛选
if is_global is not None:
query = query.where(Jargon.is_global == is_global)
statement = statement.order_by(col(Jargon.count).desc(), col(Jargon.id).desc())
statement = statement.offset((page - 1) * page_size).limit(page_size)
# 获取总数
total = query.count()
# 分页和排序(按使用次数降序)
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
query = query.paginate(page, page_size)
# 转换为响应格式
data = [jargon_to_dict(j) for j in query]
with get_db_session() as session:
total = session.exec(count_statement).one()
jargons = session.exec(statement).all()
data = [jargon_to_dict(jargon, session) for jargon in jargons]
return JargonListResponse(
success=True,
@ -276,10 +264,9 @@ async def get_jargon_list(
async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
# 获取所有不同的 chat_id
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
with get_db_session() as session:
statement = select(Jargon.session_id).distinct().where(col(Jargon.session_id).is_not(None))
chat_id_list = [chat_id for chat_id in session.exec(statement).all() if chat_id]
# 用于按 stream_id 去重
seen_stream_ids: set[str] = set()
@ -290,27 +277,28 @@ async def get_chat_list():
seen_stream_ids.add(stream_ids[0])
result = []
for stream_id in seen_stream_ids:
# 尝试从 ChatStreams 表获取聊天名称
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id方便筛选匹配
chat_name=chat_stream.group_name or stream_id,
platform=chat_stream.platform,
is_group=True,
with get_db_session() as session:
for stream_id in seen_stream_ids:
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
if chat_session:
chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
result.append(
ChatInfoResponse(
chat_id=stream_id,
chat_name=chat_name,
platform=chat_session.platform,
is_group=bool(chat_session.group_id),
)
)
)
else:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
platform=None,
is_group=False,
else:
result.append(
ChatInfoResponse(
chat_id=stream_id,
chat_name=stream_id[:20],
platform=None,
is_group=False,
)
)
)
return ChatListResponse(success=True, data=result)
@ -323,35 +311,35 @@ async def get_chat_list():
async def get_jargon_stats():
"""获取黑话统计数据"""
try:
# 总数量
total = Jargon.select().count()
with get_db_session() as session:
total = session.exec(select(fn.count()).select_from(Jargon)).one()
# 已确认是黑话的数量
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
confirmed_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True)
).one()
confirmed_not_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False)
).one()
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
# 已确认不是黑话的数量
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
complete_count = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True)
).one()
# 未判定的数量
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
chat_count = session.exec(
select(fn.count()).select_from(
select(col(Jargon.session_id)).distinct().where(col(Jargon.session_id).is_not(None)).subquery()
)
).one()
# 全局黑话数量
global_count = Jargon.select().where(Jargon.is_global).count()
# 已完成推断的数量
complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
# 按聊天统计 TOP 5
top_chats = (
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
.group_by(Jargon.chat_id)
.order_by(fn.COUNT(Jargon.id).desc())
.limit(5)
)
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
top_chats = session.exec(
select(col(Jargon.session_id), fn.count().label("count"))
.where(col(Jargon.session_id).is_not(None))
.group_by(col(Jargon.session_id))
.order_by(fn.count().desc())
.limit(5)
).all()
top_chats_dict = {session_id: count for session_id, count in top_chats if session_id}
return JargonStatsResponse(
success=True,
@ -360,7 +348,6 @@ async def get_jargon_stats():
"confirmed_jargon": confirmed_jargon,
"confirmed_not_jargon": confirmed_not_jargon,
"pending": pending,
"global_count": global_count,
"complete_count": complete_count,
"chat_count": chat_count,
"top_chats": top_chats_dict,
@ -376,11 +363,13 @@ async def get_jargon_stats():
async def get_jargon_detail(jargon_id: int):
"""获取黑话详情"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
with get_db_session() as session:
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
return JargonDetailResponse(success=True, data=data)
except HTTPException:
raise
@ -393,30 +382,31 @@ async def get_jargon_detail(jargon_id: int):
async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
# 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
with get_db_session() as session:
existing = session.exec(
select(Jargon).where(
(col(Jargon.content) == request.content) & (col(Jargon.session_id) == request.chat_id)
)
).first()
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
# 创建黑话
jargon = Jargon.create(
content=request.content,
raw_content=request.raw_content,
meaning=request.meaning,
chat_id=request.chat_id,
is_global=request.is_global,
count=0,
is_jargon=None,
is_complete=False,
)
jargon = Jargon(
content=request.content,
raw_content=request.raw_content,
meaning=request.meaning or "",
session_id=request.chat_id,
count=0,
is_jargon=None,
is_complete=False,
)
session.add(jargon)
session.flush()
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonCreateResponse(
success=True,
message="创建成功",
data=jargon_to_dict(jargon),
)
return JargonCreateResponse(success=True, message="创建成功", data=data)
except HTTPException:
raise
@ -429,25 +419,27 @@ async def create_jargon(request: JargonCreateRequest):
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
"""更新黑话(增量更新)"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
with get_db_session() as session:
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
# 增量更新字段
update_data = request.model_dump(exclude_unset=True)
if update_data:
for field, value in update_data.items():
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)
jargon.save()
update_data = request.model_dump(exclude_unset=True)
if update_data:
for field, value in update_data.items():
if field == "is_global":
continue
if field == "chat_id":
jargon.session_id = value
continue
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)
session.add(jargon)
logger.info(f"更新黑话成功: id={jargon_id}")
logger.info(f"更新黑话成功: id={jargon_id}")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonUpdateResponse(
success=True,
message="更新成功",
data=jargon_to_dict(jargon),
)
return JargonUpdateResponse(success=True, message="更新成功", data=data)
except HTTPException:
raise
@ -460,20 +452,17 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
async def delete_jargon(jargon_id: int):
"""删除黑话"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
with get_db_session() as session:
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
content = jargon.content
jargon.delete_instance()
content = jargon.content
session.delete(jargon)
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
return JargonDeleteResponse(
success=True,
message="删除成功",
deleted_count=1,
)
return JargonDeleteResponse(success=True, message="删除成功", deleted_count=1)
except HTTPException:
raise
@ -489,9 +478,11 @@ async def batch_delete_jargons(request: BatchDeleteRequest):
if not request.ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
with get_db_session() as session:
result = session.exec(delete(Jargon).where(col(Jargon.id).in_(request.ids)))
deleted_count = result.rowcount or 0
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
return JargonDeleteResponse(
success=True,
@ -516,14 +507,16 @@ async def batch_set_jargon_status(
if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
with get_db_session() as session:
jargons = session.exec(select(Jargon).where(col(Jargon.id).in_(ids))).all()
for jargon in jargons:
jargon.is_jargon = is_jargon
session.add(jargon)
updated_count = len(jargons)
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")
return JargonUpdateResponse(
success=True,
message=f"成功更新 {updated_count} 条黑话状态",
)
return JargonUpdateResponse(success=True, message=f"成功更新 {updated_count} 条黑话状态")
except HTTPException:
raise