#!/usr/bin/env python3
"""
Produce Pareto-Based Multi-Objective Evaluation Table

This script produces the comprehensive evaluation table with Pareto analysis,
computing all Four Pillars metrics and determining Pareto frontier status.

The Four Pillars Framework:
- R (Reliability):  Mean F1 - "How well does it detect?"
- S (Stability):    σ[F1] - "How consistent is it?" (NOT CV[F1])
- T (Tail Risk):    Min[F1] - "What's the worst case?"
- τ (Temporal):     Mann-Kendall tau - "Is it degrading over time?"

Output Format:
- All metrics for all B_M values (50, 100, 200, 400)
- Pareto status (Frontier/Dominated) per dataset
- Universal Pareto optimality across datasets
"""

import sys
sys.path.insert(0, 'src')

import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from collections import defaultdict
from tqdm import tqdm

# Import Aurora framework
from aurora import (
    AuroraAnalyzer,
    ResultsCollection,
    ExperimentResult,
    ReliabilityMetrics,
    StabilityMetrics,
    DrawdownMetrics,
    ParetoAnalyzer,
    ParetoResult,
    UniversalParetoResult,
)

# Import data loading from load_all_aurora_data
from load_all_aurora_data import all_results

# Import metrics computation from tools
from aurora import compute_metrics_numba, compute_aurc


# ============================================================================
# CONFIGURATION
# ============================================================================

# Output directory for generated files.
# The path is relative, so it resolves against the current working directory.
# NOTE: the directory is created at import time (module-level side effect).
OUTPUT_DIR = Path("examples/reproduce_paper_table/results")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Number of months for validation (hyperparameter selection).
# The first VALIDATION_MONTHS months are used to pick hyperparameters and are
# excluded from the reported test-period metrics.
VALIDATION_MONTHS = 6

# Dataset names and abbreviations.
# Maps the dataset identifier stored on each result to the short prefix used
# for the table's column names (e.g. "azoo_mean_f1").
DATASET_ABBREV = {
    "androzoo": "azoo",
    "apigraph": "apig",
    "transcendent": "trans",
}

# Maps the short column prefix to the display name used in the LaTeX header.
DATASET_FULL_NAMES = {
    "azoo": "AndroZoo",
    "apig": "API-Graph",
    "trans": "Transcendent",
}

# Method ordering for table (Pareto frontier methods first).
# Methods found in the data but not listed here are appended after these.
METHOD_ORDER = [
    "DeepDrebin (cold) - MSP",
    "HCC (warm) - Pseudo-Loss",
    "HCC (warm) - MSP",
    "Drebin (cold) - Margin",
    "CADE (cold) - OOD",
    "CADE (cold) - MSP",
    "CADE (warm) - OOD",
    "CADE (warm) - MSP",
]

# Budget values to include (B_M dimension).
# Budgets present in the data but absent from this list are ignored.
BUDGETS = [50, 100, 200, 400]

# Sampler modes to include (B_0 dimension).
# Values are the B_0 labels shown in the table; note they are of mixed type
# (the string "D_0" and the integers 4800 and 0). Sampler modes that are not
# keys of this mapping are ignored when the table is generated.
SAMPLER_MODE_TO_B0 = {
    "full_first_year_subsample_months": "D_0",
    "subsample_first_year_subsample_months": 4800,
    "zero_first_year_subsample_months": 0,
}

# Four Pillars metrics: the objective names handed to ParetoAnalyzer.
FOUR_PILLARS = ['mean_f1', 'sigma_f1', 'min_f1', 'mann_kendall_tau']


# ============================================================================
# HELPER FUNCTIONS
# ============================================================================

