#!/usr/bin/env python3
"""Cross-reference the database-wide CAUTION survey against genes already in this
repo, to produce a prioritized worklist of high-value entries NOT yet reviewed.

Reads:
    caution_uniprot_reviewed.tsv  (from uniprot_api_survey.py)
    genes/*/*/*-uniprot.txt        (first accession on the first AC line = already fetched)

Writes:
    candidates_high_value.tsv      one row per high-signal, not-yet-fetched entry
    candidates.md                  per-category summary: counts, top organisms, up to 20 human examples

"High value" = a CAUTION category that signals likely GO over/mis-annotation:
degenerate-domain (pseudo-enzyme), retracted-reference, reclassified-function,
possible-artifact. Only the first CAUTION note of each entry is categorized;
notes classed as contested-function or other (boilerplate, generic) are excluded.
"""
from __future__ import annotations

import re
from collections import Counter
from pathlib import Path

# Paths are anchored to this script's location, not the current working directory.
OUT_DIR = Path(__file__).resolve().parent          # this script's directory; outputs are written here
REPO = OUT_DIR.parents[1]                          # two levels above OUT_DIR
SURVEY = OUT_DIR / "caution_uniprot_reviewed.tsv"  # input produced by uniprot_api_survey.py

# categorize() labels that are kept in the shortlist; all other labels are dropped.
HIGH_VALUE = (
    "degenerate-domain",
    "retracted-reference",
    "reclassified-function",
    "possible-artifact",
)


def categorize(text: str) -> str:
    """Label a single CAUTION note with a category via case-insensitive keyword rules.

    Rules are checked in a fixed priority order and the first match wins:
    retracted-reference, degenerate-domain, reclassified-function,
    possible-artifact, contested-function; anything else is "other".
    """
    lowered = text.lower()

    def mentions(*phrases: str) -> bool:
        # True if any of the given substrings occurs in the lowered note.
        return any(phrase in lowered for phrase in phrases)

    if mentions("retract"):
        return "retracted-reference"

    has_lacks = "lacks" in lowered
    lacks_key_feature = has_lacks and mentions(
        "active site", "catalytic", "conserved", "heme-binding",
        "phospho-accepting", "essential", "required",
    )
    # Checked before the generic "may not" rule below so that
    # "may not be functional" is treated as a degenerate domain.
    if (lacks_key_feature
            or (has_lacks and "although" in lowered)
            or mentions("probably not", "may not be functional")):
        return "degenerate-domain"

    if mentions("was initially", "was originally", "was reported",
                "originally thought", "initially believed", "previously"):
        return "reclassified-function"

    if mentions("artifact", "artefact"):
        return "possible-artifact"

    if mentions("controvers", "uncertain", "in contrast", "however",
                "may not", "could be", "disput", "remains to be",
                "not clear", "unclear", "questioned"):
        return "contested-function"

    return "other"


def local_accessions() -> set[str]:
    """Return the accessions of UniProt entries already stored in the repo.

    Scans ``genes/*/*/*-uniprot.txt`` under REPO and, for each file, keeps only
    the first accession on its first ``AC`` line (UniProt's primary accession).
    Files with no ``AC`` line contribute nothing; undecodable bytes are
    replaced rather than raising.
    """
    found: set[str] = set()
    for txt_path in REPO.glob("genes/*/*/*-uniprot.txt"):
        content = txt_path.read_text(encoding="utf-8", errors="replace")
        ac_line = next((ln for ln in content.splitlines() if ln.startswith("AC ")), None)
        if ac_line is None:
            continue
        # Skip the 5-character line prefix; accessions are ';'-separated.
        found.add(ac_line[5:].split(";")[0].strip())
    return found


def _read_survey_rows(path: Path) -> list[list[str]]:
    """Return the tab-split data rows of the survey TSV, without its header line.

    The file is decoded explicitly as UTF-8, matching every other read/write in
    this script, instead of relying on the platform default encoding. Records
    are delimited by iterating the file (universal newlines: \\n, \\r, \\r\\n)
    rather than with ``str.splitlines()``, which also breaks on Unicode
    separators such as U+2028 or U+0085 and would cut a record containing one
    inside a text field.
    """
    with path.open(encoding="utf-8") as fh:
        next(fh, None)  # skip the header row
        return [line.rstrip("\r\n").split("\t") for line in fh]


