pull/1001/head
SnowindMe 2025-05-21 13:15:27 +08:00
commit 54ede46df9
3 changed files with 28 additions and 6 deletions

View File

@ -6,7 +6,7 @@ from .global_logger import logger
from . import prompt_template from . import prompt_template
from .lpmmconfig import global_config, INVALID_ENTITY from .lpmmconfig import global_config, INVALID_ENTITY
from .llm_client import LLMClient from .llm_client import LLMClient
from .utils.json_fix import fix_broken_generated_json from .utils.json_fix import new_fix_broken_generated_json
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]: def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
@ -24,7 +24,7 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
if "]" in request_result: if "]" in request_result:
request_result = request_result[: request_result.rindex("]") + 1] request_result = request_result[: request_result.rindex("]") + 1]
entity_extract_result = json.loads(fix_broken_generated_json(request_result)) entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
entity_extract_result = [ entity_extract_result = [
entity entity
@ -53,7 +53,7 @@ def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -
if "]" in request_result: if "]" in request_result:
request_result = request_result[: request_result.rindex("]") + 1] request_result = request_result[: request_result.rindex("]") + 1]
entity_extract_result = json.loads(fix_broken_generated_json(request_result)) entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
for triple in entity_extract_result: for triple in entity_extract_result:
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:

View File

@ -1,4 +1,5 @@
import json import json
from json_repair import repair_json
def _find_unclosed(json_str): def _find_unclosed(json_str):
@ -74,3 +75,24 @@ def fix_broken_generated_json(json_str: str) -> str:
json_str += closing_map[open_char] json_str += closing_map[open_char]
return json_str return json_str
def new_fix_broken_generated_json(json_str: str) -> str:
    """Return a parseable JSON string, repairing *json_str* only when needed.

    The input is first probed with ``json.loads``. If it parses, the very
    same string is handed back untouched; otherwise it is passed to
    ``repair_json`` from the json-repair library and that result is returned.

    Args:
        json_str: The possibly malformed JSON text to check and repair.

    Returns:
        ``json_str`` itself when it is already valid JSON, else the value
        produced by ``repair_json(json_str)``.

    Only ``json.JSONDecodeError`` is handled here; any other exception
    raised while probing propagates to the caller.
    """
    try:
        # Parse purely as a validity probe; the decoded value is discarded.
        json.loads(json_str)
    except json.JSONDecodeError:
        # Malformed input: delegate the fix-up to json-repair.
        repaired = repair_json(json_str)
        return repaired
    else:
        # Already valid: return the original text byte-for-byte.
        return json_str

View File

@ -8,7 +8,7 @@ class Identity:
identity_detail: List[str] # 身份细节描述 identity_detail: List[str] # 身份细节描述
height: int # 身高(厘米) height: int # 身高(厘米)
weight: int # 体重(千克) weight: float # 体重(千克)
age: int # 年龄 age: int # 年龄
gender: str # 性别 gender: str # 性别
appearance: str # 外貌特征 appearance: str # 外貌特征
@ -24,7 +24,7 @@ class Identity:
self, self,
identity_detail: List[str] = None, identity_detail: List[str] = None,
height: int = 0, height: int = 0,
weight: int = 0, weight: float = 0,
age: int = 0, age: int = 0,
gender: str = "", gender: str = "",
appearance: str = "", appearance: str = "",
@ -61,7 +61,7 @@ class Identity:
@classmethod @classmethod
def initialize( def initialize(
cls, identity_detail: List[str], height: int, weight: int, age: int, gender: str, appearance: str cls, identity_detail: List[str], height: int, weight: float, age: int, gender: str, appearance: str
) -> "Identity": ) -> "Identity":
"""初始化身份特征 """初始化身份特征