#!/usr/bin/env python3
"""Patch ``source_refs`` fields in NDJSON rule files from a JSON mapping.

Reads an NDJSON file of rule objects and a JSON file mapping rule id ->
source_refs, validates every new reference against ``SOURCE_REF_RE``, and
atomically rewrites the NDJSON file in place when anything changed.
"""
from __future__ import annotations

import argparse
import json
import os
import re
import sys
import tempfile
from pathlib import Path
from typing import Any

# Citation format: "<SOURCE> §<section> p<pages>[ (scan p<n>)]",
# e.g. "CMOS18 §6.19 p312-314 (scan p7)". Sources are a closed set;
# page numbers may be arabic or roman numerals, optionally a range.
SOURCE_REF_RE = re.compile(
    r"^(CMOS18|BRING|HOUSE)\s§[0-9A-Za-z][0-9A-Za-z._\-]*\s+p[0-9ivxlcdmIVXLCDM]+"
    r"(?:-[0-9ivxlcdmIVXLCDM]+)?(?:\s\(scan p[0-9]+\))?$"
)


def _load_ndjson(path: Path) -> list[dict[str, Any]]:
    """Parse *path* as NDJSON (one object per line), skipping blank lines.

    Exits with a file:line diagnostic on the first malformed JSON line.
    """
    rows: list[dict[str, Any]] = []
    for idx, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
        if not line.strip():
            continue
        try:
            rows.append(json.loads(line))
        except json.JSONDecodeError as e:
            raise SystemExit(f"{path}:{idx}: invalid JSON ({e})")
    return rows


def _write_ndjson(path: Path, rows: list[dict[str, Any]]) -> None:
    """Atomically rewrite *path* with *rows*, one compact JSON object per line.

    The temp file is created in the destination directory so that
    ``os.replace`` is an atomic same-filesystem rename. (A temp file in the
    system temp dir moved with ``shutil.move`` can silently fall back to a
    non-atomic copy+delete when /tmp is a different filesystem.)
    """
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".ndjson.tmp")
    tmp_path = Path(tmp_name)
    try:
        # open() takes ownership of fd and closes it with the file object.
        with open(fd, "w", encoding="utf-8") as tmp:
            for row in rows:
                tmp.write(json.dumps(row, ensure_ascii=False, separators=(",", ":")))
                tmp.write("\n")
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a stray temp file behind if writing/replacing fails.
        tmp_path.unlink(missing_ok=True)
        raise


def _load_mapping(path: Path) -> dict[str, list[str]]:
    """Load the rule-id -> source_refs mapping from *path*.

    A bare string value is coerced to a single-element list. Exits with a
    diagnostic on any shape violation.
    """
    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise SystemExit(f"{path}: expected object mapping rule id -> source_refs list")
    out: dict[str, list[str]] = {}
    for k, v in data.items():
        if not isinstance(k, str):
            raise SystemExit(f"{path}: mapping key must be string, got {type(k).__name__}")
        if isinstance(v, str):
            v = [v]
        if not (isinstance(v, list) and all(isinstance(x, str) for x in v)):
            raise SystemExit(f"{path}: mapping value for {k} must be string or list of strings")
        out[k] = list(v)
    return out


def _validate_source_refs(rule_id: str, source_refs: list[str]) -> list[str]:
    """Return validation error messages for *source_refs* (empty list if valid).

    *rule_id* is accepted for signature stability; callers prefix errors with it.
    """
    errors: list[str] = []
    if not source_refs:
        errors.append("source_refs must be non-empty")
        return errors
    for ref in source_refs:
        if not SOURCE_REF_RE.match(ref):
            errors.append(f"invalid source_ref: {ref!r}")
    return errors


def patch_source_refs(ndjson_path: Path, mapping: dict[str, list[str]], *, strict: bool) -> int:
    """Apply *mapping* to the rules in *ndjson_path*; return a process exit code.

    Returns 0 on success (file rewritten only if something changed), 2 when
    any new refs fail validation or, under ``strict``, when the mapping
    names ids not present in the file. The file is never written on error.
    """
    rows = _load_ndjson(ndjson_path)
    changed = 0
    # Ids from the mapping we have not (yet) seen in the file.
    missing: set[str] = set(mapping.keys())
    problems: list[str] = []
    for row in rows:
        rule_id = row.get("id")
        if rule_id in mapping:
            new_refs = mapping[rule_id]
            missing.discard(rule_id)
            errs = _validate_source_refs(rule_id, new_refs)
            if errs:
                problems.extend([f"{rule_id}: {e}" for e in errs])
                continue
            old_refs = row.get("source_refs")
            if old_refs != new_refs:
                row["source_refs"] = new_refs
                changed += 1
    if missing:
        msg = f"{ndjson_path}: mapping includes unknown ids: {', '.join(sorted(missing))}"
        if strict:
            problems.append(msg)
        else:
            print(f"WARNING: {msg}", file=sys.stderr)
    if problems:
        for p in problems:
            print(f"ERROR: {p}", file=sys.stderr)
        return 2
    if changed:
        _write_ndjson(ndjson_path, rows)
    print(json.dumps({"file": str(ndjson_path), "updated": changed}, indent=2))
    return 0


def main(argv: list[str]) -> int:
    """CLI entry point: parse arguments and dispatch to :func:`patch_source_refs`."""
    ap = argparse.ArgumentParser(description="Patch NDJSON rule files without editing by hand.")
    ap.add_argument("--file", required=True, type=Path, help="NDJSON file to patch in-place")
    ap.add_argument("--mapping-json", required=True, type=Path, help="JSON mapping: rule id -> source_refs")
    ap.add_argument(
        "--strict",
        action="store_true",
        help="Fail if mapping contains unknown ids (default: warn only)",
    )
    args = ap.parse_args(argv)
    return patch_source_refs(args.file, _load_mapping(args.mapping_json), strict=args.strict)


if __name__ == "__main__":
    raise SystemExit(main(sys.argv[1:]))