def compute_monthly_f1_values(predictions: List[np.ndarray],
                              labels: List[np.ndarray]) -> List[float]:
    """
    Score each month on its own and collect the resulting F1 values.

    Args:
        predictions: Per-month prediction arrays.
        labels: Per-month label arrays, paired positionally with ``predictions``.

    Returns:
        List of F1 values as plain floats, one per (predictions, labels) pair,
        in input order.
    """
    # compute_metrics_numba yields a 3-tuple; only its first element (F1) is kept.
    monthly_scores = (
        compute_metrics_numba(month_labels, month_preds)
        for month_preds, month_labels in zip(predictions, labels)
    )
    return [float(score) for score, _, _ in monthly_scores]


def compute_four_pillars(monthly_f1: List[float]) -> Dict[str, float]:
    """
    Derive the Four Pillars metrics (plus reference extras) from monthly F1.

    Args:
        monthly_f1: List of F1 scores, one per month

    Returns:
        Dict with, in this order:
        - mean_f1: Average F1 (Reliability)
        - sigma_f1: the stability suite's 'sigma' entry (Stability)
        - min_f1: the stability suite's 'min' entry (Tail Risk)
        - mann_kendall_tau: the suite's 'mann_kendall_tau' entry (Temporal)
        - cv_f1, max_f1, p5_f1: the suite's 'cv', 'max' and 'p5' entries,
          kept for reference only
    """
    stability = StabilityMetrics.compute_stability_suite(monthly_f1)

    # Output key -> key inside the stability suite. The first three are pillars,
    # the remaining three are reference-only metrics.
    suite_fields = (
        ('sigma_f1', 'sigma'),
        ('min_f1', 'min'),
        ('mann_kendall_tau', 'mann_kendall_tau'),
        ('cv_f1', 'cv'),
        ('max_f1', 'max'),
        ('p5_f1', 'p5'),
    )

    pillars = {'mean_f1': float(np.mean(monthly_f1))}
    for out_key, suite_key in suite_fields:
        pillars[out_key] = stability[suite_key]
    return pillars


def aggregate_across_seeds(results: List[ExperimentResult]) -> Tuple[List[np.ndarray],
                                                                      List[np.ndarray],
                                                                      Optional[List[np.ndarray]]]:
    """
    Aggregate predictions, labels, and uncertainties across multiple seeds.

    Each result is expected to describe a single test month. Results sharing
    the same ``test_month`` (typically different seeds) are concatenated, and
    the months are returned in ascending ``test_month`` order.

    Args:
        results: Per-month experiment results. Each item must expose
            ``test_month``, ``predictions`` and ``labels``; the attribute
            ``uncertainties_month_ahead`` is optional.

    Returns:
        (predictions_by_month, labels_by_month, uncertainties_by_month)

        ``uncertainties_by_month`` is ``None`` when at least one result has no
        uncertainties (attribute missing or ``None``): a partially populated
        list could not be aligned with the predictions.

    Raises:
        ValueError: If ``results`` is empty, or if the first result's
            predictions are not a numpy array.
    """
    if not results:
        raise ValueError("Cannot aggregate empty results list")

    if not isinstance(results[0].predictions, np.ndarray):
        raise ValueError("Unexpected result format: predictions should be numpy arrays")

    # Per-month format: each result is a single month; group seeds by month.
    results_by_month = defaultdict(list)
    for result in results:
        results_by_month[result.test_month].append(result)

    # Uncertainties are optional. Concatenating missing (None) entries would
    # raise and make callers discard predictions that are perfectly usable.
    has_uncertainties = all(
        getattr(result, "uncertainties_month_ahead", None) is not None
        for result in results
    )

    predictions_by_month = []
    labels_by_month = []
    uncertainties_by_month = [] if has_uncertainties else None

    # Every group holds at least one result, so no month can be empty here.
    for month in sorted(results_by_month):
        month_results = results_by_month[month]

        predictions_by_month.append(
            np.concatenate([r.predictions for r in month_results])
        )
        labels_by_month.append(
            np.concatenate([r.labels for r in month_results])
        )
        if has_uncertainties:
            uncertainties_by_month.append(
                np.concatenate([r.uncertainties_month_ahead for r in month_results])
            )

    return predictions_by_month, labels_by_month, uncertainties_by_month


