Merge branch 'r-dev' of github.com:Mai-with-u/MaiBot into r-dev

pull/1496/head
UnCLAS-Prommer 2026-02-18 16:00:58 +08:00
commit ccd1be7bed
No known key found for this signature in database
45 changed files with 7190 additions and 744 deletions

1
.gitignore vendored
View File

@ -353,3 +353,4 @@ interested_rates.txt
MaiBot.code-workspace MaiBot.code-workspace
*.lock *.lock
actionlint actionlint
.sisyphus/

View File

@ -8,7 +8,9 @@
"build": "tsc -b && vite build", "build": "tsc -b && vite build",
"lint": "eslint .", "lint": "eslint .",
"preview": "vite preview", "preview": "vite preview",
"format": "prettier --write \"src/**/*.{ts,tsx,css}\"" "format": "prettier --write \"src/**/*.{ts,tsx,css}\"",
"test": "vitest",
"test:ui": "vitest --ui"
}, },
"dependencies": { "dependencies": {
"@codemirror/lang-javascript": "^6.2.4", "@codemirror/lang-javascript": "^6.2.4",
@ -75,21 +77,27 @@
}, },
"devDependencies": { "devDependencies": {
"@eslint/js": "^9.39.1", "@eslint/js": "^9.39.1",
"@testing-library/jest-dom": "^6.9.1",
"@testing-library/react": "^16.3.2",
"@testing-library/user-event": "^14.6.1",
"@types/node": "^24.10.2", "@types/node": "^24.10.2",
"@types/react": "^19.2.7", "@types/react": "^19.2.7",
"@types/react-dom": "^19.2.3", "@types/react-dom": "^19.2.3",
"@vitejs/plugin-react": "^5.1.2", "@vitejs/plugin-react": "^5.1.2",
"@vitest/ui": "^4.0.18",
"autoprefixer": "^10.4.22", "autoprefixer": "^10.4.22",
"eslint": "^9.39.1", "eslint": "^9.39.1",
"eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-hooks": "^7.0.1",
"eslint-plugin-react-refresh": "^0.4.24", "eslint-plugin-react-refresh": "^0.4.24",
"globals": "^16.5.0", "globals": "^16.5.0",
"jsdom": "^28.1.0",
"postcss": "^8.5.6", "postcss": "^8.5.6",
"prettier": "^3.7.4", "prettier": "^3.7.4",
"prettier-plugin-tailwindcss": "^0.7.2", "prettier-plugin-tailwindcss": "^0.7.2",
"tailwindcss": "^3", "tailwindcss": "^3",
"typescript": "~5.9.3", "typescript": "~5.9.3",
"typescript-eslint": "^8.49.0", "typescript-eslint": "^8.49.0",
"vite": "^7.2.7" "vite": "^7.2.7",
"vitest": "^4.0.18"
} }
} }

View File

@ -0,0 +1,114 @@
import * as React from 'react'
import type { ConfigSchema, FieldSchema } from '@/types/config-schema'
import { fieldHooks, type FieldHookRegistry } from '@/lib/field-hooks'
import { DynamicField } from './DynamicField'
export interface DynamicConfigFormProps {
schema: ConfigSchema
values: Record<string, unknown>
onChange: (field: string, value: unknown) => void
hooks?: FieldHookRegistry
}
/**
* DynamicConfigForm -
*
* ConfigSchema
* 1. Hook FieldHookRegistry
* - replace
* - wrapper children
* 2. schema schema.nested
* 3. 使 DynamicField
*/
export const DynamicConfigForm: React.FC<DynamicConfigFormProps> = ({
schema,
values,
onChange,
hooks = fieldHooks, // 默认使用全局单例
}) => {
/**
*
* Hook Hook
*/
const renderField = (field: FieldSchema) => {
const fieldPath = field.name
// 检查是否有注册的 Hook
if (hooks.has(fieldPath)) {
const hookEntry = hooks.get(fieldPath)
if (!hookEntry) return null // Type guard理论上不会发生
const HookComponent = hookEntry.component
if (hookEntry.type === 'replace') {
// replace 模式:完全替换默认渲染
return (
<HookComponent
fieldPath={fieldPath}
value={values[field.name]}
onChange={(v) => onChange(field.name, v)}
/>
)
} else {
// wrapper 模式:包装默认渲染
return (
<HookComponent
fieldPath={fieldPath}
value={values[field.name]}
onChange={(v) => onChange(field.name, v)}
>
<DynamicField
schema={field}
value={values[field.name]}
onChange={(v) => onChange(field.name, v)}
fieldPath={fieldPath}
/>
</HookComponent>
)
}
}
// 无 Hook使用默认渲染
return (
<DynamicField
schema={field}
value={values[field.name]}
onChange={(v) => onChange(field.name, v)}
fieldPath={fieldPath}
/>
)
}
return (
<div className="space-y-4">
{/* 渲染顶层字段 */}
{schema.fields.map((field) => (
<div key={field.name}>{renderField(field)}</div>
))}
{/* 渲染嵌套 schema */}
{schema.nested &&
Object.entries(schema.nested).map(([key, nestedSchema]) => (
<div key={key} className="mt-6 space-y-4">
{/* 嵌套 schema 标题 */}
<div className="border-b pb-2">
<h3 className="text-lg font-semibold">{nestedSchema.className}</h3>
{nestedSchema.classDoc && (
<p className="text-sm text-muted-foreground">{nestedSchema.classDoc}</p>
)}
</div>
{/* 递归渲染嵌套表单 */}
<DynamicConfigForm
schema={nestedSchema}
values={(values[key] as Record<string, unknown>) || {}}
onChange={(field, value) => onChange(`${key}.${field}`, value)}
hooks={hooks}
/>
</div>
))}
</div>
)
}

View File

@ -0,0 +1,246 @@
import * as React from "react"
import * as LucideIcons from "lucide-react"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"
import { Slider } from "@/components/ui/slider"
import { Switch } from "@/components/ui/switch"
import { Textarea } from "@/components/ui/textarea"
import type { FieldSchema } from "@/types/config-schema"
export interface DynamicFieldProps {
schema: FieldSchema
value: unknown
onChange: (value: unknown) => void
// eslint-disable-next-line @typescript-eslint/no-unused-vars
fieldPath?: string // 用于 Hook 系统(未来使用)
}
/**
* DynamicField - x-widget shadcn/ui
*
*
* 1. x-widget schema x-widget使
* 2. type 退 x-widget type
*/
export const DynamicField: React.FC<DynamicFieldProps> = ({
schema,
value,
onChange,
}) => {
/**
*
*/
const renderIcon = () => {
if (!schema['x-icon']) return null
const IconComponent = (LucideIcons as any)[schema['x-icon']]
if (!IconComponent) return null
return <IconComponent className="h-4 w-4" />
}
/**
* x-widget type
*/
const renderInputComponent = () => {
const widget = schema['x-widget']
const type = schema.type
// x-widget 优先
if (widget) {
switch (widget) {
case 'slider':
return renderSlider()
case 'switch':
return renderSwitch()
case 'textarea':
return renderTextarea()
case 'select':
return renderSelect()
case 'custom':
return (
<div className="rounded-md border border-dashed border-muted-foreground/25 bg-muted/10 p-4 text-center text-sm text-muted-foreground">
Custom field requires Hook
</div>
)
default:
// 未知的 x-widget回退到 type
break
}
}
// type 回退
switch (type) {
case 'boolean':
return renderSwitch()
case 'number':
case 'integer':
return renderNumberInput()
case 'string':
return renderTextInput()
case 'select':
return renderSelect()
case 'array':
return (
<div className="rounded-md border border-dashed border-muted-foreground/25 bg-muted/10 p-4 text-center text-sm text-muted-foreground">
Array fields not yet supported
</div>
)
case 'object':
return (
<div className="rounded-md border border-dashed border-muted-foreground/25 bg-muted/10 p-4 text-center text-sm text-muted-foreground">
Object fields not yet supported
</div>
)
case 'textarea':
return renderTextarea()
default:
return (
<div className="rounded-md border border-dashed border-muted-foreground/25 bg-muted/10 p-4 text-center text-sm text-muted-foreground">
Unknown field type: {type}
</div>
)
}
}
/**
* Switch boolean
*/
const renderSwitch = () => {
const checked = Boolean(value)
return (
<Switch
checked={checked}
onCheckedChange={(checked) => onChange(checked)}
/>
)
}
/**
* Slider number + x-widget: slider
*/
const renderSlider = () => {
const numValue = typeof value === 'number' ? value : (schema.default as number ?? 0)
const min = schema.minValue ?? 0
const max = schema.maxValue ?? 100
const step = schema.step ?? 1
return (
<div className="space-y-2">
<Slider
value={[numValue]}
onValueChange={(values) => onChange(values[0])}
min={min}
max={max}
step={step}
/>
<div className="flex justify-between text-xs text-muted-foreground">
<span>{min}</span>
<span className="font-medium text-foreground">{numValue}</span>
<span>{max}</span>
</div>
</div>
)
}
/**
* Input[type="number"] number/integer
*/
const renderNumberInput = () => {
const numValue = typeof value === 'number' ? value : (schema.default as number ?? 0)
const min = schema.minValue
const max = schema.maxValue
const step = schema.step ?? (schema.type === 'integer' ? 1 : 0.1)
return (
<Input
type="number"
value={numValue}
onChange={(e) => onChange(parseFloat(e.target.value) || 0)}
min={min}
max={max}
step={step}
/>
)
}
/**
* Input[type="text"] string
*/
const renderTextInput = () => {
const strValue = typeof value === 'string' ? value : (schema.default as string ?? '')
return (
<Input
type="text"
value={strValue}
onChange={(e) => onChange(e.target.value)}
/>
)
}
/**
* Textarea textarea x-widget: textarea
*/
const renderTextarea = () => {
const strValue = typeof value === 'string' ? value : (schema.default as string ?? '')
return (
<Textarea
value={strValue}
onChange={(e) => onChange(e.target.value)}
rows={4}
/>
)
}
/**
* Select select x-widget: select
*/
const renderSelect = () => {
const strValue = typeof value === 'string' ? value : (schema.default as string ?? '')
const options = schema.options ?? []
if (options.length === 0) {
return (
<div className="rounded-md border border-dashed border-muted-foreground/25 bg-muted/10 p-4 text-center text-sm text-muted-foreground">
No options available for select
</div>
)
}
return (
<Select value={strValue} onValueChange={(val) => onChange(val)}>
<SelectTrigger>
<SelectValue placeholder={`Select ${schema.label}`} />
</SelectTrigger>
<SelectContent>
{options.map((option) => (
<SelectItem key={option} value={option}>
{option}
</SelectItem>
))}
</SelectContent>
</Select>
)
}
return (
<div className="space-y-2">
{/* Label with icon */}
<Label className="text-sm font-medium flex items-center gap-2">
{renderIcon()}
{schema.label}
{schema.required && <span className="text-destructive">*</span>}
</Label>
{/* Input component */}
{renderInputComponent()}
{/* Description */}
{schema.description && (
<p className="text-sm text-muted-foreground">{schema.description}</p>
)}
</div>
)
}

View File

@ -0,0 +1,126 @@
# Dynamic Config Form System
## Overview
The Dynamic Config Form system is a schema-driven UI component designed to automatically generate configuration forms based on backend Pydantic models. It supports rich metadata for UI customization and a flexible Hook system for complex fields.
### Core Components
- **DynamicConfigForm**: The main component that takes a `ConfigSchema` and renders the entire form.
- **DynamicField**: A lower-level component that renders individual fields based on their type and UI metadata.
- **FieldHookRegistry**: A registry for custom React components that can replace or wrap default field rendering.
## Quick Start
To use the dynamic form in your page:
```typescript
import { DynamicConfigForm } from '@/components/dynamic-form'
import { fieldHooks } from '@/lib/field-hooks'
// Example usage in a component
export function ConfigPage() {
const [config, setConfig] = useState({})
const schema = useConfigSchema() // Fetch from API
const handleChange = (fieldPath: string, value: unknown) => {
// fieldPath can be nested, e.g., 'section.subfield'
updateConfigAt(fieldPath, value)
}
return (
<DynamicConfigForm
schema={schema}
values={config}
onChange={handleChange}
hooks={fieldHooks}
/>
)
}
```
## Adding UI Metadata (Backend)
You can customize how fields are rendered by adding `json_schema_extra` to your Pydantic `Field` definitions.
### Supported Metadata
- `x-widget`: Specifies the UI component to use.
- `slider`: A range slider (requires `ge`, `le`, and `step`).
- `switch`: A toggle switch (for booleans).
- `textarea`: A multi-line text input.
- `select`: A dropdown menu (for `Literal` or enum types).
- `custom`: Indicates that this field requires a Hook for rendering.
- `x-icon`: A Lucide icon name (e.g., `MessageSquare`, `Settings`).
- `step`: Incremental step for sliders or number inputs.
### Example
```python
class ChatConfig(ConfigBase):
talk_value: float = Field(
default=0.5,
ge=0.0,
le=1.0,
json_schema_extra={
"x-widget": "slider",
"x-icon": "MessageSquare",
"step": 0.1
}
)
```
## Creating Hook Components
Hooks allow you to provide custom UI for complex configuration sections or fields.
### FieldHookComponent Interface
A Hook component receives the following props:
- `fieldPath`: The full path to the field.
- `value`: The current value of the field/section.
- `onChange`: Callback to update the value.
- `children`: (Only for `wrapper` hooks) The default field renderer.
### Implementation Example
```typescript
import type { FieldHookComponent } from '@/lib/field-hooks'
export const CustomSectionHook: FieldHookComponent = ({
fieldPath,
value,
onChange
}) => {
return (
<div className="custom-section">
<h3>Custom UI</h3>
<input
value={value.some_prop}
onChange={(e) => onChange({ ...value, some_prop: e.target.value })}
/>
</div>
)
}
```
### Registering Hooks
Register hooks in your component's lifecycle:
```typescript
useEffect(() => {
fieldHooks.register('chat', ChatSectionHook, 'replace')
return () => fieldHooks.unregister('chat')
}, [])
```
## API Reference
### DynamicConfigForm
| Prop | Type | Description |
|------|------|-------------|
| `schema` | `ConfigSchema` | The schema generated by the backend. |
| `values` | `Record<string, any>` | Current configuration values. |
| `onChange` | `(field: string, value: any) => void` | Change handler. |
| `hooks` | `FieldHookRegistry` | Optional custom hook registry. |
### FieldHookRegistry
- `register(path, component, type)`: Register a hook.
- `get(path)`: Retrieve a registered hook.
- `has(path)`: Check if a hook exists.
- `unregister(path)`: Remove a hook.
## Troubleshooting
- **Hook not rendering**: Ensure the registration path matches the schema field name exactly (e.g., `chat` vs `Chat`).
- **Field missing**: Check if the field is present in the `ConfigSchema` returned by the backend.
- **TypeScript errors**: Ensure your Hook implements the `FieldHookComponent` type.

View File

