#!/usr/bin/env python3
"""
Produce Comprehensive Results Tables for Aurora Paper

Generates tables in the EXACT format of paper_table_comprehensive_v2_standalone.tex
but with:
- AURC[F1]* computed over coverage spectrum c ∈ [0.5, 1.0] in 5% steps (like standard AURC)
- FNR* (FNR after selective classification at B_M)
- Two variants: MIRROR and OPERABLE budget sets

=============================================================================
TABLE VARIANTS
=============================================================================

VARIANT A - MIRROR BUDGETS:
  Rejection budgets = [50, 100, 200, 400]
  Mirrors the monthly label budget (B_M) values

VARIANT B - OPERABLE SPECTRUM:
  Rejection budgets = [100, 200, 400, 800, 1600]
  Tests the full operational deployment range

Usage:
  cd <repository-root>
  python examples/reproduce_paper_table/produce_comprehensive_results_new_transcendent.py
"""

import sys
sys.path.insert(0, 'src')

import json
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional
from collections import defaultdict
from datetime import datetime
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import roc_auc_score

from aurora import (
    PickleResultsLoader,
    JSONResultsLoader,
    create_cutoff_month_filter,
    create_hyperparameter_filter,
    combine_collections,
    expand_results_with_ncms,
    ResultsCollection,
    ExperimentResult,
    ExperimentMetadata,
    StabilityMetrics,
    compute_metrics_numba,
    compute_aurc,
)
from aurora.performance_rejection import PerformanceRejectionSimulator
from aurora.pareto import ParetoAnalyzer


# ============================================================================
# CONFIGURATION
# ============================================================================

# All paths are relative, so the script must be launched from the repository
# root (see the "Usage" section of the module docstring).
RESULTS_DIR = Path("data-for-export/deep_drebin_svc")
OTHER_RESULTS_DIR = Path("data-for-export/others_v2")
TRANSCENDENT_DIR = Path("data-for-export/deep_drebin_svc/transcendent")
OUTPUT_DIR = Path("examples/reproduce_paper_table/results")
# Import-time side effect: the output directory (and any missing parents)
# is created as soon as this module is imported.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Number of leading months used for hyperparameter selection and as the
# default number of validation months for threshold calibration.
VALIDATION_MONTHS = 6
# Monthly label budgets (B_M); the MIRROR rejection budgets below repeat
# these same values.
BUDGETS = [50, 100, 200, 400]

# TWO REJECTION BUDGET VARIANTS
REJECTION_BUDGETS_MIRROR = [50, 100, 200, 400]
REJECTION_BUDGETS_OPERABLE = [100, 200, 400, 800, 1600]

# FNR targets (relaxed for better coverage)
FNR_TARGETS = [0.10, 0.15, 0.20, 0.25, 0.30]
FPR_TARGETS = [0.005, 0.01, 0.02, 0.05]

# Dataset name -> abbreviation -> two-letter short form.
DATASET_ABBREV = {"androzoo": "azoo", "apigraph": "apig", "transcendent": "trans"}
DATASET_SHORT = {"azoo": "AZ", "apig": "AP", "trans": "TR"}

# Full method label (the `base_name` of a result) -> short label.
METHOD_SHORT = {
    "DeepDrebin (cold) - MSP": "DeepDrebin",
    "HCC (warm) - Pseudo-Loss": "HCC-PL",
    "HCC (warm) - MSP": "HCC-MSP",
    "Drebin (cold) - Margin": "Drebin",
    "CADE (cold) - OOD": "CADE-OOD",
    "CADE (cold) - MSP": "CADE-MSP",
    "CADE (warm) - OOD": "CADE-OOD-W",
    "CADE (warm) - MSP": "CADE-MSP-W",
}

# Ordered list of method labels. Note that the two "CADE (warm)" labels from
# METHOD_SHORT are not listed here.
METHOD_ORDER = [
    "DeepDrebin (cold) - MSP",
    "HCC (warm) - Pseudo-Loss",
    "HCC (warm) - MSP",
    "Drebin (cold) - Margin",
    "CADE (cold) - OOD",
    "CADE (cold) - MSP",
]

# Methods that should be excluded for B_0=4800 (subsampled) because no data exists
EXCLUDE_FOR_SUBSAMPLED = ["CADE (cold) - OOD", "CADE (cold) - MSP"]

# Sampler mode -> initial training-set label: "D_0" for the full first year,
# 4800 for the subsampled first year.
SAMPLER_MODE_TO_B0 = {
    "full_first_year_subsample_months": "D_0",
    "subsample_first_year_subsample_months": 4800,
}

# Fields to exclude from hyperparameter comparisons
# These are either seed-like fields or alignment metadata
# ("_Original-Test-Month" is written by align_old_transcendent_months).
SEED_FIELDS = {"Random-Seed", "Run", "Seed", "_Original-Test-Month"}

# ============================================================================
# PROBLEMATIC SEED FILTERING (for clean_seeds variant)
# ============================================================================
# CADE-OOD on Transcendent has 2 problematic seeds (154733, 154734) with extreme
# OOD score outliers that cause unreliable calibration. These seeds (and their
# budget variants) should be excluded for a "clean" analysis.
#
# Set EXCLUDE_PROBLEMATIC_SEEDS = True to filter these out
# See: investigate_cade_ood_seeds.py for full investigation
EXCLUDE_PROBLEMATIC_SEEDS = False  # Set to True for clean_seeds variant

# Base problematic seeds and their budget variants
# Seeds are budget-specific: 154730-154734 for B_M=50, 204730-204734 for B_M=100, etc.
PROBLEMATIC_CADE_OOD_SEEDS = {
    154733, 154734,  # B_M=50
    204733, 204734,  # B_M=100
    304733, 304734,  # B_M=200
    504733, 504734,  # B_M=400
}


# ============================================================================
# TRANSCENDENT DATA ALIGNMENT
# ============================================================================
#
# PROBLEM:
#   OLD (HCC/CADE) and NEW (DeepDrebin/Drebin) Transcendent datasets have
#   different month numbering due to different dataset versions.
#
# INVESTIGATION FINDINGS (from investigate_transcendent_alignment.py):
#   - OLD has months 0-46 (47 months)
#   - NEW has months 0-41 (42 months)
#   - Sample counts align PERFECTLY with offset -6:
#       OLD month 6 = NEW month 0 (same 2581 samples)
#       OLD month 7 = NEW month 1 (same 2730 samples)
#       ... and so on
#   - 41 months align (OLD 6-46 ↔ NEW 0-40)
#   - Labels differ slightly (<1% of samples) due to label corrections in NEW
#
# SOLUTION:
#   1. Filter OLD (HCC/CADE) transcendent data to months 6-46
#   2. Relabel OLD months by subtracting 6 (so month 6 → 0, month 7 → 1, etc.)
#   3. Filter NEW data to months 0-40 (exclude month 41 which has no OLD equivalent)
#   4. Result: Both datasets have aligned months 0-40 (41 test months)
#
# VALIDATION IMPACT:
#   - VALIDATION_MONTHS = 6, so months 0-5 are used for hyperparameter selection
#   - After alignment: test months are 6-40 (35 months) for both OLD and NEW
#
# ============================================================================

# Single source of truth for the OLD <-> NEW Transcendent month alignment.
# NOTE(review): in the part of the file reviewed, 'new_month_min' and
# 'aligned_month_count' are never read; only the offset, the OLD min/max,
# 'new_month_max' and the aligned min/max are used. Confirm whether the
# unused keys are consumed further down the file.
TRANSCENDENT_ALIGNMENT = {
    # Month offset: OLD month X corresponds to NEW month (X - offset)
    'old_to_new_offset': 6,

    # OLD data: include only months in this range (inclusive)
    'old_month_min': 6,
    'old_month_max': 46,

    # NEW data: include only months in this range (inclusive)
    'new_month_min': 0,
    'new_month_max': 40,

    # Resulting aligned range after transformation
    'aligned_month_min': 0,
    'aligned_month_max': 40,
    'aligned_month_count': 41,
}


def align_old_transcendent_months(results: List[Dict], source_name: str = "OLD") -> List[Dict]:
    """
    Align OLD Transcendent data (HCC/CADE) with NEW data (DeepDrebin/Drebin).

    Transformation (applied only to records whose 'Dataset' is 'transcendent',
    compared case-insensitively, and that carry a 'Test-Month'):
      1. Drop records whose month lies outside
         [old_month_min, old_month_max] from TRANSCENDENT_ALIGNMENT.
      2. Relabel the remaining months by subtracting old_to_new_offset
         (month 6 → 0, month 7 → 1, etc. with the current configuration).

    Records of other datasets, and transcendent records without a
    'Test-Month', are passed through untouched (same object, not a copy).
    Relabeled records are shallow copies, so the input dictionaries are never
    modified; the pre-alignment month is kept under '_Original-Test-Month'.

    Args:
        results: List of result dictionaries
        source_name: Name for logging (e.g., "HCC", "CADE")

    Returns:
        List of aligned results with transformed month numbers
    """
    offset = TRANSCENDENT_ALIGNMENT['old_to_new_offset']
    min_month = TRANSCENDENT_ALIGNMENT['old_month_min']
    max_month = TRANSCENDENT_ALIGNMENT['old_month_max']

    aligned = []
    # Track both sides separately so the log line reports what really happened.
    excluded_before = 0
    excluded_after = 0

    for r in results:
        # str() keeps a present-but-None 'Dataset' from raising AttributeError
        # and matches the dataset check used when filtering CADE seeds.
        dataset = str(r.get('Dataset', '')).lower()

        # Only transform transcendent data
        if dataset != 'transcendent':
            aligned.append(r)
            continue

        month = r.get('Test-Month')
        if month is None:
            aligned.append(r)
            continue

        # Filter to aligned month range
        if month < min_month:
            excluded_before += 1
            continue
        if month > max_month:
            excluded_after += 1
            continue

        # Create aligned copy with transformed month
        aligned_r = r.copy()
        aligned_r['Test-Month'] = month - offset

        # Preserve original month for debugging if needed
        aligned_r['_Original-Test-Month'] = month

        aligned.append(aligned_r)

    excluded_count = excluded_before + excluded_after
    if excluded_count > 0:
        print(f"    [{source_name}] Aligned transcendent months: excluded {excluded_count} results "
              f"({excluded_before} before month {min_month}, {excluded_after} after month {max_month}), "
              f"relabeled {min_month}-{max_month} → 0-{max_month-offset}")

    return aligned


def filter_new_transcendent_months(results: List[Dict]) -> List[Dict]:
    """
    Drop NEW Transcendent results that lie beyond the aligned month range.

    Only the upper bound (TRANSCENDENT_ALIGNMENT['new_month_max']) is
    enforced; records without a 'Test-Month' are always kept. With the
    current configuration this removes month 41, which has no OLD equivalent.

    Args:
        results: List of result dictionaries

    Returns:
        New list containing the retained result dictionaries (same objects)
    """
    last_aligned_month = TRANSCENDENT_ALIGNMENT['new_month_max']

    kept = []
    n_dropped = 0
    for record in results:
        record_month = record.get('Test-Month')
        if record_month is not None and record_month > last_aligned_month:
            n_dropped += 1
        else:
            kept.append(record)

    if n_dropped > 0:
        print(f"    [NEW] Filtered to aligned months: excluded {n_dropped} results "
              f"(month {last_aligned_month + 1}+)")

    return kept


