#!/usr/bin/env python3
"""
Publication-Ready Reliability Diagrams for Performance-Based Rejection

Creates dense, paper-ready figures for FNR and FPR target adherence.
"""

import pandas as pd
import numpy as np
from pathlib import Path
from typing import Optional, List, Tuple

# matplotlib is a hard requirement for this script: abort early when it is
# missing. Only the imports are guarded so that an unrelated failure while
# applying the style settings is not misreported as a missing dependency.
try:
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    from matplotlib.ticker import MaxNLocator
except ImportError:
    print("matplotlib not available")
    # `exit()` is injected by the `site` module for interactive use and is
    # absent under `python -S` / frozen builds; SystemExit is always available.
    raise SystemExit(1)
else:
    # Global style shared by every figure produced below (serif font, small
    # point sizes).
    for _rc_key, _rc_value in (
        ('font.family', 'serif'),
        ('font.size', 8),
        ('axes.labelsize', 9),
        ('axes.titlesize', 10),
        ('legend.fontsize', 7),
        ('xtick.labelsize', 8),
        ('ytick.labelsize', 8),
    ):
        plt.rcParams[_rc_key] = _rc_value

# ============================================================================
# CONFIGURATION
# ============================================================================

# Directory that holds the input CSV and receives every file written by main().
OUTPUT_DIR = Path("examples/reproduce_paper_table/results")
DATA_FILE = OUTPUT_DIR / "rejection_analysis.csv"

# Paper figure sizes (in inches); used as default figsize widths below.
SINGLE_COLUMN_WIDTH = 3.5  # IEEE single column
DOUBLE_COLUMN_WIDTH = 7.0  # IEEE double column
GOLDEN_RATIO = 1.618  # NOTE(review): not referenced in the visible part of this file

# Color scheme (colorblind-friendly), keyed by the `base_name` values looked up
# in the plotting functions. The hex codes match the Okabe-Ito palette plus a
# neutral gray. Methods missing from this dict fall back to a per-function
# default color.
COLORS = {
    'DeepDrebin (cold) - MSP': '#0072B2',  # Blue
    'SVC - Margin': '#009E73',              # Bluish green
    'HCC (warm) - MSP': '#D55E00',          # Vermillion (red-orange)
    'HCC (warm) - Pseudo-Loss': '#CC79A7',  # Reddish purple
    'CADE (cold) - MSP': '#E69F00',         # Orange
    'CADE (cold) - OOD': '#56B4E9',         # Sky blue
    'CADE (warm) - MSP': '#F0E442',         # Yellow
    'CADE (warm) - OOD': '#999999',         # Gray
}

# Short names for legends, tick labels and the summary table (same keys as COLORS)
SHORT_NAMES = {
    'DeepDrebin (cold) - MSP': 'DeepDrebin',
    'SVC - Margin': 'SVC',
    'HCC (warm) - MSP': 'HCC-MSP',
    'HCC (warm) - Pseudo-Loss': 'HCC-PL',
    'CADE (cold) - MSP': 'CADE-c-MSP',
    'CADE (cold) - OOD': 'CADE-c-OOD',
    'CADE (warm) - MSP': 'CADE-w-MSP',
    'CADE (warm) - OOD': 'CADE-w-OOD',
}

# Methods treated as outliers: dropped when a plot is called with
# exclude_outliers=True, and drawn with a distinct style (square markers,
# dashed line, ' *' label suffix) by plot_reliability_with_outliers.
EXCLUDE_METHODS = ['CADE (cold) - MSP', 'CADE (warm) - MSP']  # Severe calibration issues


# ============================================================================
# PLOTTING FUNCTIONS
# ============================================================================

def load_data(data_file: Optional[Path] = None) -> pd.DataFrame:
    """
    Load rejection analysis data and derive the observed rate.

    Args:
        data_file: CSV file to read. When omitted, the module-level
            ``DATA_FILE`` is used, which is the original behavior.

    Returns:
        The CSV contents with one extra column, ``observed``, computed as
        ``target_value + bias`` (bias is presumably observed minus target;
        confirm against the script that writes the CSV).

    Raises:
        KeyError: if the CSV lacks the ``target_value`` or ``bias`` column.
    """
    # Resolve the default lazily so an explicit path never touches DATA_FILE.
    source = DATA_FILE if data_file is None else data_file
    df = pd.read_csv(source)

    # Fail with a message that names the file and the missing columns instead
    # of the bare column-name KeyError raised by the arithmetic below.
    missing = [col for col in ('target_value', 'bias') if col not in df.columns]
    if missing:
        raise KeyError(f"{source} is missing required column(s): {missing}")

    # Compute observed value
    df['observed'] = df['target_value'] + df['bias']
    return df