@ -0,0 +1,362 @@
import { describe, it, expect, vi } from 'vitest'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { DynamicConfigForm } from '../DynamicConfigForm'
import { FieldHookRegistry } from '@/lib/field-hooks'
import type { ConfigSchema } from '@/types/config-schema'
import type { FieldHookComponentProps } from '@/lib/field-hooks'
describe('DynamicConfigForm', () => {
describe('basic rendering', () => {
it('renders simple fields', () => {
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'field1',
type: 'string',
label: 'Field 1',
description: 'First field',
required: false,
default: 'value1',
},
{
name: 'field2',
type: 'boolean',
label: 'Field 2',
description: 'Second field',
required: false,
default: false,
},
],
}
const values = { field1: 'value1', field2: false }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
expect(screen.getByText('Field 1')).toBeInTheDocument()
expect(screen.getByText('Field 2')).toBeInTheDocument()
expect(screen.getByText('First field')).toBeInTheDocument()
expect(screen.getByText('Second field')).toBeInTheDocument()
})
it('renders nested schema', () => {
const schema: ConfigSchema = {
className: 'MainConfig',
classDoc: 'Main configuration',
fields: [
{
name: 'top_field',
type: 'string',
label: 'Top Field',
description: 'Top level field',
required: false,
},
],
nested: {
sub_config: {
className: 'SubConfig',
classDoc: 'Sub configuration',
fields: [
{
name: 'nested_field',
type: 'number',
label: 'Nested Field',
description: 'Nested field',
required: false,
default: 42,
},
],
},
},
}
const values = {
top_field: 'top',
sub_config: {
nested_field: 42,
},
}
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
expect(screen.getByText('Top Field')).toBeInTheDocument()
expect(screen.getByText('SubConfig')).toBeInTheDocument()
expect(screen.getByText('Sub configuration')).toBeInTheDocument()
expect(screen.getByText('Nested Field')).toBeInTheDocument()
})
})
describe('Hook system', () => {
it('renders Hook component in replace mode', () => {
const TestHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, value }) => {
return <div data-testid="hook-component">Hook: {fieldPath} = {String(value)}</div>
}
const hooks = new FieldHookRegistry()
hooks.register('hooked_field', TestHookComponent, 'replace')
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'hooked_field',
type: 'string',
label: 'Hooked Field',
description: 'A field with hook',
required: false,
},
{
name: 'normal_field',
type: 'string',
label: 'Normal Field',
description: 'A normal field',
required: false,
},
],
}
const values = { hooked_field: 'test', normal_field: 'normal' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
expect(screen.getByTestId('hook-component')).toBeInTheDocument()
expect(screen.getByText('Hook: hooked_field = test')).toBeInTheDocument()
expect(screen.queryByText('Hooked Field')).not.toBeInTheDocument()
expect(screen.getByText('Normal Field')).toBeInTheDocument()
})
it('renders Hook component in wrapper mode', () => {
const WrapperHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, children }) => {
return (
<div data-testid="wrapper-hook">
<div>Wrapper for: {fieldPath}</div>
{children}
</div>
)
}
const hooks = new FieldHookRegistry()
hooks.register('wrapped_field', WrapperHookComponent, 'wrapper')
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'wrapped_field',
type: 'string',
label: 'Wrapped Field',
description: 'A wrapped field',
required: false,
},
],
}
const values = { wrapped_field: 'test' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
expect(screen.getByTestId('wrapper-hook')).toBeInTheDocument()
expect(screen.getByText('Wrapper for: wrapped_field')).toBeInTheDocument()
expect(screen.getByText('Wrapped Field')).toBeInTheDocument()
})
it('passes correct props to Hook component', () => {
const TestHookComponent: React.FC<FieldHookComponentProps> = ({ fieldPath, value, onChange }) => {
return (
<div>
<div data-testid="field-path">{fieldPath}</div>
<div data-testid="field-value">{String(value)}</div>
<button onClick={() => onChange?.('new_value')}>Change</button>
</div>
)
}
const hooks = new FieldHookRegistry()
hooks.register('test_field', TestHookComponent, 'replace')
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'test_field',
type: 'string',
label: 'Test Field',
description: 'A test field',
required: false,
},
],
}
const values = { test_field: 'original' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
expect(screen.getByTestId('field-path')).toHaveTextContent('test_field')
expect(screen.getByTestId('field-value')).toHaveTextContent('original')
})
})
describe('onChange propagation', () => {
it('propagates onChange from simple field', async () => {
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'test_field',
type: 'string',
label: 'Test Field',
description: 'A test field',
required: false,
},
],
}
const values = { test_field: '' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
const input = screen.getByRole('textbox')
input.focus()
await userEvent.keyboard('Hello')
expect(onChange).toHaveBeenCalledTimes(5)
expect(onChange.mock.calls.every(call => call[0] === 'test_field')).toBe(true)
expect(onChange).toHaveBeenNthCalledWith(1, 'test_field', 'H')
expect(onChange).toHaveBeenNthCalledWith(5, 'test_field', 'o')
})
it('propagates onChange from nested field with correct path', async () => {
const schema: ConfigSchema = {
className: 'MainConfig',
classDoc: 'Main configuration',
fields: [],
nested: {
sub_config: {
className: 'SubConfig',
classDoc: 'Sub configuration',
fields: [
{
name: 'nested_field',
type: 'string',
label: 'Nested Field',
description: 'Nested field',
required: false,
},
],
},
},
}
const values = {
sub_config: {
nested_field: '',
},
}
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
const input = screen.getByRole('textbox')
input.focus()
await userEvent.keyboard('Test')
expect(onChange).toHaveBeenCalledTimes(4)
expect(onChange.mock.calls.every(call => call[0] === 'sub_config.nested_field')).toBe(true)
expect(onChange).toHaveBeenNthCalledWith(1, 'sub_config.nested_field', 'T')
expect(onChange).toHaveBeenNthCalledWith(4, 'sub_config.nested_field', 't')
})
it('propagates onChange from Hook component', async () => {
const TestHookComponent: React.FC<FieldHookComponentProps> = ({ onChange }) => {
return <button onClick={() => onChange?.('hook_value')}>Set Value</button>
}
const hooks = new FieldHookRegistry()
hooks.register('hooked_field', TestHookComponent, 'replace')
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'hooked_field',
type: 'string',
label: 'Hooked Field',
description: 'A hooked field',
required: false,
},
],
}
const values = { hooked_field: '' }
const onChange = vi.fn()
const user = userEvent.setup()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} hooks={hooks} />)
await user.click(screen.getByRole('button'))
expect(onChange).toHaveBeenCalledWith('hooked_field', 'hook_value')
})
})
describe('edge cases', () => {
it('renders with empty nested values', () => {
const schema: ConfigSchema = {
className: 'MainConfig',
classDoc: 'Main configuration',
fields: [],
nested: {
sub_config: {
className: 'SubConfig',
classDoc: 'Sub configuration',
fields: [
{
name: 'nested_field',
type: 'string',
label: 'Nested Field',
description: 'Nested field',
required: false,
},
],
},
},
}
const values = {}
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
expect(screen.getByText('SubConfig')).toBeInTheDocument()
expect(screen.getByText('Nested Field')).toBeInTheDocument()
})
it('uses default hook registry when not provided', () => {
const schema: ConfigSchema = {
className: 'TestConfig',
classDoc: 'Test configuration',
fields: [
{
name: 'test_field',
type: 'string',
label: 'Test Field',
description: 'A test field',
required: false,
},
],
}
const values = { test_field: 'test' }
const onChange = vi.fn()
render(<DynamicConfigForm schema={schema} values={values} onChange={onChange} />)
expect(screen.getByText('Test Field')).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,394 @@
import { describe, it, expect, vi } from 'vitest'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { DynamicField } from '../DynamicField'
import type { FieldSchema } from '@/types/config-schema'
describe('DynamicField', () => {
describe('x-widget priority', () => {
it('renders Slider when x-widget is slider', () => {
const schema: FieldSchema = {
name: 'test_slider',
type: 'number',
label: 'Test Slider',
description: 'A test slider',
required: false,
'x-widget': 'slider',
minValue: 0,
maxValue: 100,
default: 50,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={50} onChange={onChange} />)
expect(screen.getByText('Test Slider')).toBeInTheDocument()
expect(screen.getByRole('slider')).toBeInTheDocument()
expect(screen.getByText('50')).toBeInTheDocument()
})
it('renders Switch when x-widget is switch', () => {
const schema: FieldSchema = {
name: 'test_switch',
type: 'boolean',
label: 'Test Switch',
description: 'A test switch',
required: false,
'x-widget': 'switch',
default: false,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={false} onChange={onChange} />)
expect(screen.getByText('Test Switch')).toBeInTheDocument()
expect(screen.getByRole('switch')).toBeInTheDocument()
})
it('renders Textarea when x-widget is textarea', () => {
const schema: FieldSchema = {
name: 'test_textarea',
type: 'string',
label: 'Test Textarea',
description: 'A test textarea',
required: false,
'x-widget': 'textarea',
default: 'Hello',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="Hello" onChange={onChange} />)
expect(screen.getByText('Test Textarea')).toBeInTheDocument()
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('textbox')).toHaveValue('Hello')
})
it('renders Select when x-widget is select', () => {
const schema: FieldSchema = {
name: 'test_select',
type: 'string',
label: 'Test Select',
description: 'A test select',
required: false,
'x-widget': 'select',
options: ['Option 1', 'Option 2', 'Option 3'],
default: 'Option 1',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="Option 1" onChange={onChange} />)
expect(screen.getByText('Test Select')).toBeInTheDocument()
expect(screen.getByRole('combobox')).toBeInTheDocument()
})
it('renders placeholder for custom widget', () => {
const schema: FieldSchema = {
name: 'test_custom',
type: 'string',
label: 'Test Custom',
description: 'A test custom field',
required: false,
'x-widget': 'custom',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('Custom field requires Hook')).toBeInTheDocument()
})
})
describe('type fallback', () => {
it('renders Input for string type', () => {
const schema: FieldSchema = {
name: 'test_string',
type: 'string',
label: 'Test String',
description: 'A test string',
required: false,
default: 'Hello',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="Hello" onChange={onChange} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('textbox')).toHaveValue('Hello')
})
it('renders Switch for boolean type', () => {
const schema: FieldSchema = {
name: 'test_bool',
type: 'boolean',
label: 'Test Boolean',
description: 'A test boolean',
required: false,
default: true,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={true} onChange={onChange} />)
expect(screen.getByRole('switch')).toBeInTheDocument()
expect(screen.getByRole('switch')).toBeChecked()
})
it('renders number Input for number type', () => {
const schema: FieldSchema = {
name: 'test_number',
type: 'number',
label: 'Test Number',
description: 'A test number',
required: false,
default: 42,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={42} onChange={onChange} />)
const input = screen.getByRole('spinbutton')
expect(input).toBeInTheDocument()
expect(input).toHaveValue(42)
})
it('renders number Input for integer type', () => {
const schema: FieldSchema = {
name: 'test_integer',
type: 'integer',
label: 'Test Integer',
description: 'A test integer',
required: false,
default: 10,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={10} onChange={onChange} />)
const input = screen.getByRole('spinbutton')
expect(input).toBeInTheDocument()
expect(input).toHaveValue(10)
})
it('renders Textarea for textarea type', () => {
const schema: FieldSchema = {
name: 'test_textarea_type',
type: 'textarea',
label: 'Test Textarea Type',
description: 'A test textarea type',
required: false,
default: 'Long text',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="Long text" onChange={onChange} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('textbox')).toHaveValue('Long text')
})
it('renders Select for select type', () => {
const schema: FieldSchema = {
name: 'test_select_type',
type: 'select',
label: 'Test Select Type',
description: 'A test select type',
required: false,
options: ['A', 'B', 'C'],
default: 'A',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="A" onChange={onChange} />)
expect(screen.getByRole('combobox')).toBeInTheDocument()
})
it('renders placeholder for array type', () => {
const schema: FieldSchema = {
name: 'test_array',
type: 'array',
label: 'Test Array',
description: 'A test array',
required: false,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={[]} onChange={onChange} />)
expect(screen.getByText('Array fields not yet supported')).toBeInTheDocument()
})
it('renders placeholder for object type', () => {
const schema: FieldSchema = {
name: 'test_object',
type: 'object',
label: 'Test Object',
description: 'A test object',
required: false,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={{}} onChange={onChange} />)
expect(screen.getByText('Object fields not yet supported')).toBeInTheDocument()
})
})
describe('onChange events', () => {
it('triggers onChange for Switch', async () => {
const schema: FieldSchema = {
name: 'test_switch',
type: 'boolean',
label: 'Test Switch',
description: 'A test switch',
required: false,
default: false,
}
const onChange = vi.fn()
const user = userEvent.setup()
render(<DynamicField schema={schema} value={false} onChange={onChange} />)
await user.click(screen.getByRole('switch'))
expect(onChange).toHaveBeenCalledWith(true)
})
it('triggers onChange for Input', async () => {
const schema: FieldSchema = {
name: 'test_input',
type: 'string',
label: 'Test Input',
description: 'A test input',
required: false,
default: '',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
const input = screen.getByRole('textbox')
input.focus()
await userEvent.keyboard('Hello')
expect(onChange).toHaveBeenCalledTimes(5)
expect(onChange).toHaveBeenNthCalledWith(1, 'H')
expect(onChange).toHaveBeenNthCalledWith(2, 'e')
expect(onChange).toHaveBeenNthCalledWith(3, 'l')
expect(onChange).toHaveBeenNthCalledWith(4, 'l')
expect(onChange).toHaveBeenNthCalledWith(5, 'o')
})
it('triggers onChange for number Input', async () => {
const schema: FieldSchema = {
name: 'test_number',
type: 'number',
label: 'Test Number',
description: 'A test number',
required: false,
default: 0,
}
const onChange = vi.fn()
const user = userEvent.setup()
render(<DynamicField schema={schema} value={0} onChange={onChange} />)
const input = screen.getByRole('spinbutton')
await user.clear(input)
await user.type(input, '123')
expect(onChange).toHaveBeenCalled()
})
})
describe('visual features', () => {
it('renders label with icon', () => {
const schema: FieldSchema = {
name: 'test_icon',
type: 'string',
label: 'Test Icon',
description: 'A test with icon',
required: false,
'x-icon': 'Settings',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('Test Icon')).toBeInTheDocument()
})
it('renders required indicator', () => {
const schema: FieldSchema = {
name: 'test_required',
type: 'string',
label: 'Test Required',
description: 'A required field',
required: true,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('*')).toBeInTheDocument()
})
it('renders description', () => {
const schema: FieldSchema = {
name: 'test_desc',
type: 'string',
label: 'Test Description',
description: 'This is a description',
required: false,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('This is a description')).toBeInTheDocument()
})
})
describe('slider features', () => {
it('renders slider with min/max/step', () => {
const schema: FieldSchema = {
name: 'test_slider_props',
type: 'number',
label: 'Test Slider Props',
description: 'A slider with props',
required: false,
'x-widget': 'slider',
minValue: 10,
maxValue: 50,
step: 5,
default: 25,
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value={25} onChange={onChange} />)
expect(screen.getByText('10')).toBeInTheDocument()
expect(screen.getByText('50')).toBeInTheDocument()
expect(screen.getByText('25')).toBeInTheDocument()
})
})
describe('select features', () => {
it('renders placeholder when no options', () => {
const schema: FieldSchema = {
name: 'test_select_no_options',
type: 'string',
label: 'Test Select No Options',
description: 'A select with no options',
required: false,
'x-widget': 'select',
}
const onChange = vi.fn()
render(<DynamicField schema={schema} value="" onChange={onChange} />)
expect(screen.getByText('No options available for select')).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,2 @@
export { DynamicConfigForm } from './DynamicConfigForm'
export { DynamicField } from './DynamicField'

View File

@ -0,0 +1,253 @@
import { describe, it, expect, beforeEach } from 'vitest'
import { FieldHookRegistry } from '../field-hooks'
import type { FieldHookComponent } from '../field-hooks'
describe('FieldHookRegistry', () => {
let registry: FieldHookRegistry
beforeEach(() => {
registry = new FieldHookRegistry()
})
describe('register', () => {
it('registers a hook with replace type', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component, 'replace')
expect(registry.has('test.field')).toBe(true)
})
it('registers a hook with wrapper type', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component, 'wrapper')
expect(registry.has('test.field')).toBe(true)
const entry = registry.get('test.field')
expect(entry?.type).toBe('wrapper')
})
it('defaults to replace type when not specified', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component)
const entry = registry.get('test.field')
expect(entry?.type).toBe('replace')
})
it('overwrites existing hook for same field path', () => {
const component1: FieldHookComponent = () => null
const component2: FieldHookComponent = () => null
registry.register('test.field', component1, 'replace')
registry.register('test.field', component2, 'wrapper')
const entry = registry.get('test.field')
expect(entry?.component).toBe(component2)
expect(entry?.type).toBe('wrapper')
})
})
describe('get', () => {
it('returns hook entry for registered field path', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component, 'replace')
const entry = registry.get('test.field')
expect(entry).toBeDefined()
expect(entry?.component).toBe(component)
expect(entry?.type).toBe('replace')
})
it('returns undefined for unregistered field path', () => {
const entry = registry.get('nonexistent.field')
expect(entry).toBeUndefined()
})
it('returns correct entry for nested field paths', () => {
const component: FieldHookComponent = () => null
registry.register('config.section.field', component, 'wrapper')
const entry = registry.get('config.section.field')
expect(entry).toBeDefined()
expect(entry?.type).toBe('wrapper')
})
})
describe('has', () => {
it('returns true for registered field path', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component)
expect(registry.has('test.field')).toBe(true)
})
it('returns false for unregistered field path', () => {
expect(registry.has('nonexistent.field')).toBe(false)
})
it('returns false after unregistering', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component)
registry.unregister('test.field')
expect(registry.has('test.field')).toBe(false)
})
})
describe('unregister', () => {
it('removes a registered hook', () => {
const component: FieldHookComponent = () => null
registry.register('test.field', component)
expect(registry.has('test.field')).toBe(true)
registry.unregister('test.field')
expect(registry.has('test.field')).toBe(false)
})
it('does not throw when unregistering non-existent hook', () => {
expect(() => registry.unregister('nonexistent.field')).not.toThrow()
})
it('only removes specified hook, not others', () => {
const component1: FieldHookComponent = () => null
const component2: FieldHookComponent = () => null
registry.register('field1', component1)
registry.register('field2', component2)
registry.unregister('field1')
expect(registry.has('field1')).toBe(false)
expect(registry.has('field2')).toBe(true)
})
})
describe('clear', () => {
it('removes all registered hooks', () => {
const component1: FieldHookComponent = () => null
const component2: FieldHookComponent = () => null
const component3: FieldHookComponent = () => null
registry.register('field1', component1)
registry.register('field2', component2)
registry.register('field3', component3)
expect(registry.getAllPaths()).toHaveLength(3)
registry.clear()
expect(registry.getAllPaths()).toHaveLength(0)
expect(registry.has('field1')).toBe(false)
expect(registry.has('field2')).toBe(false)
expect(registry.has('field3')).toBe(false)
})
it('works correctly on empty registry', () => {
expect(() => registry.clear()).not.toThrow()
expect(registry.getAllPaths()).toHaveLength(0)
})
})
describe('getAllPaths', () => {
it('returns empty array when no hooks registered', () => {
expect(registry.getAllPaths()).toEqual([])
})
it('returns all registered field paths', () => {
const component: FieldHookComponent = () => null
registry.register('field1', component)
registry.register('field2', component)
registry.register('field3', component)
const paths = registry.getAllPaths()
expect(paths).toHaveLength(3)
expect(paths).toContain('field1')
expect(paths).toContain('field2')
expect(paths).toContain('field3')
})
it('returns updated paths after unregister', () => {
const component: FieldHookComponent = () => null
registry.register('field1', component)
registry.register('field2', component)
registry.register('field3', component)
registry.unregister('field2')
const paths = registry.getAllPaths()
expect(paths).toHaveLength(2)
expect(paths).toContain('field1')
expect(paths).toContain('field3')
expect(paths).not.toContain('field2')
})
it('handles nested field paths correctly', () => {
const component: FieldHookComponent = () => null
registry.register('config.chat.enabled', component)
registry.register('config.chat.model', component)
registry.register('config.api.key', component)
const paths = registry.getAllPaths()
expect(paths).toHaveLength(3)
expect(paths).toContain('config.chat.enabled')
expect(paths).toContain('config.chat.model')
expect(paths).toContain('config.api.key')
})
})
describe('integration scenarios', () => {
it('supports full lifecycle of multiple hooks', () => {
const replaceComponent: FieldHookComponent = () => null
const wrapperComponent: FieldHookComponent = () => null
registry.register('field1', replaceComponent, 'replace')
registry.register('field2', wrapperComponent, 'wrapper')
expect(registry.getAllPaths()).toHaveLength(2)
const entry1 = registry.get('field1')
expect(entry1?.type).toBe('replace')
expect(entry1?.component).toBe(replaceComponent)
const entry2 = registry.get('field2')
expect(entry2?.type).toBe('wrapper')
expect(entry2?.component).toBe(wrapperComponent)
registry.unregister('field1')
expect(registry.getAllPaths()).toHaveLength(1)
expect(registry.has('field2')).toBe(true)
registry.clear()
expect(registry.getAllPaths()).toHaveLength(0)
})
it('handles rapid register/unregister cycles', () => {
const component: FieldHookComponent = () => null
for (let i = 0; i < 100; i++) {
registry.register(`field${i}`, component)
}
expect(registry.getAllPaths()).toHaveLength(100)
for (let i = 0; i < 50; i++) {
registry.unregister(`field${i}`)
}
expect(registry.getAllPaths()).toHaveLength(50)
registry.clear()
expect(registry.getAllPaths()).toHaveLength(0)
})
})
})

View File

@ -0,0 +1,99 @@
import type { ReactNode } from 'react'
/**
* Hook type for field-level customization
*/
export type FieldHookType = 'replace' | 'wrapper'
/**
* Props passed to a FieldHookComponent
*/
export interface FieldHookComponentProps {
fieldPath: string
value: unknown
onChange?: (value: unknown) => void
children?: ReactNode
}
/**
* A React component that can be registered as a field hook
*/
export type FieldHookComponent = React.FC<FieldHookComponentProps>
/**
* Registry entry for a field hook
*/
interface FieldHookEntry {
component: FieldHookComponent
type: FieldHookType
}
/**
* Registry for managing field-level hooks
* Supports two types of hooks:
* - replace: Completely replaces the default field renderer
* - wrapper: Wraps the default field renderer with additional functionality
*/
export class FieldHookRegistry {
private hooks: Map<string, FieldHookEntry> = new Map()
/**
* Register a hook for a specific field path
* @param fieldPath The field path (e.g., 'chat.talk_value')
* @param component The React component to register
* @param type The hook type ('replace' or 'wrapper')
*/
register(
fieldPath: string,
component: FieldHookComponent,
type: FieldHookType = 'replace'
): void {
this.hooks.set(fieldPath, { component, type })
}
/**
* Get a registered hook for a specific field path
* @param fieldPath The field path to look up
* @returns The hook entry if found, undefined otherwise
*/
get(fieldPath: string): FieldHookEntry | undefined {
return this.hooks.get(fieldPath)
}
/**
* Check if a hook is registered for a specific field path
* @param fieldPath The field path to check
* @returns True if a hook is registered, false otherwise
*/
has(fieldPath: string): boolean {
return this.hooks.has(fieldPath)
}
/**
* Unregister a hook for a specific field path
* @param fieldPath The field path to unregister
*/
unregister(fieldPath: string): void {
this.hooks.delete(fieldPath)
}
/**
* Clear all registered hooks
*/
clear(): void {
this.hooks.clear()
}
/**
* Get all registered field paths
* @returns Array of registered field paths
*/
getAllPaths(): string[] {
return Array.from(this.hooks.keys())
}
}
/**
* Singleton instance of the field hook registry
*/
export const fieldHooks = new FieldHookRegistry()

View File

@ -5,7 +5,7 @@
* *
*/ */
export const APP_VERSION = '0.12.2' export const APP_VERSION = '1.0.0'
export const APP_NAME = 'MaiBot Dashboard' export const APP_NAME = 'MaiBot Dashboard'
export const APP_FULL_NAME = `${APP_NAME} v${APP_VERSION}` export const APP_FULL_NAME = `${APP_NAME} v${APP_VERSION}`

View File

