def new_fix_broken_generated_json(json_str: str) -> str:
    """
    Return JSON text for ``json_str``, repairing it only when it fails to parse.

    The input is first validated with ``json.loads``. When that succeeds the
    original string is handed back untouched (whitespace and formatting
    included). Only when ``json.loads`` raises ``json.JSONDecodeError`` is the
    string passed to ``repair_json`` from the json-repair library.

    Args:
        json_str (str): JSON text that may be malformed.

    Returns:
        str: ``json_str`` itself when it already parses; otherwise whatever
        ``repair_json(json_str)`` returns.
    """
    try:
        # Validation only: the parsed value is discarded on purpose, because
        # callers expect text back and run their own json.loads on it.
        json.loads(json_str)
    except json.JSONDecodeError:
        # Not parseable as-is, so delegate the fix-up to json-repair.
        result = repair_json(json_str)
    else:
        # Already valid: return the caller's string unchanged.
        result = json_str
    return result