# ============================================================================
# NEW TRANSCENDENT DATA LOADING
# ============================================================================

def convert_new_transcendent_format(raw_results, trainer_mode, base_name):
    """
    Re-shape NEW Transcendent records into the OLD result-dictionary layout.

    Each input record yields exactly one output record. The sampler mode is
    translated to its OLD name when a translation is known (unknown names are
    kept as-is), 'Dataset' is forced to 'transcendent', 'Num-Epochs' is fixed
    to 30, and the past-month uncertainties fall back to the month-ahead ones
    when the record does not provide them.

    Args:
        raw_results: Iterable of NEW-format result dictionaries
        trainer_mode: Value stored under 'Trainer-Mode'
        base_name: Value stored under 'base_name'

    Returns:
        List of OLD-format result dictionaries
    """
    # NEW sampler-mode name -> OLD sampler-mode name
    old_sampler_names = {
        'full_first_year': 'full_first_year_subsample_months',
        'subsampled_4800': 'subsample_first_year_subsample_months',
    }

    def _to_old_layout(record):
        new_mode = record['Sampler-Mode']
        return {
            'Dataset': 'transcendent',
            'Trainer-Mode': trainer_mode,
            'Sampler-Mode': old_sampler_names.get(new_mode, new_mode),
            'Monthly-Label-Budget': record['Monthly-Label-Budget'],
            'Random-Seed': record['Random-Seed'],
            'Test-Month': record['Test-Month'],
            'Predictions': record['Predictions'],
            'Labels': record['Labels'],
            'Uncertainties (Month Ahead)': record['Uncertainties (Month Ahead)'],
            'Uncertainties (Past Month)': record.get(
                'Uncertainties (Past Month)', record['Uncertainties (Month Ahead)']
            ),
            'base_name': base_name,
            'Num-Epochs': 30,
        }

    return [_to_old_layout(record) for record in raw_results]


def load_new_transcendent_data():
    """
    Read the NEW Transcendent pickles for DeepDrebin and Drebin.

    Every file found under TRANSCENDENT_DIR is unpickled, converted to the
    OLD record layout and appended to one list; missing files only produce a
    warning line. The combined list is then restricted to the aligned month
    range via filter_new_transcendent_months (drops month 41 with the current
    configuration).

    Returns:
        List of OLD-format result dictionaries (possibly empty)
    """
    print("\n[NEW] Loading NEW Transcendent data (correct dataset version)...")

    # (method label, NEW sampler-mode tag, pickle file name). The sampler tag
    # is informational only: the sampler mode is read from each record.
    sources = [
        ('DeepDrebin (cold) - MSP', 'full_first_year', 'deep_drebin_full.pkl'),
        ('DeepDrebin (cold) - MSP', 'subsampled_4800', 'deep_drebin_subsampled_4800.pkl'),
        ('Drebin (cold) - Margin', 'full_first_year', 'drebin_full.pkl'),
        ('Drebin (cold) - Margin', 'subsampled_4800', 'drebin_subsampled_4800.pkl'),
    ]

    collected = []
    for method_label, _sampler_tag, pickle_name in sources:
        pickle_path = TRANSCENDENT_DIR / pickle_name
        if not pickle_path.exists():
            print(f"    Warning: File not found: {pickle_name}")
            continue

        # NOTE: pickle executes arbitrary code on load; these files must come
        # from a trusted source.
        with open(pickle_path, 'rb') as handle:
            payload = pickle.load(handle)

        # The method label doubles as trainer mode and base name.
        batch = convert_new_transcendent_format(payload, method_label, method_label)
        print(f"    Loaded {len(batch)} results from {pickle_name}")
        collected.extend(batch)

    print(f"    Total NEW Transcendent results (before alignment): {len(collected)}")

    # Keep only months that also exist in the OLD data
    collected = filter_new_transcendent_months(collected)
    print(f"    Total NEW Transcendent results (after alignment): {len(collected)}")

    return collected


# ============================================================================
# DATA LOADING (MODIFIED FOR NEW TRANSCENDENT)
# ============================================================================

def make_np_arrays(results):
    """
    Convert every list-valued field of each record to a NumPy array, in place.

    Only values that are `list` instances are converted; tuples, scalars and
    strings are left untouched. The same `results` object is returned for
    convenience.
    """
    for record in results:
        list_fields = [name for name, payload in record.items() if isinstance(payload, list)]
        for name in list_fields:
            record[name] = np.array(record[name])
    return results


def load_all_data():
    """
    Load all experimental data with NCM expansion and transcendent alignment.

    Five sources are loaded in order and merged with combine_collections:
      1. Deep Drebin pickle (Transcendent results filtered out)
      2. Drebin/SVC pickle (Transcendent results filtered out)
      3. NEW Transcendent pickles for DeepDrebin and Drebin
      4. HCC JSON files (transcendent months aligned, then NCM-expanded)
      5. CADE JSON files (transcendent months aligned, then NCM-expanded,
         optionally with the problematic CADE-OOD seeds removed)

    Any source file that does not exist is skipped; sources 1, 2, 4 and 5
    skip silently, source 3 prints a warning per missing file.

    TRANSCENDENT ALIGNMENT:
      - NEW data (DeepDrebin/Drebin): Uses correct dataset, months 0-40
      - OLD data (HCC/CADE): Months 6-46 relabeled to 0-40 to align with NEW
      - Result: All methods use aligned months 0-40 (41 test months)

    See TRANSCENDENT_ALIGNMENT configuration for details.

    Returns:
        The combined ResultsCollection, or None when no source produced a
        collection.
    """
    print("=" * 80)
    print("LOADING EXPERIMENTAL DATA (WITH TRANSCENDENT ALIGNMENT)")
    print("=" * 80)
    print(f"  Transcendent alignment: OLD months {TRANSCENDENT_ALIGNMENT['old_month_min']}-"
          f"{TRANSCENDENT_ALIGNMENT['old_month_max']} → aligned months "
          f"{TRANSCENDENT_ALIGNMENT['aligned_month_min']}-{TRANSCENDENT_ALIGNMENT['aligned_month_max']}")
    print("=" * 80)

    # Filters shared by the two pickle loaders below.
    # NOTE(review): presumably the second filter drops runs with
    # Num-Epochs == 10 (exclude=True) — confirm against the aurora API.
    common_filters = [
        create_cutoff_month_filter(),
        create_hyperparameter_filter("Num-Epochs", [10], exclude=True),
    ]
    all_collections = []

    # 1. Deep Drebin (OLD data - FILTER OUT Transcendent)
    print("\n[1/5] Loading Deep Drebin (excluding OLD Transcendent)...")
    # Custom filter to exclude Transcendent
    def dd_filter(r):
        ds = r.get('Dataset', '').lower()
        return ds != 'transcendent'

    loader_dd = PickleResultsLoader(
        base_name_mapper=lambda r: "DeepDrebin (cold) - MSP" if r.get("Trainer-Mode") in ["CE", "DeepDrebin"] else r.get("Trainer-Mode"),
        filters=common_filters + [dd_filter],
        auto_validate=False
    )
    dd_file = RESULTS_DIR / "parallel_ce_no_aug_v2.pkl"
    if dd_file.exists():
        deep_drebin = loader_dd.load(dd_file, experiment_name="DeepDrebin")
        print(f"  Loaded {len(deep_drebin)} results (AndroZoo + API-Graph only)")
        all_collections.append(deep_drebin)

    # 2. SVC/Drebin (OLD data - FILTER OUT Transcendent)
    print("\n[2/5] Loading Drebin/SVC (excluding OLD Transcendent)...")
    loader_svc = PickleResultsLoader(
        base_name_mapper=lambda r: "Drebin (cold) - Margin",
        filters=common_filters + [dd_filter],  # Same filter to exclude Transcendent
        auto_validate=False
    )
    svc_file = RESULTS_DIR / "parallel_svc_v2.pkl"
    if svc_file.exists():
        svc = loader_svc.load(svc_file, experiment_name="SVC")
        # Force a uniform method label on every SVC result.
        for result in svc.results:
            result.trainer_mode = "Drebin (cold) - Margin"
            result.base_name = "Drebin (cold) - Margin"
        print(f"  Loaded {len(svc)} results (AndroZoo + API-Graph only)")
        all_collections.append(svc)

    # 3. NEW Transcendent data for DeepDrebin and Drebin
    # NOTE: Do NOT use expand_results_with_ncms - NEW data only has MSP uncertainty
    print("\n[3/5] Loading NEW Transcendent data (correct dataset)...")
    new_trans = load_new_transcendent_data()
    if new_trans:
        # Convert to ResultsCollection directly (no NCM expansion needed)
        loader_new = PickleResultsLoader(
            base_name_mapper=lambda r: r.get("base_name", r.get("Trainer-Mode")),
            filters=[],
            auto_validate=False
        )
        metadata = ExperimentMetadata(
            experiment_name="NewTranscendent",
            source_file="(NEW aurora-experiments)",
            load_timestamp=datetime.now().isoformat()
        )
        # NOTE(review): relies on the loader's private _convert_result; falsy
        # conversion results are dropped on the next line.
        results = [loader_new._convert_result(r) for r in new_trans if r]
        new_trans_collection = ResultsCollection(metadata=metadata, results=[r for r in results if r])
        print(f"  Loaded {len(new_trans_collection)} NEW Transcendent results")
        all_collections.append(new_trans_collection)

    # 4. HCC (with transcendent month alignment)
    print("\n[4/5] Loading HCC...")
    hcc_files = [
        "hcc_mlp_warm-androzoo.json", "hcc_mlp_warm-apigraph.json", "hcc_mlp_warm-transcendent.json",
        "hcc_mlp_warm-androzoo-subsampling.json", "hcc_mlp_warm-apigraph-subsampling.json",
        "hcc_mlp_warm-transcendent-subsampling.json",
    ]
    all_hcc = []
    for fn in hcc_files:
        fp = OTHER_RESULTS_DIR / fn
        if fp.exists():
            with open(fp, "r") as f:
                all_hcc.extend(json.load(f))

    if all_hcc:
        all_hcc = make_np_arrays(all_hcc)
        cutoff_filter = create_cutoff_month_filter()
        all_hcc = [r for r in all_hcc if cutoff_filter(r)]
        for item in all_hcc:
            # Normalize the "subsampled_" spelling to the "subsample_" name
            # used as a key in SAMPLER_MODE_TO_B0.
            if item.get("Sampler-Mode") == "subsampled_first_year_subsample_months":
                item["Sampler-Mode"] = "subsample_first_year_subsample_months"
            # Unify the seed field name with the pickle-based sources.
            if "Seed" in item:
                item["Random-Seed"] = item.pop("Seed")

        # ALIGNMENT: Transform transcendent months to align with NEW data
        print(f"    HCC before alignment: {len(all_hcc)} results")
        all_hcc = align_old_transcendent_months(all_hcc, source_name="HCC")
        print(f"    HCC after alignment: {len(all_hcc)} results")

        hcc_expanded = expand_results_with_ncms(all_hcc, trainer_mode_key="Trainer-Mode",
                                                 uncertainty_key="Uncertainties (Month Ahead)", clean_temp_fields=True)
        loader_hcc = JSONResultsLoader(base_name_mapper=lambda r: r.get("Trainer-Mode"), filters=[], auto_validate=False, rename_seed_field=False)
        metadata = ExperimentMetadata(experiment_name="HCC", source_file="(multiple)", load_timestamp=datetime.now().isoformat())
        results = [loader_hcc._convert_result(raw) for raw in hcc_expanded if raw]
        hcc_combined = ResultsCollection(metadata=metadata, results=[r for r in results if r])
        print(f"  Loaded {len(hcc_combined)} results (with alignment)")
        all_collections.append(hcc_combined)

    # 5. CADE (with transcendent month alignment)
    print("\n[5/5] Loading CADE...")
    cade_files = [
        "cade_mlp_cold-androzoo.json", "cade_mlp_cold-apigraph.json", "cade_mlp_cold-transcendent.json",
        "cade_mlp_warm-androzoo.json", "cade_mlp_warm-apigraph.json", "cade_mlp_warm-transcendent.json",
    ]
    all_cade = []
    for fn in cade_files:
        fp = OTHER_RESULTS_DIR / fn
        if fp.exists():
            with open(fp, "r") as f:
                all_cade.extend(json.load(f))

    if all_cade:
        all_cade = make_np_arrays(all_cade)
        cutoff_filter = create_cutoff_month_filter()
        all_cade = [r for r in all_cade if cutoff_filter(r)]
        for item in all_cade:
            if "Seed" in item:
                item["Random-Seed"] = item.pop("Seed")

        # ALIGNMENT: Transform transcendent months to align with NEW data
        print(f"    CADE before alignment: {len(all_cade)} results")
        all_cade = align_old_transcendent_months(all_cade, source_name="CADE")
        print(f"    CADE after alignment: {len(all_cade)} results")

        cade_expanded = expand_results_with_ncms(all_cade, trainer_mode_key="Trainer-Mode",
                                                  uncertainty_key="Uncertainties (Month Ahead)", clean_temp_fields=True)

        # OPTIONAL: Filter out problematic CADE-OOD seeds for transcendent
        if EXCLUDE_PROBLEMATIC_SEEDS:
            # NOTE(review): before_count is assigned but never read.
            before_count = len(cade_expanded)
            filtered_cade = []
            excluded_count = 0
            for r in cade_expanded:
                is_cade_ood = "OOD" in str(r.get("Trainer-Mode", ""))
                is_transcendent = str(r.get("Dataset", "")).lower() == "transcendent"
                # Falls back to "Seed" when "Random-Seed" is missing or falsy.
                seed = r.get("Random-Seed") or r.get("Seed")
                is_problematic = seed in PROBLEMATIC_CADE_OOD_SEEDS

                # A result is dropped only when all three conditions hold.
                if is_cade_ood and is_transcendent and is_problematic:
                    excluded_count += 1
                    continue
                filtered_cade.append(r)
            cade_expanded = filtered_cade
            print(f"    [CLEAN SEEDS] Excluded {excluded_count} problematic CADE-OOD transcendent results")

        loader_cade = JSONResultsLoader(base_name_mapper=lambda r: r.get("Trainer-Mode"), filters=[], auto_validate=False, rename_seed_field=False)
        metadata = ExperimentMetadata(experiment_name="CADE", source_file="(multiple)", load_timestamp=datetime.now().isoformat())
        results = [loader_cade._convert_result(raw) for raw in cade_expanded if raw]
        cade_combined = ResultsCollection(metadata=metadata, results=[r for r in results if r])
        print(f"  Loaded {len(cade_combined)} results (with alignment)")
        all_collections.append(cade_combined)

    if all_collections:
        combined = combine_collections(*all_collections)
        print(f"\nTOTAL: {len(combined)} results")

        # Verify transcendent alignment
        # NOTE(review): this compares only the global min/max month over ALL
        # transcendent results; it does not check the range per method, so
        # the "all methods" wording in the success message is stronger than
        # what is actually verified.
        print("\n" + "-" * 40)
        print("TRANSCENDENT ALIGNMENT VERIFICATION")
        print("-" * 40)
        trans_results = [r for r in combined.results if r.dataset == 'transcendent']
        if trans_results:
            months = set(r.test_month for r in trans_results)
            methods = set(r.base_name for r in trans_results)
            expected_min = TRANSCENDENT_ALIGNMENT['aligned_month_min']
            expected_max = TRANSCENDENT_ALIGNMENT['aligned_month_max']

            print(f"  Transcendent results: {len(trans_results)}")
            print(f"  Month range: {min(months)} - {max(months)}")
            print(f"  Expected range: {expected_min} - {expected_max}")
            print(f"  Methods: {len(methods)}")

            if min(months) == expected_min and max(months) == expected_max:
                print(f"  ✅ Alignment verified: all methods use months {expected_min}-{expected_max}")
            else:
                print(f"  ⚠️  Unexpected month range detected!")
        print("-" * 40)

        return combined
    return None