def get_b0_value(sampler_mode: str) -> str:
    """
    Translate a sampler mode into the B_0 label shown in the table.

    Returns the entry of SAMPLER_MODE_TO_B0 for ``sampler_mode`` (which may be
    a string or an integer, depending on the mapping), or "?" when the mode is
    not a key of the mapping.
    """
    if sampler_mode in SAMPLER_MODE_TO_B0:
        return SAMPLER_MODE_TO_B0[sampler_mode]
    return "?"


# ============================================================================
# HYPERPARAMETER SELECTION
# ============================================================================

def select_best_hyperparameters(collection: ResultsCollection,
                                dataset: str,
                                budget: int,
                                base_name: str,
                                sampler_mode: str,
                                validation_months: int = 6) -> Optional[Dict]:
    """
    Pick the hyperparameter setting with the best validation-period F1.

    Results matching (dataset, budget, base_name, sampler_mode) are grouped by
    their hyperparameters with the "Random-Seed" entry ignored, so different
    seeds of one setting share a group.

    Returns:
        - None when no result matches the configuration.
        - With at most one group: the first matching result's hyperparameters
          without "Random-Seed", or None if nothing remains.
        - Otherwise: the hyperparameters of the group whose pooled F1 over the
          first ``validation_months`` months is highest; None if no group
          could be scored.
    """
    def _strip_seed(hyperparameters):
        # The seed distinguishes repetitions, not settings.
        return {k: v for k, v in hyperparameters.items() if k != "Random-Seed"}

    candidates = []
    for entry in collection.results:
        if (entry.dataset == dataset
                and entry.monthly_label_budget == budget
                and entry.base_name == base_name
                and entry.sampler_mode == sampler_mode):
            candidates.append(entry)

    if not candidates:
        return None

    # Hashable, order-independent key per hyperparameter setting.
    groups = {}
    for entry in candidates:
        setting = tuple(
            (k, v) for k, v in sorted(entry.hyperparameters.items())
            if k != "Random-Seed"
        )
        groups.setdefault(setting, []).append(entry)

    if len(groups) <= 1:
        # Nothing to choose between: report the only setting there is.
        return _strip_seed(candidates[0].hyperparameters) or None

    top_score = -np.inf
    winner = None

    for setting, members in groups.items():
        try:
            preds, labels, _ = aggregate_across_seeds(members)
            val_preds = preds[:validation_months]
            val_labels = labels[:validation_months]

            if not (val_preds and val_labels):
                continue

            # Pool all validation months into a single F1 computation.
            pooled_preds = np.concatenate(val_preds)
            pooled_labels = np.concatenate(val_labels)
            score, _, _ = compute_metrics_numba(pooled_labels, pooled_preds)

            if score > top_score:
                top_score = score
                winner = dict(setting)

        except (ValueError, IndexError):
            # Settings that cannot be aggregated/scored are skipped.
            continue

    return winner


def get_results_with_hyperparameters(collection: ResultsCollection,
                                    dataset: str,
                                    budget: int,
                                    base_name: str,
                                    sampler_mode: str,
                                    hyperparams: Optional[Dict]) -> List[ExperimentResult]:
    """
    Collect the results of one configuration that use the given hyperparameters.

    A result qualifies when its dataset, monthly label budget, base name and
    sampler mode equal the arguments. If ``hyperparams`` is None every such
    result is returned; otherwise each entry of ``hyperparams`` (except
    "Random-Seed", which is never compared) must equal the result's value.
    """
    def _same_configuration(candidate):
        return (candidate.dataset == dataset
                and candidate.monthly_label_budget == budget
                and candidate.base_name == base_name
                and candidate.sampler_mode == sampler_mode)

    def _same_hyperparameters(candidate):
        for key, wanted in hyperparams.items():
            if key == "Random-Seed":
                continue
            if not candidate.hyperparameters.get(key) == wanted:
                return False
        return True

    selected = []
    for candidate in collection.results:
        if not _same_configuration(candidate):
            continue
        if hyperparams is None or _same_hyperparameters(candidate):
            selected.append(candidate)
    return selected


