MaiBot/code_scripts/generate_database_datamodel...

128 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import ast
import io
import subprocess
import sys
import tokenize
from pathlib import Path
# Input (SQLModel definitions) and output (generated pydantic models) both live
# in <parent of this script's directory>/src/common/database.
database_dir = Path(__file__).parent.parent.absolute().resolve() / "src" / "common" / "database"
base_file_path = database_dir / "database_model.py"
target_file_path = database_dir / "database_datamodel.py"

source_text = base_file_path.read_text(encoding="utf-8")

# Split on "\n" only. str.splitlines() also breaks on form feeds, U+2028 and
# similar characters that the Python parser does not count as line breaks, which
# would shift source_lines[lineno - 1] away from the AST's line numbers.
# (Text-mode reading has already normalised "\r\n" and "\r" to "\n".)
source_lines = source_text.split("\n")

# A SyntaxError simply propagates; passing the filename makes the error point at
# the real model file instead of "<unknown>".
tree = ast.parse(source_text, filename=str(base_file_path))

# Header of the generated module; class definitions are appended after it.
code_lines = [
    "from typing import Optional",
    "from pydantic import BaseModel",
    "from datetime import datetime",
    "from .database_model import ModelUser, ImageType",
]
def src(node: ast.AST) -> str:
    """Return the source code text for ``node``.

    The exact slice of the module-level ``source_text`` is used when
    ``ast.get_source_segment`` can locate it; when that returns ``None`` the
    node is regenerated with ``ast.unparse`` instead.
    """
    exact_text = ast.get_source_segment(source_text, node)
    if exact_text is None:
        return ast.unparse(node)
    return exact_text
def _inline_comments(text):
    """Map 1-based line numbers of ``text`` to the comment found on that line.

    Comments are located with ``tokenize`` so that a ``#`` inside a string
    literal is never mistaken for a comment. The leading ``#`` and surrounding
    whitespace are dropped from the stored text.
    """
    comments = {}
    for tok in tokenize.generate_tokens(io.StringIO(text).readline):
        if tok.type == tokenize.COMMENT:
            comments[tok.start[0]] = tok.string[1:].strip()
    return comments


def _is_sqlmodel_table(class_node):
    """Return True for classes declared as ``class X(SQLModel, table=True)``."""
    inherits_sqlmodel = any(
        (isinstance(base, ast.Name) and base.id == "SQLModel")
        or (isinstance(base, ast.Attribute) and base.attr == "SQLModel")
        for base in class_node.bases
    )
    is_table = any(
        kw.arg == "table" and isinstance(kw.value, ast.Constant) and kw.value.value is True
        for kw in class_node.keywords
    )
    return inherits_sqlmodel and is_table


def _is_field_call(value):
    """Return True when ``value`` is a call to ``Field(...)`` or ``<x>.Field(...)``."""
    if not isinstance(value, ast.Call):
        return False
    func = value.func
    return (isinstance(func, ast.Name) and func.id == "Field") or (
        isinstance(func, ast.Attribute) and func.attr == "Field"
    )


def _field_default(call):
    """Return the AST node of the default passed to a ``Field(...)`` call, or None."""
    for kw in call.keywords:
        if kw.arg == "default":
            return kw.value
    # pydantic/SQLModel ``Field`` also accepts the default as its first
    # positional argument, e.g. ``Field(None, primary_key=True)``.
    if call.args and not isinstance(call.args[0], ast.Starred):
        return call.args[0]
    return None


def _field_line(item):
    """Render one class-body statement as a field line of the generated model.

    Returns ``(name, line)``, or ``None`` when the statement is not a plain
    field (``__tablename__``, non-name targets, methods, ...).
    """
    if isinstance(item, ast.Assign):
        if len(item.targets) != 1 or not isinstance(item.targets[0], ast.Name):
            return None
        name = item.targets[0].id
        if name == "__tablename__":
            return None
        return name, f" {name} = {src(item.value)}"

    if isinstance(item, ast.AnnAssign):
        if not isinstance(item.target, ast.Name):
            return None
        name = item.target.id
        # An AnnAssign node always carries an annotation.
        ann = src(item.annotation)
        if item.value is None:
            return name, f" {name}: {ann}"
        if _is_field_call(item.value):
            default_node = _field_default(item.value)
            if default_node is None:
                # No default given: keep the type but assign nothing.
                return name, f" {name}: {ann}"
            return name, f" {name}: {ann} = {src(default_node)}"
        return name, f" {name}: {ann} = {src(item.value)}"

    return None


def _docstring_line(comment):
    """Render ``comment`` as a triple-quoted string line that always parses.

    Backslashes are doubled, interior ``\"\"\"`` runs are escaped quote by
    quote, and trailing quotes are escaped so they cannot merge with the
    closing delimiter.
    """
    body = comment.rstrip('"')
    trailing_quotes = len(comment) - len(body)
    body = body.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')
    return f' """{body}' + '\\"' * trailing_quotes + '"""'


comment_by_line = _inline_comments(source_text)

for node in tree.body:
    # Only SQLModel classes declared with table=True are converted.
    if not (isinstance(node, ast.ClassDef) and _is_sqlmodel_table(node)):
        continue
    class_name = node.name
    code_lines.append("")
    code_lines.append(f"class {class_name}(BaseModel):")
    fields_added = 0
    for item in node.body:
        rendered = _field_line(item)
        if rendered is None:
            continue
        name, line = rendered
        fields_added += 1
        # Use the inline comment as the field description: first the line the
        # statement starts on, then (for multi-line fields) the line it ends on.
        comment = comment_by_line.get(item.lineno) or comment_by_line.get(item.end_lineno)
        code_lines.append(line)
        if comment:
            code_lines.append(_docstring_line(comment))
        else:
            print(f"Warning: No comment found for field '{name}' in class '{class_name}'.")
    if fields_added == 0:
        code_lines.append(" pass")
# Persist the generated module, terminated by a single trailing newline.
generated_source = "\n".join(code_lines) + "\n"
target_file_path.write_text(generated_source, encoding="utf-8")

# Run ruff's formatter over the freshly written file.
ruff_command = ["ruff", "format", str(target_file_path)]
try:
    result = subprocess.run(ruff_command, capture_output=True, text=True)
except FileNotFoundError:
    # The ruff executable could not be found on PATH.
    print("ruff 未找到,请安装 ruff 并确保其在 PATH 中例如pip install ruff", file=sys.stderr)
    sys.exit(127)

# Relay ruff's captured stdout/stderr to the matching stream of this process.
for captured, stream in ((result.stdout, sys.stdout), (result.stderr, sys.stderr)):
    if captured:
        print(captured, file=stream, end="")

# Exit with ruff's own status code when it reported a failure.
if result.returncode != 0:
    print(f"ruff 检查失败,退出码:{result.returncode}", file=sys.stderr)
    sys.exit(result.returncode)