# ============================================================================
# AGGREGATION HELPERS
# ============================================================================

def aggregate_across_seeds(results):
    """
    Pool results that share a test month into one array per month.

    Results are grouped by `test_month`; within a month the arrays of all
    results (typically one per seed) are concatenated in input order. Months
    are emitted in ascending order.

    Uncertainties are taken from `uncertainties_month_ahead`; if that fails
    for a month, `uncertainties` is tried, and as a last resort an all-zero
    array shaped like that month's predictions is used. Only ordinary errors
    trigger the fallbacks: KeyboardInterrupt and SystemExit propagate.

    Args:
        results: Iterable of result objects exposing `test_month`,
            `predictions`, `labels` and an uncertainty attribute

    Returns:
        (preds, labels, uncs): three lists of arrays, one entry per month
    """
    results_by_month = defaultdict(list)
    for r in results:
        results_by_month[r.test_month].append(r)

    preds, labels, uncs = [], [], []
    for month in sorted(results_by_month):
        month_results = results_by_month[month]
        preds.append(np.concatenate([r.predictions for r in month_results]))
        labels.append(np.concatenate([r.labels for r in month_results]))
        try:
            pooled_uncs = np.concatenate([r.uncertainties_month_ahead for r in month_results])
        except Exception:
            # Best-effort fallback chain; deliberately broad, but no longer a
            # bare `except`, so Ctrl-C / SystemExit are not swallowed.
            try:
                pooled_uncs = np.concatenate([r.uncertainties for r in month_results])
            except Exception:
                pooled_uncs = np.zeros_like(preds[-1])
        uncs.append(pooled_uncs)
    return preds, labels, uncs


def select_best_hyperparams(collection, dataset, budget, base_name, sampler_mode):
    """
    Pick the hyperparameter setting with the best validation F1.

    Results matching (dataset, budget, base_name, sampler_mode) are grouped
    by their hyperparameters, ignoring the keys in SEED_FIELDS. For each
    group the first VALIDATION_MONTHS months present in that group (ascending
    month order, pooled across seeds) are concatenated and scored with
    compute_metrics_numba; the group with the highest F1 wins.

    A group whose evaluation raises an ordinary exception is skipped with a
    warning line rather than aborting the selection. KeyboardInterrupt and
    SystemExit are not intercepted.

    Args:
        collection: Object exposing `.results`
        dataset, budget, base_name, sampler_mode: Cell to select for

    Returns:
        Dict of the winning hyperparameters (seed-like keys removed); the
        single available setting when there is nothing to compare (None if
        that setting is empty); or None when no result matches or no group
        could be scored.
    """
    results = [r for r in collection.results
               if r.dataset == dataset and r.monthly_label_budget == budget
               and r.base_name == base_name and r.sampler_mode == sampler_mode]
    if not results:
        return None

    hp_groups = defaultdict(list)
    for r in results:
        hp_key = tuple((k, v) for k, v in sorted(r.hyperparameters.items()) if k not in SEED_FIELDS)
        hp_groups[hp_key].append(r)

    # Nothing to compare: return the only setting (or None if it is empty).
    if len(hp_groups) <= 1:
        return {k: v for k, v in results[0].hyperparameters.items() if k not in SEED_FIELDS} or None

    best_f1, best_hp = -np.inf, None
    for hp_key, group in hp_groups.items():
        try:
            p, l, _ = aggregate_across_seeds(group)
            vp, vl = p[:VALIDATION_MONTHS], l[:VALIDATION_MONTHS]
            if vp:
                f1, _, _ = compute_metrics_numba(np.concatenate(vl), np.concatenate(vp))
                # A NaN F1 never compares greater, so such groups cannot win.
                if f1 > best_f1:
                    best_f1, best_hp = f1, dict(hp_key)
        except Exception as exc:
            # Best-effort: keep scoring the other groups, but make the skip
            # visible so a broken group cannot silently change the selection.
            print(f"    Warning: skipping hyperparameter group {dict(hp_key)} "
                  f"for {base_name} / {dataset} / B={budget}: {exc!r}")
            continue
    return best_hp


def get_results_with_hp(collection, dataset, budget, base_name, sampler_mode, hp):
    """
    Return the results of one (dataset, budget, method, sampler) cell that
    were produced with the hyperparameters `hp`.

    Keys listed in SEED_FIELDS are ignored when comparing hyperparameters.
    `hp=None` disables the hyperparameter check, so every result of the cell
    is returned. Input order is preserved.
    """
    def _belongs_to_cell(res):
        return (res.dataset == dataset
                and res.monthly_label_budget == budget
                and res.base_name == base_name
                and res.sampler_mode == sampler_mode)

    def _uses_hp(res):
        if hp is None:
            return True
        for name, wanted in hp.items():
            if name in SEED_FIELDS:
                continue
            if not (res.hyperparameters.get(name) == wanted):
                return False
        return True

    return [res for res in collection.results if _belongs_to_cell(res) and _uses_hp(res)]


# ============================================================================
# METRICS (PROPER THRESHOLD-BASED REJECTION)
# ============================================================================
#
# KEY FIX: All rejection metrics now use proper threshold calibration:
#   1. Calibrate threshold on VALIDATION data to achieve target rejection quota
#   2. Apply SAME threshold across all TEST months (actual rejections may vary)
#   3. This is the correct budget-based rejection simulation approach
#
# This matches how MAE[FNR*] works and is realistic for deployment scenarios.
# ============================================================================