# ============================================================================
# TABLE GENERATION
# ============================================================================

# Per-dataset metric columns, in the order they appear in the table.
_DATASET_METRIC_SUFFIXES = (
    'mean_f1', 'sigma_f1', 'cv_f1', 'min_f1', 'max_f1', 'p5_f1', 'tau', 'aurc',
)


def _pareto_method_key(base_name: str, b0_value: Any) -> str:
    """
    Build the ParetoAnalyzer method name for one (method, B_0) table row.

    The B_0 label itself is embedded so that every sampler mode gets its own
    entry; a coarse "Full"/"Subsample" tag would merge distinct B_0 values
    (e.g. 4800 and 0) into a single method.
    """
    return f"{base_name} (B_0={b0_value})"


def _evaluate_dataset(collection: ResultsCollection,
                      dataset: str,
                      budget: int,
                      base_name: str,
                      sampler_mode: str,
                      validation_months: int) -> Optional[Tuple[Dict[str, float], List[float]]]:
    """
    Compute the test-period metrics of one configuration on one dataset.

    Returns:
        ``(metrics, monthly_f1)`` where ``metrics`` is keyed by the entries of
        ``_DATASET_METRIC_SUFFIXES``, or None when there are no results, the
        results cannot be aggregated, or no month is left after the
        validation period.
    """
    best_hyperparams = select_best_hyperparameters(
        collection, dataset, budget, base_name, sampler_mode, validation_months
    )
    results = get_results_with_hyperparameters(
        collection, dataset, budget, base_name, sampler_mode, best_hyperparams
    )
    if not results:
        return None

    try:
        all_preds, all_labels, all_uncs = aggregate_across_seeds(results)
    except (ValueError, IndexError):
        return None

    # The first `validation_months` months were used for hyperparameter
    # selection; only the remaining months are reported.
    test_preds = all_preds[validation_months:]
    test_labels = all_labels[validation_months:]
    test_uncs = all_uncs[validation_months:] if all_uncs else None

    if not test_preds:
        return None

    monthly_f1 = compute_monthly_f1_values(test_preds, test_labels)
    pillars = compute_four_pillars(monthly_f1)

    # AURC is best-effort: it needs uncertainties and must not abort the table.
    aurc = np.nan
    if test_uncs is not None:
        try:
            aurc = float(compute_aurc(test_labels, test_preds, test_uncs, compute_eaurc=False))
        except Exception:
            aurc = np.nan

    metrics = {
        'mean_f1': pillars['mean_f1'],
        'sigma_f1': pillars['sigma_f1'],
        'cv_f1': pillars['cv_f1'],
        'min_f1': pillars['min_f1'],
        'max_f1': pillars['max_f1'],
        'p5_f1': pillars['p5_f1'],
        'tau': pillars['mann_kendall_tau'],
        'aurc': aurc,
    }
    return metrics, monthly_f1


