#!/usr/bin/env python3
"""Query a local RAG index built by rag_index.py."""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import List

import hnswlib
import numpy as np
from fastembed import TextEmbedding


def load_meta(meta_path: Path) -> List[dict]:
    """Load per-chunk metadata records from a JSON Lines file.

    Each non-blank line of *meta_path* is parsed as one JSON value. Blank or
    whitespace-only lines (e.g. a stray trailing newline) are skipped instead
    of aborting the load with a ``JSONDecodeError``.

    Args:
        meta_path: Path to the ``.jsonl`` file, read as UTF-8.

    Returns:
        The parsed records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with meta_path.open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _embed_query(model_name: str, text: str) -> np.ndarray:
    """Embed *text* with fastembed and return it as an L2-normalised float32 vector.

    A zero-norm vector is returned as-is to avoid dividing by zero.
    """
    embedding = TextEmbedding(model_name=model_name)
    # embed() yields one vector per input text; exactly one text is passed.
    vec = np.asarray(next(iter(embedding.embed([text]))), dtype=np.float32)
    norm = np.linalg.norm(vec)
    if norm > 0:
        vec = vec / norm
    return vec


def _search(index_dir: Path, query_vec: np.ndarray, top_k: int, ef: int) -> List[tuple[int, float]]:
    """Load ``index.bin`` from *index_dir* and return up to *top_k* (label, distance) pairs.

    Pairs are returned in the order hnswlib reports them. Fewer than *top_k*
    pairs are returned when the index holds fewer elements; an empty list is
    returned for an empty index.
    """
    index = hnswlib.Index(space="cosine", dim=len(query_vec))
    index.load_index(str(index_dir / "index.bin"))

    # knn_query raises RuntimeError when it cannot return k results, so never
    # ask for more neighbours than the index actually holds.
    k = min(top_k, index.get_current_count())
    if k <= 0:
        return []

    # Search breadth should be at least k for the query to be meaningful.
    index.set_ef(max(ef, k))
    labels, distances = index.knn_query(query_vec, k=k)
    return [(int(label), float(dist)) for label, dist in zip(labels[0], distances[0])]


def main() -> int:
    """Command-line entry point: embed a query and print its nearest chunks.

    Reads ``config.json``, ``meta.jsonl`` and ``index.bin`` from
    ``--index-dir``. It embeds the query with the model named in the config,
    defaulting to ``BAAI/bge-base-en-v1.5``, then prints the top matches.
    Output is either a human-readable list or, with ``--json``, a single JSON
    object.

    Returns:
        Process exit status (0 on success). Invalid arguments exit via
        ``argparse`` with status 2.
    """
    parser = argparse.ArgumentParser(description="Query local RAG index")
    parser.add_argument("query", help="Query text")
    parser.add_argument(
        "--index-dir",
        default="/home/elphel/git/imagej-elphel/attic/CODEX/rag_index",
        help="Index directory",
    )
    parser.add_argument("--top-k", type=int, default=5)
    parser.add_argument(
        "--ef",
        type=int,
        default=100,
        help="HNSW search breadth (raised to at least the effective top-k)",
    )
    parser.add_argument("--json", action="store_true", help="Output JSON only")

    args = parser.parse_args()
    if args.top_k < 1:
        parser.error("--top-k must be at least 1")

    index_dir = Path(args.index_dir)
    config = json.loads((index_dir / "config.json").read_text(encoding="utf-8"))
    meta = load_meta(index_dir / "meta.jsonl")

    model = config.get("model", "BAAI/bge-base-en-v1.5")
    query_vec = _embed_query(model, args.query)
    hits = _search(index_dir, query_vec, args.top_k, args.ef)

    results = []
    for rank, (idx, dist) in enumerate(hits, start=1):
        # NOTE(review): assumes hnswlib labels are 0-based positions into
        # meta.jsonl, i.e. the indexer added items in file order. Confirm
        # against rag_index.py.
        item = meta[idx]
        snippet = item["text"][:400].replace("\n", " ")
        results.append(
            {
                "rank": rank,
                "source": item["source"],
                "chunk_index": item["chunk_index"],
                # hnswlib's cosine distance is 1 - cosine similarity.
                "score": float(1.0 - dist),
                "snippet": snippet,
            }
        )

    if args.json:
        print(json.dumps({"ok": True, "query": args.query, "results": results}, ensure_ascii=False))
        return 0

    print(f"Query: {args.query}")
    print("Results:\n")
    for item in results:
        print(f"{item['rank']}. {item['source']} (chunk {item['chunk_index']})")
        print(f"   score: {item['score']:.4f}")
        print(f"   {item['snippet']}")
        print()

    return 0


if __name__ == "__main__":
    # Run the CLI and hand its integer status back to the shell.
    exit_status = main()
    raise SystemExit(exit_status)
