mirror of https://github.com/Mai-with-u/MaiBot.git
128 lines
4.5 KiB
Python
128 lines
4.5 KiB
Python
from pathlib import Path
|
||
import ast
|
||
import subprocess
|
||
import sys
|
||
|
||
# Both files live in src/common/database under the directory two levels above
# this script: the SQLModel definitions are the input, the generated pydantic
# data models are the output.
_database_dir = Path(__file__).parent.parent.absolute().resolve() / "src" / "common" / "database"
base_file_path = _database_dir / "database_model.py"
target_file_path = _database_dir / "database_datamodel.py"

# Keep both the raw text (for ast.get_source_segment) and the per-line view
# (for looking up inline comments by line number).
source_text = base_file_path.read_text(encoding="utf-8")
source_lines = source_text.splitlines()

# A SyntaxError propagates unchanged; passing the filename makes the error
# message point at the model file instead of "<unknown>".
tree = ast.parse(source_text, filename=str(base_file_path))

# Header of the generated module; class definitions are appended after it.
code_lines = [
    "from typing import Optional",
    "from pydantic import BaseModel",
    "from datetime import datetime",
    "from .database_model import ModelUser, ImageType",
]
|
||
|
||
|
||
def src(node):
    """Return the source code text for an AST node.

    The exact text is taken from the module-level ``source_text`` when
    ``ast.get_source_segment`` can locate the node; if it returns ``None``
    the node is regenerated with ``ast.unparse`` instead.
    """
    segment = ast.get_source_segment(source_text, node)
    if segment is None:
        return ast.unparse(node)
    return segment
|
||
|
||
|
||
def _collect_inline_comments(text):
    """Map 1-based line numbers of ``text`` to the comment found on that line.

    The tokenizer is used instead of searching for ``#`` so that a ``#``
    inside a string literal is never mistaken for a comment, and so that
    line numbers agree with the ones reported by ``ast``.

    Returns a dict ``{line_number: comment_text}`` where ``comment_text`` has
    the leading ``#`` removed and surrounding whitespace stripped.
    """
    import io
    import tokenize

    comments = {}
    for tok in tokenize.generate_tokens(io.StringIO(text).readline):
        if tok.type == tokenize.COMMENT:
            comments[tok.start[0]] = tok.string[1:].strip()
    return comments


def _is_field_call(value):
    """Return True when ``value`` is a call to ``Field`` or ``<anything>.Field``."""
    if not isinstance(value, ast.Call):
        return False
    func = value.func
    return (isinstance(func, ast.Name) and func.id == "Field") or (
        isinstance(func, ast.Attribute) and func.attr == "Field"
    )


def _field_default(call):
    """Return the AST node of the default given to a ``Field(...)`` call.

    The ``default=`` keyword wins; otherwise the first positional argument is
    used. Returns ``None`` when the call specifies neither.
    """
    for kw in call.keywords:
        if kw.arg == "default":
            return kw.value
    if call.args and not isinstance(call.args[0], ast.Starred):
        return call.args[0]
    return None


def _docstring_literal(text):
    """Render ``text`` as a triple-double-quoted string literal.

    Backslashes and double quotes are escaped so the literal stays valid even
    when the text ends with a quote, ends with a backslash, or contains a
    triple quote.
    """
    escaped = text.replace("\\", "\\\\").replace('"', '\\"')
    return f'"""{escaped}"""'


# Comment text for every line of the model file, keyed by line number.
inline_comments = _collect_inline_comments(source_text)

for node in tree.body:
    if not isinstance(node, ast.ClassDef):
        continue
    # Only classes that derive from SQLModel and are declared with table=True
    # are mirrored.
    has_sqlmodel = any(
        (isinstance(b, ast.Name) and b.id == "SQLModel") or (isinstance(b, ast.Attribute) and b.attr == "SQLModel")
        for b in node.bases
    )
    has_table_kw = any(
        (kw.arg == "table" and isinstance(kw.value, ast.Constant) and kw.value.value is True) for kw in node.keywords
    )
    if not (has_sqlmodel and has_table_kw):
        continue

    class_name = node.name
    code_lines.append("")
    code_lines.append(f"class {class_name}(BaseModel):")

    fields_added = 0
    for item in node.body:
        if isinstance(item, ast.Assign):
            # Plain assignment: only the single-name form is copied, and
            # __tablename__ is skipped.
            if len(item.targets) != 1 or not isinstance(item.targets[0], ast.Name):
                continue
            name = item.targets[0].id
            if name == "__tablename__":
                continue
            value_src = src(item.value)
            line = f" {name} = {value_src}"
        elif isinstance(item, ast.AnnAssign):
            # Annotated assignment; the annotation is always present on an
            # AnnAssign node.
            if not isinstance(item.target, ast.Name):
                continue
            name = item.target.id
            ann = src(item.annotation)
            if item.value is None:
                line = f" {name}: {ann}"
            elif _is_field_call(item.value):
                # Field(...) is reduced to its plain default value.
                # NOTE: a default_factory= argument is not carried over; such
                # fields are emitted without a default.
                default_node = _field_default(item.value)
                if default_node is None:
                    # No default: keep the type but assign nothing.
                    line = f" {name}: {ann}"
                else:
                    default_src = src(default_node)
                    line = f" {name}: {ann} = {default_src}"
            else:
                value_src = src(item.value)
                line = f" {name}: {ann} = {value_src}"
        else:
            continue
        fields_added += 1

        # Use the inline comment as the field description. For a statement
        # spanning several lines, fall back to a comment on its last line.
        comment = inline_comments.get(item.lineno) or inline_comments.get(item.end_lineno)

        code_lines.append(line)
        if comment:
            code_lines.append(" " + _docstring_literal(comment))
        else:
            print(f"Warning: No comment found for field '{name}' in class '{class_name}'.")

    if fields_added == 0:
        # A class body must not be empty.
        code_lines.append(" pass")
|
||
|
||
# Write the generated module; every line, including the last, ends with "\n".
generated_source = "\n".join(code_lines) + "\n"
target_file_path.write_text(generated_source, encoding="utf-8")

# Format the generated file with ruff. A missing executable is reported and
# the script exits with status 127.
format_cmd = ["ruff", "format", str(target_file_path)]
try:
    result = subprocess.run(format_cmd, capture_output=True, text=True)
except FileNotFoundError:
    print("ruff 未找到,请安装 ruff 并确保其在 PATH 中(例如:pip install ruff)", file=sys.stderr)
    sys.exit(127)

# Relay whatever ruff printed to the matching stream of this process.
for captured, stream in ((result.stdout, sys.stdout), (result.stderr, sys.stderr)):
    if captured:
        print(captured, file=stream, end="")

# Propagate ruff's failure status as this script's exit status.
if result.returncode:
    print(f"ruff 检查失败,退出码:{result.returncode}", file=sys.stderr)
    sys.exit(result.returncode)
|