def calibrate_threshold_for_quota(
    val_uncertainties: np.ndarray,
    target_rejections_per_month: int,
    n_months: int,
    tolerance: float = 1e-4,
    max_iterations: int = 100
) -> float:
    """
    Find an uncertainty threshold that rejects about
    `target_rejections_per_month` validation samples PER MONTH.

    The per-month budget is multiplied by `n_months` and divided by the number
    of pooled validation samples to obtain a target rejection rate (capped at
    95%). A bisection over the threshold value then looks for a threshold
    whose rejection rate is within `tolerance` of that target; if none is
    found within `max_iterations` steps, the last midpoint is returned.

    A sample is rejected when its uncertainty is strictly greater than the
    threshold.

    Args:
        val_uncertainties: Pooled uncertainties from validation data
        target_rejections_per_month: Target rejections per month (e.g., 50, 100, 200, 400)
        n_months: Number of months pooled in `val_uncertainties`
        tolerance: Allowed absolute gap between achieved and target rate
        max_iterations: Maximum number of bisection steps

    Returns:
        Calibrated threshold, or np.inf (reject nothing) when there is no
        validation data or the target rate is not positive
    """
    scores = np.asarray(val_uncertainties).flatten()
    n_scores = len(scores)
    if n_scores == 0:
        return np.inf

    # Budget is per month, so scale it to the whole pooled validation set.
    wanted_rejections = target_rejections_per_month * n_months
    wanted_rate = wanted_rejections / n_scores
    wanted_rate = min(wanted_rate, 0.95)
    if wanted_rate <= 0:
        return np.inf

    # Lower bound sits just below the smallest score (everything is rejected),
    # upper bound just above the largest score (nothing is rejected).
    lower = scores.min() - 1e-6
    upper = scores.max() + 1e-6

    for _ in range(max_iterations):
        candidate = (lower + upper) / 2
        achieved_rate = np.count_nonzero(scores > candidate) / n_scores

        if abs(achieved_rate - wanted_rate) < tolerance:
            return candidate

        if achieved_rate < wanted_rate:
            # Too few rejections: the threshold must come down.
            upper = candidate
        else:
            # Too many rejections: the threshold must go up.
            lower = candidate

    return candidate


def apply_rejection_threshold(
    predictions: np.ndarray,
    labels: np.ndarray,
    uncertainties: np.ndarray,
    threshold: float,
    min_samples: int = 20
) -> Tuple[float, int, int]:
    """
    Reject uncertain samples and score the remainder with F1.

    A sample is kept when its uncertainty is <= `threshold`. If fewer than
    `min_samples` samples are kept, no F1 is computed and NaN is returned in
    its place.

    Args:
        predictions: Binary predictions
        labels: True labels
        uncertainties: Uncertainty scores (array supporting `<=` broadcasting)
        threshold: Rejection threshold (reject if uncertainty > threshold)
        min_samples: Minimum number of kept samples required for a valid F1

    Returns:
        (f1, n_accepted, n_rejected)
    """
    kept = uncertainties <= threshold
    n_kept = np.sum(kept)
    n_dropped = len(predictions) - n_kept

    if n_kept >= min_samples:
        # F1 is the first of the three values returned by the metric helper.
        f1_kept, _, _ = compute_metrics_numba(labels[kept], predictions[kept])
        return float(f1_kept), n_kept, n_dropped

    # Too few accepted samples for a meaningful score.
    return np.nan, n_kept, n_dropped


def compute_auroc_from_predictions_and_uncertainties(
    labels: np.ndarray,
    predictions: np.ndarray,
    uncertainties: np.ndarray
) -> float:
    """
    Compute AUROC from hard predictions plus an uncertainty score.

    A positive-class score is rebuilt per sample and clipped to [0, 1]:
      - prediction == 1  ->  score = 1 - uncertainty
      - otherwise        ->  score = uncertainty

    This reconstruction is exact when uncertainty = 1 - max_softmax of a
    binary classifier (MSP); for other uncertainty definitions it is only a
    ranking heuristic.

    Args:
        labels: Ground truth binary labels (0/1)
        predictions: Binary predictions (0/1)
        uncertainties: Uncertainty scores (higher = more uncertain)

    Returns:
        AUROC as a float, or NaN when `labels` is empty, contains a single
        class, or the AUROC computation itself fails
    """
    # AUROC is undefined without both classes present.
    if len(labels) == 0:
        return np.nan
    if len(np.unique(labels)) < 2:
        return np.nan

    y_true = np.asarray(labels)
    y_pred = np.asarray(predictions)
    unc = np.asarray(uncertainties)

    positive_score = np.clip(np.where(y_pred == 1, 1 - unc, unc), 0, 1)

    try:
        auroc = roc_auc_score(y_true, positive_score)
        return float(auroc)
    except Exception:
        return np.nan


def compute_selective_f1_with_calibration(
    val_preds: np.ndarray,
    val_labels: np.ndarray,
    val_uncs: np.ndarray,
    test_preds: List[np.ndarray],
    test_labels: List[np.ndarray],
    test_uncs: List[np.ndarray],
    target_rejections_per_month: int,
    n_val_months: int = VALIDATION_MONTHS,
    min_samples: int = 20
) -> Tuple[float, List[float], float, List[int]]:
    """
    Selective F1 with a rejection threshold calibrated on validation data.

    Steps:
        1. Pool the validation uncertainties and calibrate one threshold that
           rejects about `target_rejections_per_month` samples PER MONTH.
        2. Apply that SAME threshold to every test month (the number of
           samples actually rejected varies by month).
        3. Compute F1 on the accepted samples of each month; empty months and
           months with fewer than `min_samples` accepted samples yield NaN.
        4. Average the monthly F1 values, ignoring NaNs.

    Only the validation uncertainties are used for calibration; `val_preds`
    and `val_labels` are accepted but not read.

    Args:
        val_preds, val_labels, val_uncs: Validation data (`val_uncs` may be a
            single array or a sequence of per-month arrays)
        test_preds, test_labels, test_uncs: Per-month test data
        target_rejections_per_month: Target rejections per month (e.g., 50, 100, 200, 400)
        n_val_months: Number of validation months (for computing total target)
        min_samples: Minimum accepted samples for a valid monthly F1

    Returns:
        (mean_f1, monthly_f1_list, calibrated_threshold, monthly_rejections)
    """
    if isinstance(val_uncs, np.ndarray):
        pooled_val_uncs = val_uncs
    else:
        pooled_val_uncs = np.concatenate(val_uncs)

    tau = calibrate_threshold_for_quota(
        pooled_val_uncs,
        target_rejections_per_month=target_rejections_per_month,
        n_months=n_val_months
    )

    f1_by_month = []
    rejected_by_month = []
    for month_preds, month_labels, month_uncs in zip(test_preds, test_labels, test_uncs):
        if len(month_preds) == 0:
            # Nothing to score and nothing to reject in an empty month.
            month_f1, month_rejected = np.nan, 0
        else:
            month_f1, _, month_rejected = apply_rejection_threshold(
                month_preds, month_labels, month_uncs, tau, min_samples
            )
        f1_by_month.append(month_f1)
        rejected_by_month.append(month_rejected)

    if f1_by_month:
        mean_f1 = float(np.nanmean(f1_by_month))
    else:
        mean_f1 = np.nan
    return mean_f1, f1_by_month, tau, rejected_by_month


def compute_auc_f1_star(
    val_preds: np.ndarray,
    val_labels: np.ndarray,
    val_uncs: np.ndarray,
    test_preds: List[np.ndarray],
    test_labels: List[np.ndarray],
    test_uncs: List[np.ndarray],
    budgets: List[int],
    n_val_months: int = VALIDATION_MONTHS,
    include_baseline: bool = True
) -> Tuple[float, Dict[int, float], Dict[int, List[float]]]:
    """
    Compute AUC[F1*] using proper threshold-based rejection simulation.

    For each rejection budget B (interpreted as rejections PER MONTH):
        1. Calibrate threshold on validation to achieve ~B rejections per month
        2. Apply threshold to test months (actual rejections vary)
        3. Compute F1* = mean(monthly F1 after rejection)

    AUC[F1*] = trapezoidal integration of F1*(B) over budgets, divided by the
    width of the integrated budget range (i.e. an average F1* over budgets).

    IMPORTANT: Following the paper's methodology (Figure 4), the integration
    STARTS from B=0 (baseline, no rejection) to properly capture the
    improvement from selective classification. The paper shows
    ρ ∈ {0, 100, 200, ..., 1500}.

    Args:
        val_preds, val_labels, val_uncs: Pooled validation data
        test_preds, test_labels, test_uncs: Per-month test data
        budgets: List of rejection budgets PER MONTH (e.g., [50, 100, 200, 400])
        n_val_months: Number of validation months
        include_baseline: If True, include B=0 (baseline) in AUC integration

    Returns:
        (auc_f1_star, budget_to_f1, budget_to_monthly_f1)
        When fewer than two budgets yield a valid F1*, the first element is
        that single F1* value (or NaN when there is none) instead of an
        integral.
    """
    budget_f1 = {}
    budget_monthly_f1 = {}

    # FIRST: Compute baseline F1 at B=0 (no rejection)
    # This is critical for proper AUC computation matching the paper's Figure 4
    if include_baseline:
        baseline_monthly_f1 = []
        for p, l in zip(test_preds, test_labels):
            if len(p) > 0:
                f1, _, _ = compute_metrics_numba(l, p)
                baseline_monthly_f1.append(float(f1))

        # Average only the months with a defined F1. The per-budget points
        # below are NaN-guarded as well; without this guard one NaN month
        # would turn the B=0 point, and hence the whole integral, into NaN.
        finite_baseline = [v for v in baseline_monthly_f1 if not np.isnan(v)]
        if finite_baseline:
            budget_f1[0] = float(np.mean(finite_baseline))
            budget_monthly_f1[0] = baseline_monthly_f1

    # Compute selective F1 for each budget
    for B in budgets:
        mean_f1, monthly_f1, _, _ = compute_selective_f1_with_calibration(
            val_preds, val_labels, val_uncs,
            test_preds, test_labels, test_uncs,
            target_rejections_per_month=B,
            n_val_months=n_val_months
        )
        if not np.isnan(mean_f1):
            budget_f1[B] = mean_f1
            budget_monthly_f1[B] = monthly_f1

    # A single point (or none) cannot be integrated
    if len(budget_f1) < 2:
        single = next(iter(budget_f1.values())) if budget_f1 else np.nan
        return single, budget_f1, budget_monthly_f1

    # Trapezoidal integration from 0 (or first budget) to max budget
    sorted_b = sorted(budget_f1)
    sorted_f1 = [budget_f1[b] for b in sorted_b]

    # Integration range: from 0 (if baseline included) to max budget
    integration_range = sorted_b[-1] - sorted_b[0]
    if integration_range <= 0:
        return sorted_f1[0], budget_f1, budget_monthly_f1

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # fall back to the old name only on NumPy versions that lack the new one.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    auc = trapezoid(sorted_f1, x=sorted_b) / integration_range

    return float(auc), budget_f1, budget_monthly_f1