def generate_pareto_table(collection: ResultsCollection,
                         validation_months: int = 6) -> Tuple[pd.DataFrame, Dict[int, ParetoAnalyzer]]:
    """
    Generate comprehensive Pareto evaluation table.

    One row is produced per (budget, sampler mode, method); each row carries
    the metric columns ``{abbrev}_{suffix}`` for every dataset. Configurations
    that cannot be evaluated get NaN in all of that dataset's metric columns
    and are not registered with the Pareto analyzer.

    Args:
        collection: All experiment results.
        validation_months: Number of leading months reserved for
            hyperparameter selection and excluded from the reported metrics.

    Returns:
        Tuple of (DataFrame with all metrics, dict of ParetoAnalyzers by budget)
    """
    rows = []
    pareto_analyzers = {}  # budget -> ParetoAnalyzer

    datasets = sorted(collection.get_unique_values('dataset')['dataset'])
    budgets = sorted([b for b in collection.get_unique_values('monthly_label_budget')['monthly_label_budget']
                     if b in BUDGETS])
    base_names = sorted(collection.get_unique_values('base_name')['base_name'])
    sampler_modes = sorted([s for s in collection.get_unique_values('sampler_mode')['sampler_mode']
                           if s in SAMPLER_MODE_TO_B0])

    # Known methods first (in METHOD_ORDER), any other method afterwards.
    ordered_base_names = [m for m in METHOD_ORDER if m in base_names]
    ordered_base_names += [m for m in base_names if m not in METHOD_ORDER]

    total_configs = len(datasets) * len(budgets) * len(ordered_base_names) * len(sampler_modes)

    print(f"\n{'='*80}")
    print(f"GENERATING PARETO EVALUATION TABLE")
    print(f"{'='*80}\n")
    print(f"Configurations: {len(datasets)} datasets × {len(budgets)} budgets × {len(ordered_base_names)} methods × {len(sampler_modes)} samplers")
    print(f"Total: {total_configs} configurations\n")

    # The context manager closes the progress bar even if evaluation raises.
    with tqdm(total=total_configs, desc="Processing configurations") as pbar:
        for budget in budgets:
            # One analyzer per budget: methods only compete at equal B_M.
            analyzer = ParetoAnalyzer(metrics=FOUR_PILLARS)
            pareto_analyzers[budget] = analyzer

            for sampler_mode in sampler_modes:
                b0_value = get_b0_value(sampler_mode)

                for base_name in ordered_base_names:
                    row = {
                        'B_M': budget,
                        'B_0': b0_value,
                        'Method': base_name,
                    }
                    method_key = _pareto_method_key(base_name, b0_value)

                    for dataset in datasets:
                        pbar.set_description(f"{dataset[:4]} | {budget:3d} | {base_name[:15]:15s}")

                        abbrev = DATASET_ABBREV.get(dataset, dataset)
                        metrics = dict.fromkeys(_DATASET_METRIC_SUFFIXES, np.nan)

                        evaluated = _evaluate_dataset(
                            collection, dataset, budget, base_name,
                            sampler_mode, validation_months,
                        )
                        if evaluated is not None:
                            metrics, monthly_f1 = evaluated
                            analyzer.add_method_results(
                                method_name=method_key,
                                dataset=dataset,
                                monthly_f1=monthly_f1,
                            )

                        for suffix in _DATASET_METRIC_SUFFIXES:
                            row[f'{abbrev}_{suffix}'] = metrics[suffix]

                        pbar.update(1)

                    rows.append(row)

    df = pd.DataFrame(rows)
    return df, pareto_analyzers


