#!/usr/bin/env python3
"""
Summarize and compare Eyesis MCP run logs.

Usage examples:
  scripts/eyesis_run_report.py extract \
    --log attic/session-logs/eyesis_mcp.log \
    --json /tmp/run_latest.json

  scripts/eyesis_run_report.py extract \
    --snapshot-dir /tmp/eyesis-reports \
    --auto-compare

  scripts/eyesis_run_report.py compare \
    --base /tmp/run_prev.json \
    --new /tmp/run_latest.json
"""

from __future__ import annotations

import argparse
import datetime as dt
import json
import math
import re
import sys
from pathlib import Path
from typing import Any


# Start-of-run marker: captures the sequence index ("idx") and the "last is" index ("last").
RUN_START_RE = re.compile(r"^STARTED PROCESSING SCENE SEQUENCE (?P<idx>\d+) \(last is (?P<last>\d+)\)")
# End-of-run marker (anchored at both ends): captures the sequence count and the elapsed seconds.
RUN_FINISHED_RE = re.compile(
    r"^PROCESSING OF (?P<count>\d+) SCENE SEQUENCES is FINISHED in (?P<seconds>[0-9]+(?:\.[0-9]+)?) sec\.$"
)
# batchRig() completion line: captures elapsed seconds plus the "Free memory=<free> (of <total>)" figures.
BATCH_FINISHED_RE = re.compile(
    r"^batchRig\(\): Processing finished at (?P<seconds>[0-9]+(?:\.[0-9]+)?) sec, --- Free memory=(?P<free>\d+) "
    r"\(of (?P<total>\d+)\)"
)
# Captures the slash-led, whitespace-free path that follows "Configuration parameters are restored from".
CONFIG_RESTORED_RE = re.compile(r"Configuration parameters are restored from (?P<path>/\S+)")
# Any slash-led token; it stops at a space, tab, CR/LF, quote character or angle bracket.
PATH_RE = re.compile(r"(/[^ \t\r\n\"'<>]+)")


def to_number(value: str) -> Any:
    """Convert a raw log token into a number when it looks like one.

    The spellings ``NaN``, ``+NaN`` and ``-NaN`` map to ``math.nan``. An
    optionally signed run of digits becomes an ``int``; anything else that
    ``float()`` accepts becomes a ``float``. Tokens that cannot be converted
    are returned unchanged as strings.
    """
    if value in ("NaN", "+NaN", "-NaN"):
        return math.nan
    looks_integral = re.fullmatch(r"[+-]?\d+", value) is not None
    convert = int if looks_integral else float
    try:
        return convert(value)
    except ValueError:
        # Not numeric after all: hand the original text back.
        return value


def parse_global_line(line: str) -> dict[str, Any] | None:
    """Parse the ``key=value`` fields of an ``IntersceneGlobalRefine`` log line.

    Returns ``None`` unless the line contains both
    ``"IntersceneGlobalRefine: outer="`` and ``" inner="``. Text from any of
    the trailer markers onward is discarded before the fields are collected.
    Each value is passed through ``to_number``; a repeated key keeps its last
    value. ``None`` is also returned when no field is found.
    """
    is_global_line = "IntersceneGlobalRefine: outer=" in line and " inner=" in line
    if not is_global_line:
        return None

    head = line
    # Each marker is applied in turn to whatever survived the previous cut.
    for trailer in (" lpfE(", " lpfR(", " pair weights applied", " LPF "):
        head = head.partition(trailer)[0]

    fields: dict[str, Any] = {
        name: to_number(token)
        for name, token in re.findall(r"([A-Za-z][A-Za-z0-9_]*)=([^\s,()]+)", head)
    }
    return fields or None


def unique_keep_order(items: list[str]) -> list[str]:
    """Return *items* without repeats, keeping each value at its first position."""
    # Dict keys are unique and remember insertion order, which is exactly
    # the "first occurrence wins" behaviour wanted here.
    return list(dict.fromkeys(items))


def choose_block(lines: list[str]) -> tuple[int, int]:
    """Find the line span of the last completed run in *lines*.

    Returns:
        ``(start, end)`` as 0-based inclusive indices. ``start`` is the
        nearest ``RUN_START_RE`` line at or before the last
        ``RUN_FINISHED_RE`` line (0 when there is none). ``end`` is the line
        just before the next ``RUN_START_RE`` line after that finish marker,
        or the final line when no later start exists.

    Raises:
        ValueError: if no line matches ``RUN_FINISHED_RE``.
    """
    # Scanning backwards yields the last finish marker first.
    end_idx = next(
        (i for i in range(len(lines) - 1, -1, -1) if RUN_FINISHED_RE.search(lines[i])),
        None,
    )
    if end_idx is None:
        raise ValueError("No completed run marker found in log.")

    start_idx = next(
        (i for i in range(end_idx, -1, -1) if RUN_START_RE.search(lines[i])),
        0,
    )
    following_start = next(
        (i for i in range(end_idx + 1, len(lines)) if RUN_START_RE.search(lines[i])),
        len(lines),
    )
    return start_idx, following_start - 1