def aggregate_by_method(df: pd.DataFrame, target_metric: str) -> pd.DataFrame:
    """
    Summarise one target metric per (method, target value) pair.

    Rows whose ``target_metric`` equals the argument are grouped by
    ``base_name`` and ``target_value``; everything else in the frame (for
    example datasets and budgets) is averaged over. The result has flat
    column names of the form ``<column>_<statistic>``, e.g. ``observed_mean``
    and ``mae_std``, next to the two grouping columns.
    """
    mean_and_std = ['mean', 'std']
    mean_only = ['mean']
    statistics = {
        'mae': mean_and_std,
        'bias': mean_and_std,
        'observed': mean_and_std,
        'hit_rate_2pct': mean_only,
        'mean_coverage': mean_only,
    }

    rows = df.loc[df['target_metric'] == target_metric]
    summary = rows.groupby(['base_name', 'target_value']).agg(statistics)
    summary = summary.reset_index()

    # The aggregation yields two-level column labels; the grouping columns
    # carry an empty second level, hence the strip of the dangling underscore.
    flat_names = []
    for levels in summary.columns.values:
        flat_names.append('_'.join(levels).strip('_'))
    summary.columns = flat_names

    return summary


def plot_combined_reliability(
    df: pd.DataFrame,
    output_path: Path,
    exclude_outliers: bool = True,
    figsize: Tuple[float, float] = (DOUBLE_COLUMN_WIDTH, 2.8)
):
    """
    Draw the two-panel (FNR left, FPR right) reliability diagram and save it.

    Each method contributes one curve of mean observed rate versus target
    rate, both in percent, next to a dashed diagonal marking perfect
    adherence. Axis ranges are fixed per metric, panels use an equal aspect
    ratio, and only the right panel carries a legend.

    Args:
        df: Frame accepted by ``aggregate_by_method`` (must contain the
            ``observed`` column).
        output_path: Destination file; the format follows the suffix.
        exclude_outliers: Skip the methods listed in ``EXCLUDE_METHODS``.
        figsize: Figure size in inches.
    """
    # Fixed (lower, upper) bound applied to both axes of each panel.
    panel_limits = {'FNR': (-0.5, 22), 'FPR': (-0.2, 6)}

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    for panel_no, metric in enumerate(panel_limits):
        ax = axes[panel_no]
        per_method = aggregate_by_method(df, metric)

        names = sorted(per_method['base_name'].unique())
        if exclude_outliers:
            names = [name for name in names if name not in EXCLUDE_METHODS]

        for name in names:
            curve = per_method.loc[per_method['base_name'] == name]
            curve = curve.sort_values('target_value')
            ax.plot(
                curve['target_value'] * 100,
                curve['observed_mean'] * 100,
                'o-',
                label=SHORT_NAMES.get(name, name),
                color=COLORS.get(name, '#333333'),
                markersize=4,
                linewidth=1.2,
                alpha=0.85,
            )

        # Perfect-calibration diagonal, drawn up to the largest target.
        diagonal_end = max(per_method['target_value'].unique()) * 100
        ax.plot([0, diagonal_end], [0, diagonal_end], 'k--', linewidth=1, alpha=0.5)

        ax.set_xlabel(f'Target {metric} (%)')
        ax.set_ylabel(f'Observed {metric} (%)')
        ax.set_title(f'({chr(97+panel_no)}) {metric} Target Adherence')

        lower, upper = panel_limits[metric]
        ax.set_xlim(lower, upper)
        ax.set_ylim(lower, upper)

        ax.grid(True, alpha=0.3, linewidth=0.5)
        ax.set_aspect('equal', adjustable='box')

        # Legend only on the right panel.
        if panel_no == 1:
            ax.legend(loc='upper left', framealpha=0.9, ncol=1)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")