def compute_auf1_star(
    val_preds: np.ndarray,
    val_labels: np.ndarray,
    val_uncs: np.ndarray,
    test_preds: List[np.ndarray],
    test_labels: List[np.ndarray],
    test_uncs: List[np.ndarray],
    coverage_min: float = 0.05,
    coverage_step: float = 0.05,
    n_val_months: int = VALIDATION_MONTHS,
    min_samples: int = 20
) -> Tuple[float, Dict[float, float]]:
    """
    Compute AURC[F1]*: Area Under F1-Risk Coverage Curve (operational setting).

    This is the COVERAGE-BASED version, analogous to standard AURC:
    - Integrates Risk_F1(c) = 1 - F1*(c) over coverage c ∈ [coverage_min, 1.0]
    - Coverage c = fraction of samples KEPT (not rejected)
    - Lower is better (like AURC)

    Key difference from standard AURC:
    - Uses F1-based risk instead of error-based risk
    - Uses operational threshold calibration (calibrate on val, apply to test)

    For each coverage c < 1 the threshold is the (c * 100)-th percentile of
    the pooled validation uncertainties; test samples with uncertainty at or
    below it are kept. Months with fewer than ``min_samples`` kept samples,
    or with a NaN F1, are skipped for that coverage level.

    Args:
        val_preds, val_labels, val_uncs: Pooled validation data for calibration
            (only val_uncs is read by this function)
        test_preds, test_labels, test_uncs: Per-month test data
        coverage_min: Minimum coverage to probe (default 0.05 = reject up to 95%)
        coverage_step: Coverage step size (default 0.05 = 5% increments)
        n_val_months: Number of validation months (not used by this function)
        min_samples: Minimum samples required per month for valid F1

    Returns:
        (aurc_f1_star, coverage_risk): The integrated F1-risk, normalized by the
        width of the covered range, and the per-coverage risk values.
        Lower is better (like AURC, but using F1-based risk).
        Returns (nan, {}) when there is no validation data, and the single
        available risk value (or NaN) when fewer than two coverage levels
        produced a valid F1.
    """
    # Pool validation uncertainties for calibration
    val_uncs_pooled = val_uncs if isinstance(val_uncs, np.ndarray) else np.concatenate(val_uncs)

    if len(val_uncs_pooled) == 0:
        return np.nan, {}

    # Grid of coverage levels: 1.0, 0.95, 0.90, ..., coverage_min
    # Coverage = fraction of samples we KEEP
    coverages = np.arange(1.0, coverage_min - 0.001, -coverage_step)

    coverage_f1 = {}

    for cov in coverages:
        if cov >= 1.0:
            # Baseline: no rejection (coverage = 100%)
            monthly_f1 = []
            for p, l in zip(test_preds, test_labels):
                if len(p) >= min_samples:
                    f1, _, _ = compute_metrics_numba(l, p)
                    if not np.isnan(f1):
                        monthly_f1.append(float(f1))
            if monthly_f1:
                coverage_f1[1.0] = np.mean(monthly_f1)
            continue

        # Coverage c means we KEEP c fraction, so we REJECT (1-c) fraction:
        # find threshold such that P(unc <= threshold) = c on validation data
        threshold = np.percentile(val_uncs_pooled, cov * 100)

        monthly_f1 = []
        for p, l, u in zip(test_preds, test_labels, test_uncs):
            if len(p) == 0:
                continue

            # Apply threshold: keep samples with uncertainty <= threshold
            keep_mask = u <= threshold

            # Months with too few kept samples at this coverage are skipped
            if np.sum(keep_mask) >= min_samples:
                f1, _, _ = compute_metrics_numba(l[keep_mask], p[keep_mask])
                if not np.isnan(f1):
                    monthly_f1.append(float(f1))

        if monthly_f1:
            coverage_f1[cov] = np.mean(monthly_f1)

    # Convert to RISK (1 - F1)
    coverage_risk = {c: 1 - f1 for c, f1 in coverage_f1.items() if not np.isnan(f1)}

    # Sort by coverage (ascending for proper integration)
    valid_coverages = sorted(coverage_risk.keys())
    valid_risk = [coverage_risk[c] for c in valid_coverages]

    if len(valid_coverages) < 2:
        return (valid_risk[0] if valid_risk else np.nan), coverage_risk

    # Integrate risk over coverage using the trapezoidal rule.
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # fall back to the old name only on NumPy versions that lack the new one.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    auc = trapezoid(valid_risk, valid_coverages)
    integration_range = valid_coverages[-1] - valid_coverages[0]

    if integration_range <= 0:
        return valid_risk[0], coverage_risk

    # Normalize by integration range to get average risk
    aurc_f1_star = auc / integration_range

    return float(aurc_f1_star), coverage_risk


def compute_fnr_star_at_bm(
    val_preds: np.ndarray,
    val_labels: np.ndarray,
    val_uncs: np.ndarray,
    test_preds: List[np.ndarray],
    test_labels: List[np.ndarray],
    test_uncs: List[np.ndarray],
    bm: int,
    n_val_months: int = VALIDATION_MONTHS,
    min_samples: int = 20
) -> float:
    """
    Mean per-month FNR on the samples accepted when deferring ~B_M per month.

    With ``bm <= 0`` nothing is rejected and the plain baseline FNR is
    averaged over every test month that has at least ``min_samples`` samples.
    Otherwise a single uncertainty threshold is calibrated on the pooled
    validation uncertainties (target: ``bm`` rejections per month over
    ``n_val_months`` months), samples with uncertainty at or below it are
    accepted, and FNR is measured on the accepted part of each month.
    Months that end up with fewer than ``min_samples`` accepted samples, or
    whose FNR is NaN, do not contribute to the mean.

    Args:
        val_preds, val_labels, val_uncs: Pooled validation data
            (only val_uncs is read here)
        test_preds, test_labels, test_uncs: Per-month test data
        bm: Monthly label budget (rejection target)
        n_val_months: Number of validation months
        min_samples: Minimum samples for valid FNR computation

    Returns:
        Mean FNR* across test months (as fraction, not percentage),
        or NaN when no month qualifies.
    """
    # Collect, per qualifying month, the (labels, predictions) pair to score
    accepted_months = []

    if bm <= 0:
        # No rejection: score every sufficiently large month as-is
        for preds, labels in zip(test_preds, test_labels):
            if len(preds) >= min_samples:
                accepted_months.append((labels, preds))
    else:
        pooled = val_uncs if isinstance(val_uncs, np.ndarray) else np.concatenate(val_uncs)
        cutoff = calibrate_threshold_for_quota(
            pooled,
            target_rejections_per_month=bm,
            n_months=n_val_months
        )
        for preds, labels, uncs in zip(test_preds, test_labels, test_uncs):
            if len(preds) == 0:
                continue
            accept = uncs <= cutoff
            if np.sum(accept) >= min_samples:
                accepted_months.append((labels[accept], preds[accept]))

    fnr_values = []
    for labels, preds in accepted_months:
        _, month_fnr, _ = compute_metrics_numba(labels, preds)
        if not np.isnan(month_fnr):
            fnr_values.append(float(month_fnr))

    if not fnr_values:
        return np.nan
    return np.mean(fnr_values)


def compute_bf_at_bm(
    val_preds: np.ndarray,
    val_labels: np.ndarray,
    val_uncs: np.ndarray,
    test_preds: List[np.ndarray],
    test_labels: List[np.ndarray],
    test_uncs: List[np.ndarray],
    bm: int
) -> float:
    """
    Benefit Fraction (BF*) of selective classification at the budget B_M.

    BF* = Sum(positive ΔF1) / Sum(|ΔF1|) × 100%
    with ΔF1_i = F1*_i (after rejection) - F1_i (no rejection) for month i.

    - 100%: rejection improved F1 in every month where it changed anything
    - 50%: gains and losses cancel out (also returned when nothing changed)
    - 0%: rejection only ever lowered F1

    NaN monthly F1 values (on either side) enter the difference as 0.
    NOTE(review): a month whose selective F1 is NaN while its baseline is
    defined therefore counts as a full loss — confirm this is intended.

    Args:
        val_preds, val_labels, val_uncs: Pooled validation data
        test_preds, test_labels, test_uncs: Per-month test data
        bm: The monthly label budget (B_M) to use as rejection target

    Returns:
        BF* as percentage (0-100); NaN when ``bm <= 0``, when there are no
        test months, or when the two monthly series differ in length.
    """
    if bm <= 0:
        return np.nan

    def _zero_filled(series):
        # NaN entries are replaced by 0 before differencing
        return np.array([0 if np.isnan(v) else v for v in series])

    # Per-month F1 without any rejection; empty months become NaN placeholders
    no_rejection = []
    for preds, labels in zip(test_preds, test_labels):
        if len(preds) == 0:
            no_rejection.append(np.nan)
            continue
        month_f1, _, _ = compute_metrics_numba(labels, preds)
        no_rejection.append(float(month_f1))

    before = _zero_filled(no_rejection)
    if len(before) == 0:
        return np.nan

    # Per-month F1 after rejecting with a threshold calibrated for B_M
    with_rejection = compute_selective_f1_with_calibration(
        val_preds, val_labels, val_uncs,
        test_preds, test_labels, test_uncs,
        target_rejections_per_month=bm
    )[1]
    after = _zero_filled(with_rejection)

    if len(after) != len(before):
        return np.nan

    change = after - before
    gains = float(np.maximum(change, 0).sum())
    movement = float(np.abs(change).sum())

    if movement == 0:
        # Nothing moved in either direction: report the neutral value
        return 50.0

    return (gains / movement) * 100


def compute_rejection_stats_at_bm(
    val_preds: np.ndarray,
    val_labels: np.ndarray,
    val_uncs: np.ndarray,
    test_preds: List[np.ndarray],
    test_labels: List[np.ndarray],
    test_uncs: List[np.ndarray],
    bm: int
) -> Tuple[float, float, List[float], List[int]]:
    """
    Compute rejection statistics at exactly B_M (the monthly label budget).

    This function calibrates a threshold on validation to reject ~B_M samples/month,
    then applies that threshold to test months and measures:
    - ΔRej: Mean(actual rejections) - B_M (systematic bias, in samples)
    - σ[Rej]: Std(actual rejections) (variability, in samples)

    All months are included in the statistics, including those with zero
    rejections.

    Args:
        val_preds, val_labels, val_uncs: Pooled validation data
        test_preds, test_labels, test_uncs: Per-month test data
        bm: The monthly label budget (B_M) to use as rejection target

    Returns:
        (delta_rej, sigma_rej, monthly_f1, monthly_rejections)
        - delta_rej: Mean(rejections) - B_M (positive = over-rejection)
        - sigma_rej: Std(rejections) across months (population std, np.std default)
        - monthly_f1: List of F1* per month (for MD* computation)
        - monthly_rejections: List of actual rejection counts per month
        delta_rej and sigma_rej are NaN when ``bm <= 0`` or when fewer than
        two months are available.
    """
    if bm <= 0:
        return np.nan, np.nan, [], []

    # compute_selective_f1_with_calibration already calibrates the threshold
    # and reports the per-month rejection counts
    _, monthly_f1, _, monthly_rejections = compute_selective_f1_with_calibration(
        val_preds, val_labels, val_uncs,
        test_preds, test_labels, test_uncs,
        target_rejections_per_month=bm
    )

    # A spread needs at least two months (this also covers the empty case)
    if len(monthly_rejections) < 2:
        return np.nan, np.nan, monthly_f1, monthly_rejections

    # ΔRej: Mean(actual) - target (bias); positive = over-rejecting
    delta_rej = float(np.mean(monthly_rejections)) - bm

    # σ[Rej]: Standard deviation
    sigma_rej = float(np.std(monthly_rejections))

    return delta_rej, sigma_rej, monthly_f1, monthly_rejections


