mirror of https://github.com/Mai-with-u/MaiBot.git
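
# Example invocations (a sketch; the script name is inferred from the logger
# name below, and file paths are illustrative only):
#   python delete_lpmm_items.py --hash-file hashes.txt --dry-run
#   python delete_lpmm_items.py --openie-file openie.json --delete-entities --delete-relations
#   python delete_lpmm_items.py --raw-file corpus.txt --raw-index 1,3 --yes
#   python delete_lpmm_items.py --search-text "keyword"
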
import argparse
import sys
from pathlib import Path
from typing import List, Tuple, Dict, Any
import json
import os

# Force UTF-8 so console output does not raise encoding errors
try:
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8")
    if hasattr(sys.stderr, "reconfigure"):
        sys.stderr.reconfigure(encoding="utf-8")
except Exception:
    pass

# Make sure the src package can be found
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256

logger = get_logger("delete_lpmm_items")


def read_hashes(file_path: Path) -> List[str]:
    """Read a list of hashes, skipping empty lines."""
    hashes: List[str] = []
    for line in file_path.read_text(encoding="utf-8").splitlines():
        val = line.strip()
        if not val:
            continue
        hashes.append(val)
    return hashes


def read_openie_hashes(file_path: Path) -> List[str]:
    """Extract each doc's idx from an OpenIE JSON file as its paragraph hash."""
    data: Dict[str, Any] = json.loads(file_path.read_text(encoding="utf-8"))
    docs = data.get("docs", []) if isinstance(data, dict) else []
    hashes: List[str] = []
    for doc in docs:
        idx = doc.get("idx") if isinstance(doc, dict) else None
        if isinstance(idx, str) and idx.strip():
            hashes.append(idx.strip())
    return hashes


def normalize_paragraph_keys(raw_hashes: List[str]) -> Tuple[List[str], List[str]]:
    """Normalize the input into two lists: full store keys and bare hashes."""
    keys: List[str] = []
    hashes: List[str] = []
    for h in raw_hashes:
        if h.startswith("paragraph-"):
            keys.append(h)
            hashes.append(h.replace("paragraph-", "", 1))
        else:
            keys.append(f"paragraph-{h}")
            hashes.append(h)
    return keys, hashes
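
# For example, normalize_paragraph_keys(["paragraph-abc", "def"]) returns
# (["paragraph-abc", "paragraph-def"], ["abc", "def"]).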


def main():
    parser = argparse.ArgumentParser(description="Delete paragraphs from LPMM knowledge base (vectors + graph).")
    parser.add_argument("--hash-file", help="Path to a text file with one paragraph hash (or prefixed key) per line")
    parser.add_argument("--openie-file", help="OpenIE output file (JSON); its docs' idx values become the paragraph hashes to delete")
    parser.add_argument("--raw-file", help="Raw txt corpus file (paragraphs separated by blank lines); combine with --raw-index")
    parser.add_argument(
        "--raw-index",
        help="1-based indexes of the paragraphs in --raw-file to delete; comma-separated, e.g. 1,3",
    )
    parser.add_argument("--search-text", help="Search the current paragraph store for a substring and interactively pick matches to delete")
    parser.add_argument(
        "--search-limit",
        type=int,
        default=10,
        help="Maximum number of candidate paragraphs to show in --search-text mode",
    )
    parser.add_argument("--delete-entities", action="store_true", help="Also delete the entity nodes/embeddings listed in the OpenIE file")
    parser.add_argument("--delete-relations", action="store_true", help="Also delete the relation embeddings listed in the OpenIE file")
    parser.add_argument("--remove-orphan-entities", action="store_true", help="Remove entity nodes left orphaned by the deletion")
    parser.add_argument("--dry-run", action="store_true", help="Preview the items that would be deleted without modifying anything")
    parser.add_argument("--yes", action="store_true", help="Skip interactive confirmation and delete immediately (use with caution)")
    parser.add_argument(
        "--max-delete-nodes",
        type=int,
        default=2000,
        help="Maximum number of nodes (paragraphs + entities) deletable in one run; going above this requires explicit confirmation or a larger value",
    )
    args = parser.parse_args()

    # At least one source must be given
    if not (args.hash_file or args.openie_file or args.raw_file or args.search_text):
        logger.error("One of --hash-file / --openie-file / --raw-file / --search-text must be specified")
        sys.exit(1)

    raw_hashes: List[str] = []
    raw_entities: List[str] = []
    raw_relations: List[str] = []

    if args.hash_file:
        hash_file = Path(args.hash_file)
        if not hash_file.exists():
            logger.error(f"Hash file does not exist: {hash_file}")
            sys.exit(1)
        raw_hashes.extend(read_hashes(hash_file))

    if args.openie_file:
        openie_path = Path(args.openie_file)
        if not openie_path.exists():
            logger.error(f"OpenIE file does not exist: {openie_path}")
            sys.exit(1)
        # Paragraphs
        raw_hashes.extend(read_openie_hashes(openie_path))
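        # Expected OpenIE layout, inferred from the fields read here (a sketch,
        # not a full schema):
        #   {"docs": [{"idx": "<paragraph sha256>",
        #              "extracted_entities": ["entity", ...],
        #              "extracted_triples": [["subj", "pred", "obj"], ...]}]}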
        # Entities/relations (entities include both extracted_entities and the
        # triples' subjects/objects, matching how the KG is built)
        try:
            data = json.loads(openie_path.read_text(encoding="utf-8"))
            docs = data.get("docs", []) if isinstance(data, dict) else []
            for doc in docs:
                if not isinstance(doc, dict):
                    continue
                ents = doc.get("extracted_entities", [])
                if isinstance(ents, list):
                    raw_entities.extend([e for e in ents if isinstance(e, str)])
                triples = doc.get("extracted_triples", [])
                if isinstance(triples, list):
                    for t in triples:
                        if isinstance(t, list) and len(t) == 3:
                            subj, _, obj = t
                            if isinstance(subj, str):
                                raw_entities.append(subj)
                            if isinstance(obj, str):
                                raw_entities.append(obj)
                            raw_relations.append(str(tuple(t)))
        except Exception as e:
            logger.error(f"Failed to read OpenIE file: {e}")
            sys.exit(1)

    # Select paragraphs to delete from the raw txt corpus by index
    if args.raw_file:
        raw_path = Path(args.raw_file)
        if not raw_path.exists():
            logger.error(f"Raw corpus file does not exist: {raw_path}")
            sys.exit(1)
        text = raw_path.read_text(encoding="utf-8")
        paragraphs: List[str] = []
        buf: List[str] = []
        for line in text.splitlines():
            if line.strip() == "":
                if buf:
                    paragraphs.append("\n".join(buf).strip())
                    buf = []
            else:
                buf.append(line)
        if buf:
            paragraphs.append("\n".join(buf).strip())

        if not paragraphs:
            logger.error(f"No paragraphs could be parsed from raw corpus file {raw_path}")
            sys.exit(1)

        if not args.raw_index:
            logger.info(f"{raw_path} parsed into {len(paragraphs)} paragraphs; specify the ones to delete with --raw-index, e.g. --raw-index 1,3")
            sys.exit(1)

        # Parse the index list (1-based)
        try:
            idx_list = [int(x.strip()) for x in str(args.raw_index).split(",") if x.strip()]
        except ValueError:
            logger.error(f"Failed to parse --raw-index: {args.raw_index}")
            sys.exit(1)

        for idx in idx_list:
            if idx < 1 or idx > len(paragraphs):
                logger.error(f"--raw-index contains an invalid index {idx} (valid range 1 to {len(paragraphs)})")
                sys.exit(1)

        logger.info("Paragraphs selected from the raw corpus:")
        for idx in idx_list:
            para = paragraphs[idx - 1]
            h = get_sha256(para)
            logger.info(f"- paragraph {idx}, hash={h}, preview: {para[:80]}")
            raw_hashes.append(h)
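        # NOTE: this assumes the import pipeline hashed the same normalized
        # paragraph text (blank-line split, stripped); otherwise the hashes
        # computed here will not match any stored paragraph keys.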

    # Search the existing store for candidate paragraphs by substring and
    # select interactively
    if args.search_text:
        search_text = args.search_text.strip()
        if not search_text:
            logger.error("--search-text must not be empty")
            sys.exit(1)
        logger.info(f"Searching the existing paragraph store for keyword: {search_text!r}")
        em_search = EmbeddingManager()
        try:
            em_search.load_from_file()
        except Exception as e:
            logger.error(f"Failed to load the embedding store; --search-text is unavailable: {e}")
            sys.exit(1)

        candidates = []
        for key, item in em_search.paragraphs_embedding_store.store.items():
            if search_text in item.str:
                candidates.append((key, item.str))
                if len(candidates) >= args.search_limit:
                    break

        if not candidates:
            logger.info("No paragraphs containing the keyword were found in the existing store")
        else:
            logger.info("Found the following candidate paragraphs (enter the numbers of the entries to delete, comma-separated for multiple):")
            for i, (key, text) in enumerate(candidates, start=1):
                logger.info(f"{i}. {key} | {text[:80]}")
            choice = input("Enter the numbers to delete (e.g. 1,3), or press Enter to cancel: ").strip()
            if choice:
                try:
                    idxs = [int(x.strip()) for x in choice.split(",") if x.strip()]
                except ValueError:
                    logger.error("Could not parse the entered numbers; --search-text deletion cancelled")
                else:
                    for i in idxs:
                        if 1 <= i <= len(candidates):
                            key, _ = candidates[i - 1]
                            # key is already a full paragraph-xxx store key
                            if key.startswith("paragraph-"):
                                raw_hashes.append(key.split("paragraph-", 1)[1])
                        else:
                            logger.warning(f"Ignoring invalid number: {i}")

    # Deduplicate while preserving order
    seen = set()
    raw_hashes = [h for h in raw_hashes if not (h in seen or seen.add(h))]
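    # (set.add returns None, so "h in seen or seen.add(h)" is falsy exactly on
    # a hash's first occurrence: order is kept and duplicates are dropped.)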

    if not raw_hashes:
        logger.error("No hashes to delete were read; nothing to do")
        sys.exit(1)

    keys, pg_hashes = normalize_paragraph_keys(raw_hashes)

    ent_hashes: List[str] = []
    rel_hashes: List[str] = []
    if args.delete_entities and raw_entities:
        ent_hashes = [get_sha256(e) for e in raw_entities]
    if args.delete_relations and raw_relations:
        rel_hashes = [get_sha256(r) for r in raw_relations]
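    # These hashes are assumed to match the keys produced at build time:
    # entities are hashed on their surface string, relations on str(tuple(triple)),
    # and the "entity-"/"relation-" prefixes are added just before deletion below.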

    logger.info("=== Deletion preflight ===")
    logger.info("Make sure data/embedding and data/rag are backed up; use --dry-run to preview if needed")
    logger.info(f"Paragraphs to delete: {len(keys)}")
    logger.info(f"Sample: {keys[:5]}")
    if ent_hashes:
        logger.info(f"Entities to delete: {len(ent_hashes)}")
    if rel_hashes:
        logger.info(f"Relations to delete: {len(rel_hashes)}")

    total_nodes_to_delete = len(pg_hashes) + (len(ent_hashes) if args.delete_entities else 0)
    logger.info(f"Total nodes expected to be deleted (paragraphs + entities): {total_nodes_to_delete}")

    if args.dry_run:
        logger.info("dry-run mode; nothing was deleted")
        return

    # Guard against oversized batch deletions
    if total_nodes_to_delete > args.max_delete_nodes and not args.yes:
        logger.error(
            f"This run would delete {total_nodes_to_delete} nodes, exceeding the threshold {args.max_delete_nodes}."
            " To avoid accidental deletion, reduce the batch size or raise the threshold via --max-delete-nodes, and confirm explicitly with --yes."
        )
        sys.exit(1)

    # Interactive confirmation
    if not args.yes:
        confirm = input("Delete the data listed above? Type YES (uppercase) to continue, anything else to cancel: ").strip()
        if confirm != "YES":
            logger.info("Deletion cancelled by user")
            return

    # Load embeddings and the graph
    embed_manager = EmbeddingManager()
    kg_manager = KGManager()

    try:
        embed_manager.load_from_file()
        kg_manager.load_from_file()
    except Exception as e:
        logger.error(f"Failed to load the existing knowledge base: {e}")
        sys.exit(1)

    # Record global stats before deletion for comparison
    before_para_vec = len(embed_manager.paragraphs_embedding_store.store)
    before_ent_vec = len(embed_manager.entities_embedding_store.store)
    before_rel_vec = len(embed_manager.relation_embedding_store.store)
    before_nodes = len(kg_manager.graph.get_node_list())
    before_edges = len(kg_manager.graph.get_edge_list())
    logger.info(
        f"Stats before deletion: paragraph vectors={before_para_vec}, entity vectors={before_ent_vec}, relation vectors={before_rel_vec}, "
        f"KG nodes={before_nodes}, KG edges={before_edges}"
    )

    # Delete vectors
    deleted, skipped = embed_manager.paragraphs_embedding_store.delete_items(keys)
    embed_manager.stored_pg_hashes = set(embed_manager.paragraphs_embedding_store.store.keys())
    logger.info(f"Paragraph vector deletion finished; deleted: {deleted}, skipped: {skipped}")
    ent_deleted = ent_skipped = rel_deleted = rel_skipped = 0
    if ent_hashes:
        ent_keys = [f"entity-{h}" for h in ent_hashes]
        ent_deleted, ent_skipped = embed_manager.entities_embedding_store.delete_items(ent_keys)
        logger.info(f"Entity vector deletion finished; deleted: {ent_deleted}, skipped: {ent_skipped}")
    if rel_hashes:
        rel_keys = [f"relation-{h}" for h in rel_hashes]
        rel_deleted, rel_skipped = embed_manager.relation_embedding_store.delete_items(rel_keys)
        logger.info(f"Relation vector deletion finished; deleted: {rel_deleted}, skipped: {rel_skipped}")

    # Delete graph nodes/edges
    kg_result = kg_manager.delete_paragraphs(
        pg_hashes,
        ent_hashes=ent_hashes if args.delete_entities else None,
        remove_orphan_entities=args.remove_orphan_entities,
    )
    logger.info(
        f"KG deletion finished; deleted: {kg_result.get('deleted', 0)}, skipped: {kg_result.get('skipped', 0)}, "
        f"orphan entities removed: {kg_result.get('orphan_removed', 0)}"
    )
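
    # delete_items presumably updates only the in-memory stores, so the Faiss
    # index is rebuilt from the remaining vectors before saving to keep the
    # index and the stores consistent.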
    # Rebuild the index and save
    logger.info("Rebuilding the Faiss index and saving embedding files...")
    embed_manager.rebuild_faiss_index()
    embed_manager.save_to_file()

    logger.info("Saving KG data...")
    kg_manager.save_to_file()

    # Stats after deletion
    after_para_vec = len(embed_manager.paragraphs_embedding_store.store)
    after_ent_vec = len(embed_manager.entities_embedding_store.store)
    after_rel_vec = len(embed_manager.relation_embedding_store.store)
    after_nodes = len(kg_manager.graph.get_node_list())
    after_edges = len(kg_manager.graph.get_edge_list())

    logger.info(
        "Stats after deletion: paragraph vectors=%d(%+d), entity vectors=%d(%+d), relation vectors=%d(%+d), KG nodes=%d(%+d), KG edges=%d(%+d)"
        % (
            after_para_vec,
            after_para_vec - before_para_vec,
            after_ent_vec,
            after_ent_vec - before_ent_vec,
            after_rel_vec,
            after_rel_vec - before_rel_vec,
            after_nodes,
            after_nodes - before_nodes,
            after_edges,
            after_edges - before_edges,
        )
    )

    logger.info("Deletion workflow finished")


if __name__ == "__main__":
    main()