"""
Rejection Simulation Module

Post-hoc rejection simulation for evaluating uncertainty quality.
Tests "what if we reject uncertain samples?" without expensive retraining.

Core Implementation:
--------------------
The core rejection logic is ported from tools.py with NO MODIFICATIONS.
All numba-accelerated (performance-critical) functions are preserved exactly as-is.

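Methods Supported:
------------------
1. single_thresh_simple: Single uncertainty threshold fit on the previous month's
   scores (month 0 uses its own scores); no compounding
2. single_thresh_compounded: Single threshold fit on the uncertainty scores
   accumulated over all previous months
3. dual_thresh_simple: Dual (class-conditional) thresholds around the 0.5 decision
   boundary, fit on the previous month's scores (month 0 uses its own scores)
4. dual_thresh_compounded: Dual (class-conditional) thresholds fit on the scores
   accumulated over all previous months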

Aurora Integration:
-------------------
Clean wrapper classes for working with ResultsCollection and multi-seed data.
"""

from typing import List, Dict, Optional, Literal, Tuple
import numpy as np
import pandas as pd
from dataclasses import dataclass
from numba import njit, prange
from tqdm import tqdm

from .schema import ResultsCollection, ExperimentResult


# ============================================================================
# CORE REJECTION FUNCTIONS (from tools.py - NO MODIFICATIONS)
# ============================================================================

@njit
def compute_metrics_numba(y_true, y_pred):
    """
    Compute F1, FNR, and FPR for binary classification using numba.

    COPIED EXACTLY FROM TOOLS.PY - DO NOT MODIFY

    Special cases, checked in this order:
      * Length mismatch, or fewer than 20 samples -> (nan, nan, nan).
      * All labels identical AND all predictions identical:
          - predictions equal labels            -> (1.0, 0.0, 0.0)
            (this includes "all 0 / all 0", which the general path below
            would have scored as F1 = 0.0)
          - labels all 1, predictions all differ -> (0.0, 1.0, 0.0)
          - labels all 0, predictions all differ -> (0.0, 0.0, 1.0)
          - any other constant label value       -> (nan, nan, nan)
      * Otherwise the metrics come from confusion-matrix counts. A metric
        whose denominator is zero is reported as 0.0.

    Only pairs where both label and prediction are 0 or 1 are counted in the
    confusion matrix. Pairs with any other value are skipped, but they still
    count toward the length checks above.

    Args:
        y_true: 1D array of ground truth labels (0 or 1)
        y_pred: 1D array of predicted labels (0 or 1)

    Returns:
        tuple: (f1, fnr, fpr)
    """
    n = y_true.shape[0]

    # Too few samples (or inconsistent inputs) -> metrics are undefined
    if n != y_pred.shape[0] or n < 20:
        return np.nan, np.nan, np.nan

    same_ytrue = True
    for i in range(1, n):
        if y_true[i] != y_true[0]:
            same_ytrue = False
            break

    same_ypred = True
    for i in range(1, n):
        if y_pred[i] != y_pred[0]:
            same_ypred = False
            break

    # Degenerate case: both arrays are constant; resolve without counting
    if same_ytrue and same_ypred:
        if y_true[0] == y_pred[0]:
            return 1.0, 0.0, 0.0
        else:
            if y_true[0] == 1:
                return 0.0, 1.0, 0.0
            elif y_true[0] == 0:
                return 0.0, 0.0, 1.0
            else:
                return np.nan, np.nan, np.nan

    # Initialize confusion matrix counts
    TP = 0
    TN = 0
    FP = 0
    FN = 0

    for i in range(n):
        yt = y_true[i]
        yp = y_pred[i]

        # Values other than 0/1 match none of these branches and are skipped
        if yt == 1 and yp == 1:
            TP += 1
        elif yt == 0 and yp == 0:
            TN += 1
        elif yt == 0 and yp == 1:
            FP += 1
        elif yt == 1 and yp == 0:
            FN += 1

    # Compute metrics
    # F1 = 2 * TP / (2*TP + FP + FN)
    denom_f1 = 2 * TP + FP + FN
    if denom_f1 == 0:
        f1 = 0.0
    else:
        f1 = (2.0 * TP) / denom_f1

    # FNR = FN / (TP + FN)
    actual_positives = TP + FN
    if actual_positives == 0:
        fnr = 0.0
    else:
        fnr = FN / actual_positives

    # FPR = FP / (FP + TN)
    actual_negatives = FP + TN
    if actual_negatives == 0:
        fpr = 0.0
    else:
        fpr = FP / actual_negatives

    return f1, fnr, fpr


@njit
def recover_positive_softmax(classifier_uncertainty: np.ndarray, predictions: np.ndarray) -> np.ndarray:
    """
    Recover the softmax probability for the positive class (class 1)

    COPIED EXACTLY FROM TOOLS.PY - DO NOT MODIFY

    This inverts ``u = 1 - |max_prob - 0.5| / 0.5``. For max_prob >= 0.5 the
    inverse is ``max_prob = 0.5 * (2 - u)``. The predicted class is assigned
    ``max_prob`` and the other class ``1 - max_prob``.
    NOTE(review): this is only correct if ``predictions`` is the argmax of the
    same probabilities that produced the uncertainty. Confirm at the call sites.
    Any prediction value other than 1 is treated as class 0.

    Args:
        classifier_uncertainty: Uncertainty calculated as 1 - (np.abs(max_probs - 0.5) / 0.5)
        predictions: Binary predictions array (0 or 1)

    Returns:
        np.ndarray: Softmax probabilities for the positive class
    """
    # Probability of the predicted class (always >= 0.5 in the binary case)
    max_probs = 0.5 * (2 - classifier_uncertainty)
    n = len(predictions)
    positive_softmax = np.empty(n, dtype=np.float64)

    for i in range(n):
        if predictions[i] == 1:
            positive_softmax[i] = max_probs[i]
        else:
            # Predicted class is 0, so P(class 1) is the complement
            positive_softmax[i] = 1 - max_probs[i]

    return positive_softmax


@njit
def binary_classification_rejection_thresholds(
    agg_scores_binary: np.ndarray,
    total_to_reject: int,
) -> Tuple[float, float]:
    """
    Compute lower and upper thresholds for rejection in binary classification.

    COPIED EXACTLY FROM TOOLS.PY - DO NOT MODIFY

    The thresholds form a band symmetric around 0.5. Its half-width is the
    distance from 0.5 of the ``total_to_reject``-th most uncertain score.

    Args:
        agg_scores_binary: Array of binary classification scores (probabilities for one class)
        total_to_reject: Number of samples to reject

    Returns:
        tuple[float, float]: (lower_threshold, upper_threshold)
            Reject if score in (lower_threshold, upper_threshold)
            ``(1.0, 0.0)`` is an empty interval ("reject nothing"). It is
            returned when ``total_to_reject <= 0`` or
            ``total_to_reject > len(agg_scores_binary)``.

    NOTE(review): callers in this module reject a score ``s`` when
    ``lower < s < upper`` (strict), i.e. ``|s - 0.5| < threshold_distance``.
    The N-th most uncertain sample sits exactly on the boundary and is not
    rejected, so at most N - 1 samples are rejected (fewer with ties, and up
    to floating-point rounding). Also, requesting more rejections than there
    are samples rejects nothing here, whereas
    single_value_uncertainty_threshold rejects everything in that case.
    Confirm both behaviours against the intent in tools.py.
    """
    n_samples = len(agg_scores_binary)

    if total_to_reject <= 0 or total_to_reject > n_samples:
        return 1.0, 0.0  # Reject nothing

    # Compute distances from decision boundary (0.5)
    distances = np.empty(n_samples, dtype=np.float64)
    for i in range(n_samples):
        distances[i] = abs(agg_scores_binary[i] - 0.5)

    # Sort indices by distance (ascending = most uncertain first)
    sorted_indices = np.argsort(distances)

    # Index of the N-th most uncertain sample; its distance sets the band width
    threshold_idx = sorted_indices[total_to_reject - 1]
    threshold_distance = distances[threshold_idx]

    # Convert back to probability thresholds
    lower_cutoff = 0.5 - threshold_distance
    upper_cutoff = 0.5 + threshold_distance

    return lower_cutoff, upper_cutoff


@njit
def single_value_uncertainty_threshold(
    agg_scores_single: np.ndarray,
    total_to_reject: int,
) -> float:
    """
    Compute uncertainty threshold for rejection.

    COPIED EXACTLY FROM TOOLS.PY - DO NOT MODIFY

    Callers in this module reject samples whose score is strictly greater
    than the returned value.

    Args:
        agg_scores_single: Array of uncertainty scores (higher = more uncertain)
        total_to_reject: Number of samples to reject

    Returns:
        float: Threshold value above which samples should be rejected.
            ``np.inf`` (reject nothing) if ``total_to_reject <= 0``;
            ``-np.inf`` (reject everything) if
            ``total_to_reject >= len(agg_scores_single)``; otherwise the
            ``total_to_reject``-th largest score.

    NOTE(review): in the general case the returned value is itself the N-th
    largest score. Because callers compare with a strict ``>``, at most N - 1
    samples are selected (fewer with ties). When
    ``total_to_reject >= n_samples``, however, every sample is selected.
    Confirm whether this off-by-one matches the intent in tools.py.
    """
    n_samples = len(agg_scores_single)

    if total_to_reject <= 0:
        return np.inf  # Reject nothing

    if total_to_reject >= n_samples:
        return -np.inf  # Reject everything

    # Sort scores from large (most uncertain) to small (most certain)
    sorted_scores = np.argsort(agg_scores_single)[::-1]

    # Index of the N-th largest score (the variable holds indices, not scores)
    cutoff_idx = sorted_scores[total_to_reject - 1]

    return agg_scores_single[cutoff_idx]


@njit(parallel=True)
def _PostHocRejectorSimulator_Refactor(
        uncertainties,  # List of np.ndarray (one per month)
        predictions,    # List of np.ndarray (one per month)
        labels,         # List of np.ndarray (one per month)
        rejection_Ns,   # 1D np.ndarray of int64 rejection quotas
        upto_reject,    # Boolean flag to limit rejections per month
        method          # String indicating which method to use
):
    """
    Numba-accelerated rejection simulator.

    COPIED EXACTLY FROM TOOLS.PY - DO NOT MODIFY

    For each rejection quota ``rej_N`` (outer loop, parallelised with
    ``prange``), the months are processed in order. Each month's samples are
    split into rejected and accepted, and F1/FNR/FPR are computed with and
    without rejection.

    Methods:
      - "single_thresh_simple": Per-month single threshold. Month 0 is cut
        with a threshold fit on its own scores. Every later month uses the
        threshold fit on the previous month's scores.
      - "single_thresh_compounded": Accumulated uncertainty (single threshold).
        Month 0 is cut with a threshold fit on its own scores and quota rej_N.
        Month Mi > 0 uses a threshold fit on the scores of months 0..Mi-1
        with quota Mi * rej_N.
      - "dual_thresh_simple": Per-month dual threshold (class-conditional).
        It applies to the positive-class probability from
        ``recover_positive_softmax`` and uses the same previous-month scheme
        as single_thresh_simple.
      - "dual_thresh_compounded": Accumulated class-conditional. Same
        accumulation scheme as single_thresh_compounded.
      Any other ``method`` string leaves every reject_mask False, i.e. it
      silently rejects nothing.

    ``upto_reject`` stops scanning a month once ``rej_N`` samples have been
    rejected (samples are scanned in index order).
    NOTE(review): it is never applied to month 0, and
    "single_thresh_compounded" ignores it entirely.

    Returns:
        3D array of shape (len(rejection_Ns), 16, n_months) with metrics.
        Rows 0-7 hold one scalar repeated along the month axis:
          0-2   F1/FNR/FPR over accepted samples pooled across all months
                (0.0 if nothing was accepted at all)
          3     mean monthly rejections
          4-6   mean monthly F1/FNR/FPR (see NOTE below)
          7     mean monthly acceptances
        Rows 8-15 are per-month values:
          8     rejections
          9     acceptances
          10-12 F1/FNR/FPR without rejection
          13-15 F1/FNR/FPR on accepted samples only

    NOTE(review): rows 4-6 average month_metrics[:, 0:3], which are the
    *no-rejection* monthly metrics, not the with-rejection ones in
    month_metrics[:, 3:6]. Confirm against tools.py whether this is intended.
    Monthly metrics are NaN for months with fewer than 20 samples (see
    compute_metrics_numba), and such a month makes these averages NaN.
    """
    n_months = len(uncertainties)
    n_rej = rejection_Ns.shape[0]
    result = np.empty((n_rej, 16, n_months), dtype=np.float64)

    # Outer loop over each rejection quota (parallelized)
    for j in prange(n_rej):
        rej_N = rejection_Ns[j]

        # Preallocate aggregated_scores array
        # (capacity = total number of samples across all months)
        total_possible = 0
        for Mi in range(n_months):
            total_possible += uncertainties[Mi].shape[0]

        # Scores of already-processed months (compounded methods only)
        aggregated_scores = np.empty(total_possible, dtype=np.float64)
        agg_score_index = 0

        # Prepare accumulators for aggregated accepted samples
        agg_labels = np.empty(total_possible, dtype=labels[0].dtype)
        agg_preds = np.empty(total_possible, dtype=predictions[0].dtype)
        agg_index = 0

        # Arrays to store per-month metrics and counts
        # month_metrics columns: 0-2 baseline F1/FNR/FPR, 3-5 with-rejection F1/FNR/FPR
        month_metrics = np.empty((n_months, 6), dtype=np.float64)
        month_rejections = np.empty(n_months, dtype=np.int64)
        month_acceptances = np.empty(n_months, dtype=np.int64)

        # Process each month
        for Mi in range(n_months):
            curr_uncert = uncertainties[Mi]
            curr_preds = predictions[Mi]
            curr_labels = labels[Mi]
            n_samples = curr_uncert.shape[0]

            # Prepare reject_mask
            reject_mask = np.zeros(n_samples, dtype=np.bool_)

            if method == "single_thresh_simple":
                if Mi == 0:
                    # Month 0: fit the cutoff on month 0 itself and apply it
                    uncert_cutoff = single_value_uncertainty_threshold(curr_uncert, rej_N)
                    for k in range(n_samples):
                        if curr_uncert[k] > uncert_cutoff:
                            reject_mask[k] = True
                else:
                    # Apply the cutoff fit on the previous month
                    rejection_counter = 0
                    for k in range(n_samples):
                        if curr_uncert[k] > uncert_cutoff:
                            reject_mask[k] = True
                            rejection_counter += 1
                            if upto_reject and rejection_counter >= rej_N:
                                break
                    # Refit on this month for use by the next month
                    uncert_cutoff = single_value_uncertainty_threshold(curr_uncert, rej_N)

            elif method == "single_thresh_compounded":
                if Mi == 0:
                    total_to_reject = rej_N
                    # Seed the accumulated scores with month 0
                    for k in range(n_samples):
                        aggregated_scores[k] = curr_uncert[k]
                        agg_score_index += 1

                    cutoff_point = single_value_uncertainty_threshold(
                        aggregated_scores[:agg_score_index], total_to_reject
                    )
                    for k in range(n_samples):
                        if curr_uncert[k] > cutoff_point:
                            reject_mask[k] = True
                else:
                    # Threshold fit on months 0..Mi-1 with quota Mi * rej_N
                    total_to_reject = Mi * rej_N
                    cutoff_point = single_value_uncertainty_threshold(
                        aggregated_scores[:agg_score_index],
                        total_to_reject
                    )
                    # NOTE(review): unlike the other methods, upto_reject is not applied here
                    for k in range(n_samples):
                        if curr_uncert[k] > cutoff_point:
                            reject_mask[k] = True

                    # Append this month's scores after fitting (used from the next month on)
                    prev_agg_score_index = agg_score_index
                    for k in range(n_samples):
                        aggregated_scores[prev_agg_score_index + k] = curr_uncert[k]
                        agg_score_index += 1

            elif method == "dual_thresh_simple":
                recovered_scores = recover_positive_softmax(curr_uncert, curr_preds)
                if Mi == 0:
                    # Month 0: fit the band on month 0 itself and apply it
                    lower_cutoff, upper_cutoff = binary_classification_rejection_thresholds(
                        recovered_scores, rej_N
                    )
                    for k in range(n_samples):
                        if recovered_scores[k] > lower_cutoff and recovered_scores[k] < upper_cutoff:
                            reject_mask[k] = True
                else:
                    # Apply the band fit on the previous month
                    rejection_counter = 0
                    for k in range(n_samples):
                        if recovered_scores[k] > lower_cutoff and recovered_scores[k] < upper_cutoff:
                            reject_mask[k] = True
                            rejection_counter += 1
                            if upto_reject and rejection_counter >= rej_N:
                                break
                    # Refit on this month for use by the next month
                    lower_cutoff, upper_cutoff = binary_classification_rejection_thresholds(
                        recovered_scores, rej_N
                    )

            elif method == "dual_thresh_compounded":
                recovered_scores = recover_positive_softmax(curr_uncert, curr_preds)
                if Mi == 0:
                    total_to_reject = rej_N
                    # Seed the accumulated scores with month 0
                    for k in range(n_samples):
                        aggregated_scores[k] = recovered_scores[k]
                        agg_score_index += 1

                    lower_cutoff, upper_cutoff = binary_classification_rejection_thresholds(
                        aggregated_scores[:agg_score_index], total_to_reject
                    )
                    for k in range(n_samples):
                        if recovered_scores[k] > lower_cutoff and recovered_scores[k] < upper_cutoff:
                            reject_mask[k] = True

                else:
                    # Band fit on months 0..Mi-1 with quota Mi * rej_N
                    total_to_reject = Mi * rej_N
                    lower_cutoff, upper_cutoff = binary_classification_rejection_thresholds(
                        aggregated_scores[:agg_score_index], total_to_reject
                    )
                    rejection_counter = 0
                    for k in range(n_samples):
                        if recovered_scores[k] > lower_cutoff and recovered_scores[k] < upper_cutoff:
                            reject_mask[k] = True
                            rejection_counter += 1
                            if upto_reject and rejection_counter >= rej_N:
                                break

                    # Append this month's scores after fitting (used from the next month on)
                    prev_agg_score_index = agg_score_index
                    for k in range(n_samples):
                        aggregated_scores[prev_agg_score_index + k] = recovered_scores[k]
                        agg_score_index += 1

            # Build accepted mask
            accept_mask = np.empty(n_samples, dtype=np.bool_)
            for k in range(n_samples):
                accept_mask[k] = not reject_mask[k]

            # Count accepted samples
            accepted_count = 0
            for k in range(n_samples):
                if accept_mask[k]:
                    accepted_count += 1

            accepted_labels = np.empty(accepted_count, dtype=curr_labels.dtype)
            accepted_preds = np.empty(accepted_count, dtype=curr_preds.dtype)
            pos = 0
            for k in range(n_samples):
                if accept_mask[k]:
                    accepted_labels[pos] = curr_labels[k]
                    accepted_preds[pos] = curr_preds[k]
                    pos += 1

            # Compute per-month metrics
            # (baseline = all samples of the month, i.e. no rejection)
            f1_baseline, fnr_baseline, fpr_baseline = compute_metrics_numba(curr_labels, curr_preds)
            f1, fnr, fpr = compute_metrics_numba(accepted_labels, accepted_preds)
            month_metrics[Mi, 0] = f1_baseline
            month_metrics[Mi, 1] = fnr_baseline
            month_metrics[Mi, 2] = fpr_baseline
            month_metrics[Mi, 3] = f1
            month_metrics[Mi, 4] = fnr
            month_metrics[Mi, 5] = fpr

            # Count rejections
            rej_count = 0
            for k in range(n_samples):
                if reject_mask[k]:
                    rej_count += 1
            month_rejections[Mi] = rej_count
            month_acceptances[Mi] = accepted_count

            # Accumulate accepted samples for aggregated metrics
            for k in range(accepted_count):
                agg_labels[agg_index] = accepted_labels[k]
                agg_preds[agg_index] = accepted_preds[k]
                agg_index += 1

        # Compute aggregated metrics across all months
        if agg_index > 0:
            agg_f1, agg_fnr, agg_fpr = compute_metrics_numba(agg_labels[:agg_index], agg_preds[:agg_index])
        else:
            agg_f1 = 0.0
            agg_fnr = 0.0
            agg_fpr = 0.0

        # Compute average per-month metrics
        # NOTE(review): columns 0-2 are the NO-rejection metrics; the
        # with-rejection metrics live in columns 3-5.
        sum_f1 = 0.0
        sum_fnr = 0.0
        sum_fpr = 0.0
        sum_rej = 0.0
        sum_acc = 0.0
        for Mi in range(n_months):
            sum_f1 += month_metrics[Mi, 0]
            sum_fnr += month_metrics[Mi, 1]
            sum_fpr += month_metrics[Mi, 2]
            sum_rej += month_rejections[Mi]
            sum_acc += month_acceptances[Mi]

        avg_month_f1 = sum_f1 / n_months
        avg_month_fnr = sum_fnr / n_months
        avg_month_fpr = sum_fpr / n_months
        avg_month_rej = sum_rej / n_months
        avg_month_acc = sum_acc / n_months

        # Populate result
        # Rows 0-7: scalars broadcast along the month axis; rows 8-15: per-month
        result[j, 0] = np.repeat(agg_f1, n_months)
        result[j, 1] = np.repeat(agg_fnr, n_months)
        result[j, 2] = np.repeat(agg_fpr, n_months)
        result[j, 3] = np.repeat(avg_month_rej, n_months)
        result[j, 4] = np.repeat(avg_month_f1, n_months)
        result[j, 5] = np.repeat(avg_month_fnr, n_months)
        result[j, 6] = np.repeat(avg_month_fpr, n_months)
        result[j, 7] = np.repeat(avg_month_acc, n_months)
        result[j, 8] = month_rejections
        result[j, 9] = month_acceptances
        result[j, 10] = month_metrics[:, 0]
        result[j, 11] = month_metrics[:, 1]
        result[j, 12] = month_metrics[:, 2]
        result[j, 13] = month_metrics[:, 3]
        result[j, 14] = month_metrics[:, 4]
        result[j, 15] = month_metrics[:, 5]

    return result


def PostHocRejectorSimulator(
    uncertainties: List[np.ndarray],
    predictions: List[np.ndarray],
    labels: List[np.ndarray],
    rejection_Ns: np.ndarray,
    upto_reject: bool,
    method: str
) -> Dict[int, Dict]:
    """
    Wrapper for numba-accelerated rejection simulator.

    Ported from tools.py. The metric computation is unchanged. This wrapper
    additionally validates its inputs before calling the numba kernel:
      * The kernel silently treats an unknown ``method`` as "reject nothing",
        so a typo would yield baseline numbers labeled as rejection results.
      * numba does not bounds-check array indexing by default, so per-month
        arrays of different lengths would be read out of range.

    Args:
        uncertainties: List of uncertainty arrays (one per month)
        predictions: List of prediction arrays (one per month)
        labels: List of label arrays (one per month)
        rejection_Ns: Array of rejection budgets to test
        upto_reject: Limit rejections per month to budget
        method: Rejection method ("single_thresh_simple", "single_thresh_compounded",
                "dual_thresh_simple", "dual_thresh_compounded")

    Returns:
        Dict mapping int(rejection_budget) -> metrics dict. Scalar entries
        ("F1", "FNR", "FPR", "avg_monthly_*") come from the first month
        column of the kernel's broadcast rows. "monthly_*" entries are
        per-month arrays.

    Raises:
        ValueError: if ``method`` is unknown, no months are given, the three
            per-month lists differ in length, or a month's uncertainty,
            prediction and label arrays differ in length.
    """
    valid_methods = (
        "single_thresh_simple",
        "single_thresh_compounded",
        "dual_thresh_simple",
        "dual_thresh_compounded",
    )
    if method not in valid_methods:
        raise ValueError(
            f"Unknown rejection method {method!r}; expected one of {valid_methods}"
        )

    n_months = len(uncertainties)
    if n_months == 0:
        raise ValueError("At least one month of data is required")
    if len(predictions) != n_months or len(labels) != n_months:
        raise ValueError(
            f"Per-month lists differ in length: {n_months} uncertainties, "
            f"{len(predictions)} predictions, {len(labels)} labels"
        )
    for month_idx, (unc, pred, lab) in enumerate(zip(uncertainties, predictions, labels)):
        if not (len(unc) == len(pred) == len(lab)):
            raise ValueError(
                f"Month index {month_idx}: array lengths differ "
                f"(uncertainties={len(unc)}, predictions={len(pred)}, labels={len(lab)})"
            )

    rejection_Ns_arr = np.array(rejection_Ns, dtype=np.int64)

    result_array = _PostHocRejectorSimulator_Refactor(
        uncertainties, predictions, labels,
        rejection_Ns_arr, upto_reject, method
    )

    result = {}
    for idx, rej_N in enumerate(rejection_Ns_arr):
        result[int(rej_N)] = {
            "F1": result_array[idx, 0][0],
            "FNR": result_array[idx, 1][0],
            "FPR": result_array[idx, 2][0],
            "avg_monthly_Rejections": result_array[idx, 3][0],
            "avg_monthly_F1": result_array[idx, 4][0],
            "avg_monthly_FNR": result_array[idx, 5][0],
            "avg_monthly_FPR": result_array[idx, 6][0],
            "avg_monthly_Acceptances": result_array[idx, 7][0],
            "monthly_Rejections": result_array[idx, 8],
            "monthly_Acceptances": result_array[idx, 9],
            "monthly_Total": result_array[idx, 8] + result_array[idx, 9],
            "monthly_F1_no_rejection": result_array[idx, 10],
            "monthly_FNR_no_rejection": result_array[idx, 11],
            "monthly_FPR_no_rejection": result_array[idx, 12],
            "monthly_F1": result_array[idx, 13],
            "monthly_FNR": result_array[idx, 14],
            "monthly_FPR": result_array[idx, 15],
        }

    return result


# ============================================================================
# AURORA-NATIVE WRAPPER (NEW - Clean Interface)
# ============================================================================

@dataclass
class RejectionMetrics:
    """
    Metrics for a single rejection budget.

    Plain typed container. Its scalar field names match the columns of the
    DataFrame returned by ``RejectionSimulator.simulate``.
    NOTE(review): nothing in the visible part of this module constructs it.
    Confirm whether it is used elsewhere or can be removed.
    """
    rejection_budget: int  # rejection budget (samples per month) this row describes
    method: str  # rejection method name, e.g. "dual_thresh_compounded"

    # Baseline (no rejection)
    baseline_f1: float
    baseline_fnr: float
    baseline_fpr: float

    # With rejection
    rejected_f1: float
    rejected_fnr: float
    rejected_fpr: float

    # Improvement: presumably rejected_* - baseline_*, as in RejectionSimulator.simulate.
    # For F1 a positive value is better; for FNR/FPR a negative value is better.
    improvement_f1: float
    improvement_fnr: float
    improvement_fpr: float

    # Rejection quality
    avg_monthly_rejections: float
    avg_monthly_acceptances: float
    coverage: float  # fraction of samples kept, in [0, 1] (accepted / (accepted + rejected))

    # Per-month arrays (one entry per test month)
    monthly_rejections: np.ndarray
    monthly_acceptances: np.ndarray
    monthly_f1_baseline: np.ndarray
    monthly_f1_rejected: np.ndarray


class RejectionSimulator:
    """
    Aurora-native rejection simulator

    Clean interface for working with ResultsCollection.
    Handles multi-seed aggregation automatically: each seed is simulated
    independently, and the per-seed metrics are averaged.

    Example:
        simulator = RejectionSimulator(collection)
        results = simulator.simulate(
            rejection_budgets=[0, 100, 200, 400],
            method="dual_thresh_compounded",
            group_by=["dataset", "base_name", "monthly_label_budget"]
        )
    """

    def __init__(self, collection: ResultsCollection):
        """
        Initialize simulator with results collection

        Args:
            collection: ResultsCollection with experimental results
        """
        self.collection = collection

    def _aggregate_by_group(self, group_by: List[str]) -> Dict[tuple, Dict]:
        """
        Aggregate results by group, handling multi-seed experiments.

        Results are grouped by ``result.get_grouping_key(*group_by)``. Within
        a group they are indexed by seed (hyperparameter ``'Random-Seed'``,
        default 0) and by ``test_month``.

        Returns:
            Dict mapping group_key -> {
                'predictions': List[np.ndarray] (one per month) for a single seed,
                               or List[List[np.ndarray]] (per month, then one
                               array per seed in sorted seed order) for multi-seed,
                'labels': same layout as 'predictions',
                'uncertainties': same layout as 'predictions',
                'is_multiseed': bool,
                'num_seeds': int,
                'months': sorted List of test months,
                'group_info': Dict mapping each group_by name to its value
            }

        Raises:
            ValueError: if, in a multi-seed group, some seed has no result
                for a test month that another seed covers. The per-seed
                monthly arrays could then not be aligned for averaging.

        Warns:
            UserWarning: if a (seed, month) pair has more than one result.
                Only the first one (in collection order) is used.
        """
        import warnings
        from collections import defaultdict

        # Group results by the requested dimensions
        grouped = defaultdict(list)
        for result in self.collection.results:
            grouped[result.get_grouping_key(*group_by)].append(result)

        aggregated = {}
        for group_key, group_results in grouped.items():
            group_info = dict(zip(group_by, group_key))

            # Index the group's results as seed -> test month -> [results]
            by_seed = defaultdict(lambda: defaultdict(list))
            for result in group_results:
                seed = result.get_hyperparameter('Random-Seed', 0)
                by_seed[seed][result.test_month].append(result)

            seeds = sorted(by_seed)
            all_months = sorted({r.test_month for r in group_results})
            is_multiseed = len(seeds) > 1

            for seed in seeds:
                # Every seed must cover every month; otherwise the month lists
                # of different seeds cannot be lined up for averaging.
                missing = [m for m in all_months if m not in by_seed[seed]]
                if missing:
                    raise ValueError(
                        f"Group {group_info}: seed {seed} has no results for "
                        f"test month(s) {missing}; all seeds in a group must "
                        f"cover the same months."
                    )
                for month, month_results in by_seed[seed].items():
                    if len(month_results) > 1:
                        warnings.warn(
                            f"Group {group_info}: {len(month_results)} results for "
                            f"seed {seed}, test month {month}; using the first. "
                            f"Consider adding more group_by dimensions.",
                            stacklevel=3,
                        )

            if is_multiseed:
                # One sublist per month, containing one array per seed (sorted seeds)
                predictions = [[by_seed[s][m][0].predictions for s in seeds] for m in all_months]
                labels = [[by_seed[s][m][0].labels for s in seeds] for m in all_months]
                uncertainties = [
                    [by_seed[s][m][0].uncertainties_past_month for s in seeds]
                    for m in all_months
                ]
            else:
                # Single seed: just one array per month
                only_seed = seeds[0]
                predictions = [by_seed[only_seed][m][0].predictions for m in all_months]
                labels = [by_seed[only_seed][m][0].labels for m in all_months]
                uncertainties = [
                    by_seed[only_seed][m][0].uncertainties_past_month for m in all_months
                ]

            aggregated[group_key] = {
                'predictions': predictions,
                'labels': labels,
                'uncertainties': uncertainties,
                'is_multiseed': is_multiseed,
                'num_seeds': len(seeds),
                'months': all_months,
                'group_info': group_info,
            }

        return aggregated

    def _simulate_single_seed(
        self,
        group_data: Dict,
        rejection_budgets: List[int],
        method: str,
        upto_reject: bool,
    ) -> Dict:
        """
        Run rejection simulation for single-seed data

        Returns:
            Dict mapping budget -> metrics (from PostHocRejectorSimulator)
        """
        return PostHocRejectorSimulator(
            uncertainties=group_data['uncertainties'],
            predictions=group_data['predictions'],
            labels=group_data['labels'],
            rejection_Ns=rejection_budgets,
            upto_reject=upto_reject,
            method=method,
        )

    def _simulate_multiseed(
        self,
        group_data: Dict,
        rejection_budgets: List[int],
        method: str,
        upto_reject: bool,
    ) -> Dict:
        """
        Run rejection simulation for multi-seed data.

        Runs the simulator once per seed, then averages each metric across
        seeds: arrays element-wise, and scalars (Python or numpy) with
        ``np.mean``. Any other value is taken from the first seed.
        (Replicates user's average_metrics_dict pattern)

        Returns:
            Dict mapping budget -> averaged metrics
        """
        num_seeds = group_data['num_seeds']

        # group_data lists are indexed [month][seed]; slice out one seed at a time
        per_seed_results = [
            PostHocRejectorSimulator(
                uncertainties=[month[seed_idx] for month in group_data['uncertainties']],
                predictions=[month[seed_idx] for month in group_data['predictions']],
                labels=[month[seed_idx] for month in group_data['labels']],
                rejection_Ns=rejection_budgets,
                upto_reject=upto_reject,
                method=method,
            )
            for seed_idx in range(num_seeds)
        ]

        averaged_results = {}
        for budget in rejection_budgets:
            budget_metrics = [seed_results[budget] for seed_results in per_seed_results]

            averaged = {}
            for key, first_value in budget_metrics[0].items():
                values = [m[key] for m in budget_metrics]

                if isinstance(first_value, np.ndarray):
                    # Average arrays element-wise
                    averaged[key] = np.mean(values, axis=0)
                elif isinstance(first_value, (int, float, np.number)):
                    # Average scalars
                    averaged[key] = np.mean(values)
                else:
                    # Keep first value for non-numeric fields
                    averaged[key] = first_value

            averaged_results[budget] = averaged

        return averaged_results

    def simulate(
        self,
        rejection_budgets: List[int],
        method: str = "dual_thresh_compounded",
        upto_reject: bool = False,
        group_by: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """
        Simulate rejection for all groups in collection

        Args:
            rejection_budgets: List of rejection budgets to test
            method: Rejection method
                - "single_thresh_simple": Per-month single threshold
                - "single_thresh_compounded": Accumulated single threshold
                - "dual_thresh_simple": Per-month class-conditional
                - "dual_thresh_compounded": Accumulated class-conditional (DEFAULT)
            upto_reject: Limit rejections per month to budget
            group_by: Grouping dimensions (default: dataset, base_name, monthly_label_budget, sampler_mode)

        Returns:
            DataFrame with one row per group/budget combination, sorted by
            ``group_by + ['rejection_budget']``. Columns:
              - the ``group_by`` columns, ``rejection_budget``, ``method``
              - ``baseline_f1/fnr/fpr``: metrics with NO rejection, pooled
                over all months. This is the same aggregation as
                ``rejected_*``, taken from the budget-0 simulation.
              - ``rejected_f1/fnr/fpr``: metrics over all accepted samples
                pooled across months
              - ``improvement_f1/fnr/fpr``: ``rejected_* - baseline_*`` (zero
                at budget 0). For F1 positive is better; for FNR/FPR
                negative is better.
              - ``avg_monthly_rejections``, ``avg_monthly_acceptances``, and
                ``coverage`` (fraction of samples accepted)
              - ``is_multiseed``, ``num_seeds``, ``num_months``
              - ``baseline_avg_monthly_f1/fnr/fpr`` and
                ``rejected_avg_monthly_f1/fnr/fpr``: the same metrics
                averaged over months instead of pooled
            For multi-seed groups every metric is the mean over seeds. If
            there is nothing to report, an empty DataFrame with the sort
            columns is returned.
        """
        from tqdm.auto import tqdm

        if group_by is None:
            group_by = ['dataset', 'base_name', 'monthly_label_budget', 'sampler_mode']

        print(f"Simulating rejection with method={method}")
        print(f"Rejection budgets: {rejection_budgets}")
        print(f"Group by: {group_by}")

        # Budget 0 rejects nothing under every method, so its pooled metrics are
        # the no-rejection baseline computed exactly like the rejected metrics.
        # Simulate it internally even when the caller did not ask for it.
        sim_budgets = list(rejection_budgets)
        if 0 not in sim_budgets:
            sim_budgets.append(0)

        # Step 1: Aggregate by group
        print("\nAggregating results by group...")
        aggregated = self._aggregate_by_group(group_by)
        print(f"Found {len(aggregated)} groups")

        # Step 2: Run rejection simulation for each group
        all_results = []

        for group_data in tqdm(aggregated.values(), desc="Simulating rejection"):
            if group_data['is_multiseed']:
                rejection_results = self._simulate_multiseed(
                    group_data, sim_budgets, method, upto_reject
                )
            else:
                rejection_results = self._simulate_single_seed(
                    group_data, sim_budgets, method, upto_reject
                )

            baseline = rejection_results[0]

            for budget in rejection_budgets:
                metrics = rejection_results[budget]

                # Create row with group info + metrics
                row = group_data['group_info'].copy()
                row['rejection_budget'] = budget
                row['method'] = method

                # Baseline (no rejection), pooled across months
                row['baseline_f1'] = baseline['F1']
                row['baseline_fnr'] = baseline['FNR']
                row['baseline_fpr'] = baseline['FPR']

                # With rejection, pooled across months
                row['rejected_f1'] = metrics['F1']
                row['rejected_fnr'] = metrics['FNR']
                row['rejected_fpr'] = metrics['FPR']

                # Like-for-like differences
                row['improvement_f1'] = metrics['F1'] - baseline['F1']
                row['improvement_fnr'] = metrics['FNR'] - baseline['FNR']
                row['improvement_fpr'] = metrics['FPR'] - baseline['FPR']

                # Add rejection stats
                row['avg_monthly_rejections'] = metrics['avg_monthly_Rejections']
                row['avg_monthly_acceptances'] = metrics['avg_monthly_Acceptances']

                # Compute coverage (fraction of samples accepted)
                total = metrics['avg_monthly_Rejections'] + metrics['avg_monthly_Acceptances']
                row['coverage'] = metrics['avg_monthly_Acceptances'] / total if total > 0 else 1.0

                # Add multi-seed info
                row['is_multiseed'] = group_data['is_multiseed']
                row['num_seeds'] = group_data['num_seeds']
                row['num_months'] = len(group_data['months'])

                # Month-averaged variants, computed from the per-month arrays
                # (the kernel's "avg_monthly_F1" etc. average the no-rejection values)
                row['baseline_avg_monthly_f1'] = np.mean(metrics['monthly_F1_no_rejection'])
                row['baseline_avg_monthly_fnr'] = np.mean(metrics['monthly_FNR_no_rejection'])
                row['baseline_avg_monthly_fpr'] = np.mean(metrics['monthly_FPR_no_rejection'])
                row['rejected_avg_monthly_f1'] = np.mean(metrics['monthly_F1'])
                row['rejected_avg_monthly_fnr'] = np.mean(metrics['monthly_FNR'])
                row['rejected_avg_monthly_fpr'] = np.mean(metrics['monthly_FPR'])

                all_results.append(row)

        # Step 3: Convert to DataFrame, sorted by group + budget
        sort_cols = group_by + ['rejection_budget']
        if not all_results:
            # sort_values would raise KeyError on a column-less frame
            return pd.DataFrame(columns=sort_cols)

        results_df = pd.DataFrame(all_results)
        results_df = results_df.sort_values(sort_cols).reset_index(drop=True)

        return results_df