def compute_perf_rejection(val_p, val_l, val_u, test_p, test_l, test_u):
    """
    Compute performance-based rejection metrics.

    For every FNR target in FNR_TARGETS a threshold is calibrated on the
    validation data, deployment is simulated on the test data, and adherence
    metrics are collected. A target whose simulation raises an exception, or
    whose MAE is NaN, is skipped (best-effort).

    Args:
        val_p, val_l, val_u: Validation predictions, labels, uncertainties
        test_p, test_l, test_u: Test predictions, labels, uncertainties

    Returns:
        dict with:
        - mae_fnr: Mean Absolute Error from FNR target (%)
        - fnr_coverage: Fraction of targets achieved (%)
        - sigma_rej: σ[Rej%] - Std dev of rejection rate across months (pp)
        mae_fnr and sigma_rej stay NaN, and fnr_coverage stays 0, when no
        target could be evaluated or when either input split is empty.
    """
    result = {'mae_fnr': np.nan, 'fnr_coverage': 0, 'sigma_rej': np.nan}
    if len(val_p) == 0 or len(test_p) == 0:
        return result

    sim = PerformanceRejectionSimulator()
    fnr_maes = []
    all_rejection_rates = []  # Collect rejection rates across all targets

    for target in FNR_TARGETS:
        try:
            thresh = sim.calibrate_threshold(val_p, val_l, val_u, target_metric="FNR", target_value=target)
            monthly = sim.simulate_deployment(test_p, test_l, test_u, threshold=thresh, target_metric="FNR", target_value=target)
            metrics = sim.compute_adherence_metrics(monthly, target_metric="FNR", target_value=target)
        except Exception:
            # Best-effort: an unreachable target is skipped. Catching
            # Exception (not a bare except) keeps KeyboardInterrupt and
            # SystemExit propagating so the run can still be aborted.
            continue

        if np.isnan(metrics.mae):
            continue

        fnr_maes.append(metrics.mae)
        # Collect rejection rates (1 - coverage) for each month
        all_rejection_rates.extend(
            1 - m.coverage for m in monthly if not np.isnan(m.coverage)
        )

    if fnr_maes:
        result['mae_fnr'] = float(np.mean(fnr_maes)) * 100
        result['fnr_coverage'] = len(fnr_maes) / len(FNR_TARGETS) * 100

    # Compute σ[Rej%]: standard deviation of rejection rate in percentage points
    if all_rejection_rates:
        result['sigma_rej'] = float(np.std(all_rejection_rates)) * 100  # Convert to pp

    return result


# ============================================================================
# TABLE GENERATION
# ============================================================================

def generate_table(collection, rejection_budgets, variant_name):
    """
    Build the results table for one budget variant.

    Iterates over every (B_M, sampler mode, method) combination and produces
    one row per combination, with one group of ``<abbrev>_<metric>`` columns
    per dataset. For each dataset the best hyperparameters are selected, the
    results are aggregated across seeds, the first VALIDATION_MONTHS months
    are pooled for calibration and the remaining months are used as test
    data. Metrics that cannot be computed are left as NaN.

    Args:
        collection: Results collection queried for datasets, methods and
            sampler modes.
        rejection_budgets: Rejection budgets (per month) integrated by AUC[F1*].
        variant_name: Label of the variant; only used in the progress banner.

    Returns:
        pandas DataFrame with one row per (B_M, B_0, Method), including a
        ``pareto_status`` column ('universal', 'partial' or 'dominated';
        'dominated' when no status was computed for the row).
    """
    print(f"\n{'='*80}\nGENERATING TABLE: {variant_name}\nBudgets: {rejection_budgets}\n{'='*80}")

    datasets = sorted(collection.get_unique_values('dataset')['dataset'])
    base_names = [m for m in METHOD_ORDER if m in collection.get_unique_values('base_name')['base_name']]
    sampler_modes = [s for s in collection.get_unique_values('sampler_mode')['sampler_mode'] if s in SAMPLER_MODE_TO_B0]

    rows = []
    # Storage for Pareto analysis: (bm, b0) -> dataset -> method -> metrics
    pareto_data = defaultdict(lambda: defaultdict(dict))

    total = len(datasets) * len(BUDGETS) * len(base_names) * len(sampler_modes)
    pbar = tqdm(total=total, desc="Computing")

    for bm in BUDGETS:
        for sm in sampler_modes:
            b0 = SAMPLER_MODE_TO_B0[sm]

            for bn in base_names:
                # Skip CADE methods for subsampled (B_0=4800) - no data exists
                if b0 == 4800 and bn in EXCLUDE_FOR_SUBSAMPLED:
                    # Still need to advance the progress bar for the skipped datasets
                    pbar.update(len(datasets))
                    continue

                row = {'B_M': bm, 'B_0': b0, 'Method': bn, 'Method_Short': METHOD_SHORT.get(bn, bn)}

                for ds in datasets:
                    pbar.set_description(f"{ds[:4]}|B_M={bm}|{bn[:15]}")
                    abbrev = DATASET_ABBREV[ds]

                    hp = select_best_hyperparams(collection, ds, bm, bn, sm)
                    results = get_results_with_hp(collection, ds, bm, bn, sm, hp)

                    # Initialize every metric column to NaN so that early exits
                    # below still leave a complete set of columns in the row
                    for m in ['mean_f1', 'mean_fnr', 'sigma_f1', 'min_f1', 'tau', 'aurc', 'auroc', 'auc_f1_star', 'auf1_star', 'bf', 'fnr_star', 'delta_rej', 'sigma_rej']:
                        row[f'{abbrev}_{m}'] = np.nan

                    if not results:
                        pbar.update(1)
                        continue

                    # Best-effort: skip this dataset if aggregation fails.
                    # Exception (not a bare except) lets Ctrl-C abort the run.
                    try:
                        all_p, all_l, all_u = aggregate_across_seeds(results)
                    except Exception:
                        pbar.update(1)
                        continue

                    # Need at least one test month after the validation months
                    if len(all_p) <= VALIDATION_MONTHS:
                        pbar.update(1)
                        continue

                    val_p = np.concatenate(all_p[:VALIDATION_MONTHS])
                    val_l = np.concatenate(all_l[:VALIDATION_MONTHS])
                    val_u = np.concatenate(all_u[:VALIDATION_MONTHS])
                    test_p, test_l, test_u = all_p[VALIDATION_MONTHS:], all_l[VALIDATION_MONTHS:], all_u[VALIDATION_MONTHS:]

                    # Baseline metrics (F1, FNR, σ[F1], min[F1], τ)
                    monthly_metrics = []
                    for p, l in zip(test_p, test_l):
                        if len(p) > 0:
                            f1, fnr, fpr = compute_metrics_numba(l, p)
                            monthly_metrics.append({'f1': float(f1), 'fnr': float(fnr), 'fpr': float(fpr)})

                    if monthly_metrics:
                        monthly_f1 = [m['f1'] for m in monthly_metrics]
                        monthly_fnr = [m['fnr'] for m in monthly_metrics]

                        suite = StabilityMetrics.compute_stability_suite(monthly_f1)
                        row[f'{abbrev}_mean_f1'] = np.mean(monthly_f1) * 100
                        row[f'{abbrev}_mean_fnr'] = np.mean(monthly_fnr) * 100
                        row[f'{abbrev}_sigma_f1'] = suite['sigma'] * 100
                        row[f'{abbrev}_min_f1'] = suite['min'] * 100
                        row[f'{abbrev}_tau'] = suite['mann_kendall_tau']

                        # Store for Pareto analysis
                        pareto_data[(bm, b0)][ds][bn] = {
                            'mean_f1': np.mean(monthly_f1),
                            'sigma_f1': suite['sigma'],
                            'min_f1': suite['min'],
                            'mann_kendall_tau': suite['mann_kendall_tau'],
                        }

                    # AURC (multiply by 100 for display as percentage);
                    # left as NaN if the computation fails
                    try:
                        aurc_raw = float(compute_aurc(test_l, test_p, test_u, compute_eaurc=False))
                        row[f'{abbrev}_aurc'] = aurc_raw * 100  # Convert to percentage
                    except Exception:
                        pass

                    # AUROC - Area Under ROC Curve (discrimination ability)
                    # Computed by pooling all test months together;
                    # left as NaN if the computation fails
                    try:
                        all_labels = np.concatenate(test_l)
                        all_preds = np.concatenate(test_p)
                        all_uncs = np.concatenate(test_u)
                        auroc = compute_auroc_from_predictions_and_uncertainties(all_labels, all_preds, all_uncs)
                        row[f'{abbrev}_auroc'] = auroc * 100 if not np.isnan(auroc) else np.nan  # Convert to percentage
                    except Exception:
                        pass

                    # AUC[F1*] with proper threshold calibration
                    # Calibrate on validation data, apply to test months
                    # Includes B=0 (baseline) in integration following paper's Figure 4
                    auc_f1, _, _ = compute_auc_f1_star(
                        val_p, val_l, val_u,  # Validation data for calibration
                        test_p, test_l, test_u,  # Test data for evaluation
                        rejection_budgets,
                        include_baseline=True  # Include B=0 per paper methodology
                    )
                    row[f'{abbrev}_auc_f1_star'] = auc_f1 * 100 if not np.isnan(auc_f1) else np.nan

                    # AURC[F1]*: Area Under F1-Risk Coverage Curve (operational setting)
                    # Coverage-based integration from c=1.0 (no rejection) to c=0.05 (95% rejection)
                    # Uses F1-based risk instead of error-based risk
                    # Probes at coverage = 1.0, 0.95, 0.90, ..., 0.05 (20 points)
                    auf1_star, _ = compute_auf1_star(
                        val_p, val_l, val_u,
                        test_p, test_l, test_u,
                        coverage_min=0.05,
                        coverage_step=0.05
                    )
                    row[f'{abbrev}_auf1_star'] = auf1_star * 100 if not np.isnan(auf1_star) else np.nan

                    # ================================================================
                    # BUDGET-BASED METRICS AT B_M EXACTLY
                    # ================================================================
                    # BF*, ΔRej, σ[Rej], FNR* are all computed at the operating point B_M
                    # This creates a coherent story: "What happens when we defer B_M samples?"

                    # Rejection stats at B_M:
                    # - ΔRej: bias (mean rejections - B_M)
                    # - σ[Rej]: variability of rejection counts
                    delta_rej, sigma_rej, _, _ = compute_rejection_stats_at_bm(
                        val_p, val_l, val_u,
                        test_p, test_l, test_u,
                        bm  # Use B_M exactly!
                    )
                    row[f'{abbrev}_delta_rej'] = delta_rej
                    row[f'{abbrev}_sigma_rej'] = sigma_rej

                    # FNR* at B_M: What's the FNR after selective classification?
                    # Calibrate threshold on validation, apply to test, compute FNR on accepted samples
                    fnr_star = compute_fnr_star_at_bm(
                        val_p, val_l, val_u,
                        test_p, test_l, test_u,
                        bm
                    )
                    row[f'{abbrev}_fnr_star'] = fnr_star * 100 if not np.isnan(fnr_star) else np.nan

                    # BF* at B_M: Does SC help at our exact operating point?
                    bf = compute_bf_at_bm(
                        val_p, val_l, val_u,
                        test_p, test_l, test_u,
                        bm  # Use B_M exactly!
                    )
                    row[f'{abbrev}_bf'] = bf if not np.isnan(bf) else np.nan

                    pbar.update(1)

                rows.append(row)

    pbar.close()

    # Compute Pareto status for each (bm, b0) group
    pareto_status = compute_pareto_status(pareto_data, datasets)

    # Add Pareto status to rows
    for row in rows:
        key = (row['B_M'], row['B_0'], row['Method'])
        row['pareto_status'] = pareto_status.get(key, 'dominated')

    return pd.DataFrame(rows)