def _first_caution_note(caution: str) -> str:
    """Return the first non-empty note of a possibly multi-note CAUTION field.

    Notes are split on "CAUTION:" markers; if no non-empty piece remains the
    raw field is returned unchanged.
    """
    pieces = re.split(r"(?:^|\s)CAUTION:\s*", caution.strip())
    return next((p for p in pieces if p.strip()), caution)


def _collect_candidates(rows: list[list[str]], have: set[str]) -> list[tuple[str, ...]]:
    """Return high-value, not-yet-fetched candidates, sorted.

    Each candidate is (category, accession, gene, organism, organism_id,
    protein_name, first_note). Rows with fewer than 6 columns are skipped.
    Results are sorted by category, then organism, then gene.
    """
    candidates = []
    for row in rows:
        if len(row) < 6:
            continue  # truncated / malformed survey row
        acc, gene, organism, org_id, pname, caution = row[:6]
        note = _first_caution_note(caution)
        cat = categorize(note)
        if acc in have or cat not in HIGH_VALUE:
            continue
        candidates.append((cat, acc, gene, organism, org_id, pname, note.strip()))
    candidates.sort(key=lambda c: (c[0], c[3], c[2]))
    return candidates


def _write_tsv(path: Path, candidates: list[tuple[str, ...]]) -> None:
    """Write a header line plus one tab-separated row per candidate (UTF-8)."""
    with path.open("w", encoding="utf-8") as fh:
        fh.write("category\taccession\tgene\torganism\torganism_id\tprotein_name\tcaution\n")
        for c in candidates:
            fh.write("\t".join(c) + "\n")


def _write_markdown(path: Path, by_cat: dict[str, list[tuple[str, ...]]], *,
                    n_local: int, n_candidates: int) -> None:
    """Write the per-category Markdown summary (counts, top organisms, human examples)."""
    with path.open("w", encoding="utf-8") as fh:
        fh.write("---\ntitle: \"UniProt CAUTION — High-Value Review Candidates\"\n---\n\n")
        fh.write("# High-Value CAUTION Review Candidates (not yet in repo)\n\n")
        fh.write("Auto-generated by `shortlist_candidates.py`. Do not edit by hand.\n\n")
        fh.write(f"- Local accessions already fetched: **{n_local}**\n")
        fh.write(f"- High-value, not-yet-fetched candidates: **{n_candidates}**\n\n")
        for cat in HIGH_VALUE:
            items = by_cat.get(cat, [])
            fh.write(f"## {cat} ({len(items)})\n\n")
            # Ties in most_common() keep first-seen order, i.e. the sorted candidate order.
            org_counts = Counter(c[3] for c in items)
            top_orgs = ", ".join(f"{o} ({n})" for o, n in org_counts.most_common(5))
            if top_orgs:
                fh.write(f"Top organisms: {top_orgs}\n\n")
            # show human entries first as the most actionable, capped
            human = [c for c in items if c[4] == "9606"][:20]
            if human:
                fh.write("Human examples:\n\n")
                for c in human:
                    fh.write(f"- **{c[2] or c[1]}** ({c[1]}): {c[6][:130]}\n")
                fh.write("\n")


def main() -> None:
    """Build the high-value CAUTION worklist and write the TSV and Markdown outputs.

    Reads SURVEY and the locally fetched accessions, keeps entries whose first
    CAUTION note falls in HIGH_VALUE and that are not already fetched, writes
    candidates_high_value.tsv and candidates.md into OUT_DIR, and prints a
    per-category summary.
    """
    have = local_accessions()
    candidates = _collect_candidates(_read_survey_rows(SURVEY), have)

    tsv = OUT_DIR / "candidates_high_value.tsv"
    _write_tsv(tsv, candidates)

    by_cat: dict[str, list[tuple[str, ...]]] = {}
    for c in candidates:
        by_cat.setdefault(c[0], []).append(c)

    md = OUT_DIR / "candidates.md"
    _write_markdown(md, by_cat, n_local=len(have), n_candidates=len(candidates))

    print(f"local={len(have)} candidates={len(candidates)}")
    for cat in HIGH_VALUE:
        print(f"  {cat}: {len(by_cat.get(cat, []))}")
    print(f"wrote {tsv} and {md}")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()