def extract_report(log_path: Path) -> dict[str, Any]:
    """Build a summary report for the last completed run found in *log_path*.

    The log is read as UTF-8 (undecodable bytes are replaced), the span of the
    last completed run is located with ``choose_block``, and that span is
    scanned for the finish markers, the restored configuration path,
    ``IntersceneGlobalRefine`` rows and referenced output files.

    Args:
        log_path: Path of the log file to read.

    Returns:
        A dict with the keys ``log_path``, ``run`` (start-line details, the
        1-based ``line_end_index`` and whichever finish metrics were found),
        ``config_path`` (``None`` when never seen), ``global_lma``
        (``lines_count``, ``last``, ``last_detailed``) and
        ``referenced_outputs`` (unique paths in first-seen order).

    Raises:
        ValueError: propagated from ``choose_block`` when the log contains no
            completed-run marker.
    """
    output_suffixes = (".tiff", ".csv", ".corr-xml", ".xml", ".log", ".list")

    lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
    start_idx, end_idx = choose_block(lines)
    block = lines[start_idx : end_idx + 1]

    run_finish: dict[str, Any] = {}
    batch_finish: dict[str, Any] = {}
    config_path: str | None = None
    global_rows: list[dict[str, Any]] = []
    paths: list[str] = []

    for line in block:
        # A later match overwrites an earlier one, so the last marker in the block wins.
        if m := RUN_FINISHED_RE.search(line):
            run_finish = {
                "scene_sequence_count": int(m.group("count")),
                "processing_seconds": float(m.group("seconds")),
            }
        if m := BATCH_FINISHED_RE.search(line):
            batch_finish = {
                "total_seconds": float(m.group("seconds")),
                "free_memory_bytes": int(m.group("free")),
                "heap_memory_bytes": int(m.group("total")),
            }
        if m := CONFIG_RESTORED_RE.search(line):
            config_path = m.group("path")
        if parsed := parse_global_line(line):
            global_rows.append(parsed)

        for raw_path in PATH_RE.findall(line):
            # Trim trailing sentence punctuation BEFORE testing the extension.
            # PATH_RE keeps characters such as "," or ")" in the match, so
            # testing the raw token would reject "/x/a.tiff," outright and the
            # strip would never have any effect.
            candidate = raw_path.rstrip(".,);]")
            if candidate.lower().endswith(output_suffixes):
                paths.append(candidate)

    if not config_path:
        # The configuration may have been restored before this run started:
        # fall back to the most recent mention anywhere up to the block end.
        for line in reversed(lines[: end_idx + 1]):
            if m := CONFIG_RESTORED_RE.search(line):
                config_path = m.group("path")
                break

    start_line = block[0] if block else ""
    # The block may begin at line 0 without a start marker; then both indices are None.
    start_match = RUN_START_RE.search(start_line)
    run_start = {
        "line_index": start_idx + 1,
        "line_text": start_line,
        "scene_index": int(start_match.group("idx")) if start_match else None,
        "last_scene_index": int(start_match.group("last")) if start_match else None,
    }

    # Prefer the last row that carries solver details; otherwise use the last row of any kind.
    last_detailed = None
    for row in reversed(global_rows):
        if "pcgIter" in row or "lambda" in row:
            last_detailed = row
            break
    if last_detailed is None and global_rows:
        last_detailed = global_rows[-1]

    report: dict[str, Any] = {
        "log_path": str(log_path),
        "run": {
            **run_start,
            "line_end_index": end_idx + 1,
            **run_finish,
            **batch_finish,
        },
        "config_path": config_path,
        "global_lma": {
            "lines_count": len(global_rows),
            "last": global_rows[-1] if global_rows else None,
            "last_detailed": last_detailed,
        },
        "referenced_outputs": unique_keep_order(paths),
    }
    return report


def get_path(data: dict[str, Any], dotted: str) -> Any:
    """Follow a dot-separated key path through nested dicts.

    Returns the value at the end of the path, or ``None`` as soon as a step
    meets a non-dict value or a missing key.
    """
    node: Any = data
    for key in dotted.split("."):
        if isinstance(node, dict) and key in node:
            node = node[key]
        else:
            return None
    return node


def is_number(v: Any) -> bool:
    """Return True for any ``int`` or ``float`` value except NaN."""
    if not isinstance(v, (int, float)):
        return False
    if isinstance(v, float):
        return not math.isnan(v)
    return True