def compute_pareto_status(pareto_data, datasets):
    """
    Classify every (B_M, B_0, method) entry by its Pareto status.

    Within one B_M group all methods from all B_0 values compete against
    each other: each (method, B_0) pair is registered with the analyzer
    under its own name, so the comparison spans the whole group rather
    than a single B_0 slice.

    Args:
        pareto_data: Mapping (bm, b0) -> dataset -> method -> metrics dict.
        datasets: Not used by this function.

    Returns:
        dict: (bm, b0, method) -> 'universal' | 'partial' | 'dominated'.
        Groups whose analysis raises are reported on stdout and contribute
        no entries.
    """
    # Regroup by B_M only, folding B_0 into the competitor name
    per_bm = defaultdict(lambda: defaultdict(dict))  # bm -> ds -> competitor -> metrics
    origin = {}  # competitor -> (b0, method)

    for (bm, b0), by_dataset in pareto_data.items():
        tag = "D0" if b0 == "D_0" else str(b0)
        for ds, methods in by_dataset.items():
            for method, metrics in methods.items():
                competitor = f"{method}|{tag}"
                per_bm[bm][ds][competitor] = metrics
                origin[competitor] = (b0, method)

    # Result attribute readers, applied in this order (later labels win)
    label_readers = (
        ('universal', lambda res: res.universal_optimal),
        ('partial', lambda res: res.partial_optimal.keys()),
        ('dominated', lambda res: res.always_dominated),
    )

    status = {}
    for bm, by_dataset in per_bm.items():
        analyzer = ParetoAnalyzer(metrics=['mean_f1', 'sigma_f1', 'min_f1', 'mann_kendall_tau'])

        for ds, methods in by_dataset.items():
            for competitor, metrics in methods.items():
                analyzer.add_precomputed_metrics(competitor, ds, metrics)

        try:
            result = analyzer.analyze_universal()

            # Translate competitor names back to (bm, b0, method) keys
            for label, read in label_readers:
                for competitor in read(result):
                    b0, method = origin[competitor]
                    status[(bm, b0, method)] = label

        except Exception as e:
            print(f"  Warning: Pareto analysis failed for B_M={bm}: {e}")

    return status


def get_display_precision(metric):
    """Return how many decimal places a metric is displayed with.

    Integer-displayed metrics yield 0, 'tau' yields 2, and everything else
    (including unknown metric names) yields 1. The values are meant to
    mirror the format strings used in export_latex().
    """
    # Shown as whole numbers (σ[Rej] and ΔRej use the ±X format)
    integer_metrics = (
        'mean_f1', 'mean_fnr', 'fnr_star', 'auc_f1_star',
        'bf', 'sigma_rej', 'delta_rej',
    )
    if metric in integer_metrics:
        return 0
    if metric == 'tau':
        return 2
    # aurc, auf1_star, auroc, sigma_f1 and any unlisted metric: one decimal
    return 1


def round_to_display(value, metric):
    """Round *value* to the number of decimals *metric* is displayed with.

    Missing values (anything pd.isna reports as missing) come back as NaN,
    so comparisons happen on the same numbers the table shows.
    """
    if not pd.isna(value):
        return round(value, get_display_precision(metric))
    return np.nan


def compute_best_worst_values(df, bm):
    """
    Find the best and worst displayed value per metric and dataset for one B_M.

    Only rows with ``B_M == bm`` are considered. Values are rounded to their
    display precision first, so that entries that look identical in the
    table are treated as ties.

    Direction per metric:
    - mean_f1, auroc, tau, bf: larger is better
    - delta_rej: the value closest to 0 is best, the one farthest from 0 worst
    - every other listed metric (mean_fnr, fnr_star, aurc, auf1_star,
      sigma_f1, sigma_rej): smaller is better

    Returns:
        best: dict (metric, dataset) -> rounded best value
        worst: dict (metric, dataset) -> rounded worst value
        Metric/dataset pairs without a column or without any non-missing
        value are absent from both dicts.
    """
    group = df[df['B_M'] == bm]

    maximise = {'mean_f1', 'auroc', 'tau', 'bf'}
    zero_centred = {'delta_rej'}
    ranked_metrics = (
        'mean_f1', 'mean_fnr', 'fnr_star', 'aurc', 'auroc', 'auf1_star',
        'bf', 'sigma_f1', 'tau', 'sigma_rej', 'delta_rej',
    )

    best, worst = {}, {}

    for metric in ranked_metrics:
        for ds in ('azoo', 'apig', 'trans'):
            column = f'{ds}_{metric}'
            if column not in group.columns:
                continue

            observed = group[column].dropna()
            if len(observed) == 0:
                continue

            # Compare on display precision, discarding anything still missing
            shown = [
                r for r in (round_to_display(v, metric) for v in observed)
                if not pd.isna(r)
            ]
            if not shown:
                continue

            key = (metric, ds)
            if metric in maximise:
                best[key], worst[key] = max(shown), min(shown)
            elif metric in zero_centred:
                # Rank by magnitude but report the signed value; on equal
                # magnitudes the first occurrence is reported
                best[key] = min(shown, key=abs)
                worst[key] = max(shown, key=abs)
            else:
                best[key], worst[key] = min(shown), max(shown)

    return best, worst


def is_best_or_worst(value, target_value, metric):
    """
    Tell whether *value*, as displayed, matches a best/worst target.

    The value is rounded to the metric's display precision before the
    comparison, so every tied entry matches, not only the first one.
    For 'delta_rej' magnitudes are compared, which makes -5 and +5
    equivalent. A missing value or target never matches.
    """
    if any(pd.isna(item) for item in (value, target_value)):
        return False

    shown = round_to_display(value, metric)

    if metric != 'delta_rej':
        return shown == target_value

    # Distance from zero is what counts for delta_rej
    return abs(shown) == abs(target_value)


def format_value_styled(value, best_value, worst_value, metric, fmt):
    """
    Render ``value`` as a LaTeX table cell with best/worst highlighting.

    ``fmt`` is a ``str.format`` template (applied as ``fmt.format(value)``).
    The result is:
    - ``"---"`` when ``value`` is missing,
    - wrapped in ``\\textbf{...}`` when it ties with ``best_value``,
    - wrapped in ``\\textcolor{red}{...}`` when it ties with ``worst_value``
      (best takes precedence if it ties with both),
    - the plain formatted text otherwise.
    """
    if pd.isna(value):
        return "---"

    text = fmt.format(value)

    if is_best_or_worst(value, best_value, metric):
        return f"\\textbf{{{text}}}"
    if is_best_or_worst(value, worst_value, metric):
        return f"\\textcolor{{red}}{{{text}}}"
    return text


def _format_tau_cell(value):
    """Render the tau trend value with two decimals, using a math-mode minus for negatives."""
    if value >= 0:
        return f"{value:.2f}"
    return f"$-${abs(value):.2f}"


def _format_delta_rej_cell(value):
    """Render the rejection bias as a whole number with an explicit '+' for non-negative values."""
    sign = "+" if value >= 0 else ""
    return f"{sign}{value:.0f}"


# Metric column groups in table order, each paired with the callable that turns
# a non-missing value into text (before best/worst styling is applied).
_LATEX_METRIC_COLUMNS = (
    ('mean_f1', "{:.0f}".format),        # F1 (higher is better)
    ('mean_fnr', "{:.0f}".format),       # FNR (lower is better)
    ('auroc', "{:.1f}".format),          # AUROC (higher is better)
    ('aurc', "{:.1f}".format),           # AURC (lower is better)
    ('auf1_star', "{:.1f}".format),      # AURC[F1]* (lower is better)
    ('sigma_f1', "{:.1f}".format),       # sigma[F1] (lower is better)
    ('tau', _format_tau_cell),           # tau (higher is better)
    ('bf', "{:.0f}".format),             # BF* (higher is better)
    ('delta_rej', _format_delta_rej_cell),  # Delta Rej* (closest to 0 is best)
    ('sigma_rej', "{:.0f}".format),      # sigma[Rej]* (lower is better)
)


def _style_latex_cell(text, value, best_value, worst_value, metric):
    """Wrap pre-formatted ``text`` in bold (best) or red (worst) markup; best wins on a double tie."""
    if is_best_or_worst(value, best_value, metric):
        return f"\\textbf{{{text}}}"
    if is_best_or_worst(value, worst_value, metric):
        return f"\\textcolor{{red}}{{{text}}}"
    return text


def _latex_metric_cells(row, datasets, best, worst):
    """Return the styled metric cells of one table row, metric-major then dataset order."""
    cells = []
    for metric, render in _LATEX_METRIC_COLUMNS:
        for ds in datasets:
            value = row.get(f'{ds}_{metric}', np.nan)
            if pd.isna(value):
                cells.append("---")
                continue
            cells.append(_style_latex_cell(
                render(value),
                value,
                best.get((metric, ds), np.nan),
                worst.get((metric, ds), np.nan),
                metric,
            ))
    return cells


def _pareto_symbol(status):
    """Map a Pareto status to its table symbol; anything unrecognised is shown as dominated."""
    if status == 'universal':
        return r"$\bigstar$"
    if status == 'partial':
        return r"$\triangle$"
    return r"$\circ$"


