#!/usr/bin/env python3
"""Build a local RAG index from attic/CODEX/rag_sources.

Outputs:
- index.bin (hnswlib)
- meta.jsonl (chunk metadata)
- config.json
"""

from __future__ import annotations

import argparse
import hashlib
import json
import os
import re
import subprocess
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable, List, Tuple

import hnswlib
import numpy as np
from fastembed import TextEmbedding
from tqdm import tqdm


def sha1_text(text: str) -> str:
    """Return the hexadecimal SHA-1 digest of *text*.

    The text is encoded as UTF-8 before hashing; characters that cannot be
    encoded are dropped instead of raising.
    """
    encoded = text.encode("utf-8", errors="ignore")
    return hashlib.sha1(encoded).hexdigest()


def read_text_file(path: Path) -> str:
    """Read *path* as UTF-8 text, dropping any bytes that fail to decode."""
    with path.open("r", encoding="utf-8", errors="ignore") as handle:
        return handle.read()


def latex_to_text(text: str) -> str:
    """Reduce LaTeX markup to approximate plain text.

    This is a regex-based heuristic, not a LaTeX parser. It removes
    comments, drops command names while keeping their brace arguments,
    discards bracketed optional arguments, unescapes common escaped
    special characters, strips braces and collapses whitespace.

    Args:
        text: LaTeX source.

    Returns:
        The simplified text on a single line, without leading or trailing
        whitespace.
    """
    # Remove comments: everything from an unescaped '%' to the end of the
    # line. The lookbehind keeps an escaped percent sign (backslash + '%'),
    # which is a literal character in LaTeX, and this also covers comments
    # that trail real content on the same line.
    text = re.sub(r"(?m)(?<!\\)%.*$", "", text)
    # Drop LaTeX commands but keep their arguments
    text = re.sub(r"\\[a-zA-Z*]+\s*\{", "{", text)
    text = re.sub(r"\\[a-zA-Z*]+\s*\[.*?\]", "", text)
    text = re.sub(r"\\[a-zA-Z*]+", " ", text)
    # Unescape special characters so no stray backslashes are left behind
    text = re.sub(r"\\([%&#_$])", r"\1", text)
    # Strip remaining braces
    text = text.replace("{", " ").replace("}", " ")
    # Collapse whitespace
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def pdf_to_text(path: Path) -> str:
    """Extract the text of a PDF by running the external ``pdftotext`` tool.

    Args:
        path: PDF file to convert.

    Returns:
        The tool's standard output decoded as UTF-8; undecodable bytes are
        dropped.

    Raises:
        RuntimeError: If ``pdftotext`` cannot be started (for example when
            it is not installed) or exits with a non-zero status. The
            underlying exception is attached as ``__cause__``.
    """
    # pdftotext writes to stdout with '-' output
    command = ["pdftotext", "-q", str(path), "-"]
    try:
        result = subprocess.run(command, check=True, capture_output=True)
    except OSError as exc:
        # The executable is missing or cannot be launched; report this the
        # same way as a failed conversion so callers handle one error type.
        raise RuntimeError(f"could not run pdftotext for {path}: {exc}") from exc
    except subprocess.CalledProcessError as exc:
        raise RuntimeError(f"pdftotext failed for {path}: {exc}") from exc
    return result.stdout.decode("utf-8", errors="ignore")


def iter_source_files(root: Path, exclude_substrings: List[str]) -> Iterable[Path]:
    """Yield the indexable files below *root* in sorted order.

    A path is yielded when it is a regular file, its suffix is one of
    ``.md``, ``.txt``, ``.tex`` or ``.pdf`` (case-insensitive), and its
    path contains none of *exclude_substrings*.

    Args:
        root: Directory to search recursively.
        exclude_substrings: Substrings compared against the forward-slash
            form of each candidate path; any match skips the file.

    Yields:
        Matching paths, sorted so that repeated runs over the same tree
        visit the files in the same order.
    """
    exts = {".md", ".txt", ".tex", ".pdf"}
    # rglob order depends on the filesystem; sort for reproducible output.
    for p in sorted(root.rglob("*")):
        if p.suffix.lower() not in exts:
            continue
        # Compare against the '/'-separated form so tokens such as
        # "/some-dir/" behave the same regardless of the OS separator.
        posix_path = p.as_posix()
        if any(token in posix_path for token in exclude_substrings):
            continue
        if p.is_file():
            yield p


def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
    """Split *text* into character windows that overlap by *overlap*.

    Each window is at most *chunk_size* characters long and starts
    ``chunk_size - overlap`` characters after the previous one. Windows are
    stripped of surrounding whitespace, and windows that are empty after
    stripping are skipped.

    Args:
        text: Text to split.
        chunk_size: Maximum window length in characters; must be positive.
        overlap: Number of characters shared by consecutive windows; must
            be smaller than *chunk_size*.

    Returns:
        The list of non-empty chunks; an empty list for empty *text*.

    Raises:
        ValueError: If *text* is non-empty and *chunk_size* is not positive
            or *overlap* is not smaller than *chunk_size*. Either condition
            would keep the window from ever advancing.
    """
    if not text:
        return []
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    n = len(text)
    # Strictly positive thanks to the checks above, so the loop terminates.
    step = chunk_size - overlap
    chunks: List[str] = []
    for start in range(0, n, step):
        end = start + chunk_size
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end >= n:
            # This window already reached the end of the text.
            break
    return chunks