def add_pareto_status(df: pd.DataFrame,
                     pareto_analyzers: Dict[int, ParetoAnalyzer]) -> pd.DataFrame:
    """
    Add Pareto frontier status columns to the table.

    Adds columns:
    - {dataset}_pareto: "F" for frontier, "D" for dominated
    - pareto_count: Number of datasets where on frontier
    - pareto_status: "UNIVERSAL", "PARTIAL(k/n)", or "DOMINATED"

    Rows whose budget has no analyzer, or whose analyzer raises ValueError
    from ``analyze_universal()``, keep the defaults ('' / 0 / ''). A dataset
    missing from the analysis leaves that row's ``{dataset}_pareto`` empty.

    Args:
        df: Table with at least the columns 'B_M', 'B_0' and 'Method'.
        pareto_analyzers: Analyzer per budget, populated with method names
            built by ``_pareto_method_key``.

    Returns:
        A copy of ``df`` with the status columns appended; ``df`` itself is
        left untouched.
    """
    df = df.copy()

    datasets = ['azoo', 'apig', 'trans']
    dataset_full = {'azoo': 'androzoo', 'apig': 'apigraph', 'trans': 'transcendent'}

    # The universal analysis depends only on the budget, so it is computed once
    # per budget rather than once per row. None marks "no usable analysis".
    universal_by_budget = {}

    def _universal_for(budget):
        if budget not in universal_by_budget:
            outcome = None
            if budget in pareto_analyzers:
                try:
                    outcome = pareto_analyzers[budget].analyze_universal()
                except ValueError:
                    outcome = None
            universal_by_budget[budget] = outcome
        return universal_by_budget[budget]

    flags = {ds: [] for ds in datasets}
    counts = []
    statuses = []

    # Columns are read directly (instead of iterrows) so B_0 keeps its type
    # and formats exactly as it did when the analyzer keys were built.
    for budget, b0, method in zip(df['B_M'], df['B_0'], df['Method']):
        universal_result = _universal_for(budget)

        if universal_result is None:
            for ds in datasets:
                flags[ds].append('')
            counts.append(0)
            statuses.append('')
            continue

        method_key = _pareto_method_key(method, b0)
        pareto_count = 0

        for ds in datasets:
            full_ds = dataset_full[ds]

            if full_ds not in universal_result.dataset_results:
                flags[ds].append('')
            elif method_key in universal_result.dataset_results[full_ds].frontier:
                flags[ds].append('F')
                pareto_count += 1
            else:
                flags[ds].append('D')

        counts.append(pareto_count)

        if pareto_count == len(datasets):
            statuses.append('UNIVERSAL')
        elif pareto_count > 0:
            statuses.append(f'PARTIAL({pareto_count}/{len(datasets)})')
        else:
            statuses.append('DOMINATED')

    for ds in datasets:
        df[f'{ds}_pareto'] = flags[ds]
    # Explicit dtype keeps the column integer even for an empty table.
    df['pareto_count'] = np.array(counts, dtype=np.int64)
    df['pareto_status'] = statuses

    return df