def fmt_num(v: Any) -> str:
    """Render a report value as text.

    Floats are formatted with nine significant digits, ``None`` becomes
    ``"n/a"`` and every other value goes through ``str()``.
    """
    if isinstance(v, float):
        return f"{v:.9g}"
    return "n/a" if v is None else str(v)


def compare_reports(base: dict[str, Any], new: dict[str, Any]) -> str:
    """Render a CSV table comparing a fixed set of metrics in two reports.

    Each row lists the metric path, the base value, the new value, the
    difference ``new - base`` and the ratio ``new / base``. The difference is
    filled in only when both values pass ``is_number``; the ratio additionally
    requires a non-zero base. Unavailable cells are rendered by ``fmt_num``.

    Returns:
        The header line and one line per metric, joined with newlines (no
        trailing newline).
    """
    metric_keys = (
        "run.processing_seconds",
        "run.total_seconds",
        "global_lma.lines_count",
        "global_lma.last_detailed.outer",
        "global_lma.last_detailed.inner",
        "global_lma.last_detailed.avgPairRms",
        "global_lma.last_detailed.avgPairRmsPure",
        "global_lma.last_detailed.maxDelta",
        "global_lma.last_detailed.pcgIter",
        "global_lma.last_detailed.lambda",
    )

    rows = ["metric,base,new,delta,new/base"]
    for metric in metric_keys:
        base_val = get_path(base, metric)
        new_val = get_path(new, metric)
        if is_number(base_val) and is_number(new_val):
            base_f = float(base_val)
            new_f = float(new_val)
            delta = new_f - base_f
            # A zero baseline has no meaningful ratio.
            ratio = new_f / base_f if base_f != 0.0 else None
        else:
            delta = None
            ratio = None
        cells = (metric, fmt_num(base_val), fmt_num(new_val), fmt_num(delta), fmt_num(ratio))
        rows.append(",".join(cells))
    return "\n".join(rows)


