#!/usr/bin/env python3
"""
Produce Performance-Based Rejection Analysis Tables

This script runs performance-based rejection analysis on experimental data,
generating tables showing target adherence metrics (MAE, HitRate, Coverage)
and reliability diagrams for FNR/FPR targets.

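Outputs (written to examples/reproduce_paper_table/results/):
  - rejection_analysis.csv: Full results
  - rejection_analysis.xlsx: Multi-sheet Excel with summary tables (requires openpyxl)
  - rejection_table.tex: LaTeX tables (FNR and FPR target adherence) for the paper
  - rejection_reliability_fnr.png: FNR reliability diagram
  - rejection_reliability_fpr.png: FPR reliability diagram
  - rejection_analysis.pkl: Pickled results DataFrame plus the target grids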

Usage:
  cd <repository-root>
  python examples/reproduce_paper_table/produce_rejection_table.py
"""

import sys
sys.path.insert(0, 'src')

import json
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional
from collections import defaultdict
from datetime import datetime
from tqdm import tqdm

# Aurora imports
from aurora import (
    PickleResultsLoader,
    JSONResultsLoader,
    DataQualityChecker,
    create_cutoff_month_filter,
    create_hyperparameter_filter,
    combine_collections,
    expand_results_with_ncms,
    ResultsCollection,
    ExperimentResult,
)
from aurora.performance_rejection import (
    PerformanceRejectionSimulator,
    TargetAdherenceMetrics,
    ReliabilityDiagramData,
    run_performance_rejection_analysis,
)

# ============================================================================
# CONFIGURATION
# ============================================================================

# Directories
RESULTS_DIR = Path("data-for-export/deep_drebin_svc")
OTHER_RESULTS_DIR = Path("data-for-export/others_v2")
OUTPUT_DIR = Path("examples/reproduce_paper_table/results")

# Target grids (revised based on analysis)
# Note: Rejection can only REDUCE error rates, not increase them.
# - Most methods have baseline FPR ~0.5-1% (except CADE-cold-MSP with severe issues)
# - FPR targets span range to capture both well-calibrated and problematic methods
# - FNR targets are appropriate since baseline FNR is typically higher
FNR_TARGETS = [0.01, 0.02, 0.05, 0.10, 0.15, 0.20]
FPR_TARGETS = [0.005, 0.01, 0.02, 0.05]  # 0.5%, 1%, 2%, 5% - meaningful range

# Validation period: First 6 months of test period
VALIDATION_MONTHS = 6

# Which budgets to analyze
BUDGETS_TO_ANALYZE = [50, 100, 200, 400]


# ============================================================================
# DATA LOADING HELPERS
# ============================================================================

def make_np_arrays(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Turn every list-valued field of each record into a NumPy array, in place.

    Values that are not ``list`` instances (scalars, tuples, arrays, ...) are
    left untouched. The same outer list object is returned for convenience.
    """
    for record in results:
        list_fields = [name for name, field in record.items() if isinstance(field, list)]
        for name in list_fields:
            record[name] = np.array(record[name])
    return results


def deep_drebin_mapper(result_dict: Dict[str, Any]) -> str:
    """
    Map a Deep Drebin result record to its display base name.

    Trainer modes "CE" and "DeepDrebin" are both reported as
    "DeepDrebin (cold) - MSP"; any other trainer mode is returned unchanged,
    and a record without "Trainer-Mode" maps to "Unknown".
    """
    mode = result_dict.get("Trainer-Mode", "Unknown")
    return "DeepDrebin (cold) - MSP" if mode in ("CE", "DeepDrebin") else mode


def svc_mapper(result_dict: Dict[str, Any]) -> str:
    """Return the fixed base name shared by every SVC result (the record is ignored)."""
    base_name = "SVC - Margin"
    return base_name


def hcc_mapper(result_dict: Dict[str, Any]) -> str:
    """
    Map an HCC result record to its base name.

    The record's "Trainer-Mode" is used verbatim (HCC trainer modes are
    already descriptive); a record without that key maps to "Unknown".
    """
    return result_dict.get("Trainer-Mode", "Unknown")


def cade_mapper(result_dict: Dict[str, Any]) -> str:
    """
    Map a CADE result record to its base name.

    The record's "Trainer-Mode" is used verbatim (CADE trainer modes are
    already descriptive); a record without that key maps to "Unknown".
    """
    return result_dict.get("Trainer-Mode", "Unknown")


# ============================================================================
# DATA LOADING (Simplified from load_all_aurora_data.py)
# ============================================================================

def _read_json_results(filenames: List[str]) -> List[Dict[str, Any]]:
    """
    Load and concatenate the JSON result files named in ``filenames``.

    Each file is read from ``OTHER_RESULTS_DIR`` and its top-level JSON value
    is ``extend``-ed onto the output (so it is expected to be a list of result
    records). Missing files are reported and skipped.
    """
    records: List[Dict[str, Any]] = []
    for filename in filenames:
        file_path = OTHER_RESULTS_DIR / filename
        if not file_path.exists():
            print(f"  File not found: {file_path}")
            continue
        with open(file_path, "r") as f:
            records.extend(json.load(f))
    return records


def _collection_from_raw(
    raw_results: List[Dict[str, Any]],
    loader: JSONResultsLoader,
    experiment_name: str,
) -> ResultsCollection:
    """
    Convert already-filtered/expanded raw records into a ResultsCollection.

    Conversion is best-effort: records the loader fails to convert are
    skipped, but how many were skipped (and the first error) is printed
    rather than silently discarded.
    """
    # Imported here (not at module level) so this helper is self-contained.
    from aurora import ExperimentMetadata

    metadata = ExperimentMetadata(
        experiment_name=experiment_name,
        source_file="(multiple)",
        load_timestamp=datetime.now().isoformat(),
    )
    results = []
    n_skipped = 0
    first_error: Optional[Exception] = None
    for raw in raw_results:
        try:
            # NOTE(review): uses the loader's private per-record converter.
            results.append(loader._convert_result(raw))
        except Exception as exc:
            n_skipped += 1
            if first_error is None:
                first_error = exc
    if n_skipped:
        print(f"  WARNING: skipped {n_skipped} record(s) that failed to convert "
              f"(first error: {first_error!r})")
    return ResultsCollection(metadata=metadata, results=results)


def load_all_data() -> Tuple[Optional[ResultsCollection], Dict[str, ResultsCollection]]:
    """
    Load all experimental data.

    Deep Drebin and SVC results come from pickles in ``RESULTS_DIR``; HCC and
    CADE results come from JSON files in ``OTHER_RESULTS_DIR`` and are
    expanded per NCM via ``expand_results_with_ncms``. Each family is loaded
    into its own collection and all of them are then combined.

    Returns:
        (all_results, individual_collections), where individual_collections
        maps "DeepDrebin" / "SVC" / "HCC" / "CADE" to whichever collections
        loaded. Returns (None, {}) when nothing could be loaded.
    """
    print("=" * 80)
    print("LOADING EXPERIMENTAL DATA FOR REJECTION ANALYSIS")
    print("=" * 80)

    common_filters = [
        create_cutoff_month_filter(),
        create_hyperparameter_filter("Num-Epochs", [10], exclude=True),
    ]

    all_collections = []
    individual = {}

    # Deep Drebin
    print("\n[1/4] Loading Deep Drebin...")
    loader_dd = PickleResultsLoader(
        base_name_mapper=deep_drebin_mapper,
        filters=common_filters,
        auto_validate=False
    )
    dd_file = RESULTS_DIR / "parallel_ce_no_aug_v2.pkl"
    if dd_file.exists():
        deep_drebin = loader_dd.load(dd_file, experiment_name="DeepDrebin")
        print(f"  Loaded {len(deep_drebin)} results")
        all_collections.append(deep_drebin)
        individual['DeepDrebin'] = deep_drebin
    else:
        print("  File not found!")

    # SVC
    print("\n[2/4] Loading SVC...")
    loader_svc = PickleResultsLoader(
        base_name_mapper=svc_mapper,
        filters=common_filters,
        auto_validate=False
    )
    svc_file = RESULTS_DIR / "parallel_svc_v2.pkl"
    if svc_file.exists():
        svc = loader_svc.load(svc_file, experiment_name="SVC")
        for result in svc.results:
            result.trainer_mode = "Drebin (cold) - Margin"
            result.base_name = "SVC - Margin"
        print(f"  Loaded {len(svc)} results")
        all_collections.append(svc)
        individual['SVC'] = svc
    else:
        print("  File not found!")

    # HCC
    print("\n[3/4] Loading HCC...")
    hcc_files = [
        "hcc_mlp_warm-androzoo.json",
        "hcc_mlp_warm-apigraph.json",
        "hcc_mlp_warm-transcendent.json",
        "hcc_mlp_warm-androzoo-subsampling.json",
        "hcc_mlp_warm-apigraph-subsampling.json",
        "hcc_mlp_warm-transcendent-subsampling.json",
    ]
    all_hcc_results = _read_json_results(hcc_files)

    if all_hcc_results:
        all_hcc_results = make_np_arrays(all_hcc_results)
        cutoff_filter = create_cutoff_month_filter()
        all_hcc_results = [r for r in all_hcc_results if cutoff_filter(r)]

        # Normalise HCC field quirks: a misspelled sampler-mode value, and the
        # seed stored under "Seed". The seed is renamed manually here, so the
        # loader below is created with rename_seed_field=False.
        for item in all_hcc_results:
            if item.get("Sampler-Mode") == "subsampled_first_year_subsample_months":
                item["Sampler-Mode"] = "subsample_first_year_subsample_months"
            if "Seed" in item:
                item["Random-Seed"] = item["Seed"]
                del item["Seed"]

        hcc_expanded = expand_results_with_ncms(
            all_hcc_results,
            trainer_mode_key="Trainer-Mode",
            uncertainty_key="Uncertainties (Month Ahead)",
            clean_temp_fields=True
        )

        loader_hcc = JSONResultsLoader(
            base_name_mapper=hcc_mapper,
            filters=[],
            auto_validate=False,
            rename_seed_field=False
        )
        hcc_combined = _collection_from_raw(hcc_expanded, loader_hcc, "HCC")
        print(f"  Loaded {len(hcc_combined)} results (after NCM expansion)")
        all_collections.append(hcc_combined)
        individual['HCC'] = hcc_combined
    else:
        print("  No HCC results found!")

    # CADE
    print("\n[4/4] Loading CADE...")
    cade_files = [
        "cade_mlp_cold-androzoo.json",
        "cade_mlp_cold-apigraph.json",
        "cade_mlp_cold-transcendent.json",
        "cade_mlp_warm-androzoo.json",
        "cade_mlp_warm-apigraph.json",
        "cade_mlp_warm-transcendent.json",
    ]
    all_cade_results = _read_json_results(cade_files)

    if all_cade_results:
        all_cade_results = make_np_arrays(all_cade_results)
        cutoff_filter = create_cutoff_month_filter()
        all_cade_results = [r for r in all_cade_results if cutoff_filter(r)]

        cade_expanded = expand_results_with_ncms(
            all_cade_results,
            trainer_mode_key="Trainer-Mode",
            uncertainty_key="Uncertainties (Month Ahead)",
            clean_temp_fields=True
        )

        loader_cade = JSONResultsLoader(
            base_name_mapper=cade_mapper,
            filters=[],
            auto_validate=False,
            rename_seed_field=True
        )
        cade_combined = _collection_from_raw(cade_expanded, loader_cade, "CADE")
        print(f"  Loaded {len(cade_combined)} results (after NCM expansion)")
        all_collections.append(cade_combined)
        individual['CADE'] = cade_combined
    else:
        print("  No CADE results found!")

    # Combine all
    if all_collections:
        all_results = combine_collections(*all_collections)
        print(f"\n{'='*80}")
        print(f"TOTAL: {len(all_results)} results loaded")
        print(f"Datasets: {sorted(all_results.get_unique_values('dataset')['dataset'])}")
        print(f"Budgets: {sorted(all_results.get_unique_values('monthly_label_budget')['monthly_label_budget'])}")
        print(f"Base names: {len(all_results.get_unique_values('base_name')['base_name'])}")
        return all_results, individual

    return None, {}


# ============================================================================
# DATA EXTRACTION FOR REJECTION ANALYSIS
# ============================================================================

def extract_monthly_data(
    collection: ResultsCollection,
    dataset: str,
    base_name: str,
    budget: int,
    sampler_mode: str = "subsample_first_year_subsample_months"
) -> Tuple[Optional[List[np.ndarray]], Optional[List[np.ndarray]], Optional[List[np.ndarray]], List[int]]:
    """
    Extract monthly predictions, labels, uncertainties for a specific configuration.

    Results are matched on dataset, base_name, monthly label budget and
    sampler mode, then grouped by ``test_month``. For each month only the
    first matching result (in collection order, i.e. the first seed
    encountered) is used. Uncertainties prefer ``uncertainties_past_month``
    and fall back to ``uncertainties``; months with neither are dropped.

    The returned ``months`` list contains only the months that were kept, so
    ``months[i]`` always corresponds to element ``i`` of the three data lists.

    Returns:
        (predictions_by_month, labels_by_month, uncertainties_by_month, months)
        or (None, None, None, []) if no usable month was found
    """
    by_month: Dict[Any, list] = defaultdict(list)
    for r in collection.results:
        if (r.dataset == dataset
                and r.base_name == base_name
                and r.monthly_label_budget == budget
                and r.sampler_mode == sampler_mode):
            by_month[r.test_month].append(r)

    if not by_month:
        return None, None, None, []

    predictions_by_month: List[np.ndarray] = []
    labels_by_month: List[np.ndarray] = []
    uncertainties_by_month: List[np.ndarray] = []
    kept_months: List[int] = []

    for month in sorted(by_month):
        # Take first seed's data (could average across seeds instead).
        r = by_month[month][0]

        # Prefer "past month" uncertainties for rejection.
        if r.uncertainties_past_month is not None and len(r.uncertainties_past_month) > 0:
            uncs = r.uncertainties_past_month
        elif r.uncertainties is not None and len(r.uncertainties) > 0:
            uncs = r.uncertainties
        else:
            # No uncertainty scores: drop this month from ALL outputs so the
            # returned lists stay aligned with `kept_months`.
            continue

        predictions_by_month.append(r.predictions)
        labels_by_month.append(r.labels)
        uncertainties_by_month.append(uncs)
        kept_months.append(month)

    if not predictions_by_month:
        return None, None, None, []

    return predictions_by_month, labels_by_month, uncertainties_by_month, kept_months


def split_validation_test(
    predictions_by_month: List[np.ndarray],
    labels_by_month: List[np.ndarray],
    uncertainties_by_month: List[np.ndarray],
    validation_months: int = 6
) -> Tuple[
    np.ndarray, np.ndarray, np.ndarray,  # Validation (pooled)
    List[np.ndarray], List[np.ndarray], List[np.ndarray]  # Test (by month)
]:
    """
    Split data into validation (pooled) and test (by month).

    The first ``validation_months`` months are concatenated into single
    validation arrays; the remaining months are returned as per-month lists.
    If there are no more than ``validation_months`` months, every month goes
    to validation and the test lists are empty.

    Args:
        predictions_by_month: Per-month prediction arrays.
        labels_by_month: Per-month label arrays, aligned with predictions.
        uncertainties_by_month: Per-month uncertainty arrays, aligned with predictions.
        validation_months: Number of months to use for validation/calibration (>= 1).

    Returns:
        (val_preds, val_labels, val_uncs, test_preds, test_labels, test_uncs)

    Raises:
        ValueError: if ``validation_months`` < 1, if there are no months, or
            if the three per-month lists have different lengths.
    """
    if validation_months < 1:
        # A negative value would otherwise slice from the end and silently
        # produce a wrong split; 0 would fail deep inside np.concatenate.
        raise ValueError(f"validation_months must be >= 1, got {validation_months}")
    n_months = len(predictions_by_month)
    if n_months == 0:
        raise ValueError("no months to split")
    if not (len(labels_by_month) == n_months == len(uncertainties_by_month)):
        raise ValueError(
            "per-month lists must have equal length, got "
            f"{n_months} predictions, {len(labels_by_month)} labels, "
            f"{len(uncertainties_by_month)} uncertainties"
        )

    # Validation: Pool first N months
    val_preds = np.concatenate(predictions_by_month[:validation_months])
    val_labels = np.concatenate(labels_by_month[:validation_months])
    val_uncs = np.concatenate(uncertainties_by_month[:validation_months])

    # Test: Remaining months
    test_preds = predictions_by_month[validation_months:]
    test_labels = labels_by_month[validation_months:]
    test_uncs = uncertainties_by_month[validation_months:]

    return val_preds, val_labels, val_uncs, test_preds, test_labels, test_uncs


# ============================================================================
# MAIN ANALYSIS
# ============================================================================

# Attributes copied from the adherence-metrics object into each result row,
# in column order.
_ADHERENCE_FIELDS = (
    'mae', 'bias', 'rmse',
    'hit_rate_1pct', 'hit_rate_2pct', 'hit_rate_5pct',
    'mean_coverage', 'cv_coverage', 'min_coverage',
    'mann_kendall_tau', 'mann_kendall_p', 'max_deviation',
)


def _initial_budget_label(sampler: str) -> Any:
    """
    Return the ``B_0`` value recorded in result rows for a sampler mode.

    Subsampled first year -> 4800, full first year -> "Full", otherwise "0".
    The mixed int/str values are kept because ``create_latex_table`` filters
    on ``B_0 == 4800``.
    """
    if sampler == "subsample_first_year_subsample_months":
        return 4800
    if sampler == "full_first_year_subsample_months":
        return "Full"
    return "0"


def _evaluate_target(
    simulator: PerformanceRejectionSimulator,
    validation: Tuple[np.ndarray, np.ndarray, np.ndarray],
    test: Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]],
    target_metric: str,
    target_value: float,
) -> Tuple[Any, Dict[str, Any]]:
    """Calibrate on validation, deploy on test months; return (threshold, adherence fields)."""
    threshold = simulator.calibrate_threshold(
        *validation, target_metric=target_metric, target_value=target_value
    )
    monthly_results = simulator.simulate_deployment(
        *test, threshold=threshold,
        target_metric=target_metric, target_value=target_value
    )
    metrics = simulator.compute_adherence_metrics(
        monthly_results, target_metric=target_metric, target_value=target_value
    )
    return threshold, {name: getattr(metrics, name) for name in _ADHERENCE_FIELDS}


def run_rejection_analysis(
    collection: ResultsCollection,
    fnr_targets: List[float] = FNR_TARGETS,
    fpr_targets: List[float] = FPR_TARGETS,
    validation_months: int = VALIDATION_MONTHS,
    budgets: List[int] = BUDGETS_TO_ANALYZE,
) -> pd.DataFrame:
    """
    Run performance-based rejection analysis on all configurations.

    A configuration is (dataset, base_name, budget, sampler mode). For each
    configuration with more than ``validation_months`` usable months, a
    threshold is calibrated on the pooled first ``validation_months`` months
    for every FNR and FPR target, deployed on the remaining months, and
    scored for target adherence.

    Individual failures (split or per-target evaluation) are skipped so one
    bad configuration does not abort the run, but they are collected and
    summarised at the end instead of being silently swallowed.

    Returns:
        DataFrame with one row per (configuration, target metric, target
        value): adherence metrics, number of test months and threshold.
    """
    from itertools import product

    print("\n" + "=" * 80)
    print("RUNNING PERFORMANCE-BASED REJECTION ANALYSIS")
    print("=" * 80)
    print(f"FNR targets: {fnr_targets}")
    print(f"FPR targets: {fpr_targets}")
    print(f"Validation months: {validation_months}")
    print(f"Budgets to analyze: {budgets}")

    # Get unique configurations
    datasets = sorted(collection.get_unique_values('dataset')['dataset'])
    base_names = sorted(collection.get_unique_values('base_name')['base_name'])

    # Sampler modes to test
    sampler_modes = [
        "subsample_first_year_subsample_months",
        "full_first_year_subsample_months",
    ]

    simulator = PerformanceRejectionSimulator()

    # Extract each configuration once and keep the (referenced, not copied)
    # monthly arrays, rather than extracting once to count and again to process.
    configs = []
    for dataset, base_name, budget, sampler in product(datasets, base_names, budgets, sampler_modes):
        preds, labels, uncs, months = extract_monthly_data(
            collection, dataset, base_name, budget, sampler
        )
        if preds is not None and len(months) > validation_months:
            configs.append((dataset, base_name, budget, sampler, preds, labels, uncs))

    print(f"\nTotal configurations to analyze: {len(configs)}")

    all_results: List[Dict[str, Any]] = []
    failures: List[Tuple[str, Exception]] = []

    for dataset, base_name, budget, sampler, preds, labels, uncs in tqdm(configs, desc="Analyzing"):
        where = f"{dataset} / {base_name} / B_M={budget} / {sampler}"
        try:
            val_preds, val_labels, val_uncs, test_preds, test_labels, test_uncs = \
                split_validation_test(preds, labels, uncs, validation_months)
        except Exception as exc:
            failures.append((f"{where} / split", exc))
            continue

        if not test_preds:
            continue

        b0 = _initial_budget_label(sampler)
        validation = (val_preds, val_labels, val_uncs)
        test = (test_preds, test_labels, test_uncs)

        # FNR targets first, then FPR targets (row order matters for output files).
        for metric, targets in (("FNR", fnr_targets), ("FPR", fpr_targets)):
            for target in targets:
                try:
                    threshold, adherence = _evaluate_target(
                        simulator, validation, test, metric, target
                    )
                except Exception as exc:
                    failures.append((f"{where} / {metric}={target}", exc))
                    continue

                all_results.append({
                    'dataset': dataset,
                    'base_name': base_name,
                    'B_M': budget,
                    'B_0': b0,
                    'sampler': sampler,
                    'target_metric': metric,
                    'target_value': target,
                    **adherence,
                    'n_test_months': len(test_preds),
                    'threshold': threshold,
                })

    if failures:
        max_shown = 10
        print(f"\nWARNING: {len(failures)} analyses failed and were skipped:")
        for where, exc in failures[:max_shown]:
            print(f"  {where}: {exc!r}")
        if len(failures) > max_shown:
            print(f"  ... and {len(failures) - max_shown} more")

    return pd.DataFrame(all_results)


# ============================================================================
# TABLE GENERATION
# ============================================================================

def create_summary_table(df: pd.DataFrame, target_metric: str, target_value: float) -> pd.DataFrame:
    """
    Build a per-method summary for a single (metric, target) pair.

    Rows are indexed by (base_name, B_M, B_0); columns are (statistic,
    dataset) for MAE, HitRate@2% and mean coverage, each the mean over any
    duplicate rows. An empty DataFrame is returned when no row matches.
    """
    matches_metric = df['target_metric'] == target_metric
    matches_target = df['target_value'] == target_value
    selected = df.loc[matches_metric & matches_target]

    if selected.empty:
        return pd.DataFrame()

    return selected.pivot_table(
        values=['mae', 'hit_rate_2pct', 'mean_coverage'],
        index=['base_name', 'B_M', 'B_0'],
        columns='dataset',
        aggfunc='mean',
    )


# LaTeX special characters and their escaped forms (mapped per character so
# replacements never interfere with each other).
_LATEX_SPECIAL_CHARS = {
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
}


def _latex_escape(text: Any) -> str:
    """Escape LaTeX special characters in ``str(text)``."""
    return "".join(_LATEX_SPECIAL_CHARS.get(ch, ch) for ch in str(text))


def _format_percent(fraction: float) -> str:
    """Format a fraction as a compact percentage label: 0.005 -> '0.5', 0.1 -> '10'."""
    # ':g' keeps sub-percent targets visible (':.0f' would render 0.5% as '0').
    return f"{fraction * 100:g}"


def _pivot_cell(pivot: pd.DataFrame, row: Any, column: Tuple[Any, ...]) -> Optional[float]:
    """Return ``pivot.loc[row, column]`` as float, or None if absent or NaN."""
    try:
        value = pivot.loc[row, column]
    except KeyError:
        return None
    return None if pd.isna(value) else float(value)


def create_latex_table(
    df: pd.DataFrame,
    target_metric: str = "FNR",
    budget: int = 200,
    b0: Any = 4800,
    max_targets: int = 3,
) -> str:
    """
    Create LaTeX table for paper.

    Shows MAE and HitRate@2% per method, for each dataset and each of the
    ``max_targets`` smallest target values, restricted to rows with
    ``B_M == budget`` and ``B_0 == b0`` (defaults: B_M=200, subsample mode).

    Cells read "MAE(%)/HitRate(%)"; combinations with no data show "--".
    Method and dataset names are LaTeX-escaped.

    Returns:
        The LaTeX source, or a LaTeX comment line if no rows match.
    """
    subset = df[
        (df['target_metric'] == target_metric) &
        (df['B_M'] == budget) &
        (df['B_0'] == b0)
    ]

    if subset.empty:
        return "% No data for specified configuration"

    # Pivot by method and (dataset, target)
    pivot = subset.pivot_table(
        index='base_name',
        columns=['dataset', 'target_value'],
        values=['mae', 'hit_rate_2pct'],
        aggfunc='mean'
    )

    datasets = sorted(subset['dataset'].unique())
    # Ascending sort, so this keeps the `max_targets` SMALLEST target values.
    targets = sorted(subset['target_value'].unique())[:max_targets]
    ncols = 1 + len(datasets) * len(targets)

    latex = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\caption{Performance-Based Rejection: Target Adherence for " + target_metric
        + f" ($B_M = {budget}$)}}",
        r"\label{tab:rejection-" + target_metric.lower() + r"}",
        r"\small",
        r"\begin{tabular}{l" + "c" * (ncols - 1) + r"}",
        r"\toprule",
    ]

    # Dataset headers
    header1 = "Method" + "".join(
        f" & \\multicolumn{{{len(targets)}}}{{c}}{{{_latex_escape(ds)}}}"
        for ds in datasets
    ) + r" \\"
    latex.append(header1)

    # Target headers
    header2 = "".join(
        f" & {_format_percent(t)}\\%" for _ in datasets for t in targets
    ) + r" \\"
    latex.append(r"\cmidrule(lr){2-" + str(ncols) + r"}")
    latex.append(header2)
    latex.append(r"\midrule")

    # Data rows
    for method in sorted(pivot.index):
        cells = [_latex_escape(method)]
        for ds in datasets:
            for t in targets:
                mae = _pivot_cell(pivot, method, ('mae', ds, t))
                hit = _pivot_cell(pivot, method, ('hit_rate_2pct', ds, t))
                if mae is None or hit is None:
                    cells.append("--")
                else:
                    cells.append(f"{mae*100:.1f}/{hit*100:.0f}")
        latex.append(" & ".join(cells) + r" \\")

    latex.extend([
        r"\bottomrule",
        r"\end{tabular}",
        r"\vspace{1mm}",
        r"\footnotesize",
        r"Values show MAE (\%) / HitRate@2\% (\%). Lower MAE = better target adherence.",
        r"\end{table}",
    ])

    return "\n".join(latex)


# ============================================================================
# VISUALIZATION
# ============================================================================

def plot_reliability_diagram(
    df: pd.DataFrame,
    target_metric: str,
    output_path: Path,
    methods_to_plot: Optional[List[str]] = None
):
    """
    Plot reliability diagram for target adherence.

    X-axis: target value (%). Y-axis: observed value (%), computed as
    ``target_value + bias`` with ``bias`` averaged per (method, target) over
    all other rows (datasets, budgets, samplers). The dashed diagonal marks
    perfect calibration.

    NOTE(review): reading ``target + bias`` as the observed rate assumes
    ``bias`` is the mean signed deviation (observed - target) produced by
    ``compute_adherence_metrics``; confirm in aurora.performance_rejection.

    Args:
        df: Rows from ``run_rejection_analysis`` (uses the ``target_metric``,
            ``base_name``, ``target_value`` and ``bias`` columns).
        target_metric: "FNR" or "FPR"; only matching rows are plotted.
        output_path: Destination of the PNG.
        methods_to_plot: Optional allow-list of base names, plotted in the
            given order; names without data are ignored.

    Prints a message and returns without writing anything when matplotlib
    is unavailable or there are no rows for ``target_metric``.
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.cm as cm
    except ImportError:
        print("matplotlib not available, skipping plot")
        return

    subset = df[df['target_metric'] == target_metric]

    if subset.empty:
        print(f"No data for {target_metric}")
        return

    # Average across datasets (and budgets/samplers) for a cleaner plot.
    avg_by_method = (
        subset.groupby(['base_name', 'target_value'])['bias']
        .mean()
        .reset_index()
    )
    avg_by_method['observed'] = avg_by_method['target_value'] + avg_by_method['bias']

    methods = sorted(avg_by_method['base_name'].unique())
    if methods_to_plot:
        methods = [m for m in methods_to_plot if m in methods]

    # tab10 has only 10 distinct colours; cycle the marker shape every 10
    # methods so each (colour, marker) pair stays unique for many methods.
    markers = ['o', 's', '^', 'D', 'v', 'P', 'X']

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        for i, method in enumerate(methods):
            method_data = avg_by_method[avg_by_method['base_name'] == method].sort_values('target_value')
            ax.plot(
                method_data['target_value'] * 100,
                method_data['observed'] * 100,
                marker=markers[(i // 10) % len(markers)],
                linestyle='-',
                label=method,
                color=cm.tab10(i % 10),
                markersize=6,
                alpha=0.8,
            )

        # Diagonal (perfect calibration)
        targets = sorted(avg_by_method['target_value'].unique())
        lo, hi = min(targets) * 100, max(targets) * 100
        ax.plot([lo, hi], [lo, hi], 'k--', label='Perfect Calibration', linewidth=2)

        ax.set_xlabel(f'Target {target_metric} (%)', fontsize=12)
        ax.set_ylabel(f'Observed {target_metric} (%)', fontsize=12)
        ax.set_title(f'Reliability Diagram: {target_metric} Target Adherence', fontsize=14)
        ax.legend(loc='best', fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.set_aspect('equal', adjustable='box')

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    print(f"Saved: {output_path}")


# ============================================================================
# MAIN
# ============================================================================

def main():
    """
    Run the full rejection-analysis pipeline and write all outputs.

    Loads every experiment collection, runs the FNR/FPR rejection analysis,
    then writes into OUTPUT_DIR: CSV, Excel (only if openpyxl is installed),
    LaTeX tables, reliability diagrams and a pickle bundle. Returns early
    (after printing a message) if no data was loaded or no rows were produced.
    """
    print("\n" + "=" * 80)
    print("PERFORMANCE-BASED REJECTION ANALYSIS")
    print("=" * 80)
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

    # Load data
    all_results, _individual = load_all_data()

    if all_results is None:
        print("ERROR: No data loaded!")
        return

    # Run analysis
    results_df = run_rejection_analysis(all_results)

    print(f"\n{'='*80}")
    print("ANALYSIS COMPLETE")
    print(f"{'='*80}")
    print(f"Total result rows: {len(results_df)}")

    if results_df.empty:
        print("WARNING: No results generated!")
        return

    # Summary statistics
    print("\nSummary by target metric:")
    for metric in ['FNR', 'FPR']:
        subset = results_df[results_df['target_metric'] == metric]
        print(f"  {metric}: {len(subset)} rows, avg MAE = {subset['mae'].mean()*100:.2f}%")

    # The output directory is nested several levels deep and may not exist yet.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # CSV
    csv_path = OUTPUT_DIR / "rejection_analysis.csv"
    results_df.to_csv(csv_path, index=False)
    print(f"\nSaved: {csv_path}")

    # Excel with multiple sheets (optional - requires openpyxl)
    excel_path = OUTPUT_DIR / "rejection_analysis.xlsx"
    try:
        with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
            results_df.to_excel(writer, sheet_name='Full Results', index=False)

            # Summary tables for key targets
            for target in [0.05, 0.10]:
                # round(), not int(): int() truncates values like 28.999999999999996.
                pct = round(target * 100)

                fnr_summary = create_summary_table(results_df, 'FNR', target)
                if not fnr_summary.empty:
                    fnr_summary.to_excel(writer, sheet_name=f'FNR_{pct}pct')

                fpr_summary = create_summary_table(results_df, 'FPR', target)
                if not fpr_summary.empty:
                    fpr_summary.to_excel(writer, sheet_name=f'FPR_{pct}pct')

        print(f"Saved: {excel_path}")
    except ImportError:
        print(f"⚠️  Excel export skipped (install openpyxl for Excel support)")

    # LaTeX table
    latex_fnr = create_latex_table(results_df, "FNR")
    latex_fpr = create_latex_table(results_df, "FPR")

    latex_path = OUTPUT_DIR / "rejection_table.tex"
    with open(latex_path, 'w', encoding='utf-8') as f:
        f.write("% Performance-Based Rejection Tables\n")
        f.write("% Generated: " + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + "\n\n")
        f.write("% FNR Target Adherence\n")
        f.write(latex_fnr)
        f.write("\n\n")
        f.write("% FPR Target Adherence\n")
        f.write(latex_fpr)

    print(f"Saved: {latex_path}")

    # Reliability diagrams
    print("\nGenerating reliability diagrams...")
    plot_reliability_diagram(
        results_df, 'FNR',
        OUTPUT_DIR / "rejection_reliability_fnr.png"
    )
    plot_reliability_diagram(
        results_df, 'FPR',
        OUTPUT_DIR / "rejection_reliability_fpr.png"
    )

    # Pickle for further analysis
    pickle_path = OUTPUT_DIR / "rejection_analysis.pkl"
    with open(pickle_path, 'wb') as f:
        pickle.dump({
            'results_df': results_df,
            'fnr_targets': FNR_TARGETS,
            'fpr_targets': FPR_TARGETS,
            'timestamp': datetime.now().isoformat(),
        }, f)
    print(f"Saved: {pickle_path}")

    print(f"\n{'='*80}")
    print("ALL OUTPUTS GENERATED SUCCESSFULLY")
    print(f"{'='*80}")


if __name__ == "__main__":
    # Execute the pipeline only when run as a script, not on import.
    main()