def format_pareto_table(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a display-ready copy of the results table.

    Columns are selected by substring of their name and rounded in three
    passes, applied in this order:
    - F1-style columns ('_f1', '_mean', '_sigma', '_min', '_max', '_p5'):
      multiplied by 100 and rounded to 1 decimal.
    - '_tau' columns: rounded to 3 decimals.
    - 'aurc' columns: multiplied by 100 and rounded to 1 decimal.

    The input frame is not modified.
    """
    display = df.copy()

    def _matching(*markers):
        # Column names containing at least one of the given substrings.
        return [name for name in display.columns
                if any(marker in name for marker in markers)]

    def _as_percent(series):
        return (series * 100).round(1)

    for name in _matching('_f1', '_mean', '_sigma', '_min', '_max', '_p5'):
        display[name] = _as_percent(display[name])

    for name in _matching('_tau'):
        display[name] = display[name].round(3)

    for name in _matching('aurc'):
        display[name] = _as_percent(display[name])

    return display


def print_pareto_summary(pareto_analyzers: Dict[int, ParetoAnalyzer]):
    """
    Print, per budget in ascending order, which methods are Pareto-optimal.

    For every budget the analyzer's universal analysis is printed in three
    optional sections: methods optimal on all datasets, methods optimal on
    some datasets (with the datasets' first four characters), and methods that
    are always dominated. A budget whose analysis raises ValueError is
    reported as having no valid data.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print("PARETO ANALYSIS SUMMARY")
    print(f"{banner}\n")

    for budget, analyzer in sorted(pareto_analyzers.items(), key=lambda item: item[0]):
        try:
            summary = analyzer.analyze_universal()
        except ValueError:
            print(f"B_M = {budget}: No valid data")
            continue

        print(f"\nB_M = {budget}")
        print("-" * 40)

        universal = summary.universal_optimal
        if universal:
            print(f"  UNIVERSAL Pareto-optimal ({len(universal)}):")
            for name in sorted(universal):
                print(f"    ★ {name}")

        partial = summary.partial_optimal
        if partial:
            print(f"\n  PARTIAL Pareto-optimal ({len(partial)}):")
            for name, optimal_on in sorted(partial.items()):
                shortened = ", ".join(ds_name[:4] for ds_name in optimal_on)
                print(f"    ◐ {name} ({shortened})")

        dominated = summary.always_dominated
        if dominated:
            print(f"\n  Always DOMINATED ({len(dominated)}):")
            for name in sorted(dominated):
                print(f"    ○ {name}")

    print()


def export_latex_table(df: pd.DataFrame, output_file: Path):
    """
    Export table in LaTeX format for the paper.

    The table has two leading columns (method, B_0), three metric columns per
    dataset (mean F1, sigma, tau) and a trailing status column. Rows are
    grouped by budget (B_M); the mean F1 is set in bold where the row is on
    that dataset's Pareto frontier.

    Args:
        df: Table to export. Metric values are written as they are, so pass
            the already formatted table if percentages are wanted. Missing
            metric columns or NaN values are rendered as "--".
        output_file: Destination of the generated ``.tex`` file (overwritten).
    """
    datasets = ['azoo', 'apig', 'trans']
    metrics_per_dataset = 3  # mean F1, sigma, tau

    # Derive the column layout from the dataset list so that the tabular spec
    # and the full-width budget rows always match the cells actually emitted.
    total_columns = 2 + metrics_per_dataset * len(datasets) + 1
    column_spec = "ll|" + ("c" * metrics_per_dataset + "|") * len(datasets) + "c"

    latex_lines = [
        r"\begin{table*}[t]",
        r"\centering",
        r"\caption{Multi-Objective Evaluation: Four Pillars Framework with Pareto Analysis}",
        r"\label{tab:pareto-evaluation}",
        r"\resizebox{\textwidth}{!}{%",
        r"\begin{tabular}{" + column_spec + "}",
        r"\toprule",
    ]

    # Header row 1: dataset names
    header1 = r"\multirow{2}{*}{\textbf{Method}} & \multirow{2}{*}{\textbf{$B_0$}}"
    for ds in datasets:
        full_name = DATASET_FULL_NAMES.get(ds, ds)
        header1 += f" & \\multicolumn{{{metrics_per_dataset}}}{{c|}}{{{full_name}}}"
    header1 += r" & \multirow{2}{*}{\textbf{Status}} \\"
    latex_lines.append(header1)

    # Header row 2: metrics (the two leading cells stay empty under the multirows)
    header2 = r" & "
    for ds in datasets:
        header2 += r" & $\bar{F_1}$ & $\sigma$ & $\tau$"
    header2 += r" \\"
    latex_lines.append(header2)
    latex_lines.append(r"\midrule")

    # Group by budget
    for budget in sorted(df['B_M'].unique()):
        budget_df = df[df['B_M'] == budget]

        latex_lines.append(
            f"\\multicolumn{{{total_columns}}}{{l}}{{\\textbf{{$B_M = {budget}$}}}} \\\\"
        )
        latex_lines.append(r"\midrule")

        for _, row in budget_df.iterrows():
            method = row['Method'].replace("_", "\\_")
            b0 = str(row['B_0']).replace("D_0", "$\\mathcal{D}_0$")

            # Build row
            line = f"{method} & {b0}"

            for ds in datasets:
                mean_f1 = row.get(f'{ds}_mean_f1', np.nan)
                sigma = row.get(f'{ds}_sigma_f1', np.nan)
                tau = row.get(f'{ds}_tau', np.nan)
                pareto = row.get(f'{ds}_pareto', '')

                # Format values
                mean_str = f"{mean_f1:.1f}" if not pd.isna(mean_f1) else "--"
                sigma_str = f"{sigma:.1f}" if not pd.isna(sigma) else "--"
                tau_str = f"{tau:.2f}" if not pd.isna(tau) else "--"

                # Bold if on frontier
                if pareto == 'F':
                    mean_str = f"\\textbf{{{mean_str}}}"

                line += f" & {mean_str} & {sigma_str} & {tau_str}"

            # Status column. A missing status may surface as NaN (a float),
            # which would break the substring test below.
            status = row.get('pareto_status', '')
            if not isinstance(status, str):
                status = ''

            if status == 'UNIVERSAL':
                status_str = r"\ding{72}"  # Star
            elif 'PARTIAL' in status:
                status_str = r"\ding{115}"  # Filled triangle
            else:
                status_str = r"\ding{109}"  # Open circle

            line += f" & {status_str} \\\\"
            latex_lines.append(line)

        latex_lines.append(r"\midrule")

    latex_lines.extend([
        r"\bottomrule",
        r"\end{tabular}%",
        r"}",
        r"\end{table*}",
    ])

    # Explicit encoding: method names need not be ASCII and the platform
    # default encoding is not guaranteed to handle them.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(latex_lines))

    print(f"✅ LaTeX table saved to: {output_file}")


# ============================================================================
# MAIN
# ============================================================================

def main():
    """
    Run the whole pipeline: build, annotate, format, print and export the table.

    Writes the formatted table as CSV, Excel (skipped if the Excel writer
    cannot be imported) and LaTeX, plus the unformatted table as a pickle,
    all into OUTPUT_DIR.

    Returns:
        (formatted DataFrame, dict of ParetoAnalyzers by budget)
    """
    rule = '=' * 80

    print(f"\n{rule}")
    print("PRODUCE PARETO EVALUATION TABLE")
    print("Four Pillars: R (Mean F1), S (σ[F1]), T (Min[F1]), τ (Mann-Kendall)")
    print(f"{rule}\n")

    # Overview of what was loaded
    known_datasets = sorted(all_results.get_unique_values('dataset')['dataset'])
    known_budgets = sorted(
        all_results.get_unique_values('monthly_label_budget')['monthly_label_budget']
    )
    method_count = len(all_results.get_unique_values('base_name')['base_name'])

    print(f"Total results loaded: {len(all_results)}")
    print(f"Datasets: {known_datasets}")
    print(f"Budgets: {known_budgets}")
    print(f"Base names: {method_count}")

    # Build the raw table, attach frontier status, then round for display
    raw_table, pareto_analyzers = generate_pareto_table(
        all_results, validation_months=VALIDATION_MONTHS
    )
    raw_table = add_pareto_status(raw_table, pareto_analyzers)
    display_table = format_pareto_table(raw_table)

    print_pareto_summary(pareto_analyzers)

    # Preview
    print(f"\n{rule}")
    print("SAMPLE OUTPUT (First 10 rows)")
    print(f"{rule}\n")

    preview_cols = [
        name for name in (
            'B_M', 'B_0', 'Method',
            'azoo_mean_f1', 'azoo_sigma_f1', 'azoo_tau', 'azoo_pareto',
            'pareto_status',
        )
        if name in display_table.columns
    ]
    print(display_table[preview_cols].head(10).to_string(index=False))

    # CSV
    csv_path = OUTPUT_DIR / "pareto_evaluation_table.csv"
    display_table.to_csv(csv_path, index=False)
    print(f"\n✅ CSV saved to: {csv_path}")

    # Excel is optional: it needs an extra writer package
    excel_path = OUTPUT_DIR / "pareto_evaluation_table.xlsx"
    try:
        display_table.to_excel(excel_path, index=False)
    except ImportError:
        print("⚠️  Excel export skipped (install openpyxl for Excel support)")
    else:
        print(f"✅ Excel saved to: {excel_path}")

    # LaTeX
    export_latex_table(display_table, OUTPUT_DIR / "pareto_table.tex")

    # Unrounded values for programmatic use
    pickle_path = OUTPUT_DIR / "pareto_evaluation_raw.pkl"
    raw_table.to_pickle(pickle_path)
    print(f"✅ Raw data saved to: {pickle_path}")

    print(f"\n{rule}")
    print("✅ PARETO TABLE GENERATION COMPLETE")
    print(f"{rule}\n")

    return display_table, pareto_analyzers


# Run the table-generation pipeline only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