def write_json(path: Path, payload: dict[str, Any]) -> None:
    """Write *payload* to *path* as indented, key-sorted JSON with a trailing newline.

    Missing parent directories are created first; the file is UTF-8 encoded.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(payload, indent=2, sort_keys=True)
    path.write_text(f"{text}\n", encoding="utf-8")


def collect_snapshots(snapshot_dir: Path, prefix: str) -> list[Path]:
    """List the ``<prefix>_*.json`` files in *snapshot_dir* in sorted order.

    The two alias files ``<prefix>_latest.json`` and ``<prefix>_previous.json``
    are left out of the result.
    """
    alias_names = (f"{prefix}_latest.json", f"{prefix}_previous.json")
    found = [
        candidate
        for candidate in snapshot_dir.glob(f"{prefix}_*.json")
        if candidate.name not in alias_names
    ]
    found.sort()
    return found


def create_snapshot(report: dict[str, Any], snapshot_dir: Path, prefix: str) -> tuple[Path, Path | None]:
    """Store *report* as a new timestamped snapshot and refresh the alias files.

    The snapshot is named ``<prefix>_<YYYYmmdd_HHMMSS>.json``; if that name is
    taken, ``_1``, ``_2``, ... is appended until a free name is found. The
    report is also written to ``<prefix>_latest.json``. When an earlier
    snapshot exists, its content is copied to ``<prefix>_previous.json``.

    Returns:
        ``(new_path, previous)`` where ``previous`` is the last snapshot that
        existed before this call, or ``None`` if there was none.
    """
    snapshot_dir.mkdir(parents=True, exist_ok=True)
    # Look for earlier snapshots before the new file is written.
    earlier = collect_snapshots(snapshot_dir, prefix)
    previous = earlier[-1] if earlier else None

    stem = f"{prefix}_{dt.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    new_path = snapshot_dir / f"{stem}.json"
    attempt = 0
    while new_path.exists():
        attempt += 1
        new_path = snapshot_dir / f"{stem}_{attempt}.json"

    write_json(new_path, report)
    write_json(snapshot_dir / f"{prefix}_latest.json", report)
    if previous is not None:
        previous_payload = json.loads(previous.read_text(encoding="utf-8"))
        write_json(snapshot_dir / f"{prefix}_previous.json", previous_payload)
    return new_path, previous


def cmd_extract(args: argparse.Namespace) -> int:
    """Handle the ``extract`` subcommand.

    Builds the report for ``args.log`` and optionally writes it to
    ``args.json``. It then prints a summary, stores a snapshot when
    ``args.snapshot_dir`` is set, and prints a comparison against the previous
    snapshot when ``args.auto_compare`` is set.

    Returns:
        Always 0; failures surface as exceptions.
    """
    report = extract_report(Path(args.log))
    if args.json:
        write_json(Path(args.json), report)

    print(f"log: {report['log_path']}")
    run_info = report["run"]
    sequences = run_info.get("scene_sequence_count", "n/a")
    processing = fmt_num(run_info.get("processing_seconds"))
    total = fmt_num(run_info.get("total_seconds"))
    print(f"run: sequences={sequences} processing={processing}s total={total}s")
    print(f"config: {report.get('config_path') or 'n/a'}")

    lma = report["global_lma"]
    print(f"global_lma_lines: {lma['lines_count']}")
    detailed = lma["last_detailed"]
    if detailed:
        shown_fields = ("outer", "inner", "avgPairRms", "avgPairRmsPure", "pcgIter", "lambda")
        summary = " ".join(f"{name}={fmt_num(detailed.get(name))}" for name in shown_fields)
        print(f"global_lma_last: {summary}")

    # Negative counts are treated as "show nothing".
    show_n = max(0, args.show_outputs)
    outputs = report["referenced_outputs"]
    if show_n and outputs:
        print(f"referenced_outputs_last_{min(show_n, len(outputs))}:")
        for item in outputs[-show_n:]:
            print(f"  {item}")
    if args.json:
        print(f"json: {args.json}")

    new_snapshot: Path | None = None
    previous_snapshot: Path | None = None
    if args.snapshot_dir:
        new_snapshot, previous_snapshot = create_snapshot(report, Path(args.snapshot_dir), args.prefix)
        print(f"snapshot: {new_snapshot}")
        print(f"snapshot_previous: {previous_snapshot if previous_snapshot else 'n/a'}")

    if not args.auto_compare:
        return 0
    if previous_snapshot is None or new_snapshot is None:
        print("compare: skipped (no previous snapshot)")
        return 0

    base = json.loads(previous_snapshot.read_text(encoding="utf-8"))
    print(f"compare_base: {previous_snapshot}")
    print(f"compare_new: {new_snapshot}")
    print(compare_reports(base, report))
    return 0


def cmd_compare(args: argparse.Namespace) -> int:
    """Handle the ``compare`` subcommand.

    Loads the JSON reports named by ``args.base`` and ``args.new`` and prints
    their comparison table. Always returns 0.
    """
    base_report, new_report = (
        json.loads(Path(source).read_text(encoding="utf-8")) for source in (args.base, args.new)
    )
    print(compare_reports(base_report, new_report))
    return 0


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with the ``extract`` and ``compare`` subcommands.

    Each subparser stores its handler under ``func`` so the caller can
    dispatch with ``args.func(args)``.
    """
    root = argparse.ArgumentParser(description="Eyesis MCP run report helper")
    commands = root.add_subparsers(dest="cmd", required=True)

    # extract: summarize the latest completed run, optionally snapshotting it.
    extract_cmd = commands.add_parser("extract", help="Extract latest completed run from log")
    extract_cmd.set_defaults(func=cmd_extract)
    extract_cmd.add_argument(
        "--log",
        default="attic/session-logs/eyesis_mcp.log",
        help="Path to Eyesis log (default: attic/session-logs/eyesis_mcp.log)",
    )
    extract_cmd.add_argument("--json", help="Write report JSON to this path")
    extract_cmd.add_argument(
        "--snapshot-dir",
        help="Write timestamped snapshot JSON under this directory and refresh <prefix>_latest.json",
    )
    extract_cmd.add_argument(
        "--prefix",
        default="run",
        help="Snapshot filename prefix for --snapshot-dir (default: run)",
    )
    extract_cmd.add_argument(
        "--auto-compare",
        action="store_true",
        help="With --snapshot-dir, compare against previous snapshot and print metric table",
    )
    extract_cmd.add_argument(
        "--show-outputs",
        type=int,
        default=10,
        help="Show last N referenced output paths from log (default: 10, 0 to disable)",
    )

    # compare: diff two previously written report files.
    compare_cmd = commands.add_parser("compare", help="Compare two report JSON files")
    compare_cmd.set_defaults(func=cmd_compare)
    compare_cmd.add_argument("--base", required=True, help="Baseline report JSON")
    compare_cmd.add_argument("--new", required=True, help="New report JSON")
    return root


def main() -> int:
    """Parse the command line, run the chosen subcommand and return its exit code.

    Any exception raised by the subcommand is reported on stderr as
    ``error: <message>`` and turned into exit code 1.
    """
    args = build_parser().parse_args()
    try:
        exit_code = args.func(args)
    except Exception as ex:
        # Top-level boundary: show a short message instead of a traceback.
        print(f"error: {ex}", file=sys.stderr)
        exit_code = 1
    return exit_code


# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
