ARGO-ProtNLM-50 Benchmark Evaluation

Systematic evaluation of ProtNLM2 GO predictions against curated GOA annotations for 50 benchmark proteins selected across 14 taxonomic groups.

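Data source: ProtNLM2 predictions from the UniProt REST API (rest.uniprot.org/uniprotkb/protnlm/{accession}); this notebook reads them from the local files entries.tsv, predictions.tsv and evidence.tsv.
Benchmark: argo_protnlm_50.csv — stratified across species, prediction types, and evidence methods.
Evaluation: closure-based comparison using isa_partof_closure from goa_uniprot_all.ddb.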

In [ ]:
import pandas as pd
import duckdb
import matplotlib.pyplot as plt
from pathlib import Path

# Notebook-wide plotting defaults.
plt.rcParams["figure.dpi"] = 120
plt.rcParams["figure.facecolor"] = "white"

# Input directory (benchmark CSV + ProtNLM TSV files) and the local GOA DuckDB.
DATA = Path(".")
GOA_DB = Path.home() / "repos/go-db/db/goa_uniprot_all.ddb"

# Figure cells save into slides_figures/; plt.savefig() does not create
# missing directories, so make sure it exists before any figure is written.
Path("slides_figures").mkdir(exist_ok=True)

benchmark = pd.read_csv(DATA / "argo_protnlm_50.csv")
entries = pd.read_csv(DATA / "entries.tsv", sep="\t")
predictions = pd.read_csv(DATA / "predictions.tsv", sep="\t")
evidence = pd.read_csv(DATA / "evidence.tsv", sep="\t")

# Restrict predictions and evidence to the benchmark accessions.
bench_accs = benchmark["accession"].tolist()
bench_preds = predictions[predictions["accession"].isin(bench_accs)]
# Only GO-term predictions are compared against GOA. .copy() so later column
# edits operate on an independent frame rather than a slice of `predictions`.
bench_go = bench_preds[bench_preds["pred_type"] == "GO"].copy()
bench_evid = evidence[evidence["accession"].isin(bench_accs)]

print(f"Benchmark proteins: {len(benchmark)}")
print(f"  with GO predictions: {bench_go['accession'].nunique()}")
print(f"  total GO predictions: {len(bench_go)}")
print(f"  total evidence rows: {len(bench_evid)}")

Benchmark composition

The ARGO-ProtNLM-50 benchmark was selected to cover:

  • Taxonomic diversity: 14 broad groups from mammals to alveolates
  • Prediction richness: name-only, GO-only, partial, and rich prediction categories
  • Evidence diversity: string match, phmmer, and tmalign evidence methods
  • Known patterns: 5 case studies from exploratory analysis
In [ ]:
# Three-panel overview of how the benchmark proteins were selected.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
ax_taxa, ax_pred, ax_ngo = axes

# Left: proteins per broad taxonomic group, most frequent group on top.
benchmark["broad_group"].value_counts().plot.barh(ax=ax_taxa, color="#4c78a8")
ax_taxa.set_xlabel("Count")
ax_taxa.set_title("Taxonomic groups")
ax_taxa.invert_yaxis()

# Middle: proteins per prediction category. Colours follow value_counts()
# order (most frequent first), not a fixed per-category mapping.
category_palette = ["#2ca02c", "#f58518", "#4c78a8", "#bab0ac"]
benchmark["pred_category"].value_counts().plot.bar(ax=ax_pred, color=category_palette)
ax_pred.set_ylabel("Count")
ax_pred.set_title("Prediction categories")
ax_pred.tick_params(axis="x", rotation=30)

# Right: GO predictions per protein, one bin per integer count.
max_go = benchmark["n_go"].max()
benchmark["n_go"].plot.hist(
    ax=ax_ngo,
    bins=range(0, max_go + 2),
    color="#4c78a8",
    edgecolor="white",
)
ax_ngo.set_xlabel("GO predictions per protein")
ax_ngo.set_ylabel("Count")
ax_ngo.set_title("GO prediction distribution")

plt.tight_layout()
plt.savefig("slides_figures/bench50_composition.png", dpi=150, bbox_inches="tight")
plt.show()

Evaluation methodology

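For each ProtNLM2 GO prediction, we classify it against all curated GOA annotations for that protein using the GO isa_partof_closure. When more than one category applies (e.g. the predicted term is annotated exactly and is also an ancestor of another annotated term), the first matching category in the table below wins: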

| Category | Definition |
|---|---|
| EXACT | Same GO term exists in GOA for this protein |
| LESS_SPECIFIC | Predicted term is a strict ancestor of a GOA term (redundant) |
| MORE_SPECIFIC | Predicted term is a strict descendant of a GOA term (potentially novel) |
| NO_OVERLAP | No hierarchical relationship — different branch (novel or incorrect) |
| NOT_IN_GOA | Protein has no curated GOA annotations — cannot evaluate |
In [ ]:
import tempfile

# Open the GOA database read-only; the evaluation only reads from it.
con = duckdb.connect(str(GOA_DB), read_only=True)

# Stage the benchmark GO predictions as a CSV that DuckDB can load. Use the
# platform temp directory instead of a hard-coded /tmp so the notebook also
# runs where /tmp does not exist.
bench_go_path = (Path(tempfile.gettempdir()) / "bench50_go_preds.csv").as_posix()
bench_go[["accession", "pred_type", "pred_id", "pred_label"]].to_csv(bench_go_path, index=False)
# The path is inlined into a SQL string literal, so escape single quotes.
sql_path = bench_go_path.replace("'", "''")

# NOTE(review): every gaf_association row counts as a GOA annotation below.
# Confirm whether NOT-qualified and electronic (IEA) annotations should be
# excluded for a "curated" comparison.
#
# Categories are assigned in priority order EXACT > LESS_SPECIFIC >
# MORE_SPECIFIC > NOT_IN_GOA > NO_OVERLAP (first matching CASE branch wins).
# CREATE OR REPLACE keeps the cell re-runnable on the same connection.
results = con.execute(f"""
    CREATE OR REPLACE TEMP TABLE protnlm AS
    SELECT * FROM read_csv_auto('{sql_path}');

    WITH in_goa AS (
        SELECT DISTINCT db_object_id FROM gaf_association
    ),
    exact AS (
        SELECT DISTINCT p.accession, p.pred_id
        FROM protnlm p
        JOIN gaf_association g ON p.accession = g.db_object_id
          AND p.pred_id = g.ontology_class_ref
    ),
    less_specific AS (
        SELECT DISTINCT p.accession, p.pred_id
        FROM protnlm p
        JOIN gaf_association g ON p.accession = g.db_object_id
        JOIN isa_partof_closure ipc ON g.ontology_class_ref = ipc.subject
        WHERE ipc.object = p.pred_id
          AND g.ontology_class_ref != p.pred_id
    ),
    more_specific AS (
        SELECT DISTINCT p.accession, p.pred_id
        FROM protnlm p
        JOIN gaf_association g ON p.accession = g.db_object_id
        JOIN isa_partof_closure ipc ON p.pred_id = ipc.subject
        WHERE ipc.object = g.ontology_class_ref
          AND g.ontology_class_ref != p.pred_id
    )
    SELECT
        CASE
            WHEN e.pred_id IS NOT NULL THEN 'EXACT'
            WHEN ls.pred_id IS NOT NULL THEN 'LESS_SPECIFIC'
            WHEN ms.pred_id IS NOT NULL THEN 'MORE_SPECIFIC'
            WHEN ig.db_object_id IS NULL THEN 'NOT_IN_GOA'
            ELSE 'NO_OVERLAP'
        END AS match_category,
        COUNT(*) AS n_predictions,
        COUNT(DISTINCT p.accession) AS n_proteins
    FROM protnlm p
    LEFT JOIN in_goa ig ON p.accession = ig.db_object_id
    LEFT JOIN exact e ON p.accession = e.accession AND p.pred_id = e.pred_id
    LEFT JOIN less_specific ls ON p.accession = ls.accession AND p.pred_id = ls.pred_id
    LEFT JOIN more_specific ms ON p.accession = ms.accession AND p.pred_id = ms.pred_id
    GROUP BY 1
    ORDER BY 2 DESC;
""").fetchdf()

print(results.to_string(index=False))
In [ ]:
# Per-prediction version of the classification query. It uses the same
# category logic, and for LESS_SPECIFIC / MORE_SPECIFIC rows it also reports
# one representative GOA term. It relies on the `con` connection and the
# `protnlm` temp table created by the previous cell, and closes the
# connection afterwards.
detail = con.execute("""
    WITH in_goa AS (
        SELECT DISTINCT db_object_id FROM gaf_association
    ),
    exact AS (
        SELECT DISTINCT p.accession, p.pred_id
        FROM protnlm p
        JOIN gaf_association g ON p.accession = g.db_object_id
          AND p.pred_id = g.ontology_class_ref
    ),
    -- term_label is LEFT JOINed so a GOA term without a label still counts as
    -- a hierarchical match. An inner join would silently turn the prediction
    -- into NO_OVERLAP and disagree with the aggregate counts.
    -- DISTINCT ON keeps one GOA term per prediction; the ORDER BY makes that
    -- choice deterministic across runs.
    less_specific AS (
        SELECT DISTINCT ON (p.accession, p.pred_id)
               p.accession, p.pred_id,
               g.ontology_class_ref AS goa_match,
               gt.label AS goa_match_label
        FROM protnlm p
        JOIN gaf_association g ON p.accession = g.db_object_id
        JOIN isa_partof_closure ipc ON g.ontology_class_ref = ipc.subject
        LEFT JOIN term_label gt ON g.ontology_class_ref = gt.id
        WHERE ipc.object = p.pred_id
          AND g.ontology_class_ref != p.pred_id
        ORDER BY p.accession, p.pred_id, g.ontology_class_ref, gt.label
    ),
    more_specific AS (
        SELECT DISTINCT ON (p.accession, p.pred_id)
               p.accession, p.pred_id,
               g.ontology_class_ref AS goa_match,
               gt.label AS goa_match_label
        FROM protnlm p
        JOIN gaf_association g ON p.accession = g.db_object_id
        JOIN isa_partof_closure ipc ON p.pred_id = ipc.subject
        LEFT JOIN term_label gt ON g.ontology_class_ref = gt.id
        WHERE ipc.object = g.ontology_class_ref
          AND g.ontology_class_ref != p.pred_id
        ORDER BY p.accession, p.pred_id, g.ontology_class_ref, gt.label
    )
    SELECT
        p.accession, p.pred_id, p.pred_label,
        CASE
            WHEN e.pred_id IS NOT NULL THEN 'EXACT'
            WHEN ls.pred_id IS NOT NULL THEN 'LESS_SPECIFIC'
            WHEN ms.pred_id IS NOT NULL THEN 'MORE_SPECIFIC'
            WHEN ig.db_object_id IS NULL THEN 'NOT_IN_GOA'
            ELSE 'NO_OVERLAP'
        END AS match_category,
        -- Take term and label from the same branch that decided the category.
        -- EXACT rows get NULL: their match is the predicted term itself, and a
        -- descendant found by less_specific would be misleading.
        CASE
            WHEN e.pred_id IS NOT NULL THEN NULL
            WHEN ls.pred_id IS NOT NULL THEN ls.goa_match
            ELSE ms.goa_match
        END AS goa_match,
        CASE
            WHEN e.pred_id IS NOT NULL THEN NULL
            WHEN ls.pred_id IS NOT NULL THEN ls.goa_match_label
            ELSE ms.goa_match_label
        END AS goa_match_label
    FROM protnlm p
    LEFT JOIN in_goa ig ON p.accession = ig.db_object_id
    LEFT JOIN exact e ON p.accession = e.accession AND p.pred_id = e.pred_id
    LEFT JOIN less_specific ls ON p.accession = ls.accession AND p.pred_id = ls.pred_id
    LEFT JOIN more_specific ms ON p.accession = ms.accession AND p.pred_id = ms.pred_id
    ORDER BY p.accession, match_category;
""").fetchdf()

con.close()
print(f"Total predictions evaluated: {len(detail)}")
detail.head(20)

Results: classification distribution

In [ ]:
# Category order and colours shared by the figures that follow.
CAT_ORDER = ["EXACT", "LESS_SPECIFIC", "MORE_SPECIFIC", "NO_OVERLAP", "NOT_IN_GOA"]
CAT_COLORS = {"EXACT": "#2ca02c", "LESS_SPECIFIC": "#f58518", "MORE_SPECIFIC": "#4c78a8",
              "NO_OVERLAP": "#e45756", "NOT_IN_GOA": "#bab0ac"}

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Count predictions per category, and distinct proteins per category, both in
# CAT_ORDER with zeros for absent categories.
per_prediction = detail["match_category"].value_counts().reindex(CAT_ORDER, fill_value=0)
per_protein = (
    detail.groupby("match_category")["accession"]
    .nunique()
    .reindex(CAT_ORDER, fill_value=0)
)

# (axis, counts, y-label, title, vertical offset of the value labels)
panel_specs = (
    (axes[0], per_prediction, "Number of GO predictions",
     "Classification of ProtNLM2 predictions", 0.3),
    (axes[1], per_protein, "Number of proteins",
     "Proteins with each category", 0.2),
)
for panel_ax, counts, ylabel, title, pad in panel_specs:
    counts.plot.bar(ax=panel_ax, color=[CAT_COLORS[name] for name in counts.index])
    panel_ax.set_ylabel(ylabel)
    panel_ax.set_title(title)
    panel_ax.tick_params(axis="x", rotation=30)
    # Annotate each bar with its count.
    for x_pos, height in enumerate(counts):
        panel_ax.text(x_pos, height + pad, str(height), ha="center", fontsize=9)

plt.tight_layout()
plt.savefig("slides_figures/bench50_classification.png", dpi=150, bbox_inches="tight")
plt.show()

Classification by taxonomic group

In [ ]:
# Attach each prediction's taxonomic group and organism from the benchmark.
detail_meta = pd.merge(
    detail,
    benchmark[["accession", "broad_group", "organism"]],
    on="accession",
    how="left",
)

# Category counts per group: columns in CAT_ORDER, rows by descending total.
group_cats = (
    pd.crosstab(detail_meta["broad_group"], detail_meta["match_category"])
    .reindex(columns=CAT_ORDER, fill_value=0)
)
rows_by_total = group_cats.sum(axis=1).sort_values(ascending=False).index
group_cats = group_cats.loc[rows_by_total]

fig, ax = plt.subplots(figsize=(10, 5))
bar_colors = [CAT_COLORS[name] for name in group_cats.columns]
group_cats.plot.barh(stacked=True, ax=ax, color=bar_colors)
ax.set_xlabel("Number of GO predictions")
ax.set_title("Classification by taxonomic group")
ax.legend(loc="lower right", fontsize=8)
ax.invert_yaxis()  # largest group at the top
plt.tight_layout()
plt.savefig("slides_figures/bench50_by_taxon.png", dpi=150, bbox_inches="tight")
plt.show()

display(group_cats)

Model score vs. classification

Does the Evidencer model score correlate with prediction accuracy? Each GO prediction is linked to the model score of the first evidence record it cites.

In [ ]:
# Label each evidence row with the method that produced it. Precedence is
# tmalign > phmmer > string_match: a row with a non-empty tmalign_accession is
# tagged "tmalign" even if it also has a phmmer hit.
bench_evid_m = bench_evid.copy()
has_phmmer = (bench_evid_m["phmmer_accession"].notna()
              & (bench_evid_m["phmmer_accession"] != ""))
has_tmalign = (bench_evid_m["tmalign_accession"].notna()
               & (bench_evid_m["tmalign_accession"] != ""))
bench_evid_m["method"] = "string_match"
bench_evid_m.loc[has_phmmer, "method"] = "phmmer"
bench_evid_m.loc[has_tmalign, "method"] = "tmalign"

# Link each GO prediction to evidence by accession + evidence_key. A prediction
# may cite several comma-separated keys; only the first is used. Keys are
# compared as stripped strings on both sides so "1, 2" and " 1" still match "1".
pred_ev = bench_go.copy()
pred_ev["evidence_key"] = (
    pred_ev["evidence_key"].astype(str).str.split(",").str[0].str.strip()
)
evid_keys = bench_evid_m[["accession", "evidence_key", "method", "model_score"]].astype(
    {"evidence_key": str})
evid_keys["evidence_key"] = evid_keys["evidence_key"].str.strip()
pred_ev = pred_ev.merge(evid_keys, on=["accession", "evidence_key"], how="left")

# The merge can fan out (several evidence rows per key, or repeated GO terms);
# keep the first row per (accession, pred_id) so `detail` is not duplicated.
detail_evid = detail.merge(
    pred_ev[["accession", "pred_id", "method", "model_score"]].drop_duplicates(
        subset=["accession", "pred_id"]),
    on=["accession", "pred_id"],
    how="left"
)

fig, ax = plt.subplots(figsize=(8, 4))
for i, cat in enumerate(CAT_ORDER):
    scores = detail_evid.loc[detail_evid["match_category"] == cat, "model_score"].dropna()
    if scores.empty:
        continue  # no scored predictions in this category; leave the slot blank
    color = CAT_COLORS[cat]
    ax.boxplot(scores, positions=[i], widths=0.6,
               boxprops=dict(color=color),
               medianprops=dict(color="black"),
               whiskerprops=dict(color=color),
               capprops=dict(color=color),
               flierprops=dict(markeredgecolor=color, markersize=3))

ax.set_xticks(range(len(CAT_ORDER)))
ax.set_xticklabels(CAT_ORDER, rotation=30)
ax.set_ylabel("Model score")
ax.set_title("Model score distribution by classification")
plt.tight_layout()
plt.savefig("slides_figures/bench50_score_by_cat.png", dpi=150, bbox_inches="tight")
plt.show()

# Summary statistics in the same category order as the plot; categories with
# no predictions at all are dropped.
score_summary = (
    detail_evid.groupby("match_category")["model_score"].describe()[
        ["count", "mean", "std", "min", "50%", "max"]]
    .reindex(CAT_ORDER)
    .dropna(how="all")
)
print(score_summary)

Case study highlights

The 5 case study proteins from the exploratory analysis:

  1. A0A3B6GK97 (wheat) — trivially correct lipid deepening
  2. A0A3B6RKV1 (wheat) — phmmer transfer from JMJ22
  3. F4JLB7 (Arabidopsis) — false positive kinase from multidomain hit
  4. F6LAX4 (wheat) — cross-kingdom neuron error
  5. Q9KZ33 (S. coelicolor) — ontology gap (sigma factor vs transcription)
In [ ]:
case_accs = ["A0A3B6GK97", "A0A3B6RKV1", "F4JLB7", "F6LAX4", "Q9KZ33"]

# Per-prediction rows enriched with benchmark metadata (also used for the
# full styled table).
detail_display = detail.merge(
    benchmark[["accession", "organism", "broad_group", "protein_name", "selection_reason"]],
    on="accession", how="left")

rule = "=" * 60
for acc in case_accs:
    rows = detail_display.loc[detail_display["accession"] == acc]
    print(f"\n{rule}")
    if rows.empty:
        print(f"{acc}: No GO predictions")
        continue

    # Protein-level fields are identical on every row; read them from the first.
    first = rows.iloc[0]
    print(f"{acc} — {first['protein_name']} ({first['organism']})")
    print(f"Reason: {first['selection_reason']}")
    print("─" * 60)

    for _, pred in rows.iterrows():
        if pd.notna(pred.get("goa_match")):
            suffix = f" (GOA: {pred['goa_match']} {pred['goa_match_label']})"
        else:
            suffix = ""
        print(f"  {pred['match_category']:15s} {pred['pred_id']} {pred['pred_label']}{suffix}")

Per-protein evaluation detail

Full table of all GO predictions with their classification.

In [ ]:
# Light background colours for the match_category column of the full table.
CAT_BG = {"EXACT": "#d4edda", "LESS_SPECIFIC": "#fff3cd",
          "MORE_SPECIFIC": "#d1ecf1", "NO_OVERLAP": "#f8d7da",
          "NOT_IN_GOA": "#e2e3e5"}


def _category_css(value):
    """Return a background-colour CSS rule for a known category, else no style.

    Unknown or missing (NaN) categories get an empty string instead of the
    invalid declaration "background-color: ".
    """
    color = CAT_BG.get(value)
    return f"background-color: {color}" if color else ""


# Styler.map is the pandas >= 2.1 name (formerly Styler.applymap).
display(detail_display[
    ["accession", "organism", "protein_name", "pred_id", "pred_label",
     "match_category", "goa_match", "goa_match_label"]
].style.map(_category_css, subset=["match_category"]))

Export evaluation results

In [ ]:
# Persist the per-prediction classification next to the input data.
results_name = "bench50_evaluation_results.csv"
detail.to_csv(DATA / results_name, index=False)
print(f"Saved {len(detail)} evaluation results to {results_name}")

# Text summary: count and share of predictions per category, in CAT_ORDER.
banner = "=" * 50
print(f"\n{banner}")
print("ARGO-ProtNLM-50 Evaluation Summary")
print(banner)
total = len(detail)
category_counts = detail["match_category"].value_counts()
for cat in CAT_ORDER:
    n = int(category_counts.get(cat, 0))
    pct = n / total * 100 if total > 0 else 0
    print(f"  {cat:17s}: {n:3d} ({pct:5.1f}%)")
print(f"  {'Total':17s}: {total:3d}")