def export_latex(df, output_file, variant_name, rejection_budgets):
    """
    Write the comprehensive results table as a standalone LaTeX document.

    The layout follows paper_table_comprehensive_v2_standalone.tex: one row
    group per monthly budget in ``BUDGETS``, split into a ``D_0`` sub-block and
    a ``4800`` sub-block, with the best value per metric/dataset in bold and
    the worst in red.

    Args:
        df: DataFrame with columns ``B_M``, ``B_0``, ``Method_Short``,
            optionally ``pareto_status``, and per-dataset metric columns named
            ``{dataset}_{metric}`` (missing columns/values render as ``---``).
        output_file: Path of the ``.tex`` file to write (overwritten).
        variant_name: Variant identifier; used in the header comment and in the
            ``\\label``.
        rejection_budgets: Rejection budgets listed in the table footnote.

    Notes:
        - The file is written as UTF-8 because the emitted header comments
          contain non-ASCII characters; relying on the platform default
          encoding can fail on non-UTF-8 locales.
        - ``\\midrule`` separators are placed between the B_M groups that are
          actually emitted, and groups without any renderable row are skipped
          so that no dangling ``\\multirow`` is produced.
    """
    budget_str = ', '.join(map(str, rejection_budgets))
    datasets = ['azoo', 'apig', 'trans']
    ds_short = ['AZ', 'AP', 'TR']

    lines = [
        r"\documentclass[11pt]{article}",
        r"\usepackage[margin=0.5in]{geometry}",
        r"\usepackage{booktabs}",
        r"\usepackage{multirow}",
        r"\usepackage{colortbl}",
        r"\usepackage{xcolor}",
        r"\usepackage{pdflscape}",
        r"\usepackage{amssymb}",
        r"",
        r"% Define colors (matching paper: reliability=green, stability=blue)",
        r"\definecolor{reliablecolor}{RGB}{230,245,220}",  # Light green (paper's reliablecolortable)
        r"\definecolor{stablecolor}{RGB}{235,240,250}",    # Light blue (paper's stablecolortable)
        r"",
        r"\begin{document}",
        r"\pagestyle{empty}",
        r"",
        f"% Aurora Comprehensive Results - {variant_name.upper()}",
        "% AURC[F1]* computed over coverage spectrum c ∈ [0.5, 1.0] in 5% steps",
        "% BF*, ΔRej, σ[Rej] computed at B_M exactly",
        r"",
        r"\begin{landscape}",
        r"\begin{table}[p]",
        r"  \centering",
        r"  \caption{%",
        r"    Comprehensive evaluation across monthly budgets $B_M$.",
        r"    \textbf{Baseline}: F1=detection rate; FNR=false negative rate; AUROC=discrimination ability (higher=better).",
        r"    \textbf{Reliability} (green): AURC=uncertainty calibration (error-based risk, lower=better);",
        # Coverage range kept in sync with the header comment emitted above ([0.5, 1.0]).
        r"    AURC[F1]$^*$=Area Under F1-Risk Coverage Curve (F1-based risk integrated over coverage $c \in [0.5, 1.0]$, lower=better).",
        r"    \textbf{Stability} (blue): $\sigma$[F1]=F1 volatility over time; $\tau$=Mann-Kendall trend ($>$0=improving);",
        r"    BF$^*$=Benefit Fraction at $B_M$ (\% of SC impact that improves F1; 50\%=break-even, 100\%=always helps);",
        r"    $\Delta$Rej$^*$=rejection bias at $B_M$ (closest to 0=best); $\sigma$[Rej]$^*$=rejection std at $B_M$ (lower=better).",
        r"    $B_0$: initial training ($\mathcal{D}_0$=full year, 4800=subsampled).",
        r"    \textbf{P}: Pareto status ($\bigstar$=universal, $\triangle$=partial, $\circ$=dominated).",
        r"    \textbf{Bold}=best, \textcolor{red}{red}=worst per metric-dataset within each $B_M$.",
        r"  }",
        f"  \\label{{tab:comprehensive-{variant_name}}}",
        r"  \scriptsize",
        r"  \setlength{\tabcolsep}{2pt}",
        r"  \renewcommand{\arraystretch}{1.05}",
        r"",
        # Column specification with colors:
        # cols 1-3: B_M, B_0, Method (no color)
        # cols 4-12: Baseline F1, FNR, AUROC (no color) - 9 cols
        # cols 13-18: Reliability AURC, AURC[F1]* (green) - 6 cols
        # cols 19-33: Stability σ[F1], τ, BF*, ΔRej*, σ[Rej]* (blue) - 15 cols
        # col 34: P (no color)
        r"  \begin{tabular}{@{}cc l | ccc | ccc | ccc | >{\columncolor{reliablecolor}}c >{\columncolor{reliablecolor}}c >{\columncolor{reliablecolor}}c | >{\columncolor{reliablecolor}}c >{\columncolor{reliablecolor}}c >{\columncolor{reliablecolor}}c | >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c | >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c | >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c | >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c | >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c >{\columncolor{stablecolor}}c | c @{}}",
        r"    \toprule",
        r"",
        r"    & & & \multicolumn{9}{c|}{\textbf{Baseline}}",
        r"        & \multicolumn{6}{c|}{\cellcolor{reliablecolor}\textbf{Reliability}}",
        r"        & \multicolumn{15}{c|}{\cellcolor{stablecolor}\textbf{Stability}} & \\",
        r"    \cmidrule(lr){4-12} \cmidrule(lr){13-18} \cmidrule(lr){19-33}",
        r"",
        r"    & & & \multicolumn{3}{c|}{F1 (\%)$\uparrow$}",
        r"        & \multicolumn{3}{c|}{FNR (\%)$\downarrow$}",
        r"        & \multicolumn{3}{c|}{AUROC (\%)$\uparrow$}",
        r"        & \multicolumn{3}{c|}{\cellcolor{reliablecolor}AURC$\downarrow$}",
        r"        & \multicolumn{3}{c|}{\cellcolor{reliablecolor}AURC[F1]$^*\downarrow$}",
        r"        & \multicolumn{3}{c|}{\cellcolor{stablecolor}$\sigma$[F1]$\downarrow$}",
        r"        & \multicolumn{3}{c|}{\cellcolor{stablecolor}$\tau$}",
        r"        & \multicolumn{3}{c|}{\cellcolor{stablecolor}BF$^*$(\%)$\uparrow$}",
        r"        & \multicolumn{3}{c|}{\cellcolor{stablecolor}$\Delta$Rej$^*$}",
        r"        & \multicolumn{3}{c|}{\cellcolor{stablecolor}$\sigma$[Rej]$^*\downarrow$}",
        r"        & \\",
        r"",
        "    $B_M$ & $B_0$ & Method"
        + "".join([f" & \\tiny {s}" for s in ds_short] * 10)
        + r" & P \\",
        r"",
        r"    \midrule",
    ]

    # One row group per monthly budget B_M.
    emitted_groups = 0
    for bm in BUDGETS:
        df_bm = df[df['B_M'] == bm]

        # Only the D_0 and 4800 sub-blocks are rendered, D_0 first.
        sub_blocks = []
        for b0 in ["D_0", 4800]:
            df_b0 = df_bm[df_bm['B_0'] == b0]
            if not df_b0.empty:
                sub_blocks.append((b0, df_b0))

        # A group with nothing to render would leave a \multirow cell without
        # any terminating row, which is invalid LaTeX, so skip it entirely.
        if not sub_blocks:
            continue

        # Separate consecutive groups, independent of the concrete budget values.
        if emitted_groups:
            lines.append(r"    \midrule")
        emitted_groups += 1

        # Best AND worst values for this B_M group (on rounded display values).
        best, worst = compute_best_worst_values(df, bm)

        lines.append("")
        lines.append(f"    % ========== B_M = {bm} ==========")
        # The rotated B_M label must span exactly the rows written below.
        n_rows = sum(len(df_b0) for _, df_b0 in sub_blocks)
        lines.append(f"    \\multirow{{{n_rows}}}{{*}}{{\\rotatebox{{90}}{{\\textbf{{$B_M={bm}$}}}}}}")

        for block_idx, (b0, df_b0) in enumerate(sub_blocks):
            # Thin rule between the D_0 sub-block and the 4800 sub-block.
            if block_idx > 0:
                lines.append(r"    \cmidrule{2-34}")

            b0_str = r"$\mathcal{D}_0$" if b0 == "D_0" else str(b0)

            for _, row in df_b0.iterrows():
                method = row['Method_Short']
                cells = _latex_metric_cells(row, datasets, best, worst)
                symbol = _pareto_symbol(row.get('pareto_status', 'dominated'))

                line = f"    & {b0_str} & {method}"
                line += "".join(f" & {cell}" for cell in cells)
                line += f" & {symbol} \\\\"
                lines.append(line)

    lines.extend([
        r"",
        r"    \bottomrule",
        r"  \end{tabular}",
        r"",
        r"  \vspace{4pt}",
        r"  {\scriptsize",
        "    AZ=AndroZoo, AP=API-Graph, TR=Transcendent. PL=Pseudo-Loss.\\\\",
        f"    Rejection budgets $B \\in \\{{{budget_str}\\}}$. ``---'' = unavailable.",
        r"  }",
        r"\end{table}",
        r"\end{landscape}",
        r"",
        r"\end{document}",
    ])

    # Explicit UTF-8: the header comments contain ∈, Δ and σ.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines))
    print(f"✅ LaTeX saved to: {output_file}")


# ============================================================================
# MAIN
# ============================================================================

def _print_variant_summary(df, variant_name):
    """Print the per-dataset mean AURC[F1]* and FNR* of a generated table."""
    print(f"\n--- {variant_name.upper()} Summary ---")
    for ds in ['azoo', 'apig', 'trans']:
        auf1_col = f'{ds}_auf1_star'
        fnr_col = f'{ds}_fnr_star'
        if auf1_col not in df.columns:
            continue
        mean_auc = df[auf1_col].mean()
        # The FNR* column is checked separately so that a table lacking it
        # does not abort the run with a KeyError.
        if fnr_col in df.columns:
            fnr_text = f"{df[fnr_col].mean():.1f}%"
        else:
            fnr_text = "n/a"
        print(f"  {ds}: AURC[F1]*={mean_auc:.1f}%, FNR*={fnr_text}")


def _export_variant(collection, variant_name, budgets):
    """Generate one table variant and write its CSV, its LaTeX file and a console summary."""
    df = generate_table(collection, budgets, variant_name)

    # Save CSV
    csv_file = OUTPUT_DIR / f"comprehensive_results_{variant_name}.csv"
    df.to_csv(csv_file, index=False)
    print(f"✅ CSV saved to: {csv_file}")

    # Save LaTeX
    tex_file = OUTPUT_DIR / f"comprehensive_results_{variant_name}.tex"
    export_latex(df, tex_file, variant_name, budgets)

    # Summary
    _print_variant_summary(df, variant_name)


def main():
    """
    Generate every table variant of the comprehensive results.

    Two passes are made:
    1. All seeds: the ``mirror_new_data`` and ``operable_new_data`` variants.
    2. Clean seeds: the ``operable_clean_seeds`` variant, produced after
       setting the module-level ``EXCLUDE_PROBLEMATIC_SEEDS`` flag to True and
       reloading the data.

    Each variant is written to ``OUTPUT_DIR`` as a CSV and a LaTeX file. If a
    load fails, a message is printed and the function returns early.
    """
    global EXCLUDE_PROBLEMATIC_SEEDS

    print("\n" + "=" * 80)
    print("AURORA COMPREHENSIVE RESULTS GENERATION - NEW TRANSCENDENT DATA")
    print("=" * 80)
    print(f"Mirror budgets:   {REJECTION_BUDGETS_MIRROR}")
    print(f"Operable budgets: {REJECTION_BUDGETS_OPERABLE}")
    print("=" * 80)
    print("NOTE: This script uses the NEW (correct) Transcendent dataset")
    print("      for DeepDrebin and Drebin methods.")
    print("=" * 80)

    # Generate REGULAR variants (all seeds).
    # NOTE(review): load_all_data() presumably reads this module-level flag - confirm.
    EXCLUDE_PROBLEMATIC_SEEDS = False
    collection = load_all_data()
    if not collection:
        print("❌ Failed to load data!")
        return

    variants = [
        ("mirror_new_data", REJECTION_BUDGETS_MIRROR),
        ("operable_new_data", REJECTION_BUDGETS_OPERABLE),
    ]

    for variant_name, budgets in variants:
        _export_variant(collection, variant_name, budgets)

    # Generate CLEAN SEEDS variant (excluding problematic CADE-OOD seeds)
    print("\n" + "=" * 80)
    print("GENERATING CLEAN SEEDS VARIANT")
    print("=" * 80)
    EXCLUDE_PROBLEMATIC_SEEDS = True
    collection_clean = load_all_data()

    if not collection_clean:
        # Do not claim overall success when the clean-seeds variant is missing.
        print("❌ Failed to load clean-seeds data; operable_clean_seeds variant skipped!")
        return

    _export_variant(collection_clean, "operable_clean_seeds", REJECTION_BUDGETS_OPERABLE)

    print("\n✅ ALL VARIANTS COMPLETE")


# Script entry point: run the table generation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