def file_to_chunks(path: Path, chunk_size: int, overlap: int) -> List[str]:
    """Load *path* as text and split it into chunks.

    ``.pdf`` files go through ``pdf_to_text``. Every other file is read
    with ``read_text_file``, and ``.tex`` files are additionally passed
    through ``latex_to_text``. The suffix is compared case-insensitively.
    The resulting text is split with ``chunk_text``.
    """
    kind = path.suffix.lower()
    if kind == ".pdf":
        return chunk_text(pdf_to_text(path), chunk_size, overlap)

    raw = read_text_file(path)
    body = latex_to_text(raw) if kind == ".tex" else raw
    return chunk_text(body, chunk_size, overlap)


def main() -> int:
    """Build the index from command-line options and write it to disk.

    Collects the source files, splits each one into chunks, embeds the
    chunks, and writes ``index.bin``, ``meta.jsonl`` and ``config.json``
    into the output directory. Files that fail to load are reported on
    stderr and skipped.

    Returns:
        0 on success; 1 when no source files were found or no embeddings
        were produced. Invalid option values end the program through
        ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="Build local RAG index")
    parser.add_argument(
        "--source",
        default="/home/elphel/git/imagej-elphel/attic/CODEX/rag_sources",
        help="Root directory with sources",
    )
    parser.add_argument(
        "--out",
        default="/home/elphel/git/imagej-elphel/attic/CODEX/rag_index",
        help="Output directory",
    )
    parser.add_argument(
        "--model",
        default="BAAI/bge-base-en-v1.5",
        help="Embedding model for fastembed",
    )
    parser.add_argument("--chunk-size", type=int, default=3000)
    parser.add_argument("--overlap", type=int, default=300)
    # NOTE(review): --top-k is parsed but not used anywhere in this function.
    parser.add_argument("--top-k", type=int, default=10)
    # With action="append" and a non-empty default, values given on the
    # command line are added after the default entry rather than replacing it.
    parser.add_argument(
        "--exclude",
        action="append",
        default=["/elphel-bib-glossary/"],
        help="Exclude paths containing this substring (can be repeated)",
    )

    args = parser.parse_args()

    # Reject settings that would keep the chunking window from advancing
    # (the next window starts at end - overlap).
    if args.chunk_size <= 0:
        parser.error("--chunk-size must be a positive integer")
    if args.overlap >= args.chunk_size:
        parser.error("--overlap must be smaller than --chunk-size")

    source_root = Path(args.source).resolve()
    out_dir = Path(args.out).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    files = list(iter_source_files(source_root, args.exclude))
    if not files:
        print(f"No source files found under {source_root}", file=sys.stderr)
        return 1

    embedding = TextEmbedding(model_name=args.model)

    meta_path = out_dir / "meta.jsonl"
    index_path = out_dir / "index.bin"
    config_path = out_dir / "config.json"

    # Build chunks and embeddings
    all_meta = []
    all_vectors = []

    for path in tqdm(files, desc="Files"):
        try:
            chunks = file_to_chunks(path, args.chunk_size, args.overlap)
        except Exception as exc:
            # Best effort: one unreadable file must not abort the build.
            # tqdm.write keeps the message from breaking the progress bar.
            tqdm.write(f"WARN: failed to read {path}: {exc}", file=sys.stderr)
            continue
        if not chunks:
            continue

        rel_path = path.relative_to(source_root)
        for idx, (chunk, vec) in enumerate(zip(chunks, embedding.embed(chunks))):
            vec = np.array(vec, dtype=np.float32)
            # normalize for cosine similarity
            norm = np.linalg.norm(vec)
            if norm > 0:
                vec = vec / norm
            all_vectors.append(vec)
            all_meta.append(
                {
                    # The id is the row position of the vector in the index.
                    "id": len(all_meta),
                    "source": str(rel_path),
                    "path": str(path),
                    "chunk_index": idx,
                    "text": chunk,
                    "sha1": sha1_text(chunk),
                }
            )

    if not all_vectors:
        print("No embeddings created; aborting", file=sys.stderr)
        return 1

    dim = len(all_vectors[0])
    index = hnswlib.Index(space="cosine", dim=dim)
    index.init_index(max_elements=len(all_vectors), ef_construction=200, M=16)
    index.add_items(np.vstack(all_vectors), np.arange(len(all_vectors)))
    index.set_ef(100)
    index.save_index(str(index_path))

    # One JSON object per line, in the same order as the index labels.
    with meta_path.open("w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(item, ensure_ascii=False) + "\n" for item in all_meta
        )

    config = {
        "source": str(source_root),
        "out": str(out_dir),
        "model": args.model,
        "chunk_size": args.chunk_size,
        "overlap": args.overlap,
        "count": len(all_meta),
        "created": datetime.now(timezone.utc).isoformat(),
    }
    config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")

    print(f"Index built: {index_path}")
    print(f"Metadata: {meta_path}")
    print(f"Config: {config_path}")
    print(f"Chunks: {len(all_meta)}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