@ -69,6 +69,11 @@ import { useAutoSave, useConfigAutoSave } from './bot/hooks'
import { useCallback, useEffect, useRef, useState } from 'react' import { useCallback, useEffect, useRef, useState } from 'react'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
// 导入动态表单和 Hook 系统
import { DynamicConfigForm } from '@/components/dynamic-form'
import { fieldHooks } from '@/lib/field-hooks'
import { ChatSectionHook } from '@/routes/config/bot/hooks'
// ==================== 常量定义 ==================== // ==================== 常量定义 ====================
/** Toast 显示前的延迟时间 (毫秒) */ /** Toast 显示前的延迟时间 (毫秒) */
const TOAST_DISPLAY_DELAY = 500 const TOAST_DISPLAY_DELAY = 500
@ -308,6 +313,13 @@ function BotConfigPageContent() {
loadConfig() loadConfig()
}, [loadConfig]) }, [loadConfig])
useEffect(() => {
fieldHooks.register('chat', ChatSectionHook, 'replace')
return () => {
fieldHooks.unregister('chat')
}
}, [])
// 使用模块化的 useAutoSave hook // 使用模块化的 useAutoSave hook
const { triggerAutoSave, cancelPendingAutoSave } = useAutoSave( const { triggerAutoSave, cancelPendingAutoSave } = useAutoSave(
initialLoadRef.current, initialLoadRef.current,
@ -652,7 +664,24 @@ function BotConfigPageContent() {
{/* 聊天配置 */} {/* 聊天配置 */}
<TabsContent value="chat" className="space-y-4"> <TabsContent value="chat" className="space-y-4">
{chatConfig && <ChatSection config={chatConfig} onChange={setChatConfig} />} {chatConfig && (
<DynamicConfigForm
schema={{
className: 'ChatConfig',
classDoc: '聊天配置',
fields: [],
nested: {},
}}
values={{ chat: chatConfig }}
onChange={(field, value) => {
if (field === 'chat') {
setChatConfig(value as ChatConfig)
setHasUnsavedChanges(true)
}
}}
hooks={fieldHooks}
/>
)}
</TabsContent> </TabsContent>
{/* 表达配置 */} {/* 表达配置 */}

View File

@ -0,0 +1,15 @@
import type { FieldHookComponent } from '@/lib/field-hooks'
import { BotInfoSection } from '../sections/BotInfoSection'
/**
* BotInfoSection as a Field Hook Component
* This component replaces the entire 'bot' nested config section rendering
*/
export const BotInfoSectionHook: FieldHookComponent = ({ value, onChange }) => {
return (
<BotInfoSection
config={value as any}
onChange={(newConfig) => onChange?.(newConfig)}
/>
)
}

View File

@ -0,0 +1,617 @@
import React, { useState, useEffect, useMemo } from 'react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Switch } from '@/components/ui/switch'
import { Slider } from '@/components/ui/slider'
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select'
import {
AlertDialog,
AlertDialogAction,
AlertDialogCancel,
AlertDialogContent,
AlertDialogDescription,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogTitle,
AlertDialogTrigger,
} from '@/components/ui/alert-dialog'
import {
Popover,
PopoverContent,
PopoverTrigger,
} from '@/components/ui/popover'
import { Plus, Trash2, Eye, Clock } from 'lucide-react'
import type { FieldHookComponent } from '@/lib/field-hooks'
import type { ChatConfig } from '../types'
// 时间选择组件
const TimeRangePicker = React.memo(function TimeRangePicker({
value,
onChange,
}: {
value: string
onChange: (value: string) => void
}) {
// 解析初始值
const parsedValue = useMemo(() => {
const parts = value.split('-')
if (parts.length === 2) {
const [start, end] = parts
const [sh, sm] = start.split(':')
const [eh, em] = end.split(':')
return {
startHour: sh ? sh.padStart(2, '0') : '00',
startMinute: sm ? sm.padStart(2, '0') : '00',
endHour: eh ? eh.padStart(2, '0') : '23',
endMinute: em ? em.padStart(2, '0') : '59',
}
}
return {
startHour: '00',
startMinute: '00',
endHour: '23',
endMinute: '59',
}
}, [value])
const [startHour, setStartHour] = useState(parsedValue.startHour)
const [startMinute, setStartMinute] = useState(parsedValue.startMinute)
const [endHour, setEndHour] = useState(parsedValue.endHour)
const [endMinute, setEndMinute] = useState(parsedValue.endMinute)
// 当value变化时同步状态
useEffect(() => {
setStartHour(parsedValue.startHour)
setStartMinute(parsedValue.startMinute)
setEndHour(parsedValue.endHour)
setEndMinute(parsedValue.endMinute)
}, [parsedValue])
const updateTime = (
newStartHour: string,
newStartMinute: string,
newEndHour: string,
newEndMinute: string
) => {
const newValue = `${newStartHour}:${newStartMinute}-${newEndHour}:${newEndMinute}`
onChange(newValue)
}
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="outline" className="w-full justify-start font-mono text-sm">
<Clock className="h-4 w-4 mr-2" />
{value || '选择时间段'}
</Button>
</PopoverTrigger>
<PopoverContent className="w-72 sm:w-80">
<div className="space-y-4">
<div>
<h4 className="font-medium text-sm mb-3"></h4>
<div className="grid grid-cols-2 gap-2 sm:gap-3">
<div>
<Label className="text-xs"></Label>
<Select
value={startHour}
onValueChange={(v) => {
setStartHour(v)
updateTime(v, startMinute, endHour, endMinute)
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
{Array.from({ length: 24 }, (_, i) => i).map((h) => (
<SelectItem key={h} value={h.toString().padStart(2, '0')}>
{h.toString().padStart(2, '0')}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div>
<Label className="text-xs"></Label>
<Select
value={startMinute}
onValueChange={(v) => {
setStartMinute(v)
updateTime(startHour, v, endHour, endMinute)
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
{Array.from({ length: 60 }, (_, i) => i).map((m) => (
<SelectItem key={m} value={m.toString().padStart(2, '0')}>
{m.toString().padStart(2, '0')}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
</div>
</div>
<div>
<h4 className="font-medium text-sm mb-3"></h4>
<div className="grid grid-cols-2 gap-2 sm:gap-3">
<div>
<Label className="text-xs"></Label>
<Select
value={endHour}
onValueChange={(v) => {
setEndHour(v)
updateTime(startHour, startMinute, v, endMinute)
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
{Array.from({ length: 24 }, (_, i) => i).map((h) => (
<SelectItem key={h} value={h.toString().padStart(2, '0')}>
{h.toString().padStart(2, '0')}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div>
<Label className="text-xs"></Label>
<Select
value={endMinute}
onValueChange={(v) => {
setEndMinute(v)
updateTime(startHour, startMinute, endHour, v)
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
{Array.from({ length: 60 }, (_, i) => i).map((m) => (
<SelectItem key={m} value={m.toString().padStart(2, '0')}>
{m.toString().padStart(2, '0')}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
</div>
</div>
</div>
</PopoverContent>
</Popover>
)
})
// 预览窗口组件
const RulePreview = React.memo(function RulePreview({ rule }: { rule: { target: string; time: string; value: number } }) {
const previewText = `{ target = "${rule.target}", time = "${rule.time}", value = ${rule.value.toFixed(1)} }`
return (
<Popover>
<PopoverTrigger asChild>
<Button variant="outline" size="sm">
<Eye className="h-4 w-4 mr-1" />
</Button>
</PopoverTrigger>
<PopoverContent className="w-80 sm:w-96">
<div className="space-y-2">
<h4 className="font-medium text-sm"></h4>
<div className="rounded-md bg-muted p-3 font-mono text-xs break-all">
{previewText}
</div>
<p className="text-xs text-muted-foreground">
bot_config.toml
</p>
</div>
</PopoverContent>
</Popover>
)
})
/**
* ChatSection as a Field Hook Component
* This component replaces the entire 'chat' nested config section rendering
*/
export const ChatSectionHook: FieldHookComponent = ({ value, onChange }) => {
// Cast value to ChatConfig (assuming it's the entire chat config object)
const config = value as ChatConfig
// Helper to update config
const updateConfig = (updates: Partial<ChatConfig>) => {
if (onChange) {
onChange({ ...config, ...updates })
}
}
// 添加发言频率规则
const addTalkValueRule = () => {
updateConfig({
talk_value_rules: [
...config.talk_value_rules,
{ target: '', time: '00:00-23:59', value: 1.0 },
],
})
}
// 删除发言频率规则
const removeTalkValueRule = (index: number) => {
updateConfig({
talk_value_rules: config.talk_value_rules.filter((_, i) => i !== index),
})
}
// 更新发言频率规则
const updateTalkValueRule = (
index: number,
field: 'target' | 'time' | 'value',
value: string | number
) => {
const newRules = [...config.talk_value_rules]
newRules[index] = {
...newRules[index],
[field]: value,
}
updateConfig({
talk_value_rules: newRules,
})
}
return (
<div className="rounded-lg border bg-card p-4 sm:p-6 space-y-6">
<div>
<h3 className="text-lg font-semibold mb-4"></h3>
<div className="grid gap-4">
<div className="grid gap-2">
<Label htmlFor="talk_value"></Label>
<Input
id="talk_value"
type="number"
step="0.1"
min="0"
max="1"
value={config.talk_value}
onChange={(e) => updateConfig({ talk_value: parseFloat(e.target.value) })}
/>
<p className="text-xs text-muted-foreground"> 0-1</p>
</div>
<div className="grid gap-2">
<Label htmlFor="think_mode"></Label>
<Select
value={config.think_mode || 'classic'}
onValueChange={(value) => updateConfig({ think_mode: value as 'classic' | 'deep' | 'dynamic' })}
>
<SelectTrigger id="think_mode">
<SelectValue placeholder="选择思考模式" />
</SelectTrigger>
<SelectContent>
<SelectItem value="classic"> - </SelectItem>
<SelectItem value="deep"> - </SelectItem>
<SelectItem value="dynamic"> - </SelectItem>
</SelectContent>
</Select>
<p className="text-xs text-muted-foreground">
</p>
</div>
<div className="flex items-center space-x-2">
<Switch
id="mentioned_bot_reply"
checked={config.mentioned_bot_reply}
onCheckedChange={(checked) =>
updateConfig({ mentioned_bot_reply: checked })
}
/>
<Label htmlFor="mentioned_bot_reply" className="cursor-pointer">
</Label>
</div>
<div className="grid gap-2">
<Label htmlFor="max_context_size"></Label>
<Input
id="max_context_size"
type="number"
min="1"
value={config.max_context_size}
onChange={(e) =>
updateConfig({ max_context_size: parseInt(e.target.value) })
}
/>
</div>
<div className="grid gap-2">
<Label htmlFor="planner_smooth"></Label>
<Input
id="planner_smooth"
type="number"
step="1"
min="0"
value={config.planner_smooth}
onChange={(e) =>
updateConfig({ planner_smooth: parseFloat(e.target.value) })
}
/>
<p className="text-xs text-muted-foreground">
planner 1-50
</p>
</div>
<div className="grid gap-2">
<Label htmlFor="plan_reply_log_max_per_chat"></Label>
<Input
id="plan_reply_log_max_per_chat"
type="number"
step="1"
min="100"
value={config.plan_reply_log_max_per_chat ?? 1024}
onChange={(e) =>
updateConfig({ plan_reply_log_max_per_chat: parseInt(e.target.value) })
}
/>
<p className="text-xs text-muted-foreground">
Plan/Reply
</p>
</div>
<div className="flex items-center space-x-2">
<Switch
id="llm_quote"
checked={config.llm_quote ?? false}
onCheckedChange={(checked) =>
updateConfig({ llm_quote: checked })
}
/>
<Label htmlFor="llm_quote" className="cursor-pointer">
LLM
</Label>
</div>
<p className="text-xs text-muted-foreground -mt-2 ml-10">
LLM
</p>
<div className="flex items-center space-x-2">
<Switch
id="enable_talk_value_rules"
checked={config.enable_talk_value_rules}
onCheckedChange={(checked) =>
updateConfig({ enable_talk_value_rules: checked })
}
/>
<Label htmlFor="enable_talk_value_rules" className="cursor-pointer">
</Label>
</div>
</div>
</div>
{/* 动态发言频率规则配置 */}
{config.enable_talk_value_rules && (
<div className="border-t pt-6">
<div className="flex items-center justify-between mb-4">
<div>
<h4 className="text-base font-semibold"></h4>
<p className="text-xs text-muted-foreground mt-1">
ID
</p>
</div>
<Button onClick={addTalkValueRule} size="sm">
<Plus className="h-4 w-4 mr-1" />
</Button>
</div>
{config.talk_value_rules && config.talk_value_rules.length > 0 ? (
<div className="space-y-4">
{config.talk_value_rules.map((rule, index) => (
<div key={index} className="rounded-lg border p-4 bg-muted/50 space-y-4">
<div className="flex items-center justify-between">
<span className="text-sm font-medium text-muted-foreground">
#{index + 1}
</span>
<div className="flex items-center gap-2">
<RulePreview rule={rule} />
<AlertDialog>
<AlertDialogTrigger asChild>
<Button variant="ghost" size="sm">
<Trash2 className="h-4 w-4 text-destructive" />
</Button>
</AlertDialogTrigger>
<AlertDialogContent>
<AlertDialogHeader>
<AlertDialogTitle></AlertDialogTitle>
<AlertDialogDescription>
#{index + 1}
</AlertDialogDescription>
</AlertDialogHeader>
<AlertDialogFooter>
<AlertDialogCancel></AlertDialogCancel>
<AlertDialogAction onClick={() => removeTalkValueRule(index)}>
</AlertDialogAction>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialog>
</div>
</div>
<div className="space-y-4">
{/* 配置类型选择 */}
<div className="grid gap-2">
<Label className="text-xs font-medium"></Label>
<Select
value={rule.target === '' ? 'global' : 'specific'}
onValueChange={(value) => {
if (value === 'global') {
updateTalkValueRule(index, 'target', '')
} else {
updateTalkValueRule(index, 'target', 'qq::group')
}
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="global"></SelectItem>
<SelectItem value="specific"></SelectItem>
</SelectContent>
</Select>
</div>
{/* 详细配置选项 - 只在非全局时显示 */}
{rule.target !== '' && (() => {
const parts = rule.target.split(':')
const platform = parts[0] || 'qq'
const chatId = parts[1] || ''
const chatType = parts[2] || 'group'
return (
<div className="grid gap-4 p-3 sm:p-4 rounded-lg bg-muted/50">
<div className="grid grid-cols-1 sm:grid-cols-3 gap-3">
<div className="grid gap-2">
<Label className="text-xs font-medium"></Label>
<Select
value={platform}
onValueChange={(value) => {
updateTalkValueRule(index, 'target', `${value}:${chatId}:${chatType}`)
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="qq">QQ</SelectItem>
<SelectItem value="wx"></SelectItem>
</SelectContent>
</Select>
</div>
<div className="grid gap-2">
<Label className="text-xs font-medium"> ID</Label>
<Input
value={chatId}
onChange={(e) => {
updateTalkValueRule(index, 'target', `${platform}:${e.target.value}:${chatType}`)
}}
placeholder="输入群 ID"
className="font-mono text-sm"
/>
</div>
<div className="grid gap-2">
<Label className="text-xs font-medium"></Label>
<Select
value={chatType}
onValueChange={(value) => {
updateTalkValueRule(index, 'target', `${platform}:${chatId}:${value}`)
}}
>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="group">group</SelectItem>
<SelectItem value="private">private</SelectItem>
</SelectContent>
</Select>
</div>
</div>
<p className="text-xs text-muted-foreground">
ID{rule.target || '(未设置)'}
</p>
</div>
)
})()}
{/* 时间段选择器 */}
<div className="grid gap-2">
<Label className="text-xs font-medium"> (Time)</Label>
<TimeRangePicker
value={rule.time}
onChange={(v) => updateTalkValueRule(index, 'time', v)}
/>
<p className="text-xs text-muted-foreground">
23:00-02:00
</p>
</div>
{/* 发言频率滑块 */}
<div className="grid gap-3">
<div className="flex items-center justify-between">
<Label htmlFor={`rule-value-${index}`} className="text-xs font-medium">
(Value)
</Label>
<Input
id={`rule-value-${index}`}
type="number"
step="0.01"
min="0.01"
max="1"
value={rule.value}
onChange={(e) => {
const val = parseFloat(e.target.value)
if (!isNaN(val)) {
updateTalkValueRule(index, 'value', Math.max(0.01, Math.min(1, val)))
}
}}
className="w-20 h-8 text-xs"
/>
</div>
<Slider
value={[rule.value]}
onValueChange={(values) =>
updateTalkValueRule(index, 'value', values[0])
}
min={0.01}
max={1}
step={0.01}
className="w-full"
/>
<div className="flex justify-between text-xs text-muted-foreground">
<span>0.01 ()</span>
<span>0.5</span>
<span>1.0 ()</span>
</div>
</div>
</div>
</div>
))}
</div>
) : (
<div className="text-center py-8 text-muted-foreground">
<p className="text-sm">"添加规则"</p>
</div>
)}
<div className="mt-4 p-4 bg-blue-50 dark:bg-blue-950/20 border border-blue-200 dark:border-blue-800 rounded-lg">
<h5 className="text-sm font-semibold text-blue-900 dark:text-blue-100 mb-2">
📝
</h5>
<ul className="text-xs text-blue-800 dark:text-blue-200 space-y-1">
<li> <strong>Target </strong></li>
<li> <strong>Target </strong>platform:id:type</li>
<li> <strong></strong></li>
<li> <strong></strong> 23:00-02:00 112</li>
<li> <strong></strong> 0-10 1 </li>
</ul>
</div>
</div>
)}
</div>
)
}

View File

@ -0,0 +1,15 @@
import type { FieldHookComponent } from '@/lib/field-hooks'
import { DebugSection } from '../sections/DebugSection'
/**
* DebugSection as a Field Hook Component
* This component replaces the entire 'debug' nested config section rendering
*/
export const DebugSectionHook: FieldHookComponent = ({ value, onChange }) => {
return (
<DebugSection
config={value as any}
onChange={(newConfig) => onChange?.(newConfig)}
/>
)
}

View File

@ -0,0 +1,15 @@
import type { FieldHookComponent } from '@/lib/field-hooks'
import { ExpressionSection } from '../sections/ExpressionSection'
/**
* ExpressionSection as a Field Hook Component
* This component replaces the entire 'expression' nested config section rendering
*/
export const ExpressionSectionHook: FieldHookComponent = ({ value, onChange }) => {
return (
<ExpressionSection
config={value as any}
onChange={(newConfig) => onChange?.(newConfig)}
/>
)
}

View File

@ -0,0 +1,15 @@
import type { FieldHookComponent } from '@/lib/field-hooks'
import { PersonalitySection } from '../sections/PersonalitySection'
/**
* PersonalitySection as a Field Hook Component
* This component replaces the entire 'personality' nested config section rendering
*/
export const PersonalitySectionHook: FieldHookComponent = ({ value, onChange }) => {
return (
<PersonalitySection
config={value as any}
onChange={(newConfig) => onChange?.(newConfig)}
/>
)
}

View File

@ -4,3 +4,8 @@
export { useAutoSave, useConfigAutoSave } from './useAutoSave' export { useAutoSave, useConfigAutoSave } from './useAutoSave'
export type { UseAutoSaveOptions, UseAutoSaveReturn, AutoSaveState } from './useAutoSave' export type { UseAutoSaveOptions, UseAutoSaveReturn, AutoSaveState } from './useAutoSave'
export { ChatSectionHook } from './ChatSectionHook'
export { PersonalitySectionHook } from './PersonalitySectionHook'
export { DebugSectionHook } from './DebugSectionHook'
export { ExpressionSectionHook } from './ExpressionSectionHook'
export { BotInfoSectionHook } from './BotInfoSectionHook'

View File

@ -58,9 +58,13 @@ import { SharePackDialog } from '@/components/share-pack-dialog'
// 导入模块化的类型定义和组件 // 导入模块化的类型定义和组件
import type { ModelInfo, ProviderConfig, ModelTaskConfig, TaskConfig } from './model/types' import type { ModelInfo, ProviderConfig, ModelTaskConfig, TaskConfig } from './model/types'
import { TaskConfigCard, Pagination, ModelTable, ModelCardList } from './model/components' import { Pagination, ModelTable, ModelCardList } from './model/components'
import { useModelTour, useModelFetcher, useModelAutoSave } from './model/hooks' import { useModelTour, useModelFetcher, useModelAutoSave } from './model/hooks'
// 导入动态表单和 Hook 系统
import { DynamicConfigForm } from '@/components/dynamic-form'
import { fieldHooks } from '@/lib/field-hooks'
// 主导出组件:包装 RestartProvider // 主导出组件:包装 RestartProvider
export function ModelConfigPage() { export function ModelConfigPage() {
return ( return (
@ -918,101 +922,22 @@ function ModelConfigPageContent() {
</p> </p>
{taskConfig && ( {taskConfig && (
<div className="grid gap-4 sm:gap-6"> <DynamicConfigForm
{/* Utils 任务 */} schema={{
<TaskConfigCard className: 'TaskConfig',
title="组件模型 (utils)" classDoc: '任务配置',
description="用于表情包、取名、关系、情绪变化等组件" fields: [],
taskConfig={taskConfig.utils} nested: {},
modelNames={modelNames} }}
onChange={(field, value) => updateTaskConfig('utils', field, value)} values={{ taskConfig }}
dataTour="task-model-select" onChange={(field, value) => {
/> if (field === 'taskConfig') {
setTaskConfig(value as ModelTaskConfig)
{/* Tool Use 任务 */} setHasUnsavedChanges(true)
<TaskConfigCard }
title="工具调用模型 (tool_use)" }}
description="需要使用支持工具调用的模型" hooks={fieldHooks}
taskConfig={taskConfig.tool_use} />
modelNames={modelNames}
onChange={(field, value) => updateTaskConfig('tool_use', field, value)}
/>
{/* Replyer 任务 */}
<TaskConfigCard
title="首要回复模型 (replyer)"
description="用于表达器和表达方式学习"
taskConfig={taskConfig.replyer}
modelNames={modelNames}
onChange={(field, value) => updateTaskConfig('replyer', field, value)}
/>
{/* Planner 任务 */}
<TaskConfigCard
title="决策模型 (planner)"
description="负责决定麦麦该什么时候回复"
taskConfig={taskConfig.planner}
modelNames={modelNames}
onChange={(field, value) => updateTaskConfig('planner', field, value)}
/>
{/* VLM 任务 */}
<TaskConfigCard
title="图像识别模型 (vlm)"
description="视觉语言模型"
taskConfig={taskConfig.vlm}
modelNames={modelNames}
onChange={(field, value) => updateTaskConfig('vlm', field, value)}
hideTemperature
/>
{/* Voice 任务 */}
<TaskConfigCard
title="语音识别模型 (voice)"
description="语音转文字"
taskConfig={taskConfig.voice}
modelNames={modelNames}
onChange={(field, value) => updateTaskConfig('voice', field, value)}
hideTemperature
hideMaxTokens
/>
{/* Embedding 任务 */}
<TaskConfigCard
title="嵌入模型 (embedding)"
description="用于向量化"
taskConfig={taskConfig.embedding}
modelNames={modelNames}
onChange={(field, value) => updateTaskConfig('embedding', field, value)}
hideTemperature
hideMaxTokens
/>
{/* LPMM 相关任务 */}
<div className="space-y-4">
<h3 className="text-lg font-semibold">LPMM </h3>
<TaskConfigCard
title="实体提取模型 (lpmm_entity_extract)"
description="从文本中提取实体"
taskConfig={taskConfig.lpmm_entity_extract}
modelNames={modelNames}
onChange={(field, value) =>
updateTaskConfig('lpmm_entity_extract', field, value)
}
/>
<TaskConfigCard
title="RDF 构建模型 (lpmm_rdf_build)"
description="构建知识图谱"
taskConfig={taskConfig.lpmm_rdf_build}
modelNames={modelNames}
onChange={(field, value) =>
updateTaskConfig('lpmm_rdf_build', field, value)
}
/>
</div>
</div>
)} )}
</TabsContent> </TabsContent>
</Tabs> </Tabs>

View File

@ -0,0 +1,22 @@
import '@testing-library/jest-dom/vitest'
global.ResizeObserver = class ResizeObserver {
observe() {}
unobserve() {}
disconnect() {}
}
Object.defineProperty(window, 'matchMedia', {
writable: true,
value: (query: string) => ({
matches: false,
media: query,
onchange: null,
addListener: () => {},
removeListener: () => {},
addEventListener: () => {},
removeEventListener: () => {},
dispatchEvent: () => {},
}),
})

View File

@ -12,6 +12,8 @@ export type FieldType =
| 'object' | 'object'
| 'textarea' | 'textarea'
export type XWidgetType = 'slider' | 'select' | 'textarea' | 'switch' | 'custom'
export interface FieldSchema { export interface FieldSchema {
name: string name: string
type: FieldType type: FieldType
@ -26,6 +28,9 @@ export interface FieldSchema {
type: string type: string
} }
properties?: ConfigSchema properties?: ConfigSchema
'x-widget'?: XWidgetType
'x-icon'?: string
step?: number
} }
export interface ConfigSchema { export interface ConfigSchema {

View File

@ -2,6 +2,7 @@
"files": [], "files": [],
"references": [ "references": [
{ "path": "./tsconfig.app.json" }, { "path": "./tsconfig.app.json" },
{ "path": "./tsconfig.node.json" } { "path": "./tsconfig.node.json" },
{ "path": "./tsconfig.vitest.json" }
] ]
} }

View File

@ -0,0 +1,7 @@
{
"extends": "./tsconfig.app.json",
"compilerOptions": {
"types": ["vitest/globals", "@testing-library/jest-dom"]
},
"include": ["src"]
}

View File

@ -1,3 +1,4 @@
/// <reference types="vitest" />
import { defineConfig } from 'vite' import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react' import react from '@vitejs/plugin-react'
import path from 'path' import path from 'path'
@ -5,6 +6,11 @@ import path from 'path'
// https://vite.dev/config/ // https://vite.dev/config/
export default defineConfig({ export default defineConfig({
plugins: [react()], plugins: [react()],
test: {
globals: true,
environment: 'jsdom',
setupFiles: './src/test/setup.ts',
},
server: { server: {
port: 7999, port: 7999,
proxy: { proxy: {
@ -23,6 +29,9 @@ export default defineConfig({
'@': path.resolve(__dirname, './src'), '@': path.resolve(__dirname, './src'),
}, },
}, },
optimizeDeps: {
include: ['react', 'react-dom'],
},
build: { build: {
rollupOptions: { rollupOptions: {
output: { output: {

View File

@ -0,0 +1,18 @@
/// <reference types="vitest" />
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
import path from 'path'
export default defineConfig({
plugins: [react()],
test: {
globals: true,
environment: 'jsdom',
setupFiles: './src/test/setup.ts',
},
resolve: {
alias: {
'@': path.resolve(__dirname, './src'),
},
},
})

View File

@ -0,0 +1,6 @@
import sys
from pathlib import Path
# Add project root to Python path so src imports work
project_root = Path(__file__).parent.parent.absolute()
sys.path.insert(0, str(project_root))

View File

View File

@ -0,0 +1,78 @@
import pytest
from src.config.official_configs import ChatConfig
from src.config.config import Config
from src.webui.config_schema import ConfigSchemaGenerator
def test_field_docs_in_schema():
"""Test that field descriptions are correctly extracted from field_docs (docstrings)."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
# Verify description field exists
assert "description" in talk_value
# Verify description contains expected Chinese text from the docstring
assert "聊天频率" in talk_value["description"]
def test_json_schema_extra_merged():
"""Test that json_schema_extra fields are correctly merged into output."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
# Verify UI metadata fields from json_schema_extra exist
assert talk_value.get("x-widget") == "slider"
assert talk_value.get("x-icon") == "message-circle"
assert talk_value.get("step") == 0.1
def test_pydantic_constraints_mapped():
"""Test that Pydantic constraints (ge/le) are correctly mapped to minValue/maxValue."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
talk_value = next(f for f in schema["fields"] if f["name"] == "talk_value")
# Verify constraints are mapped to frontend naming convention
assert "minValue" in talk_value
assert "maxValue" in talk_value
assert talk_value["minValue"] == 0 # From ge=0
assert talk_value["maxValue"] == 1 # From le=1
def test_nested_model_schema():
"""Test that nested models (ConfigBase fields) are correctly handled."""
schema = ConfigSchemaGenerator.generate_schema(Config)
# Verify nested structure exists
assert "nested" in schema
assert "chat" in schema["nested"]
# Verify nested chat schema is complete
chat_schema = schema["nested"]["chat"]
assert chat_schema["className"] == "ChatConfig"
assert "fields" in chat_schema
# Verify nested schema fields include description and metadata
talk_value = next(f for f in chat_schema["fields"] if f["name"] == "talk_value")
assert "description" in talk_value
assert talk_value.get("x-widget") == "slider"
assert talk_value.get("minValue") == 0
def test_field_without_extra_metadata():
"""Test that fields without json_schema_extra still generate valid schema."""
schema = ConfigSchemaGenerator.generate_schema(ChatConfig)
max_context_size = next(f for f in schema["fields"] if f["name"] == "max_context_size")
# Verify basic fields are generated
assert "name" in max_context_size
assert max_context_size["name"] == "max_context_size"
assert "type" in max_context_size
assert max_context_size["type"] == "integer"
assert "label" in max_context_size
assert "required" in max_context_size
# Verify no x-widget or x-icon from json_schema_extra (since field has none)
# These fields should only be present if explicitly defined in json_schema_extra
assert not max_context_size.get("x-widget")
assert not max_context_size.get("x-icon")

View File

@ -0,0 +1,461 @@
"""表情包路由 API 测试
测试 src/webui/routers/emoji.py 中的核心 emoji 路由端点
使用内存 SQLite 数据库和 FastAPI TestClient
"""
from contextlib import contextmanager
from datetime import datetime
from typing import Generator
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import Images, ImageType
from src.webui.core import TokenManager
from src.webui.routers.emoji import router
@pytest.fixture(scope="function")
def test_engine():
"""创建内存 SQLite 引擎用于测试"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
return engine
@pytest.fixture(scope="function")
def test_session(test_engine) -> Generator[Session, None, None]:
"""创建测试数据库会话"""
with Session(test_engine) as session:
yield session
@pytest.fixture(scope="function")
def test_app(test_session):
"""创建测试 FastAPI 应用并覆盖 get_db_session 依赖"""
app = FastAPI()
app.include_router(router)
# Create a context manager that yields the test session
@contextmanager
def override_get_db_session(auto_commit=True):
"""Override get_db_session to use test session"""
try:
yield test_session
if auto_commit:
test_session.commit()
except Exception:
test_session.rollback()
raise
with patch("src.webui.routers.emoji.get_db_session", override_get_db_session):
yield app
@pytest.fixture(scope="function")
def client(test_app):
"""创建 TestClient"""
return TestClient(test_app)
@pytest.fixture(scope="function")
def auth_token():
"""创建有效的认证 token"""
token_manager = TokenManager(secret_key="test-secret-key", token_expire_hours=24)
return token_manager.create_token()
@pytest.fixture(scope="function")
def sample_emojis(test_session) -> list[Images]:
"""插入测试用表情包数据"""
import hashlib
emojis = [
Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/test1.png",
image_hash=hashlib.sha256(b"test1").hexdigest(),
description="测试表情包 1",
emotion="开心,快乐",
query_count=10,
is_registered=True,
is_banned=False,
record_time=datetime(2026, 1, 1, 10, 0, 0),
register_time=datetime(2026, 1, 1, 10, 0, 0),
last_used_time=datetime(2026, 1, 2, 10, 0, 0),
),
Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/test2.gif",
image_hash=hashlib.sha256(b"test2").hexdigest(),
description="测试表情包 2",
emotion="难过",
query_count=5,
is_registered=False,
is_banned=False,
record_time=datetime(2026, 1, 3, 10, 0, 0),
register_time=None,
last_used_time=None,
),
Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/test3.webp",
image_hash=hashlib.sha256(b"test3").hexdigest(),
description="测试表情包 3",
emotion="生气",
query_count=20,
is_registered=True,
is_banned=True,
record_time=datetime(2026, 1, 4, 10, 0, 0),
register_time=datetime(2026, 1, 4, 10, 0, 0),
last_used_time=datetime(2026, 1, 5, 10, 0, 0),
),
]
for emoji in emojis:
test_session.add(emoji)
test_session.commit()
for emoji in emojis:
test_session.refresh(emoji)
return emojis
@pytest.fixture(scope="function")
def mock_token_verify():
"""Mock token verification to always succeed"""
with patch("src.webui.routers.emoji.verify_auth_token", return_value=True):
yield
# ==================== 测试用例 ====================
def test_list_emojis_basic(client, sample_emojis, mock_token_verify):
"""测试获取表情包列表(基本分页)"""
response = client.get("/emoji/list?page=1&page_size=10")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 3
assert data["page"] == 1
assert data["page_size"] == 10
assert len(data["data"]) == 3
# 验证第一个表情包字段
emoji = data["data"][0]
assert "id" in emoji
assert "full_path" in emoji
assert "emoji_hash" in emoji
assert "description" in emoji
assert "query_count" in emoji
assert "is_registered" in emoji
assert "is_banned" in emoji
assert "emotion" in emoji
assert "record_time" in emoji
assert "register_time" in emoji
assert "last_used_time" in emoji
def test_list_emojis_pagination(client, sample_emojis, mock_token_verify):
"""测试分页功能"""
response = client.get("/emoji/list?page=1&page_size=2")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 3
assert len(data["data"]) == 2
# 第二页
response = client.get("/emoji/list?page=2&page_size=2")
data = response.json()
assert len(data["data"]) == 1
def test_list_emojis_search(client, sample_emojis, mock_token_verify):
"""测试搜索过滤"""
response = client.get("/emoji/list?search=表情包 2")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert data["data"][0]["description"] == "测试表情包 2"
def test_list_emojis_filter_registered(client, sample_emojis, mock_token_verify):
"""测试 is_registered 过滤"""
response = client.get("/emoji/list?is_registered=true")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 2
assert all(emoji["is_registered"] is True for emoji in data["data"])
def test_list_emojis_filter_banned(client, sample_emojis, mock_token_verify):
"""测试 is_banned 过滤"""
response = client.get("/emoji/list?is_banned=true")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert data["data"][0]["is_banned"] is True
def test_list_emojis_sort_by_query_count(client, sample_emojis, mock_token_verify):
"""测试按 query_count 排序"""
response = client.get("/emoji/list?sort_by=query_count&sort_order=desc")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# 验证降序排列 (20 > 10 > 5)
assert data["data"][0]["query_count"] == 20
assert data["data"][1]["query_count"] == 10
assert data["data"][2]["query_count"] == 5
def test_get_emoji_detail_success(client, sample_emojis, mock_token_verify):
"""测试获取表情包详情(成功)"""
emoji_id = sample_emojis[0].id
response = client.get(f"/emoji/{emoji_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["id"] == emoji_id
assert data["data"]["emoji_hash"] == sample_emojis[0].image_hash
def test_get_emoji_detail_not_found(client, mock_token_verify):
"""测试获取不存在的表情包404"""
response = client.get("/emoji/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_update_emoji_description(client, sample_emojis, mock_token_verify):
"""测试更新表情包描述"""
emoji_id = sample_emojis[0].id
response = client.patch(
f"/emoji/{emoji_id}",
json={"description": "更新后的描述"},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["description"] == "更新后的描述"
assert "成功更新" in data["message"]
def test_update_emoji_register_status(client, sample_emojis, mock_token_verify, test_session):
"""测试更新注册状态False -> True 应设置 register_time"""
emoji_id = sample_emojis[1].id # 未注册的表情包
response = client.patch(
f"/emoji/{emoji_id}",
json={"is_registered": True},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["is_registered"] is True
assert data["data"]["register_time"] is not None # 应该设置了注册时间
def test_update_emoji_no_fields(client, sample_emojis, mock_token_verify):
"""测试更新请求未提供任何字段400"""
emoji_id = sample_emojis[0].id
response = client.patch(f"/emoji/{emoji_id}", json={})
assert response.status_code == 400
data = response.json()
assert "未提供任何需要更新的字段" in data["detail"]
def test_update_emoji_not_found(client, mock_token_verify):
"""测试更新不存在的表情包404"""
response = client.patch("/emoji/99999", json={"description": "test"})
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_delete_emoji_success(client, sample_emojis, mock_token_verify, test_session):
"""测试删除表情包(成功)"""
emoji_id = sample_emojis[0].id
response = client.delete(f"/emoji/{emoji_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功删除" in data["message"]
# 验证数据库中已删除
from sqlmodel import select
statement = select(Images).where(Images.id == emoji_id)
result = test_session.exec(statement).first()
assert result is None
def test_delete_emoji_not_found(client, mock_token_verify):
"""测试删除不存在的表情包404"""
response = client.delete("/emoji/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_batch_delete_success(client, sample_emojis, mock_token_verify, test_session):
"""测试批量删除表情包(全部成功)"""
emoji_ids = [sample_emojis[0].id, sample_emojis[1].id]
response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 2
assert data["failed_count"] == 0
assert "成功删除 2 个表情包" in data["message"]
# 验证数据库中已删除
from sqlmodel import select
for emoji_id in emoji_ids:
statement = select(Images).where(Images.id == emoji_id)
result = test_session.exec(statement).first()
assert result is None
def test_batch_delete_partial_failure(client, sample_emojis, mock_token_verify):
"""测试批量删除(部分失败)"""
emoji_ids = [sample_emojis[0].id, 99999] # 第二个 ID 不存在
response = client.post("/emoji/batch/delete", json={"emoji_ids": emoji_ids})
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 1
assert data["failed_count"] == 1
assert 99999 in data["failed_ids"]
def test_batch_delete_empty_list(client, mock_token_verify):
"""测试批量删除空列表400"""
response = client.post("/emoji/batch/delete", json={"emoji_ids": []})
assert response.status_code == 400
data = response.json()
assert "未提供要删除的表情包ID" in data["detail"]
def test_auth_required_list(client):
"""测试未认证访问列表端点401"""
# Without mock_token_verify fixture
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
response = client.get("/emoji/list")
# verify_auth_token 返回 False 会触发 HTTPException
# 但具体状态码取决于 verify_auth_token_from_cookie_or_header 的实现
# 这里假设它抛出 401
def test_auth_required_update(client, sample_emojis):
"""测试未认证访问更新端点401"""
with patch("src.webui.routers.emoji.verify_auth_token", return_value=False):
emoji_id = sample_emojis[0].id
response = client.patch(f"/emoji/{emoji_id}", json={"description": "test"})
# Should be unauthorized
def test_emoji_to_response_field_mapping(sample_emojis):
"""测试 emoji_to_response 字段映射image_hash -> emoji_hash"""
from src.webui.routers.emoji import emoji_to_response
emoji = sample_emojis[0]
response = emoji_to_response(emoji)
# 验证 API 字段名称
assert hasattr(response, "emoji_hash")
assert response.emoji_hash == emoji.image_hash
# 验证时间戳转换
assert isinstance(response.record_time, float)
assert response.record_time == emoji.record_time.timestamp()
if emoji.register_time:
assert isinstance(response.register_time, float)
assert response.register_time == emoji.register_time.timestamp()
def test_list_emojis_only_emoji_type(client, test_session, mock_token_verify):
"""测试列表只返回 type=EMOJI 的记录(不包括其他类型)"""
# 插入一个非 EMOJI 类型的图片
non_emoji = Images(
image_type=ImageType.IMAGE, # 不是 EMOJI
full_path="/data/images/test.png",
image_hash="hash_image",
description="非表情包图片",
query_count=0,
is_registered=False,
is_banned=False,
record_time=datetime.now(),
)
test_session.add(non_emoji)
test_session.commit()
# 插入一个 EMOJI 类型
emoji = Images(
image_type=ImageType.EMOJI,
full_path="/data/emoji_registed/emoji.png",
image_hash="hash_emoji",
description="表情包",
query_count=0,
is_registered=True,
is_banned=False,
record_time=datetime.now(),
)
test_session.add(emoji)
test_session.commit()
response = client.get("/emoji/list")
assert response.status_code == 200
data = response.json()
# 只应该返回 1 个 EMOJI 类型的记录
assert data["total"] == 1
assert data["data"][0]["description"] == "表情包"

View File

@ -0,0 +1,504 @@
"""Expression routes pytest tests"""
from datetime import datetime
from typing import Generator
from unittest.mock import MagicMock
import pytest
from fastapi import FastAPI, APIRouter
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine, select
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
def create_test_app() -> FastAPI:
"""Create minimal test app with only expression router"""
app = FastAPI(title="Test App")
from src.webui.routers.expression import router as expression_router
main_router = APIRouter(prefix="/api/webui")
main_router.include_router(expression_router)
app.include_router(main_router)
return app
app = create_test_app()
# Test database setup
@pytest.fixture(name="test_engine")
def test_engine_fixture():
"""Create in-memory SQLite database for testing"""
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
return engine
@pytest.fixture(name="test_session")
def test_session_fixture(test_engine) -> Generator[Session, None, None]:
"""Create a test database session with transaction rollback"""
connection = test_engine.connect()
transaction = connection.begin()
session = Session(bind=connection)
yield session
session.close()
transaction.rollback()
connection.close()
@pytest.fixture(name="client")
def client_fixture(test_session: Session, monkeypatch) -> Generator[TestClient, None, None]:
"""Create TestClient with overridden database session"""
from contextlib import contextmanager
@contextmanager
def get_test_db_session():
yield test_session
monkeypatch.setattr("src.webui.routers.expression.get_db_session", get_test_db_session)
with TestClient(app) as client:
yield client
@pytest.fixture(name="mock_auth")
def mock_auth_fixture(monkeypatch):
"""Mock authentication to always return True"""
mock_verify = MagicMock(return_value=True)
monkeypatch.setattr("src.webui.routers.expression.verify_auth_token_from_cookie_or_header", mock_verify)
@pytest.fixture(name="sample_expression")
def sample_expression_fixture(test_session: Session) -> Expression:
"""Insert a sample expression into test database"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
"VALUES (1, '测试情景', '测试风格', '测试上下文', '测试上文', '[\"测试内容1\", \"测试内容2\"]', 10, '2026-02-17 12:00:00', '2026-02-15 10:00:00', 'test_chat_001')"
)
)
test_session.commit()
expression = test_session.exec(select(Expression).where(Expression.id == 1)).first()
assert expression is not None
return expression
# ============ Tests ============
def test_list_expressions_empty(client: TestClient, mock_auth):
"""Test GET /expression/list with empty database"""
response = client.get("/api/webui/expression/list")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 0
assert data["page"] == 1
assert data["page_size"] == 20
assert data["data"] == []
def test_list_expressions_with_data(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/list returns expression data"""
response = client.get("/api/webui/expression/list")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert len(data["data"]) == 1
expr_data = data["data"][0]
assert expr_data["id"] == sample_expression.id
assert expr_data["situation"] == "测试情景"
assert expr_data["style"] == "测试风格"
assert expr_data["chat_id"] == "test_chat_001"
def test_list_expressions_pagination(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/list pagination works correctly"""
for i in range(5):
test_session.execute(
text(
f"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '', '', '[]', 0, '2026-02-17 12:0{i}:00', '2026-02-15 10:00:00', 'chat_{i}')"
)
)
test_session.commit()
# Request page 1 with page_size=2
response = client.get("/api/webui/expression/list?page=1&page_size=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 5
assert data["page"] == 1
assert data["page_size"] == 2
assert len(data["data"]) == 2
# Request page 2
response = client.get("/api/webui/expression/list?page=2&page_size=2")
data = response.json()
assert data["page"] == 2
assert len(data["data"]) == 2
def test_list_expressions_search(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/list with search filter"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
"VALUES (1, '找人吃饭', '热情', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_001')"
)
)
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
"VALUES (2, '拒绝邀请', '礼貌', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_002')"
)
)
test_session.commit()
# Search for "吃饭"
response = client.get("/api/webui/expression/list?search=吃饭")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["situation"] == "找人吃饭"
def test_list_expressions_chat_filter(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/list with chat_id filter"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
"VALUES (1, '情景A', '风格A', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_A')"
)
)
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
"VALUES (2, '情景B', '风格B', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_B')"
)
)
test_session.commit()
# Filter by chat_A
response = client.get("/api/webui/expression/list?chat_id=chat_A")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["situation"] == "情景A"
assert data["data"][0]["chat_id"] == "chat_A"
def test_get_expression_detail_success(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/{id} returns correct detail"""
response = client.get(f"/api/webui/expression/{sample_expression.id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["id"] == sample_expression.id
assert data["data"]["situation"] == "测试情景"
assert data["data"]["style"] == "测试风格"
assert data["data"]["chat_id"] == "test_chat_001"
def test_get_expression_detail_not_found(client: TestClient, mock_auth):
"""Test GET /expression/{id} returns 404 for non-existent ID"""
response = client.get("/api/webui/expression/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_expression_response_has_legacy_fields(client: TestClient, mock_auth, sample_expression: Expression):
"""Test that ExpressionResponse includes legacy fields (checked/rejected/modified_by)"""
response = client.get(f"/api/webui/expression/{sample_expression.id}")
assert response.status_code == 200
data = response.json()["data"]
# Verify legacy fields exist and have default values
assert "checked" in data
assert "rejected" in data
assert "modified_by" in data
# Verify hardcoded default values (from expression_to_response)
assert data["checked"] is False
assert data["rejected"] is False
assert data["modified_by"] is None
def test_update_expression_without_removed_fields(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} does not accept checked/rejected fields"""
# Valid update request (only allowed fields)
update_payload = {
"situation": "更新后的情景",
"style": "更新后的风格",
}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["situation"] == "更新后的情景"
assert data["data"]["style"] == "更新后的风格"
# Verify legacy fields still returned (hardcoded values)
assert data["data"]["checked"] is False
assert data["data"]["rejected"] is False
def test_update_expression_ignores_invalid_fields(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} ignores fields not in ExpressionUpdateRequest"""
# Request with invalid field (checked not in schema)
update_payload = {
"situation": "新情景",
"checked": True, # This field should be ignored by Pydantic
"rejected": True, # This field should be ignored
}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["situation"] == "新情景"
# Response should have hardcoded False values (not True from request)
assert data["data"]["checked"] is False
assert data["data"]["rejected"] is False
def test_update_expression_chat_id_mapping(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} correctly maps chat_id to session_id"""
update_payload = {"chat_id": "updated_chat_999"}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# Verify chat_id is returned in response (mapped from session_id)
assert data["data"]["chat_id"] == "updated_chat_999"
def test_update_expression_not_found(client: TestClient, mock_auth):
"""Test PATCH /expression/{id} returns 404 for non-existent ID"""
update_payload = {"situation": "新情景"}
response = client.patch("/api/webui/expression/99999", json=update_payload)
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_update_expression_empty_request(client: TestClient, mock_auth, sample_expression: Expression):
"""Test PATCH /expression/{id} returns 400 for empty update request"""
update_payload = {}
response = client.patch(f"/api/webui/expression/{sample_expression.id}", json=update_payload)
assert response.status_code == 400
data = response.json()
assert "未提供任何需要更新的字段" in data["detail"]
def test_delete_expression_success(client: TestClient, mock_auth, sample_expression: Expression):
"""Test DELETE /expression/{id} successfully deletes expression"""
expression_id = sample_expression.id
response = client.delete(f"/api/webui/expression/{expression_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功删除" in data["message"]
# Verify expression is deleted
get_response = client.get(f"/api/webui/expression/{expression_id}")
assert get_response.status_code == 404
def test_delete_expression_not_found(client: TestClient, mock_auth):
"""Test DELETE /expression/{id} returns 404 for non-existent ID"""
response = client.delete("/api/webui/expression/99999")
assert response.status_code == 404
data = response.json()
assert "未找到" in data["detail"]
def test_create_expression_success(client: TestClient, mock_auth):
"""Test POST /expression/ successfully creates expression"""
create_payload = {
"situation": "新建情景",
"style": "新建风格",
"chat_id": "new_chat_123",
}
response = client.post("/api/webui/expression/", json=create_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "创建成功" in data["message"]
assert data["data"]["situation"] == "新建情景"
assert data["data"]["style"] == "新建风格"
assert data["data"]["chat_id"] == "new_chat_123"
# Verify legacy fields
assert data["data"]["checked"] is False
assert data["data"]["rejected"] is False
assert data["data"]["modified_by"] is None
def test_batch_delete_expressions_success(client: TestClient, mock_auth, test_session: Session):
"""Test POST /expression/batch/delete successfully deletes multiple expressions"""
expression_ids = []
for i in range(3):
test_session.execute(
text(
f"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
f"VALUES ({i + 1}, '批量删除{i}', '风格{i}', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_{i}')"
)
)
expression_ids.append(i + 1)
test_session.commit()
delete_payload = {"ids": expression_ids}
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功删除 3 个" in data["message"]
for expr_id in expression_ids:
get_response = client.get(f"/api/webui/expression/{expr_id}")
assert get_response.status_code == 404
def test_batch_delete_partial_not_found(client: TestClient, mock_auth, sample_expression: Expression):
"""Test POST /expression/batch/delete handles partial not found IDs"""
delete_payload = {"ids": [sample_expression.id, 88888, 99999]}
response = client.post("/api/webui/expression/batch/delete", json=delete_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
# Should delete only the 1 valid ID
assert "成功删除 1 个" in data["message"]
def test_get_expression_stats(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/stats/summary returns correct statistics"""
for i in range(3):
test_session.execute(
text(
f"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
f"VALUES ({i + 1}, '情景{i}', '风格{i}', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_{i % 2}')"
)
)
test_session.commit()
response = client.get("/api/webui/expression/stats/summary")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["total"] == 3
assert data["data"]["chat_count"] == 2
def test_get_review_stats(client: TestClient, mock_auth, test_session: Session):
"""Test GET /expression/review/stats returns hardcoded 0 counts"""
test_session.execute(
text(
"INSERT INTO expressions (id, situation, style, context, up_content, content_list, count, last_active_time, create_time, session_id) "
"VALUES (1, '待审核', '风格', '', '', '[]', 0, datetime('now'), datetime('now'), 'chat_001')"
)
)
test_session.commit()
response = client.get("/api/webui/expression/review/stats")
assert response.status_code == 200
data = response.json()
# Verify all review counts are 0 (hardcoded in refactored code)
assert data["total"] == 1 # Total expressions exists
assert data["unchecked"] == 0
assert data["passed"] == 0
assert data["rejected"] == 0
assert data["ai_checked"] == 0
assert data["user_checked"] == 0
def test_get_review_list_filter_unchecked(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/review/list with filter_type=unchecked returns empty (legacy behavior)"""
# filter_type=unchecked should return no results (legacy removed)
response = client.get("/api/webui/expression/review/list?filter_type=unchecked")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 0 # No results (legacy fields removed)
def test_get_review_list_filter_all(client: TestClient, mock_auth, sample_expression: Expression):
"""Test GET /expression/review/list with filter_type=all returns all expressions"""
response = client.get("/api/webui/expression/review/list?filter_type=all")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 1
assert len(data["data"]) == 1
def test_batch_review_expressions_unsupported(client: TestClient, mock_auth, sample_expression: Expression):
"""Test POST /expression/review/batch returns failure for require_unchecked=True"""
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": True}]}
response = client.post("/api/webui/expression/review/batch", json=review_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["failed"] == 1 # Should fail because require_unchecked=True
assert "不支持审核状态过滤" in data["results"][0]["message"]
def test_batch_review_expressions_no_unchecked_check(client: TestClient, mock_auth, sample_expression: Expression):
"""Test POST /expression/review/batch succeeds when require_unchecked=False"""
review_payload = {"items": [{"id": sample_expression.id, "rejected": False, "require_unchecked": False}]}
response = client.post("/api/webui/expression/review/batch", json=review_payload)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["succeeded"] == 1
assert data["results"][0]["success"] is True

View File

@ -0,0 +1,512 @@
"""测试 jargon 路由的完整性和正确性"""
import json
from datetime import datetime
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine
from src.common.database.database_model import ChatSession, Jargon
from src.webui.routers.jargon import router as jargon_router
@pytest.fixture(name="app", scope="function")
def app_fixture():
app = FastAPI()
app.include_router(jargon_router, prefix="/api/webui")
return app
@pytest.fixture(name="engine", scope="function")
def engine_fixture():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
yield engine
@pytest.fixture(name="session", scope="function")
def session_fixture(engine):
connection = engine.connect()
transaction = connection.begin()
session = Session(bind=connection)
yield session
session.close()
transaction.rollback()
connection.close()
@pytest.fixture(name="client", scope="function")
def client_fixture(app: FastAPI, session: Session, monkeypatch):
from contextlib import contextmanager
@contextmanager
def mock_get_db_session():
yield session
monkeypatch.setattr("src.webui.routers.jargon.get_db_session", mock_get_db_session)
with TestClient(app) as client:
yield client
@pytest.fixture(name="sample_chat_session")
def sample_chat_session_fixture(session: Session):
"""创建示例 ChatSession"""
chat_session = ChatSession(
session_id="test_stream_001",
platform="qq",
group_id="123456789",
user_id=None,
created_timestamp=datetime.now(),
last_active_timestamp=datetime.now(),
)
session.add(chat_session)
session.commit()
session.refresh(chat_session)
return chat_session
@pytest.fixture(name="sample_jargons")
def sample_jargons_fixture(session: Session, sample_chat_session: ChatSession):
"""创建示例 Jargon 数据"""
jargons = [
Jargon(
id=1,
content="yyds",
raw_content="永远的神",
meaning="永远的神",
session_id=sample_chat_session.session_id,
count=10,
is_jargon=True,
is_complete=False,
),
Jargon(
id=2,
content="awsl",
raw_content="啊我死了",
meaning="啊我死了",
session_id=sample_chat_session.session_id,
count=5,
is_jargon=True,
is_complete=False,
),
Jargon(
id=3,
content="hello",
raw_content=None,
meaning="你好",
session_id=sample_chat_session.session_id,
count=2,
is_jargon=False,
is_complete=False,
),
]
for jargon in jargons:
session.add(jargon)
session.commit()
for jargon in jargons:
session.refresh(jargon)
return jargons
# ==================== Test Cases ====================
def test_list_jargons(client: TestClient, sample_jargons):
"""测试 GET /jargon/list 基础列表功能"""
response = client.get("/api/webui/jargon/list")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["total"] == 3
assert data["page"] == 1
assert data["page_size"] == 20
assert len(data["data"]) == 3
assert data["data"][0]["content"] == "yyds"
assert data["data"][0]["count"] == 10
def test_list_jargons_with_pagination(client: TestClient, sample_jargons):
"""测试分页功能"""
response = client.get("/api/webui/jargon/list?page=1&page_size=2")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
assert len(data["data"]) == 2
response = client.get("/api/webui/jargon/list?page=2&page_size=2")
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
def test_list_jargons_with_search(client: TestClient, sample_jargons):
"""测试 GET /jargon/list?search=xxx 搜索功能"""
response = client.get("/api/webui/jargon/list?search=yyds")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "yyds"
# 测试搜索 meaning
response = client.get("/api/webui/jargon/list?search=你好")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "hello"
def test_list_jargons_with_chat_id_filter(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试按 chat_id 筛选"""
response = client.get(f"/api/webui/jargon/list?chat_id={sample_chat_session.session_id}")
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
# 测试不存在的 chat_id
response = client.get("/api/webui/jargon/list?chat_id=nonexistent")
assert response.status_code == 200
data = response.json()
assert data["total"] == 0
def test_list_jargons_with_is_jargon_filter(client: TestClient, sample_jargons):
"""测试按 is_jargon 筛选"""
response = client.get("/api/webui/jargon/list?is_jargon=true")
assert response.status_code == 200
data = response.json()
assert data["total"] == 2
assert all(item["is_jargon"] is True for item in data["data"])
response = client.get("/api/webui/jargon/list?is_jargon=false")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "hello"
def test_get_jargon_detail(client: TestClient, sample_jargons):
"""测试 GET /jargon/{id} 获取详情"""
jargon_id = sample_jargons[0].id
response = client.get(f"/api/webui/jargon/{jargon_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["id"] == jargon_id
assert data["data"]["content"] == "yyds"
assert data["data"]["meaning"] == "永远的神"
assert data["data"]["count"] == 10
assert data["data"]["is_jargon"] is True
def test_get_jargon_detail_not_found(client: TestClient):
"""测试获取不存在的黑话详情"""
response = client.get("/api/webui/jargon/99999")
assert response.status_code == 404
assert "黑话不存在" in response.json()["detail"]
@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
def test_create_jargon(client: TestClient, sample_chat_session: ChatSession):
"""测试 POST /jargon/ 创建黑话"""
request_data = {
"content": "新黑话",
"raw_content": "原始内容",
"meaning": "含义",
"chat_id": sample_chat_session.session_id,
}
response = client.post("/api/webui/jargon/", json=request_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "创建成功"
assert data["data"]["content"] == "新黑话"
assert data["data"]["meaning"] == "含义"
assert data["data"]["count"] == 0
assert data["data"]["is_jargon"] is None
assert data["data"]["is_complete"] is False
def test_create_duplicate_jargon(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试创建重复黑话返回 400"""
request_data = {
"content": "yyds",
"meaning": "重复的",
"chat_id": sample_chat_session.session_id,
}
response = client.post("/api/webui/jargon/", json=request_data)
assert response.status_code == 400
assert "已存在相同内容的黑话" in response.json()["detail"]
def test_update_jargon(client: TestClient, sample_jargons):
"""测试 PATCH /jargon/{id} 更新黑话"""
jargon_id = sample_jargons[0].id
update_data = {
"meaning": "更新后的含义",
"is_jargon": True,
}
response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "更新成功"
assert data["data"]["meaning"] == "更新后的含义"
assert data["data"]["is_jargon"] is True
assert data["data"]["content"] == "yyds" # 未改变的字段保持不变
def test_update_jargon_with_chat_id_mapping(client: TestClient, sample_jargons):
"""测试更新时 chat_id → session_id 的映射"""
jargon_id = sample_jargons[0].id
update_data = {
"chat_id": "new_session_id",
}
response = client.patch(f"/api/webui/jargon/{jargon_id}", json=update_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["data"]["chat_id"] == "new_session_id"
def test_update_jargon_not_found(client: TestClient):
"""测试更新不存在的黑话"""
response = client.patch("/api/webui/jargon/99999", json={"meaning": "test"})
assert response.status_code == 404
assert "黑话不存在" in response.json()["detail"]
def test_delete_jargon(client: TestClient, sample_jargons, session: Session):
"""测试 DELETE /jargon/{id} 删除黑话"""
jargon_id = sample_jargons[0].id
response = client.delete(f"/api/webui/jargon/{jargon_id}")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["message"] == "删除成功"
assert data["deleted_count"] == 1
# 验证数据库中已删除
response = client.get(f"/api/webui/jargon/{jargon_id}")
assert response.status_code == 404
def test_delete_jargon_not_found(client: TestClient):
"""测试删除不存在的黑话"""
response = client.delete("/api/webui/jargon/99999")
assert response.status_code == 404
assert "黑话不存在" in response.json()["detail"]
def test_batch_delete(client: TestClient, sample_jargons):
"""测试 POST /jargon/batch/delete 批量删除"""
ids_to_delete = [sample_jargons[0].id, sample_jargons[1].id]
request_data = {"ids": ids_to_delete}
response = client.post("/api/webui/jargon/batch/delete", json=request_data)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["deleted_count"] == 2
assert "成功删除 2 条黑话" in data["message"]
# 验证已删除
response = client.get(f"/api/webui/jargon/{ids_to_delete[0]}")
assert response.status_code == 404
def test_batch_delete_empty_list(client: TestClient):
"""测试批量删除空列表返回 400"""
response = client.post("/api/webui/jargon/batch/delete", json={"ids": []})
assert response.status_code == 400
assert "ID列表不能为空" in response.json()["detail"]
def test_batch_set_jargon_status(client: TestClient, sample_jargons):
"""测试批量设置黑话状态"""
ids = [sample_jargons[0].id, sample_jargons[1].id]
response = client.post(
"/api/webui/jargon/batch/set-jargon",
params={"ids": ids, "is_jargon": False},
)
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "成功更新 2 条黑话状态" in data["message"]
# 验证状态已更新
detail_response = client.get(f"/api/webui/jargon/{ids[0]}")
assert detail_response.json()["data"]["is_jargon"] is False
def test_get_stats(client: TestClient, sample_jargons):
"""测试 GET /jargon/stats/summary 统计数据"""
response = client.get("/api/webui/jargon/stats/summary")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
stats = data["data"]
assert stats["total"] == 3
assert stats["confirmed_jargon"] == 2
assert stats["confirmed_not_jargon"] == 1
assert stats["pending"] == 0
assert stats["complete_count"] == 0
assert stats["chat_count"] == 1
assert isinstance(stats["top_chats"], dict)
def test_get_chat_list(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试 GET /jargon/chats 获取聊天列表"""
response = client.get("/api/webui/jargon/chats")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert len(data["data"]) == 1
chat_info = data["data"][0]
assert chat_info["chat_id"] == sample_chat_session.session_id
assert chat_info["platform"] == "qq"
assert chat_info["is_group"] is True
assert chat_info["chat_name"] == sample_chat_session.group_id
def test_get_chat_list_with_json_chat_id(client: TestClient, session: Session, sample_chat_session: ChatSession):
"""测试解析 JSON 格式的 chat_id"""
json_chat_id = json.dumps([[sample_chat_session.session_id, "user123"]])
jargon = Jargon(
id=100,
content="测试黑话",
meaning="测试",
session_id=json_chat_id,
count=1,
)
session.add(jargon)
session.commit()
response = client.get("/api/webui/jargon/chats")
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
assert data["data"][0]["chat_id"] == sample_chat_session.session_id
def test_get_chat_list_without_chat_session(client: TestClient, session: Session):
"""测试聊天列表中没有对应 ChatSession 的情况"""
jargon = Jargon(
id=101,
content="孤立黑话",
meaning="无对应会话",
session_id="nonexistent_stream_id",
count=1,
)
session.add(jargon)
session.commit()
response = client.get("/api/webui/jargon/chats")
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 1
assert data["data"][0]["chat_id"] == "nonexistent_stream_id"
assert data["data"][0]["chat_name"] == "nonexistent_stream_id"[:20]
assert data["data"][0]["platform"] is None
assert data["data"][0]["is_group"] is False
def test_jargon_response_fields(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试 JargonResponse 字段完整性"""
response = client.get(f"/api/webui/jargon/{sample_jargons[0].id}")
assert response.status_code == 200
data = response.json()["data"]
# 验证所有必需字段存在
required_fields = [
"id",
"content",
"raw_content",
"meaning",
"chat_id",
"stream_id",
"chat_name",
"count",
"is_jargon",
"is_complete",
"inference_with_context",
"inference_content_only",
]
for field in required_fields:
assert field in data
# 验证 chat_name 显示逻辑
assert data["chat_name"] == sample_chat_session.group_id
@pytest.mark.skip(reason="Composite PK (id+content) prevents autoincrement - database model issue")
def test_create_jargon_without_optional_fields(client: TestClient, sample_chat_session: ChatSession):
"""测试创建黑话时可选字段为空"""
request_data = {
"content": "简单黑话",
"chat_id": sample_chat_session.session_id,
}
response = client.post("/api/webui/jargon/", json=request_data)
assert response.status_code == 200
data = response.json()["data"]
assert data["raw_content"] is None
assert data["meaning"] == ""
def test_update_jargon_partial_fields(client: TestClient, sample_jargons):
"""测试增量更新(只更新部分字段)"""
jargon_id = sample_jargons[0].id
original_content = sample_jargons[0].content
# 只更新 meaning
response = client.patch(f"/api/webui/jargon/{jargon_id}", json={"meaning": "新含义"})
assert response.status_code == 200
data = response.json()["data"]
assert data["meaning"] == "新含义"
assert data["content"] == original_content # 其他字段不变
def test_list_jargons_multiple_filters(client: TestClient, sample_jargons, sample_chat_session: ChatSession):
"""测试组合多个过滤条件"""
response = client.get(f"/api/webui/jargon/list?search=永远&chat_id={sample_chat_session.session_id}&is_jargon=true")
assert response.status_code == 200
data = response.json()
assert data["total"] == 1
assert data["data"][0]["content"] == "yyds"

View File

@ -0,0 +1,89 @@
"""
TOML文件工具函数 - 保留格式和注释
"""
import os
import tomlkit
from typing import Any
def save_toml_with_format(data: dict[str, Any], file_path: str) -> None:
"""
保存TOML数据到文件保留现有格式如果文件存在
Args:
data: 要保存的数据字典
file_path: 文件路径
"""
# 如果文件不存在,直接创建
if not os.path.exists(file_path):
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(data, f)
return
# 如果文件存在,尝试读取现有文件以保留格式
try:
with open(file_path, "r", encoding="utf-8") as f:
existing_doc = tomlkit.load(f)
except Exception:
# 如果读取失败,直接覆盖
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(data, f)
return
# 递归更新,保留现有格式
_merge_toml_preserving_format(existing_doc, data)
# 保存
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(existing_doc, f)
def _merge_toml_preserving_format(target: dict[str, Any], source: dict[str, Any]) -> None:
"""
递归合并source到target保留target中的格式和注释
Args:
target: 目标文档保留格式
source: 源数据新数据
"""
for key, value in source.items():
if key in target:
# 如果两个都是字典且都是表格,递归合并
if isinstance(value, dict) and isinstance(target[key], dict):
if hasattr(target[key], "items"): # 确实是字典/表格
_merge_toml_preserving_format(target[key], value)
else:
target[key] = value
else:
# 其他情况直接替换
target[key] = value
else:
# 新键直接添加
target[key] = value
def _update_toml_doc(target: dict[str, Any], source: dict[str, Any]) -> None:
"""
更新TOML文档中的字段保留现有的格式和注释
这是一个递归函数用于在部分更新配置时保留现有的格式和注释
Args:
target: 目标表格会被修改
source: 源数据新数据
"""
for key, value in source.items():
if key in target:
# 如果两个都是字典,递归更新
if isinstance(value, dict) and isinstance(target[key], dict):
if hasattr(target[key], "items"): # 确实是表格
_update_toml_doc(target[key], value)
else:
target[key] = value
else:
# 直接更新值,保留注释
target[key] = value
else:
# 新键直接添加
target[key] = value

View File

@ -5,7 +5,7 @@ import types
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set, Literal from typing import Any, Dict, List, Literal, Set, Tuple, Union, cast, get_args, get_origin
__all__ = ["ConfigBase", "Field", "AttributeData"] __all__ = ["ConfigBase", "Field", "AttributeData"]
@ -44,6 +44,16 @@ class AttrDocBase:
# 从类定义节点中提取字段文档 # 从类定义节点中提取字段文档
return self._extract_field_docs(class_node, allow_extra_methods) return self._extract_field_docs(class_node, allow_extra_methods)
@classmethod
def get_class_field_docs(cls) -> dict[str, str]:
class_source = cls._get_class_source()
class_node = cls._find_class_node(class_source)
return AttrDocBase._extract_field_docs(
cast(AttrDocBase, cast(Any, cls)),
class_node,
allow_extra_methods=False,
)
@classmethod @classmethod
def _get_class_source(cls) -> str: def _get_class_source(cls) -> str:
"""获取类定义所在文件的完整源代码""" """获取类定义所在文件的完整源代码"""
@ -265,7 +275,7 @@ class ConfigBase(BaseModel, AttrDocBase):
if origin_type in (int, float, str, bool, complex, bytes, Any): if origin_type in (int, float, str, bool, complex, bytes, Any):
continue continue
# 允许嵌套的ConfigBase自定义类 # 允许嵌套的ConfigBase自定义类
if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore if isinstance(origin_type, type) and issubclass(cast(type, origin_type), ConfigBase):
continue continue
# 只允许 list, set, dict 三类泛型 # 只允许 list, set, dict 三类泛型
if origin_type not in (list, set, dict, List, Set, Dict, Literal): if origin_type not in (list, set, dict, List, Set, Dict, Literal):

View File

@ -5,25 +5,73 @@ from .config_base import ConfigBase, Field
class APIProvider(ConfigBase): class APIProvider(ConfigBase):
"""API提供商配置类""" """API提供商配置类"""
name: str = "" name: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "tag",
},
)
"""API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)""" """API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)"""
base_url: str = "" base_url: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "link",
},
)
"""API服务商的BaseURL""" """API服务商的BaseURL"""
api_key: str = Field(default_factory=str, repr=False) api_key: str = Field(
default_factory=str,
repr=False,
json_schema_extra={
"x-widget": "input",
"x-icon": "key",
},
)
"""API密钥""" """API密钥"""
client_type: str = Field(default="openai") client_type: str = Field(
default="openai",
json_schema_extra={
"x-widget": "select",
"x-icon": "settings",
},
)
"""客户端类型 (可选: openai/google, 默认为openai)""" """客户端类型 (可选: openai/google, 默认为openai)"""
max_retry: int = Field(default=2) max_retry: int = Field(
default=2,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "repeat",
},
)
"""最大重试次数 (单个模型API调用失败, 最多重试的次数)""" """最大重试次数 (单个模型API调用失败, 最多重试的次数)"""
timeout: int = 10 timeout: int = Field(
default=10,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "clock",
"step": 1,
},
)
"""API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)""" """API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)"""
retry_interval: int = 10 retry_interval: int = Field(
default=10,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "timer",
"step": 1,
},
)
"""重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)""" """重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)"""
def model_post_init(self, context: Any = None): def model_post_init(self, context: Any = None):
@ -39,34 +87,93 @@ class APIProvider(ConfigBase):
class ModelInfo(ConfigBase): class ModelInfo(ConfigBase):
"""单个模型信息配置类""" """单个模型信息配置类"""
_validate_any: bool = False _validate_any: bool = False
suppress_any_warning: bool = True suppress_any_warning: bool = True
model_identifier: str = "" model_identifier: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "package",
},
)
"""模型标识符 (API服务商提供的模型标识符)""" """模型标识符 (API服务商提供的模型标识符)"""
name: str = "" name: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "tag",
},
)
"""模型名称 (可随意命名, 在models中需使用这个命名)""" """模型名称 (可随意命名, 在models中需使用这个命名)"""
api_provider: str = "" api_provider: str = Field(
default="",
json_schema_extra={
"x-widget": "select",
"x-icon": "link",
},
)
"""API服务商名称 (对应在api_providers中配置的服务商名称)""" """API服务商名称 (对应在api_providers中配置的服务商名称)"""
price_in: float = Field(default=0.0) price_in: float = Field(
default=0.0,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "dollar-sign",
"step": 0.001,
},
)
"""输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)""" """输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
price_out: float = Field(default=0.0) price_out: float = Field(
default=0.0,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "dollar-sign",
"step": 0.001,
},
)
"""输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)""" """输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
temperature: float | None = Field(default=None) temperature: float | None = Field(
default=None,
json_schema_extra={
"x-widget": "input",
"x-icon": "thermometer",
},
)
"""模型级别温度(可选),会覆盖任务配置中的温度""" """模型级别温度(可选),会覆盖任务配置中的温度"""
max_tokens: int | None = Field(default=None) max_tokens: int | None = Field(
default=None,
json_schema_extra={
"x-widget": "input",
"x-icon": "layers",
},
)
"""模型级别最大token数可选会覆盖任务配置中的max_tokens""" """模型级别最大token数可选会覆盖任务配置中的max_tokens"""
force_stream_mode: bool = Field(default=False) force_stream_mode: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "zap",
},
)
"""强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)""" """强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)"""
extra_params: dict[str, Any] = Field(default_factory=dict) extra_params: dict[str, Any] = Field(
default_factory=dict,
json_schema_extra={
"x-widget": "custom",
"x-icon": "sliders",
},
)
"""额外参数 (用于API调用时的额外配置)""" """额外参数 (用于API调用时的额外配置)"""
def model_post_init(self, context: Any = None): def model_post_init(self, context: Any = None):
@ -82,48 +189,139 @@ class ModelInfo(ConfigBase):
class TaskConfig(ConfigBase): class TaskConfig(ConfigBase):
"""任务配置类""" """任务配置类"""
model_list: list[str] = Field(default_factory=list) model_list: list[str] = Field(
default_factory=list,
json_schema_extra={
"x-widget": "custom",
"x-icon": "list",
},
)
"""使用的模型列表, 每个元素对应上面的模型名称(name)""" """使用的模型列表, 每个元素对应上面的模型名称(name)"""
max_tokens: int = 1024 max_tokens: int = Field(
default=1024,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "layers",
"step": 1,
},
)
"""任务最大输出token数""" """任务最大输出token数"""
temperature: float = 0.3 temperature: float = Field(
default=0.3,
ge=0,
le=2,
json_schema_extra={
"x-widget": "slider",
"x-icon": "thermometer",
"step": 0.1,
},
)
"""模型温度""" """模型温度"""
slow_threshold: float = 15.0 slow_threshold: float = Field(
default=15.0,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "alert-circle",
"step": 0.1,
},
)
"""慢请求阈值(秒),超过此值会输出警告日志""" """慢请求阈值(秒),超过此值会输出警告日志"""
selection_strategy: str = Field(default="balance") selection_strategy: str = Field(
default="balance",
json_schema_extra={
"x-widget": "select",
"x-icon": "shuffle",
},
)
"""模型选择策略balance负载均衡或 random随机选择""" """模型选择策略balance负载均衡或 random随机选择"""
class ModelTaskConfig(ConfigBase): class ModelTaskConfig(ConfigBase):
"""模型配置类""" """模型配置类"""
utils: TaskConfig = Field(default_factory=TaskConfig) utils: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "wrench",
},
)
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型""" """组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
replyer: TaskConfig = Field(default_factory=TaskConfig) replyer: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "message-square",
},
)
"""首要回复模型配置, 还用于表达器和表达方式学习""" """首要回复模型配置, 还用于表达器和表达方式学习"""
vlm: TaskConfig = Field(default_factory=TaskConfig) vlm: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "image",
},
)
"""视觉模型配置""" """视觉模型配置"""
voice: TaskConfig = Field(default_factory=TaskConfig) voice: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "volume-2",
},
)
"""语音识别模型配置""" """语音识别模型配置"""
tool_use: TaskConfig = Field(default_factory=TaskConfig) tool_use: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "tools",
},
)
"""工具使用模型配置, 需要使用支持工具调用的模型""" """工具使用模型配置, 需要使用支持工具调用的模型"""
planner: TaskConfig = Field(default_factory=TaskConfig) planner: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "map",
},
)
"""规划模型配置""" """规划模型配置"""
embedding: TaskConfig = Field(default_factory=TaskConfig) embedding: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "database",
},
)
"""嵌入模型配置""" """嵌入模型配置"""
lpmm_entity_extract: TaskConfig = Field(default_factory=TaskConfig) lpmm_entity_extract: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "filter",
},
)
"""LPMM实体提取模型配置""" """LPMM实体提取模型配置"""
lpmm_rdf_build: TaskConfig = Field(default_factory=TaskConfig) lpmm_rdf_build: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "network",
},
)
"""LPMM RDF构建模型配置""" """LPMM RDF构建模型配置"""

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,301 @@
"""
规划器监控API
提供规划器日志数据的查询接口
性能优化
1. 聊天摘要只统计文件数量和最新时间戳不读取文件内容
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
router = APIRouter(prefix="/api/planner", tags=["planner"])
# 规划器日志目录
PLAN_LOG_DIR = Path("logs/plan")
class ChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
plan_count: int
latest_timestamp: float
latest_filename: str
class PlanLogSummary(BaseModel):
"""规划日志摘要"""
chat_id: str
timestamp: float
filename: str
action_count: int
action_types: List[str] # 动作类型列表
total_plan_ms: float
llm_duration_ms: float
reasoning_preview: str
class PlanLogDetail(BaseModel):
"""规划日志详情"""
type: str
chat_id: str
timestamp: float
prompt: str
reasoning: str
raw_output: str
actions: List[Dict]
timing: Dict
extra: Optional[Dict] = None
class PlannerOverview(BaseModel):
"""规划器总览 - 轻量级统计"""
total_chats: int
total_plans: int
chats: List[ChatSummary]
class PaginatedChatLogs(BaseModel):
"""分页的聊天日志列表"""
data: List[PlanLogSummary]
total: int
page: int
page_size: int
chat_id: str
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
# 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000
except (ValueError, IndexError):
return 0
@router.get("/overview", response_model=PlannerOverview)
async def get_planner_overview():
"""
获取规划器总览 - 轻量级接口
只统计文件数量不读取文件内容
"""
if not PLAN_LOG_DIR.exists():
return PlannerOverview(total_chats=0, total_plans=0, chats=[])
chats = []
total_plans = 0
for chat_dir in PLAN_LOG_DIR.iterdir():
if not chat_dir.is_dir():
continue
# 只统计json文件数量
json_files = list(chat_dir.glob("*.json"))
plan_count = len(json_files)
total_plans += plan_count
if plan_count == 0:
continue
# 从文件名获取最新时间戳
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ChatSummary(
chat_id=chat_dir.name,
plan_count=plan_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
# 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return PlannerOverview(
total_chats=len(chats),
total_plans=total_plans,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
async def get_chat_plan_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
):
"""
获取指定聊天的规划日志列表分页
需要读取文件内容获取摘要信息
支持搜索提示词内容
"""
chat_dir = PLAN_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedChatLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# 如果有搜索关键词,需要过滤文件
if search:
search_lower = search.lower()
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
prompt = data.get('prompt', '')
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
continue
json_files = filtered_files
total = len(json_files)
# 分页 - 只读取当前页的文件
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
actions = data.get('actions', [])
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
logs.append(PlanLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
action_count=len(actions),
action_types=action_types,
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
reasoning_preview=reasoning[:100] if reasoning else ''
))
except Exception:
# 文件读取失败时使用文件名信息
logs.append(PlanLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
action_count=0,
action_types=[],
total_plan_ms=0,
llm_duration_ms=0,
reasoning_preview='[读取失败]'
))
return PaginatedChatLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
async def get_log_detail(chat_id: str, filename: str):
"""获取规划日志详情 - 按需加载完整内容"""
log_file = PLAN_LOG_DIR / chat_id / filename
if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在")
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
return PlanLogDetail(**data)
except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
# ========== 兼容旧接口 ==========
@router.get("/stats")
async def get_planner_stats():
"""获取规划器统计信息 - 兼容旧接口"""
overview = await get_planner_overview()
# 获取最近10条计划的摘要
recent_plans = []
for chat in overview.chats[:5]: # 从最近5个聊天中获取
try:
chat_logs = await get_chat_plan_logs(chat.chat_id, page=1, page_size=2)
recent_plans.extend(chat_logs.data)
except Exception:
continue
# 按时间排序取前10
recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
recent_plans = recent_plans[:10]
return {
"total_chats": overview.total_chats,
"total_plans": overview.total_plans,
"avg_plan_time_ms": 0,
"avg_llm_time_ms": 0,
"recent_plans": recent_plans
}
@router.get("/chats")
async def get_chat_list():
"""获取所有聊天ID列表 - 兼容旧接口"""
overview = await get_planner_overview()
return [chat.chat_id for chat in overview.chats]
@router.get("/all-logs")
async def get_all_logs(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100)
):
"""获取所有规划日志 - 兼容旧接口"""
if not PLAN_LOG_DIR.exists():
return {"data": [], "total": 0, "page": page, "page_size": page_size}
# 收集所有文件
all_files = []
for chat_dir in PLAN_LOG_DIR.iterdir():
if chat_dir.is_dir():
for log_file in chat_dir.glob("*.json"):
all_files.append((chat_dir.name, log_file))
# 按时间戳排序
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
total = len(all_files)
offset = (page - 1) * page_size
page_files = all_files[offset:offset + page_size]
logs = []
for chat_id, log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
logs.append({
"chat_id": data.get('chat_id', chat_id),
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
"filename": log_file.name,
"action_count": len(data.get('actions', [])),
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
"reasoning_preview": reasoning[:100] if reasoning else ''
})
except Exception:
continue
return {"data": logs, "total": total, "page": page, "page_size": page_size}

View File

@ -0,0 +1,269 @@
"""
回复器监控API
提供回复器日志数据的查询接口
性能优化
1. 聊天摘要只统计文件数量和最新时间戳不读取文件内容
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
router = APIRouter(prefix="/api/replier", tags=["replier"])
# 回复器日志目录
REPLY_LOG_DIR = Path("logs/reply")
class ReplierChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
reply_count: int
latest_timestamp: float
latest_filename: str
class ReplyLogSummary(BaseModel):
"""回复日志摘要"""
chat_id: str
timestamp: float
filename: str
model: str
success: bool
llm_ms: float
overall_ms: float
output_preview: str
class ReplyLogDetail(BaseModel):
"""回复日志详情"""
type: str
chat_id: str
timestamp: float
prompt: str
output: str
processed_output: List[str]
model: str
reasoning: str
think_level: int
timing: Dict
error: Optional[str] = None
success: bool
class ReplierOverview(BaseModel):
"""回复器总览 - 轻量级统计"""
total_chats: int
total_replies: int
chats: List[ReplierChatSummary]
class PaginatedReplyLogs(BaseModel):
"""分页的回复日志列表"""
data: List[ReplyLogSummary]
total: int
page: int
page_size: int
chat_id: str
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
# 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000
except (ValueError, IndexError):
return 0
@router.get("/overview", response_model=ReplierOverview)
async def get_replier_overview():
"""
获取回复器总览 - 轻量级接口
只统计文件数量不读取文件内容
"""
if not REPLY_LOG_DIR.exists():
return ReplierOverview(total_chats=0, total_replies=0, chats=[])
chats = []
total_replies = 0
for chat_dir in REPLY_LOG_DIR.iterdir():
if not chat_dir.is_dir():
continue
# 只统计json文件数量
json_files = list(chat_dir.glob("*.json"))
reply_count = len(json_files)
total_replies += reply_count
if reply_count == 0:
continue
# 从文件名获取最新时间戳
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ReplierChatSummary(
chat_id=chat_dir.name,
reply_count=reply_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
# 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return ReplierOverview(
total_chats=len(chats),
total_replies=total_replies,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
async def get_chat_reply_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
):
"""
获取指定聊天的回复日志列表分页
需要读取文件内容获取摘要信息
支持搜索提示词内容
"""
chat_dir = REPLY_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedReplyLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# 如果有搜索关键词,需要过滤文件
if search:
search_lower = search.lower()
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
prompt = data.get('prompt', '')
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
continue
json_files = filtered_files
total = len(json_files)
# 分页 - 只读取当前页的文件
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
output = data.get('output', '')
logs.append(ReplyLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
model=data.get('model', ''),
success=data.get('success', True),
llm_ms=data.get('timing', {}).get('llm_ms', 0),
overall_ms=data.get('timing', {}).get('overall_ms', 0),
output_preview=output[:100] if output else ''
))
except Exception:
# 文件读取失败时使用文件名信息
logs.append(ReplyLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
model='',
success=False,
llm_ms=0,
overall_ms=0,
output_preview='[读取失败]'
))
return PaginatedReplyLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
async def get_reply_log_detail(chat_id: str, filename: str):
"""获取回复日志详情 - 按需加载完整内容"""
log_file = REPLY_LOG_DIR / chat_id / filename
if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在")
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
return ReplyLogDetail(
type=data.get('type', 'reply'),
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', 0),
prompt=data.get('prompt', ''),
output=data.get('output', ''),
processed_output=data.get('processed_output', []),
model=data.get('model', ''),
reasoning=data.get('reasoning', ''),
think_level=data.get('think_level', 0),
timing=data.get('timing', {}),
error=data.get('error'),
success=data.get('success', True)
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
# ========== 兼容接口 ==========
@router.get("/stats")
async def get_replier_stats():
"""获取回复器统计信息"""
overview = await get_replier_overview()
# 获取最近10条回复的摘要
recent_replies = []
for chat in overview.chats[:5]: # 从最近5个聊天中获取
try:
chat_logs = await get_chat_reply_logs(chat.chat_id, page=1, page_size=2)
recent_replies.extend(chat_logs.data)
except Exception:
continue
# 按时间排序取前10
recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
recent_replies = recent_replies[:10]
return {
"total_chats": overview.total_chats,
"total_replies": overview.total_replies,
"recent_replies": recent_replies
}
@router.get("/chats")
async def get_replier_chat_list():
"""获取所有聊天ID列表"""
overview = await get_replier_overview()
return [chat.chat_id for chat in overview.chats]

View File

@ -0,0 +1,133 @@
import inspect
from typing import Any, get_args, get_origin
from pydantic_core import PydanticUndefined
from src.config.config_base import ConfigBase
class ConfigSchemaGenerator:
@classmethod
def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
return cls.generate_config_schema(config_class, include_nested=include_nested)
@classmethod
def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
fields: list[dict[str, Any]] = []
nested: dict[str, dict[str, Any]] = {}
for field_name, field_info in config_class.model_fields.items():
if field_name in {"field_docs", "_validate_any", "suppress_any_warning"}:
continue
field_schema = cls._build_field_schema(config_class, field_name, field_info.annotation, field_info)
fields.append(field_schema)
if include_nested:
nested_schema = cls._build_nested_schema(field_info.annotation)
if nested_schema is not None:
nested[field_name] = nested_schema
return {
"className": config_class.__name__,
"classDoc": (config_class.__doc__ or "").strip(),
"fields": fields,
"nested": nested,
}
@classmethod
def _build_nested_schema(cls, annotation: Any) -> dict[str, Any] | None:
origin = get_origin(annotation)
args = get_args(annotation)
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
return cls.generate_config_schema(annotation)
if origin in {list, tuple} and args:
first = args[0]
if inspect.isclass(first) and issubclass(first, ConfigBase):
return cls.generate_config_schema(first)
return None
@classmethod
def _build_field_schema(
cls, config_class: type[ConfigBase], field_name: str, annotation: Any, field_info: Any
) -> dict[str, Any]:
field_docs = config_class.get_class_field_docs()
field_type = cls._map_field_type(annotation)
schema: dict[str, Any] = {
"name": field_name,
"type": field_type,
"label": field_name,
"description": field_docs.get(field_name, field_info.description or ""),
"required": field_info.is_required(),
}
if field_info.default is not PydanticUndefined:
schema["default"] = field_info.default
origin = get_origin(annotation)
args = get_args(annotation)
if origin is list and args:
schema["items"] = {"type": cls._map_field_type(args[0])}
options = cls._extract_options(annotation)
if options:
schema["options"] = options
# Task 1c: Merge json_schema_extra (x-widget, x-icon, step, etc.)
if hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra:
schema.update(field_info.json_schema_extra)
# Task 1d: Map Pydantic constraints to minValue/maxValue (frontend naming convention)
if hasattr(field_info, "metadata") and field_info.metadata:
for constraint in field_info.metadata:
if hasattr(constraint, "ge"):
schema["minValue"] = constraint.ge
if hasattr(constraint, "le"):
schema["maxValue"] = constraint.le
return schema
@staticmethod
def _extract_options(annotation: Any) -> list[str] | None:
origin = get_origin(annotation)
if origin is None:
return None
if str(origin) != "typing.Literal":
return None
args = get_args(annotation)
options = [str(item) for item in args]
return options or None
@classmethod
def _map_field_type(cls, annotation: Any) -> str:
origin = get_origin(annotation)
args = get_args(annotation)
if origin in {list, tuple}:
return "array"
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
return "object"
if annotation is bool:
return "boolean"
if annotation is int:
return "integer"
if annotation is float:
return "number"
if annotation is str:
return "string"
if origin in {list, tuple} and args:
return "array"
if origin in {dict}:
return "object"
if origin is not None and str(origin) == "typing.Literal":
return "select"
return "string"

View File

@ -10,7 +10,7 @@ from typing import Any, Annotated, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header from src.webui.core import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import ( from src.config.official_configs import (
BotConfig, BotConfig,
PersonalityConfig, PersonalityConfig,
@ -77,7 +77,7 @@ async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
async def get_model_config_schema(_auth: bool = Depends(require_auth)): async def get_model_config_schema(_auth: bool = Depends(require_auth)):
"""获取模型配置架构(包含提供商和模型任务配置)""" """获取模型配置架构(包含提供商和模型任务配置)"""
try: try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig) schema = ConfigSchemaGenerator.generate_config_schema(ModelConfig)
return {"success": True, "schema": schema} return {"success": True, "schema": schema}
except Exception as e: except Exception as e:
logger.error(f"获取模型配置架构失败: {e}") logger.error(f"获取模型配置架构失败: {e}")
@ -227,7 +227,7 @@ async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(req
try: try:
# 验证配置数据 # 验证配置数据
try: try:
APIAdapterConfig.from_dict(config_data) ModelConfig.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
@ -377,7 +377,7 @@ async def update_model_config_section(
# 验证完整配置 # 验证完整配置
try: try:
APIAdapterConfig.from_dict(config_data) ModelConfig.from_dict(config_data)
except Exception as e: except Exception as e:
logger.error(f"配置数据验证失败,详细错误: {str(e)}") logger.error(f"配置数据验证失败,详细错误: {str(e)}")
# 特殊处理:如果是更新 api_providers检查是否有模型引用了已删除的provider # 特殊处理:如果是更新 api_providers检查是否有模型引用了已删除的provider

View File

@ -1,21 +1,27 @@
"""表情包管理 API 路由""" """表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Annotated, Any, List, Optional
import asyncio
import hashlib
import io
import os
import threading
from fastapi import APIRouter, Cookie, File, Form, Header, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse, JSONResponse from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List, Annotated
from src.common.logger import get_logger
from src.common.database.database_model import Emoji
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
import time
import os
import hashlib
from PIL import Image from PIL import Image
import io from sqlalchemy import func
from pathlib import Path from sqlmodel import col, select
import threading
import asyncio from src.common.database.database import get_db_session
from concurrent.futures import ThreadPoolExecutor from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
logger = get_logger("webui.emoji") logger = get_logger("webui.emoji")
@ -61,7 +67,7 @@ def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
def _ensure_thumbnail_cache_dir() -> Path: def _ensure_thumbnail_cache_dir() -> Path:
"""确保缩略图缓存目录存在""" """确保缩略图缓存目录存在"""
THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True) _ = THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
return THUMBNAIL_CACHE_DIR return THUMBNAIL_CACHE_DIR
@ -99,7 +105,7 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
try: try:
with Image.open(source_path) as img: with Image.open(source_path) as img:
# GIF 处理:提取第一帧 # GIF 处理:提取第一帧
if hasattr(img, "n_frames") and img.n_frames > 1: if getattr(img, "n_frames", 1) > 1:
img.seek(0) # 确保在第一帧 img.seek(0) # 确保在第一帧
# 转换为 RGB/RGBAWebP 支持透明度) # 转换为 RGB/RGBAWebP 支持透明度)
@ -138,9 +144,9 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
return 0, 0 return 0, 0
# 获取所有表情包的哈希值 # 获取所有表情包的哈希值
valid_hashes = set() with get_db_session() as session:
for emoji in Emoji.select(Emoji.emoji_hash): statement = select(Images.image_hash).where(col(Images.image_type) == ImageType.EMOJI)
valid_hashes.add(emoji.emoji_hash) valid_hashes = set(session.exec(statement).all())
cleaned = 0 cleaned = 0
kept = 0 kept = 0
@ -179,7 +185,6 @@ class EmojiResponse(BaseModel):
id: int id: int
full_path: str full_path: str
format: str
emoji_hash: str emoji_hash: str
description: str description: str
query_count: int query_count: int
@ -188,7 +193,6 @@ class EmojiResponse(BaseModel):
emotion: Optional[str] # 直接返回字符串 emotion: Optional[str] # 直接返回字符串
record_time: float record_time: float
register_time: Optional[float] register_time: Optional[float]
usage_count: int
last_used_time: Optional[float] last_used_time: Optional[float]
@ -257,22 +261,19 @@ def verify_auth_token(
return verify_auth_token_from_cookie_or_header(maibot_session, authorization) return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def emoji_to_response(emoji: Emoji) -> EmojiResponse: def emoji_to_response(image: Images) -> EmojiResponse:
"""将 Emoji 模型转换为响应对象"""
return EmojiResponse( return EmojiResponse(
id=emoji.id, id=image.id if image.id is not None else 0,
full_path=emoji.full_path, full_path=image.full_path,
format=emoji.format, emoji_hash=image.image_hash,
emoji_hash=emoji.emoji_hash, description=image.description,
description=emoji.description, query_count=image.query_count,
query_count=emoji.query_count, is_registered=image.is_registered,
is_registered=emoji.is_registered, is_banned=image.is_banned,
is_banned=emoji.is_banned, emotion=image.emotion,
emotion=str(emoji.emotion) if emoji.emotion is not None else None, record_time=image.record_time.timestamp() if image.record_time else 0.0,
record_time=emoji.record_time, register_time=image.register_time.timestamp() if image.register_time else None,
register_time=emoji.register_time, last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,
usage_count=emoji.usage_count,
last_used_time=emoji.last_used_time,
) )
@ -283,8 +284,7 @@ async def get_emoji_list(
search: Optional[str] = Query(None, description="搜索关键词"), search: Optional[str] = Query(None, description="搜索关键词"),
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"), is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"), is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
format: Optional[str] = Query(None, description="格式筛选"), sort_by: Optional[str] = Query("query_count", description="排序字段"),
sort_by: Optional[str] = Query("usage_count", description="排序字段"),
sort_order: Optional[str] = Query("desc", description="排序方向"), sort_order: Optional[str] = Query("desc", description="排序方向"),
maibot_session: Optional[str] = Cookie(None), maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None), authorization: Optional[str] = Header(None),
@ -298,8 +298,7 @@ async def get_emoji_list(
search: 搜索关键词 (匹配 description, emoji_hash) search: 搜索关键词 (匹配 description, emoji_hash)
is_registered: 是否已注册筛选 is_registered: 是否已注册筛选
is_banned: 是否被禁用筛选 is_banned: 是否被禁用筛选
format: 格式筛选 sort_by: 排序字段 (query_count, register_time, record_time, last_used_time)
sort_by: 排序字段 (usage_count, register_time, record_time, last_used_time)
sort_order: 排序方向 (asc, desc) sort_order: 排序方向 (asc, desc)
authorization: Authorization header authorization: Authorization header
@ -310,47 +309,58 @@ async def get_emoji_list(
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
# 构建查询 # 构建查询
query = Emoji.select() statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
# 搜索过滤 # 搜索过滤
if search: if search:
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search))) statement = statement.where(
(col(Images.description).contains(search)) | (col(Images.image_hash).contains(search))
)
# 注册状态过滤 # 注册状态过滤
if is_registered is not None: if is_registered is not None:
query = query.where(Emoji.is_registered == is_registered) statement = statement.where(col(Images.is_registered) == is_registered)
# 禁用状态过滤 # 禁用状态过滤
if is_banned is not None: if is_banned is not None:
query = query.where(Emoji.is_banned == is_banned) statement = statement.where(col(Images.is_banned) == is_banned)
# 格式过滤
if format:
query = query.where(Emoji.format == format)
# 排序字段映射 # 排序字段映射
sort_field_map = { sort_field_map = {
"usage_count": Emoji.usage_count, "usage_count": col(Images.query_count),
"register_time": Emoji.register_time, "query_count": col(Images.query_count),
"record_time": Emoji.record_time, "register_time": col(Images.register_time),
"last_used_time": Emoji.last_used_time, "record_time": col(Images.record_time),
"last_used_time": col(Images.last_used_time),
} }
# 获取排序字段,默认使用 usage_count # 获取排序字段,默认使用 usage_count
sort_field = sort_field_map.get(sort_by, Emoji.usage_count) sort_key = sort_by or "query_count"
sort_field = sort_field_map.get(sort_key, col(Images.query_count))
# 应用排序 # 应用排序
if sort_order == "asc": if sort_order == "asc":
query = query.order_by(sort_field.asc()) statement = statement.order_by(sort_field.asc())
else: else:
query = query.order_by(sort_field.desc()) statement = statement.order_by(sort_field.desc())
# 获取总数
total = query.count()
# 分页 # 分页
offset = (page - 1) * page_size offset = (page - 1) * page_size
emojis = query.offset(offset).limit(page_size) statement = statement.offset(offset).limit(page_size)
with get_db_session() as session:
emojis = session.exec(statement).all()
count_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI)
if search:
count_statement = count_statement.where(
(col(Images.description).contains(search)) | (col(Images.image_hash).contains(search))
)
if is_registered is not None:
count_statement = count_statement.where(col(Images.is_registered) == is_registered)
if is_banned is not None:
count_statement = count_statement.where(col(Images.is_banned) == is_banned)
total = session.exec(count_statement).one()
# 转换为响应对象 # 转换为响应对象
data = [emoji_to_response(emoji) for emoji in emojis] data = [emoji_to_response(emoji) for emoji in emojis]
@ -381,12 +391,17 @@ async def get_emoji_detail(
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id) with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji: if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji)) return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
except HTTPException: except HTTPException:
raise raise
@ -416,34 +431,37 @@ async def update_emoji(
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id) with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji: if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 只更新提供的字段 # 只更新提供的字段
update_data = request.model_dump(exclude_unset=True) update_data = request.model_dump(exclude_unset=True)
if not update_data: if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段") raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# emotion 字段直接使用字符串,无需转换 # 如果注册状态从 False 变为 True记录注册时间
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
update_data["register_time"] = datetime.now()
# 如果注册状态从 False 变为 True记录注册时间 # 执行更新
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered: for field, value in update_data.items():
update_data["register_time"] = time.time() setattr(emoji, field, value)
# 执行更新 session.add(emoji)
for field, value in update_data.items():
setattr(emoji, field, value)
emoji.save() logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}") return EmojiUpdateResponse(
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
return EmojiUpdateResponse( )
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
)
except HTTPException: except HTTPException:
raise raise
@ -469,20 +487,22 @@ async def delete_emoji(
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id) with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji: if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 记录删除信息 emoji_hash = emoji.image_hash
emoji_hash = emoji.emoji_hash session.delete(emoji)
# 执行删除 logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
emoji.delete_instance()
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}") return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
except HTTPException: except HTTPException:
raise raise
@ -505,27 +525,51 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
total = Emoji.select().count() with get_db_session() as session:
registered = Emoji.select().where(Emoji.is_registered).count() total_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI)
banned = Emoji.select().where(Emoji.is_banned).count() registered_statement = (
select(func.count())
.select_from(Images)
.where(
col(Images.image_type) == ImageType.EMOJI,
col(Images.is_registered) == True,
)
)
banned_statement = (
select(func.count())
.select_from(Images)
.where(
col(Images.image_type) == ImageType.EMOJI,
col(Images.is_banned) == True,
)
)
# 按格式统计 total = session.exec(total_statement).one()
formats = {} registered = session.exec(registered_statement).one()
for emoji in Emoji.select(Emoji.format): banned = session.exec(banned_statement).one()
fmt = emoji.format
formats[fmt] = formats.get(fmt, 0) + 1
# 获取最常用的表情包前10 formats: dict[str, int] = {}
top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10) format_statement = select(Images.full_path).where(col(Images.image_type) == ImageType.EMOJI)
top_used_list = [ for full_path in session.exec(format_statement).all():
{ suffix = Path(full_path).suffix.lower().lstrip(".")
"id": emoji.id, fmt = suffix or "unknown"
"emoji_hash": emoji.emoji_hash, formats[fmt] = formats.get(fmt, 0) + 1
"description": emoji.description,
"usage_count": emoji.usage_count, top_used_statement = (
} select(Images)
for emoji in top_used .where(col(Images.image_type) == ImageType.EMOJI)
] .order_by(col(Images.query_count).desc())
.limit(10)
)
top_used_list = [
{
"id": emoji.id,
"emoji_hash": emoji.image_hash,
"description": emoji.description,
"usage_count": emoji.query_count,
}
for emoji in session.exec(top_used_statement).all()
]
return { return {
"success": True, "success": True,
@ -563,23 +607,27 @@ async def register_emoji(
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id) with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji: if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
if emoji.is_registered: if emoji.is_registered:
raise HTTPException(status_code=400, detail="该表情包已经注册") raise HTTPException(status_code=400, detail="该表情包已经注册")
# 注册表情包(如果已封禁,自动解除封禁) emoji.is_registered = True
emoji.is_registered = True emoji.is_banned = False
emoji.is_banned = False # 注册时自动解除封禁 emoji.register_time = datetime.now()
emoji.register_time = time.time() session.add(emoji)
emoji.save()
logger.info(f"表情包已注册: ID={emoji_id}") logger.info(f"表情包已注册: ID={emoji_id}")
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji)) return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
except HTTPException: except HTTPException:
raise raise
@ -605,19 +653,23 @@ async def ban_emoji(
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id) with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji: if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 禁用表情包(同时取消注册) emoji.is_banned = True
emoji.is_banned = True emoji.is_registered = False
emoji.is_registered = False session.add(emoji)
emoji.save()
logger.info(f"表情包已禁用: ID={emoji_id}") logger.info(f"表情包已禁用: ID={emoji_id}")
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji)) return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
except HTTPException: except HTTPException:
raise raise
@ -672,61 +724,58 @@ async def get_emoji_thumbnail(
if not is_valid: if not is_valid:
raise HTTPException(status_code=401, detail="Token 无效或已过期") raise HTTPException(status_code=401, detail="Token 无效或已过期")
emoji = Emoji.get_or_none(Emoji.id == emoji_id) with get_db_session() as session:
statement = select(Images).where(
if not emoji: col(Images.id) == emoji_id,
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包") col(Images.image_type) == ImageType.EMOJI,
# 检查文件是否存在
if not os.path.exists(emoji.full_path):
raise HTTPException(status_code=404, detail="表情包文件不存在")
# 如果请求原图,直接返回原文件
if original:
mime_types = {
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
"webp": "image/webp",
"bmp": "image/bmp",
}
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
) )
emoji = session.exec(statement).first()
# 尝试获取或生成缩略图 if not emoji:
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash) raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 检查缓存是否存在 if not os.path.exists(emoji.full_path):
if cache_path.exists(): raise HTTPException(status_code=404, detail="表情包文件不存在")
# 缓存命中,直接返回
return FileResponse( if original:
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp" mime_types = {
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
"webp": "image/webp",
"bmp": "image/bmp",
}
suffix = Path(emoji.full_path).suffix.lower().lstrip(".")
media_type = mime_types.get(suffix, "application/octet-stream")
return FileResponse(
path=emoji.full_path, media_type=media_type, filename=f"{emoji.image_hash}.{suffix}"
)
cache_path = _get_thumbnail_cache_path(emoji.image_hash)
if cache_path.exists():
return FileResponse(
path=str(cache_path), media_type="image/webp", filename=f"{emoji.image_hash}_thumb.webp"
)
with _generating_lock:
if emoji.image_hash not in _generating_thumbnails:
_generating_thumbnails.add(emoji.image_hash)
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.image_hash)
return JSONResponse(
status_code=202,
content={
"status": "generating",
"message": "缩略图正在生成中,请稍后重试",
"emoji_id": emoji_id,
},
headers={
"Retry-After": "1",
},
) )
# 缓存未命中,触发后台生成并返回 202
with _generating_lock:
if emoji.emoji_hash not in _generating_thumbnails:
# 标记为正在生成
_generating_thumbnails.add(emoji.emoji_hash)
# 提交到线程池后台生成
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
# 返回 202 Accepted告诉前端缩略图正在生成中
return JSONResponse(
status_code=202,
content={
"status": "generating",
"message": "缩略图正在生成中,请稍后重试",
"emoji_id": emoji_id,
},
headers={
"Retry-After": "1", # 建议 1 秒后重试
},
)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -762,14 +811,19 @@ async def batch_delete_emojis(
for emoji_id in request.emoji_ids: for emoji_id in request.emoji_ids:
try: try:
emoji = Emoji.get_or_none(Emoji.id == emoji_id) with get_db_session() as session:
if emoji: statement = select(Images).where(
emoji.delete_instance() col(Images.id) == emoji_id,
deleted_count += 1 col(Images.image_type) == ImageType.EMOJI,
logger.info(f"批量删除表情包: {emoji_id}") )
else: emoji = session.exec(statement).first()
failed_count += 1 if emoji:
failed_ids.append(emoji_id) session.delete(emoji)
deleted_count += 1
logger.info(f"批量删除表情包: {emoji_id}")
else:
failed_count += 1
failed_ids.append(emoji_id)
except Exception as e: except Exception as e:
logger.error(f"删除表情包 {emoji_id} 失败: {e}") logger.error(f"删除表情包 {emoji_id} 失败: {e}")
failed_count += 1 failed_count += 1
@ -864,19 +918,23 @@ async def upload_emoji(
# 计算文件哈希 # 计算文件哈希
emoji_hash = hashlib.md5(file_content).hexdigest() emoji_hash = hashlib.md5(file_content).hexdigest()
# 检查是否已存在相同哈希的表情包 with get_db_session() as session:
existing_emoji = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash) existing_statement = select(Images).where(
if existing_emoji: col(Images.image_hash) == emoji_hash,
raise HTTPException( col(Images.image_type) == ImageType.EMOJI,
status_code=409,
detail=f"已存在相同的表情包 (ID: {existing_emoji.id})",
) )
existing_emoji = session.exec(existing_statement).first()
if existing_emoji:
raise HTTPException(
status_code=409,
detail=f"已存在相同的表情包 (ID: {existing_emoji.id})",
)
# 确保目录存在 # 确保目录存在
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True) os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
# 生成文件名 # 生成文件名
timestamp = int(time.time()) timestamp = int(datetime.now().timestamp())
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}" filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
@ -889,37 +947,38 @@ async def upload_emoji(
# 保存文件 # 保存文件
with open(full_path, "wb") as f: with open(full_path, "wb") as f:
f.write(file_content) _ = f.write(file_content)
logger.info(f"表情包文件已保存: {full_path}") logger.info(f"表情包文件已保存: {full_path}")
# 处理情感标签 # 处理情感标签
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else "" emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
# 创建数据库记录 current_time = datetime.now()
current_time = time.time() with get_db_session() as session:
emoji = Emoji.create( emoji = Images(
full_path=full_path, image_type=ImageType.EMOJI,
format=img_format, full_path=full_path,
emoji_hash=emoji_hash, image_hash=emoji_hash,
description=description, description=description,
emotion=emotion_str, emotion=emotion_str or None,
query_count=0, query_count=0,
is_registered=is_registered, is_registered=is_registered,
is_banned=False, is_banned=False,
record_time=current_time, record_time=current_time,
register_time=current_time if is_registered else None, register_time=current_time if is_registered else None,
usage_count=0, last_used_time=None,
last_used_time=None, )
) session.add(emoji)
session.flush()
logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}") logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}")
return EmojiUploadResponse( return EmojiUploadResponse(
success=True, success=True,
message="表情包上传成功" + ("并已注册" if is_registered else ""), message="表情包上传成功" + ("并已注册" if is_registered else ""),
data=emoji_to_response(emoji), data=emoji_to_response(emoji),
) )
except HTTPException: except HTTPException:
raise raise
@ -951,7 +1010,7 @@ async def batch_upload_emoji(
try: try:
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
results = { results: dict[str, Any] = {
"success": True, "success": True,
"total": len(files), "total": len(files),
"uploaded": 0, "uploaded": 0,
@ -1008,20 +1067,24 @@ async def batch_upload_emoji(
# 计算哈希 # 计算哈希
emoji_hash = hashlib.md5(file_content).hexdigest() emoji_hash = hashlib.md5(file_content).hexdigest()
# 检查重复 with get_db_session() as session:
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash): existing_statement = select(Images).where(
results["failed"] += 1 col(Images.image_hash) == emoji_hash,
results["details"].append( col(Images.image_type) == ImageType.EMOJI,
{
"filename": file.filename,
"success": False,
"error": "已存在相同的表情包",
}
) )
continue if session.exec(existing_statement).first():
results["failed"] += 1
results["details"].append(
{
"filename": file.filename,
"success": False,
"error": "已存在相同的表情包",
}
)
continue
# 生成文件名并保存 # 生成文件名并保存
timestamp = int(time.time()) timestamp = int(datetime.now().timestamp())
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}" filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename) full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
@ -1032,36 +1095,37 @@ async def batch_upload_emoji(
counter += 1 counter += 1
with open(full_path, "wb") as f: with open(full_path, "wb") as f:
f.write(file_content) _ = f.write(file_content)
# 处理情感标签 # 处理情感标签
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else "" emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
# 创建数据库记录 current_time = datetime.now()
current_time = time.time() with get_db_session() as session:
emoji = Emoji.create( emoji = Images(
full_path=full_path, image_type=ImageType.EMOJI,
format=img_format, full_path=full_path,
emoji_hash=emoji_hash, image_hash=emoji_hash,
description="", # 批量上传暂不设置描述 description="",
emotion=emotion_str, emotion=emotion_str or None,
query_count=0, query_count=0,
is_registered=is_registered, is_registered=is_registered,
is_banned=False, is_banned=False,
record_time=current_time, record_time=current_time,
register_time=current_time if is_registered else None, register_time=current_time if is_registered else None,
usage_count=0, last_used_time=None,
last_used_time=None, )
) session.add(emoji)
session.flush()
results["uploaded"] += 1 results["uploaded"] += 1
results["details"].append( results["details"].append(
{ {
"filename": file.filename, "filename": file.filename,
"success": True, "success": True,
"id": emoji.id, "id": emoji.id,
} }
) )
except Exception as e: except Exception as e:
results["failed"] += 1 results["failed"] += 1
@ -1138,8 +1202,9 @@ async def get_thumbnail_cache_stats(
total_size = sum(f.stat().st_size for f in cache_files) total_size = sum(f.stat().st_size for f in cache_files)
total_size_mb = round(total_size / (1024 * 1024), 2) total_size_mb = round(total_size / (1024 * 1024), 2)
# 统计表情包总数 with get_db_session() as session:
emoji_count = Emoji.select().count() count_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI)
emoji_count = session.exec(count_statement).one()
# 计算覆盖率 # 计算覆盖率
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1) coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
@ -1213,12 +1278,17 @@ async def preheat_thumbnail_cache(
_ensure_thumbnail_cache_dir() _ensure_thumbnail_cache_dir()
# 获取使用次数最高的表情包(未缓存的优先) # 获取使用次数最高的表情包(未缓存的优先)
emojis = ( with get_db_session() as session:
Emoji.select() statement = (
.where(Emoji.is_banned == False) # noqa: E712 Peewee ORM requires == for boolean comparison select(Images)
.order_by(Emoji.usage_count.desc()) .where(
.limit(limit * 2) # 多查一些,因为有些可能已缓存 col(Images.image_type) == ImageType.EMOJI,
) col(Images.is_banned) == False,
)
.order_by(col(Images.query_count).desc())
.limit(limit * 2)
)
emojis = session.exec(statement).all()
generated = 0 generated = 0
skipped = 0 skipped = 0
@ -1228,25 +1298,22 @@ async def preheat_thumbnail_cache(
if generated >= limit: if generated >= limit:
break break
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash) cache_path = _get_thumbnail_cache_path(emoji.image_hash)
# 已缓存,跳过
if cache_path.exists(): if cache_path.exists():
skipped += 1 skipped += 1
continue continue
# 原文件不存在,跳过
if not os.path.exists(emoji.full_path): if not os.path.exists(emoji.full_path):
failed += 1 failed += 1
continue continue
try: try:
# 使用线程池异步生成缩略图,避免阻塞事件循环
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash) await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.image_hash)
generated += 1 generated += 1
except Exception as e: except Exception as e:
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}") logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}")
failed += 1 failed += 1
return ThumbnailPreheatResponse( return ThumbnailPreheatResponse(

View File

@ -65,9 +65,6 @@ class ExpressionUpdateRequest(BaseModel):
situation: Optional[str] = None situation: Optional[str] = None
style: Optional[str] = None style: Optional[str] = None
chat_id: Optional[str] = None chat_id: Optional[str] = None
checked: Optional[bool] = None
rejected: Optional[bool] = None
require_unchecked: Optional[bool] = False # 用于人工审核时的冲突检测
class ExpressionUpdateResponse(BaseModel): class ExpressionUpdateResponse(BaseModel):
@ -388,26 +385,16 @@ async def update_expression(
if not expression: if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式") raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 冲突检测:如果要求未检查状态,但已经被检查了
if request.require_unchecked and getattr(expression, "checked", False):
raise HTTPException(
status_code=409,
detail=f"此表达方式已被{'AI自动' if getattr(expression, 'modified_by', None) == 'ai' else '人工'}检查,请刷新列表",
)
# 只更新提供的字段 # 只更新提供的字段
update_data = request.model_dump(exclude_unset=True) update_data = request.model_dump(exclude_unset=True)
# 移除 require_unchecked它不是数据库字段 # 映射 API 字段名到数据库字段名
update_data.pop("require_unchecked", None) if "chat_id" in update_data:
update_data["session_id"] = update_data.pop("chat_id")
if not update_data: if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段") raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 如果更新了 checked 或 rejected标记为用户修改
if "checked" in update_data or "rejected" in update_data:
update_data["modified_by"] = "user"
# 更新最后活跃时间 # 更新最后活跃时间
update_data["last_active_time"] = datetime.now() update_data["last_active_time"] = datetime.now()

View File

@ -1,13 +1,16 @@
"""黑话(俚语)管理路由""" """黑话(俚语)管理路由"""
import json from typing import Annotated, Any, List, Optional
from typing import Optional, List, Annotated
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func as fn from sqlalchemy import func as fn
from sqlmodel import Session, col, delete, select
import json
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession, Jargon
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Jargon, ChatStreams
logger = get_logger("webui.jargon") logger = get_logger("webui.jargon")
@ -43,27 +46,26 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
return [chat_id_str] return [chat_id_str]
def get_display_name_for_chat_id(chat_id_str: str) -> str: def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
""" """
获取 chat_id 的显示名称 获取 chat_id 的显示名称
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称 尝试解析 JSON 并查询 ChatSession 表获取群聊名称
""" """
stream_ids = parse_chat_id_to_stream_ids(chat_id_str) stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids: if not stream_ids:
return chat_id_str return chat_id_str[:20]
# 查询所有 stream_id 对应的名称 stream_id = stream_ids[0]
names = [] chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
for stream_id in stream_ids:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream and chat_stream.group_name:
names.append(chat_stream.group_name)
else:
# 如果没找到,显示截断的 stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
return ", ".join(names) if names else chat_id_str if not chat_session:
return stream_id[:20]
if chat_session.group_id:
return str(chat_session.group_id)
return chat_session.session_id[:20]
# ==================== 请求/响应模型 ==================== # ==================== 请求/响应模型 ====================
@ -79,7 +81,6 @@ class JargonResponse(BaseModel):
chat_id: str chat_id: str
stream_id: Optional[str] = None # 解析后的 stream_id用于前端编辑时匹配 stream_id: Optional[str] = None # 解析后的 stream_id用于前端编辑时匹配
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示 chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
is_global: bool = False
count: int = 0 count: int = 0
is_jargon: Optional[bool] = None is_jargon: Optional[bool] = None
is_complete: bool = False is_complete: bool = False
@ -94,7 +95,7 @@ class JargonListResponse(BaseModel):
total: int total: int
page: int page: int
page_size: int page_size: int
data: List[JargonResponse] data: List[dict[str, Any]]
class JargonDetailResponse(BaseModel): class JargonDetailResponse(BaseModel):
@ -111,7 +112,6 @@ class JargonCreateRequest(BaseModel):
raw_content: Optional[str] = Field(None, description="原始内容") raw_content: Optional[str] = Field(None, description="原始内容")
meaning: Optional[str] = Field(None, description="含义") meaning: Optional[str] = Field(None, description="含义")
chat_id: str = Field(..., description="聊天ID") chat_id: str = Field(..., description="聊天ID")
is_global: bool = Field(False, description="是否全局")
class JargonUpdateRequest(BaseModel): class JargonUpdateRequest(BaseModel):
@ -121,7 +121,6 @@ class JargonUpdateRequest(BaseModel):
raw_content: Optional[str] = None raw_content: Optional[str] = None
meaning: Optional[str] = None meaning: Optional[str] = None
chat_id: Optional[str] = None chat_id: Optional[str] = None
is_global: Optional[bool] = None
is_jargon: Optional[bool] = None is_jargon: Optional[bool] = None
@ -159,7 +158,7 @@ class JargonStatsResponse(BaseModel):
"""黑话统计响应""" """黑话统计响应"""
success: bool = True success: bool = True
data: dict data: dict[str, Any]
class ChatInfoResponse(BaseModel): class ChatInfoResponse(BaseModel):
@ -181,27 +180,24 @@ class ChatListResponse(BaseModel):
# ==================== 工具函数 ==================== # ==================== 工具函数 ====================
def jargon_to_dict(jargon: Jargon) -> dict: def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
"""将 Jargon ORM 对象转换为字典""" """将 Jargon ORM 对象转换为字典"""
# 解析 chat_id 获取显示名称和 stream_id chat_id = jargon.session_id or ""
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None
return { return {
"id": jargon.id, "id": jargon.id,
"content": jargon.content, "content": jargon.content,
"raw_content": jargon.raw_content, "raw_content": jargon.raw_content,
"meaning": jargon.meaning, "meaning": jargon.meaning,
"chat_id": jargon.chat_id, "chat_id": chat_id,
"stream_id": stream_id, "stream_id": jargon.session_id,
"chat_name": chat_name, "chat_name": chat_name,
"is_global": jargon.is_global,
"count": jargon.count, "count": jargon.count,
"is_jargon": jargon.is_jargon, "is_jargon": jargon.is_jargon,
"is_complete": jargon.is_complete, "is_complete": jargon.is_complete,
"inference_with_context": jargon.inference_with_context, "inference_with_context": jargon.inference_with_context,
"inference_content_only": jargon.inference_content_only, "inference_content_only": jargon.inference_with_content_only,
} }
@ -215,49 +211,41 @@ async def get_jargon_list(
search: Optional[str] = Query(None, description="搜索关键词"), search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"), chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"), is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
): ):
"""获取黑话列表""" """获取黑话列表"""
try: try:
# 构建查询 statement = select(Jargon)
query = Jargon.select() count_statement = select(fn.count()).select_from(Jargon)
# 搜索过滤
if search: if search:
query = query.where( search_filter = (
(Jargon.content.contains(search)) (col(Jargon.content).contains(search))
| (Jargon.meaning.contains(search)) | (col(Jargon.meaning).contains(search))
| (Jargon.raw_content.contains(search)) | (col(Jargon.raw_content).contains(search))
) )
statement = statement.where(search_filter)
count_statement = count_statement.where(search_filter)
# 按聊天ID筛选使用 contains 匹配,因为 chat_id 是 JSON 格式)
if chat_id: if chat_id:
# 从传入的 chat_id 中解析出 stream_id
stream_ids = parse_chat_id_to_stream_ids(chat_id) stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids: if stream_ids:
# 使用第一个 stream_id 进行模糊匹配 chat_filter = col(Jargon.session_id).contains(stream_ids[0])
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
else: else:
# 如果无法解析,使用精确匹配 chat_filter = col(Jargon.session_id) == chat_id
query = query.where(Jargon.chat_id == chat_id) statement = statement.where(chat_filter)
count_statement = count_statement.where(chat_filter)
# 按是否是黑话筛选
if is_jargon is not None: if is_jargon is not None:
query = query.where(Jargon.is_jargon == is_jargon) statement = statement.where(col(Jargon.is_jargon) == is_jargon)
count_statement = count_statement.where(col(Jargon.is_jargon) == is_jargon)
# 按是否全局筛选 statement = statement.order_by(col(Jargon.count).desc(), col(Jargon.id).desc())
if is_global is not None: statement = statement.offset((page - 1) * page_size).limit(page_size)
query = query.where(Jargon.is_global == is_global)
# 获取总数 with get_db_session() as session:
total = query.count() total = session.exec(count_statement).one()
jargons = session.exec(statement).all()
# 分页和排序(按使用次数降序) data = [jargon_to_dict(jargon, session) for jargon in jargons]
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
query = query.paginate(page, page_size)
# 转换为响应格式
data = [jargon_to_dict(j) for j in query]
return JargonListResponse( return JargonListResponse(
success=True, success=True,
@ -276,10 +264,9 @@ async def get_jargon_list(
async def get_chat_list(): async def get_chat_list():
"""获取所有有黑话记录的聊天列表""" """获取所有有黑话记录的聊天列表"""
try: try:
# 获取所有不同的 chat_id with get_db_session() as session:
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)) statement = select(Jargon.session_id).distinct().where(col(Jargon.session_id).is_not(None))
chat_id_list = [chat_id for chat_id in session.exec(statement).all() if chat_id]
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
# 用于按 stream_id 去重 # 用于按 stream_id 去重
seen_stream_ids: set[str] = set() seen_stream_ids: set[str] = set()
@ -290,27 +277,28 @@ async def get_chat_list():
seen_stream_ids.add(stream_ids[0]) seen_stream_ids.add(stream_ids[0])
result = [] result = []
for stream_id in seen_stream_ids: with get_db_session() as session:
# 尝试从 ChatStreams 表获取聊天名称 for stream_id in seen_stream_ids:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id) chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
if chat_stream: if chat_session:
result.append( chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
ChatInfoResponse( result.append(
chat_id=stream_id, # 使用 stream_id方便筛选匹配 ChatInfoResponse(
chat_name=chat_stream.group_name or stream_id, chat_id=stream_id,
platform=chat_stream.platform, chat_name=chat_name,
is_group=True, platform=chat_session.platform,
is_group=bool(chat_session.group_id),
)
) )
) else:
else: result.append(
result.append( ChatInfoResponse(
ChatInfoResponse( chat_id=stream_id,
chat_id=stream_id, # 使用 stream_id chat_name=stream_id[:20],
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id, platform=None,
platform=None, is_group=False,
is_group=False, )
) )
)
return ChatListResponse(success=True, data=result) return ChatListResponse(success=True, data=result)
@ -323,35 +311,35 @@ async def get_chat_list():
async def get_jargon_stats(): async def get_jargon_stats():
"""获取黑话统计数据""" """获取黑话统计数据"""
try: try:
# 总数量 with get_db_session() as session:
total = Jargon.select().count() total = session.exec(select(fn.count()).select_from(Jargon)).one()
# 已确认是黑话的数量 confirmed_jargon = session.exec(
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count() select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True)
).one()
confirmed_not_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False)
).one()
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
# 已确认不是黑话的数量 complete_count = session.exec(
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count() select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True)
).one()
# 未判定的数量 chat_count = session.exec(
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count() select(fn.count()).select_from(
select(col(Jargon.session_id)).distinct().where(col(Jargon.session_id).is_not(None)).subquery()
)
).one()
# 全局黑话数量 top_chats = session.exec(
global_count = Jargon.select().where(Jargon.is_global).count() select(col(Jargon.session_id), fn.count().label("count"))
.where(col(Jargon.session_id).is_not(None))
# 已完成推断的数量 .group_by(col(Jargon.session_id))
complete_count = Jargon.select().where(Jargon.is_complete).count() .order_by(fn.count().desc())
.limit(5)
# 关联的聊天数量 ).all()
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count() top_chats_dict = {session_id: count for session_id, count in top_chats if session_id}
# 按聊天统计 TOP 5
top_chats = (
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
.group_by(Jargon.chat_id)
.order_by(fn.COUNT(Jargon.id).desc())
.limit(5)
)
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
return JargonStatsResponse( return JargonStatsResponse(
success=True, success=True,
@ -360,7 +348,6 @@ async def get_jargon_stats():
"confirmed_jargon": confirmed_jargon, "confirmed_jargon": confirmed_jargon,
"confirmed_not_jargon": confirmed_not_jargon, "confirmed_not_jargon": confirmed_not_jargon,
"pending": pending, "pending": pending,
"global_count": global_count,
"complete_count": complete_count, "complete_count": complete_count,
"chat_count": chat_count, "chat_count": chat_count,
"top_chats": top_chats_dict, "top_chats": top_chats_dict,
@ -376,11 +363,13 @@ async def get_jargon_stats():
async def get_jargon_detail(jargon_id: int): async def get_jargon_detail(jargon_id: int):
"""获取黑话详情""" """获取黑话详情"""
try: try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id) with get_db_session() as session:
if not jargon: jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
raise HTTPException(status_code=404, detail="黑话不存在") if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon)) return JargonDetailResponse(success=True, data=data)
except HTTPException: except HTTPException:
raise raise
@ -393,30 +382,31 @@ async def get_jargon_detail(jargon_id: int):
async def create_jargon(request: JargonCreateRequest): async def create_jargon(request: JargonCreateRequest):
"""创建黑话""" """创建黑话"""
try: try:
# 检查是否已存在相同内容的黑话 with get_db_session() as session:
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id)) existing = session.exec(
if existing: select(Jargon).where(
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话") (col(Jargon.content) == request.content) & (col(Jargon.session_id) == request.chat_id)
)
).first()
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
# 创建黑话 jargon = Jargon(
jargon = Jargon.create( content=request.content,
content=request.content, raw_content=request.raw_content,
raw_content=request.raw_content, meaning=request.meaning or "",
meaning=request.meaning, session_id=request.chat_id,
chat_id=request.chat_id, count=0,
is_global=request.is_global, is_jargon=None,
count=0, is_complete=False,
is_jargon=None, )
is_complete=False, session.add(jargon)
) session.flush()
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}") logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonCreateResponse( return JargonCreateResponse(success=True, message="创建成功", data=data)
success=True,
message="创建成功",
data=jargon_to_dict(jargon),
)
except HTTPException: except HTTPException:
raise raise
@ -429,25 +419,27 @@ async def create_jargon(request: JargonCreateRequest):
async def update_jargon(jargon_id: int, request: JargonUpdateRequest): async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
"""更新黑话(增量更新)""" """更新黑话(增量更新)"""
try: try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id) with get_db_session() as session:
if not jargon: jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
raise HTTPException(status_code=404, detail="黑话不存在") if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
# 增量更新字段 update_data = request.model_dump(exclude_unset=True)
update_data = request.model_dump(exclude_unset=True) if update_data:
if update_data: for field, value in update_data.items():
for field, value in update_data.items(): if field == "is_global":
if value is not None or field in ["meaning", "raw_content", "is_jargon"]: continue
setattr(jargon, field, value) if field == "chat_id":
jargon.save() jargon.session_id = value
continue
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)
session.add(jargon)
logger.info(f"更新黑话成功: id={jargon_id}") logger.info(f"更新黑话成功: id={jargon_id}")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonUpdateResponse( return JargonUpdateResponse(success=True, message="更新成功", data=data)
success=True,
message="更新成功",
data=jargon_to_dict(jargon),
)
except HTTPException: except HTTPException:
raise raise
@ -460,20 +452,17 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
async def delete_jargon(jargon_id: int): async def delete_jargon(jargon_id: int):
"""删除黑话""" """删除黑话"""
try: try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id) with get_db_session() as session:
if not jargon: jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
raise HTTPException(status_code=404, detail="黑话不存在") if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
content = jargon.content content = jargon.content
jargon.delete_instance() session.delete(jargon)
logger.info(f"删除黑话成功: id={jargon_id}, content={content}") logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
return JargonDeleteResponse( return JargonDeleteResponse(success=True, message="删除成功", deleted_count=1)
success=True,
message="删除成功",
deleted_count=1,
)
except HTTPException: except HTTPException:
raise raise
@ -489,9 +478,11 @@ async def batch_delete_jargons(request: BatchDeleteRequest):
if not request.ids: if not request.ids:
raise HTTPException(status_code=400, detail="ID列表不能为空") raise HTTPException(status_code=400, detail="ID列表不能为空")
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute() with get_db_session() as session:
result = session.exec(delete(Jargon).where(col(Jargon.id).in_(request.ids)))
deleted_count = result.rowcount or 0
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录") logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
return JargonDeleteResponse( return JargonDeleteResponse(
success=True, success=True,
@ -516,14 +507,16 @@ async def batch_set_jargon_status(
if not ids: if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空") raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute() with get_db_session() as session:
jargons = session.exec(select(Jargon).where(col(Jargon.id).in_(ids))).all()
for jargon in jargons:
jargon.is_jargon = is_jargon
session.add(jargon)
updated_count = len(jargons)
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}") logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")
return JargonUpdateResponse( return JargonUpdateResponse(success=True, message=f"成功更新 {updated_count} 条黑话状态")
success=True,
message=f"成功更新 {updated_count} 条黑话状态",
)
except HTTPException: except HTTPException:
raise raise

View File

@ -7,7 +7,7 @@ from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION from src.config.config import MMC_VERSION
from src.plugin_system.base.config_types import ConfigField from src.plugin_system.base.config_types import ConfigField
from src.webui.git_mirror_service import get_git_mirror_service, set_update_progress_callback from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback
from src.webui.core import get_token_manager from src.webui.core import get_token_manager
from src.webui.routers.websocket.plugin_progress import update_progress from src.webui.routers.websocket.plugin_progress import update_progress

View File

@ -7,7 +7,6 @@ class EmojiResponse(BaseModel):
id: int id: int
full_path: str full_path: str
format: str
emoji_hash: str emoji_hash: str
description: str description: str
query_count: int query_count: int
@ -16,7 +15,6 @@ class EmojiResponse(BaseModel):
emotion: Optional[str] emotion: Optional[str]
record_time: float record_time: float
register_time: Optional[float] register_time: Optional[float]
usage_count: int
last_used_time: Optional[float] last_used_time: Optional[float]

View File

@ -0,0 +1,662 @@
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
import json
import asyncio
import subprocess
import shutil
from pathlib import Path
from datetime import datetime
from src.common.logger import get_logger
logger = get_logger("webui.git_mirror")
# 导入进度更新函数(避免循环导入)
_update_progress = None
def set_update_progress_callback(callback):
"""设置进度更新回调函数"""
global _update_progress
_update_progress = callback
class MirrorType(str, Enum):
"""镜像源类型"""
GH_PROXY = "gh-proxy" # gh-proxy 主节点
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
EDGEONE_GH_PROXY = "edgeone-gh-proxy" # gh-proxy EdgeOne 节点
MEYZH_GITHUB = "meyzh-github" # Meyzh GitHub 镜像
GITHUB = "github" # GitHub 官方源(兜底)
CUSTOM = "custom" # 自定义镜像源
class GitMirrorConfig:
"""Git 镜像源配置管理"""
# 配置文件路径
CONFIG_FILE = Path("data/webui.json")
# 默认镜像源配置
DEFAULT_MIRRORS = [
{
"id": "gh-proxy",
"name": "gh-proxy 镜像",
"raw_prefix": "https://gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://gh-proxy.org/https://github.com",
"enabled": True,
"priority": 1,
"created_at": None,
},
{
"id": "hk-gh-proxy",
"name": "gh-proxy 香港节点",
"raw_prefix": "https://hk.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 2,
"created_at": None,
},
{
"id": "cdn-gh-proxy",
"name": "gh-proxy CDN 节点",
"raw_prefix": "https://cdn.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 3,
"created_at": None,
},
{
"id": "edgeone-gh-proxy",
"name": "gh-proxy EdgeOne 节点",
"raw_prefix": "https://edgeone.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 4,
"created_at": None,
},
{
"id": "meyzh-github",
"name": "Meyzh GitHub 镜像",
"raw_prefix": "https://meyzh.github.io/https://raw.githubusercontent.com",
"clone_prefix": "https://meyzh.github.io/https://github.com",
"enabled": True,
"priority": 5,
"created_at": None,
},
{
"id": "github",
"name": "GitHub 官方源(兜底)",
"raw_prefix": "https://raw.githubusercontent.com",
"clone_prefix": "https://github.com",
"enabled": True,
"priority": 999,
"created_at": None,
},
]
def __init__(self):
"""初始化配置管理器"""
self.config_file = self.CONFIG_FILE
self.mirrors: List[Dict[str, Any]] = []
self._load_config()
def _load_config(self) -> None:
"""加载配置文件"""
try:
if self.config_file.exists():
with open(self.config_file, "r", encoding="utf-8") as f:
data = json.load(f)
# 检查是否有镜像源配置
if "git_mirrors" not in data or not data["git_mirrors"]:
logger.info("配置文件中未找到镜像源配置,使用默认配置")
self._init_default_mirrors()
else:
self.mirrors = data["git_mirrors"]
logger.info(f"已加载 {len(self.mirrors)} 个镜像源配置")
else:
logger.info("配置文件不存在,创建默认配置")
self._init_default_mirrors()
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
self._init_default_mirrors()
def _init_default_mirrors(self) -> None:
"""初始化默认镜像源"""
current_time = datetime.now().isoformat()
self.mirrors = []
for mirror in self.DEFAULT_MIRRORS:
mirror_copy = mirror.copy()
mirror_copy["created_at"] = current_time
self.mirrors.append(mirror_copy)
self._save_config()
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
def _save_config(self) -> None:
"""保存配置到文件"""
try:
# 确保目录存在
self.config_file.parent.mkdir(parents=True, exist_ok=True)
# 读取现有配置
existing_data = {}
if self.config_file.exists():
with open(self.config_file, "r", encoding="utf-8") as f:
existing_data = json.load(f)
# 更新镜像源配置
existing_data["git_mirrors"] = self.mirrors
# 写入文件
with open(self.config_file, "w", encoding="utf-8") as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
logger.debug(f"配置已保存到 {self.config_file}")
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
def get_all_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有镜像源"""
return self.mirrors.copy()
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有启用的镜像源,按优先级排序"""
enabled = [m for m in self.mirrors if m.get("enabled", False)]
return sorted(enabled, key=lambda x: x.get("priority", 999))
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
"""根据 ID 获取镜像源"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
return mirror.copy()
return None
def add_mirror(
self,
mirror_id: str,
name: str,
raw_prefix: str,
clone_prefix: str,
enabled: bool = True,
priority: Optional[int] = None,
) -> Dict[str, Any]:
"""
添加新的镜像源
Returns:
添加的镜像源配置
Raises:
ValueError: 如果镜像源 ID 已存在
"""
# 检查 ID 是否已存在
if self.get_mirror_by_id(mirror_id):
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
# 如果未指定优先级,使用最大优先级 + 1
if priority is None:
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
priority = max_priority + 1
new_mirror = {
"id": mirror_id,
"name": name,
"raw_prefix": raw_prefix,
"clone_prefix": clone_prefix,
"enabled": enabled,
"priority": priority,
"created_at": datetime.now().isoformat(),
}
self.mirrors.append(new_mirror)
self._save_config()
logger.info(f"已添加镜像源: {mirror_id} - {name}")
return new_mirror.copy()
def update_mirror(
self,
mirror_id: str,
name: Optional[str] = None,
raw_prefix: Optional[str] = None,
clone_prefix: Optional[str] = None,
enabled: Optional[bool] = None,
priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新镜像源配置
Returns:
更新后的镜像源配置如果不存在则返回 None
"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
if name is not None:
mirror["name"] = name
if raw_prefix is not None:
mirror["raw_prefix"] = raw_prefix
if clone_prefix is not None:
mirror["clone_prefix"] = clone_prefix
if enabled is not None:
mirror["enabled"] = enabled
if priority is not None:
mirror["priority"] = priority
mirror["updated_at"] = datetime.now().isoformat()
self._save_config()
logger.info(f"已更新镜像源: {mirror_id}")
return mirror.copy()
return None
def delete_mirror(self, mirror_id: str) -> bool:
"""
删除镜像源
Returns:
True 如果删除成功False 如果镜像源不存在
"""
for i, mirror in enumerate(self.mirrors):
if mirror.get("id") == mirror_id:
self.mirrors.pop(i)
self._save_config()
logger.info(f"已删除镜像源: {mirror_id}")
return True
return False
def get_default_priority_list(self) -> List[str]:
"""获取默认优先级列表(仅启用的镜像源 ID"""
enabled = self.get_enabled_mirrors()
return [m["id"] for m in enabled]
class GitMirrorService:
"""Git 镜像源服务"""
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
"""
初始化 Git 镜像源服务
Args:
max_retries: 最大重试次数
timeout: 请求超时时间
config: 镜像源配置管理器可选默认创建新实例
"""
self.max_retries = max_retries
self.timeout = timeout
self.config = config or GitMirrorConfig()
logger.info(f"Git镜像源服务初始化完成已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
def get_mirror_config(self) -> GitMirrorConfig:
"""获取镜像源配置管理器"""
return self.config
@staticmethod
def check_git_installed() -> Dict[str, Any]:
"""
检查本机是否安装了 Git
Returns:
Dict 包含:
- installed: bool - 是否已安装 Git
- version: str - Git 版本号如果已安装
- path: str - Git 可执行文件路径如果已安装
- error: str - 错误信息如果未安装或检测失败
"""
import subprocess
import shutil
try:
# 查找 git 可执行文件路径
git_path = shutil.which("git")
if not git_path:
logger.warning("未找到 Git 可执行文件")
return {"installed": False, "error": "系统中未找到 Git请先安装 Git"}
# 获取 Git 版本
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
version = result.stdout.strip()
logger.info(f"检测到 Git: {version} at {git_path}")
return {"installed": True, "version": version, "path": git_path}
else:
logger.warning(f"Git 命令执行失败: {result.stderr}")
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
except subprocess.TimeoutExpired:
logger.error("Git 版本检测超时")
return {"installed": False, "error": "Git 版本检测超时"}
except Exception as e:
logger.error(f"检测 Git 时发生错误: {e}")
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
async def fetch_raw_file(
self,
owner: str,
repo: str,
branch: str,
file_path: str,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
获取 GitHub 仓库的 Raw 文件内容
Args:
owner: 仓库所有者
repo: 仓库名称
branch: 分支名称
file_path: 文件路径
mirror_id: 指定的镜像源 ID
custom_url: 自定义完整 URL如果提供将忽略其他参数
Returns:
Dict 包含:
- success: bool - 是否成功
- data: str - 文件内容成功时
- error: str - 错误信息失败时
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
if custom_url:
# 使用自定义 URL
return await self._fetch_with_url(custom_url, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
total_mirrors = len(mirrors_to_try)
# 依次尝试每个镜像源
for index, mirror in enumerate(mirrors_to_try, 1):
# 推送进度:正在尝试第 N 个镜像源
if _update_progress:
try:
progress = 30 + int((index - 1) / total_mirrors * 40) # 30% - 70%
await _update_progress(
stage="loading",
progress=progress,
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
if result["success"]:
# 成功,推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=70,
message=f"成功从 {mirror['name']} 获取数据",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
return result
# 失败,记录日志并推送失败信息
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
if _update_progress and index < total_mirrors:
try:
await _update_progress(
stage="loading",
progress=30 + int(index / total_mirrors * 40),
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 所有镜像源都失败
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _fetch_raw_from_mirror(
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
) -> Dict[str, Any]:
"""从指定镜像源获取文件"""
# 构建 URL
raw_prefix = mirror["raw_prefix"]
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
return await self._fetch_with_url(url, mirror["id"])
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
"""使用指定 URL 获取文件,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
logger.debug(f"尝试 #{attempt + 1}: {url}")
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(url)
response.raise_for_status()
logger.info(f"成功获取文件: {url}")
return {
"success": True,
"data": response.text,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
}
except httpx.HTTPStatusError as e:
last_error = f"HTTP {e.response.status_code}: {e}"
logger.warning(f"HTTP 错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except httpx.TimeoutException as e:
last_error = f"请求超时: {e}"
logger.warning(f"超时 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
async def clone_repository(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str] = None,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
depth: Optional[int] = None,
) -> Dict[str, Any]:
"""
克隆 GitHub 仓库
Args:
owner: 仓库所有者
repo: 仓库名称
target_path: 目标路径
branch: 分支名称可选
mirror_id: 指定的镜像源 ID
custom_url: 自定义克隆 URL
depth: 克隆深度浅克隆
Returns:
Dict 包含:
- success: bool - 是否成功
- path: str - 克隆路径成功时
- error: str - 错误信息失败时
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始克隆仓库: {owner}/{repo}{target_path}")
if custom_url:
# 使用自定义 URL
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
# 依次尝试每个镜像源
for mirror in mirrors_to_try:
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
if result["success"]:
return result
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
# 所有镜像源都失败
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _clone_from_mirror(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror: Dict[str, Any],
) -> Dict[str, Any]:
"""从指定镜像源克隆仓库"""
# 构建克隆 URL
clone_prefix = mirror["clone_prefix"]
url = f"{clone_prefix}/{owner}/{repo}.git"
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
async def _clone_with_url(
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
) -> Dict[str, Any]:
"""使用指定 URL 克隆仓库,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
# 确保目标路径不存在
if target_path.exists():
logger.warning(f"目标路径已存在,删除: {target_path}")
shutil.rmtree(target_path, ignore_errors=True)
# 构建 git clone 命令
cmd = ["git", "clone"]
# 添加分支参数
if branch:
cmd.extend(["-b", branch])
# 添加深度参数(浅克隆)
if depth:
cmd.extend(["--depth", str(depth)])
# 添加 URL 和目标路径
cmd.extend([url, str(target_path)])
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
# 推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=20 + attempt * 10,
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
operation="install",
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop()
def run_git_clone(clone_cmd=cmd):
return subprocess.run(
clone_cmd,
capture_output=True,
text=True,
timeout=300, # 5分钟超时
)
process = await loop.run_in_executor(None, run_git_clone)
if process.returncode == 0:
logger.info(f"成功克隆仓库: {url} -> {target_path}")
return {
"success": True,
"path": str(target_path),
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
"branch": branch or "default",
}
else:
last_error = f"Git 克隆失败: {process.stderr}"
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except subprocess.TimeoutExpired:
last_error = "克隆超时(超过 5 分钟)"
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
except FileNotFoundError:
last_error = "Git 未安装或不在 PATH 中"
logger.error(f"Git 未找到: {last_error}")
break # Git 不存在,不需要重试
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
# 全局服务实例
_git_mirror_service: Optional[GitMirrorService] = None
def get_git_mirror_service() -> GitMirrorService:
"""获取 Git 镜像源服务实例(单例)"""
global _git_mirror_service
if _git_mirror_service is None:
_git_mirror_service = GitMirrorService()
return _git_mirror_service