def plot_reliability_with_outliers(
    df: pd.DataFrame,
    output_path: Path,
    figsize: Tuple[float, float] = (DOUBLE_COLUMN_WIDTH, 3.2)
):
    """
    Draw the two-panel reliability diagram INCLUDING the outlier methods.

    Methods listed in ``EXCLUDE_METHODS`` are drawn after the regular ones
    with square markers, dashed lines and a ' *' suffix in the legend. Both
    axes stay linear and auto-scaled (no broken axis or log scale is
    applied), so extreme observed values stretch the panel. The dashed
    diagonal extends to the larger of the maximum observed mean and the
    maximum target.

    Args:
        df: Frame accepted by ``aggregate_by_method``.
        output_path: Destination file; the format follows the suffix.
        figsize: Figure size in inches.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    for panel_no, metric in enumerate(('FNR', 'FPR')):
        ax = axes[panel_no]
        per_method = aggregate_by_method(df, metric)
        names = sorted(per_method['base_name'].unique())

        regular = [name for name in names if name not in EXCLUDE_METHODS]
        flagged = [name for name in names if name in EXCLUDE_METHODS]

        # (methods, line format, fallback color, label suffix, alpha);
        # regular methods first so they also lead the legend.
        style_groups = (
            (regular, 'o-', '#333333', '', 0.85),
            (flagged, 's--', '#FF0000', ' *', 0.7),
        )
        for group, line_format, fallback_color, suffix, opacity in style_groups:
            for name in group:
                curve = per_method.loc[per_method['base_name'] == name]
                curve = curve.sort_values('target_value')
                ax.plot(
                    curve['target_value'] * 100,
                    curve['observed_mean'] * 100,
                    line_format,
                    label=SHORT_NAMES.get(name, name) + suffix,
                    color=COLORS.get(name, fallback_color),
                    markersize=4,
                    linewidth=1.2,
                    alpha=opacity,
                )

        # Diagonal long enough to span the data on either axis.
        largest_observed = per_method['observed_mean'].max() * 100
        largest_target = max(per_method['target_value'].unique()) * 100
        diagonal_end = max(largest_observed, largest_target)
        ax.plot([0, diagonal_end], [0, diagonal_end], 'k--', linewidth=1, alpha=0.5)

        ax.set_xlabel(f'Target {metric} (%)')
        ax.set_ylabel(f'Observed {metric} (%)')
        ax.set_title(f'({chr(97+panel_no)}) {metric} Target Adherence')

        ax.grid(True, alpha=0.3, linewidth=0.5)

        if panel_no == 1:
            ax.legend(loc='upper left', framealpha=0.9, ncol=1, fontsize=6)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")


def plot_per_dataset_reliability(
    df: pd.DataFrame,
    output_path: Path,
    target_metric: str = 'FPR',
    exclude_outliers: bool = True,
    figsize: Tuple[float, float] = (DOUBLE_COLUMN_WIDTH, 2.2)
):
    """
    Create per-dataset reliability diagram (one panel per dataset).

    For every dataset that has rows for ``target_metric``, plots the mean
    observed rate against the target rate (both in percent) per method, next
    to a dashed perfect-calibration diagonal. Panels share the y axis and a
    single legend is placed below the figure.

    Args:
        df: Frame with ``dataset``, ``target_metric``, ``base_name``,
            ``target_value`` and ``observed`` columns.
        output_path: Destination file; the format follows the suffix.
        target_metric: Value of ``target_metric`` to plot.
        exclude_outliers: Skip the methods listed in ``EXCLUDE_METHODS``.
        figsize: Figure size in inches.

    Raises:
        ValueError: if ``df`` has no rows for ``target_metric``.
    """
    subset = df[df['target_metric'] == target_metric]

    # Take the panels from the filtered rows: a dataset without data for this
    # metric would otherwise get an empty panel and break the diagonal below.
    datasets = sorted(subset['dataset'].unique())
    if not datasets:
        raise ValueError(
            f"No rows with target_metric == {target_metric!r}; nothing to plot"
        )

    # squeeze=False keeps `axes` an array even for a single dataset (the
    # default would hand back a bare Axes that can be neither zipped nor indexed).
    fig, axes = plt.subplots(
        1, len(datasets), figsize=figsize, sharey=True, squeeze=False
    )
    axes = axes.ravel()

    # label -> first line drawn with it; feeds the shared legend so a method
    # missing from the first dataset is still listed.
    legend_entries = {}

    for idx, (ax, dataset) in enumerate(zip(axes, datasets)):
        ds_data = subset[subset['dataset'] == dataset]

        # Aggregate by method
        agg = ds_data.groupby(['base_name', 'target_value']).agg({
            'observed': 'mean',
        }).reset_index()

        methods = sorted(agg['base_name'].unique())
        if exclude_outliers:
            methods = [m for m in methods if m not in EXCLUDE_METHODS]

        for method in methods:
            method_data = agg[agg['base_name'] == method].sort_values('target_value')
            color = COLORS.get(method, '#333333')
            label = SHORT_NAMES.get(method, method)

            (line,) = ax.plot(
                method_data['target_value'] * 100,
                method_data['observed'] * 100,
                'o-',
                label=label,
                color=color,
                markersize=3,
                linewidth=1,
                alpha=0.8
            )
            legend_entries.setdefault(label, line)

        # Diagonal
        max_t = max(agg['target_value'].unique()) * 100
        ax.plot([0, max_t], [0, max_t], 'k--', linewidth=0.8, alpha=0.5)

        ax.set_xlabel(f'Target {target_metric} (%)')
        if idx == 0:
            ax.set_ylabel(f'Observed {target_metric} (%)')
        # str() so non-string dataset identifiers do not break the title.
        ax.set_title(str(dataset).capitalize())
        ax.grid(True, alpha=0.3, linewidth=0.5)
        ax.set_aspect('equal', adjustable='box')

    # Shared legend at bottom
    fig.legend(list(legend_entries.values()), list(legend_entries.keys()),
               loc='lower center', ncol=6, framealpha=0.9,
               bbox_to_anchor=(0.5, -0.15), fontsize=6)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")


def plot_mae_heatmap(
    df: pd.DataFrame,
    output_path: Path,
    target_metric: str = 'FPR',
    figsize: Tuple[float, float] = (SINGLE_COLUMN_WIDTH, 3.0)
):
    """
    Create MAE heatmap: Methods x Targets.

    Rows are methods (short names), columns are target values, and each cell
    shows the mean ``mae`` in percent for ``target_metric``. Method/target
    combinations with no data are left without an annotation instead of
    being labelled "nan".

    Args:
        df: Frame with ``target_metric``, ``base_name``, ``target_value`` and
            ``mae`` columns.
        output_path: Destination file; the format follows the suffix.
        target_metric: Value of ``target_metric`` to plot.
        figsize: Figure size in inches.
    """
    subset = df[df['target_metric'] == target_metric]

    # Aggregate
    agg = subset.groupby(['base_name', 'target_value']).agg({
        'mae': 'mean'
    }).reset_index()

    # Pivot to matrix
    pivot = agg.pivot(index='base_name', columns='target_value', values='mae')
    pivot = pivot * 100  # Convert to percentage

    # Display labels
    row_labels = [SHORT_NAMES.get(m, m) for m in pivot.index]
    col_labels = [f'{t*100:.1f}%' for t in pivot.columns]
    cells = pivot.values

    fig, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(cells, cmap='RdYlGn_r', aspect='auto')

    # Ticks
    ax.set_xticks(np.arange(len(col_labels)))
    ax.set_yticks(np.arange(len(row_labels)))
    ax.set_xticklabels(col_labels)
    ax.set_yticklabels(row_labels)

    # Annotate
    for (i, j), val in np.ndenumerate(cells):
        # A method that lacks this target value pivots to NaN; skip it
        # rather than printing the text "nan" into the figure.
        if pd.isna(val):
            continue
        # White text on the high-MAE cells, black elsewhere.
        text_color = 'white' if val > 15 else 'black'
        ax.text(j, i, f'{val:.1f}', ha='center', va='center',
                fontsize=6, color=text_color)

    ax.set_xlabel(f'Target {target_metric}')
    ax.set_ylabel('Method')
    ax.set_title(f'MAE (%) by Method and {target_metric} Target')

    # Colorbar
    cbar = fig.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label('MAE (%)')

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")


def plot_hitrate_comparison(
    df: pd.DataFrame,
    output_path: Path,
    figsize: Tuple[float, float] = (DOUBLE_COLUMN_WIDTH, 2.5)
):
    """
    Bar chart comparing HitRate@2% across methods for FNR vs FPR.

    One horizontal-bar panel per metric; each bar is the mean of
    ``hit_rate_2pct`` over all rows of a method, shown in percent and sorted
    ascending so the highest value ends up at the top.

    Args:
        df: Frame with ``target_metric``, ``base_name`` and ``hit_rate_2pct``
            columns.
        output_path: Destination file; the format follows the suffix.
        figsize: Figure size in inches.
    """
    fig, axes = plt.subplots(1, 2, figsize=figsize)

    for idx, (ax, metric) in enumerate(zip(axes, ['FNR', 'FPR'])):
        subset = df[df['target_metric'] == metric]

        # Aggregate: mean HitRate@2% per method
        agg = subset.groupby('base_name').agg({
            'hit_rate_2pct': 'mean'
        }).reset_index()

        agg = agg.sort_values('hit_rate_2pct', ascending=True)

        # Fall back to the full method name when no short name is registered;
        # Series.map(dict) would turn such methods into NaN tick labels.
        short_names = [SHORT_NAMES.get(m, m) for m in agg['base_name']]
        colors = [COLORS.get(m, '#333333') for m in agg['base_name']]

        bars = ax.barh(short_names, agg['hit_rate_2pct'] * 100, color=colors, alpha=0.8)

        ax.set_xlabel('HitRate@2% (%)')
        ax.set_title(f'({chr(97+idx)}) {metric} Target HitRate')
        ax.set_xlim(0, 100)
        ax.grid(True, axis='x', alpha=0.3, linewidth=0.5)

        # Add value labels just right of each bar end
        for bar, val in zip(bars, agg['hit_rate_2pct']):
            ax.text(val * 100 + 1, bar.get_y() + bar.get_height()/2,
                    f'{val*100:.0f}%', va='center', fontsize=6)

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")


def create_summary_table(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build the per-method summary table used in the paper.

    For each target metric (FNR first, then FPR) and each method, reports the
    mean ``mae``, ``hit_rate_2pct`` and ``mean_coverage`` as percent strings.
    Methods are labelled by their short name when one is registered, else by
    their full name.
    """
    averaged_columns = ['mae', 'hit_rate_2pct', 'mean_coverage']
    table_rows = []

    for metric in ('FNR', 'FPR'):
        per_method = (
            df.loc[df['target_metric'] == metric]
            .groupby('base_name')
            .agg({column: 'mean' for column in averaged_columns})
            .reset_index()
        )

        table_rows.extend(
            {
                'Method': SHORT_NAMES.get(name, name),
                'Target': metric,
                'MAE (%)': f"{mae*100:.1f}",
                'HitRate@2% (%)': f"{hit_rate*100:.0f}",
                'Coverage (%)': f"{coverage*100:.1f}",
            }
            for name, mae, hit_rate, coverage in zip(
                per_method['base_name'],
                per_method['mae'],
                per_method['hit_rate_2pct'],
                per_method['mean_coverage'],
            )
        )

    return pd.DataFrame(table_rows)


# ============================================================================
# MAIN
# ============================================================================

def main():
    """
    Load the rejection-analysis CSV, render every figure and write the
    summary table. All outputs are written under ``OUTPUT_DIR``.
    """
    rule = "=" * 60
    print(rule)
    print("GENERATING PUBLICATION-READY REJECTION DIAGRAMS")
    print(rule)

    # Load data
    df = load_data()
    print(f"Loaded {len(df)} rows")

    # Figures needed in both vector and raster form are rendered once per
    # suffix, PDF first.
    both_formats = ("pdf", "png")

    print("\n[1/6] Combined reliability diagram (excluding outliers)...")
    for suffix in both_formats:
        plot_combined_reliability(
            df,
            OUTPUT_DIR / f"fig_rejection_reliability_clean.{suffix}",
            exclude_outliers=True
        )

    print("\n[2/6] Combined reliability diagram (with outliers)...")
    for suffix in both_formats:
        plot_reliability_with_outliers(
            df,
            OUTPUT_DIR / f"fig_rejection_reliability_full.{suffix}"
        )

    print("\n[3/6] Per-dataset FPR reliability diagram...")
    plot_per_dataset_reliability(
        df,
        OUTPUT_DIR / "fig_rejection_fpr_by_dataset.pdf",
        target_metric='FPR'
    )

    print("\n[4/6] MAE heatmaps...")
    plot_mae_heatmap(df, OUTPUT_DIR / "fig_rejection_mae_fpr.pdf", 'FPR')
    plot_mae_heatmap(df, OUTPUT_DIR / "fig_rejection_mae_fnr.pdf", 'FNR')

    print("\n[5/6] HitRate comparison...")
    for suffix in both_formats:
        plot_hitrate_comparison(df, OUTPUT_DIR / f"fig_rejection_hitrate.{suffix}")

    print("\n[6/6] Summary table...")
    summary = create_summary_table(df)
    print(summary.to_string(index=False))
    summary.to_csv(OUTPUT_DIR / "rejection_summary_table.csv", index=False)

    print("\n" + rule)
    # Report the directory the files were actually written to.
    print(f"DONE - All figures saved to {OUTPUT_DIR}/")
    print(rule)


if __name__ == "__main__":
    main()
