from sklearn.metrics import f1_score, confusion_matrix
import pandas as pd
import numba
from numba import njit, prange
import numpy as np
from numba.core.errors import NumbaError
from numba import njit
import numpy as np
import pandas as pd
import numba
from concurrent.futures import ProcessPoolExecutor
import os
import pickle
import time
import tempfile
import shutil
import random
import logging
from typing import Any, Optional, List, Dict, Tuple
from filelock import FileLock


#############
### Utils ###
#############


def aggregate_df_with_basename(df, get_base_name):
    """
    Collapse rows whose 'Method' values share a base name, keeping the first
    non-NaN value per column within each collapsed group.

    Rows are grouped by ('Monthly-Label-Budget', 'Sampler-Mode',
    get_base_name(Method)). For every column, the first non-NaN value in the
    group (in row order) is kept. Columns that are NaN for every row of the
    group stay NaN. If a group holds several different non-NaN values for a
    column, the first one wins.

    Parameters:
    -----------
    df : pandas.DataFrame
        A dataframe with a 3-level multiindex ('Monthly-Label-Budget',
        'Sampler-Mode', 'Method'). Columns may be single- or multi-level
        (e.g. dataset names x metrics).
    get_base_name : function
        Function that maps a 'Method' index value to its base name.

    Returns:
    --------
    pandas.DataFrame
        One row per (budget, sampler mode, base name), indexed by a 3-level
        multiindex ('Monthly-Label-Budget', 'Sampler-Mode', 'Method') where
        'Method' now holds the base name. The input dataframe is not modified.
    """
    # Work on a copy so re-indexing does not touch the caller's dataframe.
    result_df = df.copy()

    idx = result_df.index
    budget = idx.get_level_values(0)
    sampler_mode = idx.get_level_values(1)
    method = idx.get_level_values(2)

    base_names = [get_base_name(m) for m in method]

    # Insert the base name as an extra level so we can group on it.
    result_df.index = pd.MultiIndex.from_arrays(
        [budget, sampler_mode, base_names, method],
        names=['Monthly-Label-Budget', 'Sampler-Mode', 'Base-Method', 'Method']
    )

    # GroupBy.first() takes the first non-null value per column and yields NaN
    # for all-null columns -- exactly the desired "fill within group" behavior,
    # while preserving column dtypes and column-index names.
    aggregated_df = result_df.groupby(
        level=['Monthly-Label-Budget', 'Sampler-Mode', 'Base-Method']
    ).first()

    # The base name becomes the new 'Method' level.
    aggregated_df.index = aggregated_df.index.set_names('Method', level=2)

    return aggregated_df


#########################################
#### Robust Pickle Reader and Writer ####
#########################################

# Configure logging
# NOTE(review): logging.basicConfig runs at import time and configures the
# *root* logger for the whole process (it does nothing if the root logger
# already has handlers). Consider moving this call to the application entry
# point so that importing this module has no global side effects.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Named logger used by RobustPickleIO for all of its diagnostics.
logger = logging.getLogger('robust_pickle')

class RobustPickleIO:
    """
    A class for robust reading and writing of pickle files.
    Handles cases where files might be partially written or being accessed by multiple processes.
    
    Features:
    - Atomic writes using temporary files
    - File locking to prevent simultaneous access
    - Retry mechanism for reading incomplete files
    - Backup mechanism to prevent data loss
    """
    
    def __init__(self, max_retries: int = 5, retry_delay: float = 0.5):
        """
        Initialize the RobustPickleIO handler.
        
        Args:
            max_retries: Maximum number of read retries before giving up
            retry_delay: Base delay between retries (will use exponential backoff)
        """
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._locks = {}  # Cache for file locks
    
    def _get_lock_filename(self, filename: str) -> str:
        """Get the lock file path for a given filename"""
        return f"{filename}.lock"
    
    def _get_backup_filename(self, filename: str) -> str:
        """Get the backup file path for a given filename"""
        return f"{filename}.bak"
    
    def _get_lock(self, filename: str) -> FileLock:
        """Get or create a file lock for the specified file"""
        lock_file = self._get_lock_filename(filename)
        if lock_file not in self._locks:
            self._locks[lock_file] = FileLock(lock_file)
        return self._locks[lock_file]
    
    def _ensure_directory_exists(self, filename: str) -> None:
        """Ensure the directory for the file exists"""
        directory = os.path.dirname(os.path.abspath(filename))
        if not os.path.exists(directory):
            try:
                os.makedirs(directory, exist_ok=True)
                logger.debug(f"Created directory: {directory}")
            except Exception as e:
                logger.warning(f"Failed to create directory {directory}: {str(e)}")
    
    def write_pickle(self, data: Any, filename: str, create_backup: bool = True) -> bool:
        """
        Write data to a pickle file atomically with file locking.
        
        Args:
            data: The Python object to pickle and write
            filename: Target file path
            create_backup: Whether to create a backup of existing data
            
        Returns:
            bool: True if the write was successful
        """
        # Ensure the directory exists
        self._ensure_directory_exists(filename)
        
        lock = self._get_lock(filename)
        
        try:
            with lock:
                # Create backup if requested and file exists
                if create_backup and os.path.exists(filename):
                    backup_file = self._get_backup_filename(filename)
                    try:
                        shutil.copy2(filename, backup_file)
                        logger.debug(f"Created backup at {backup_file}")
                    except Exception as e:
                        logger.warning(f"Failed to create backup: {str(e)}")
                
                # Create a temporary file in the same directory
                directory = os.path.dirname(os.path.abspath(filename))
                try:
                    # Use NamedTemporaryFile instead of mkstemp to avoid file descriptor issues
                    with tempfile.NamedTemporaryFile(dir=directory, 
                                                   prefix=os.path.basename(filename) + '_tmp',
                                                   suffix='.tmp',
                                                   delete=False) as tmp_file:
                        temp_path = tmp_file.name
                        # Write to the temporary file
                        pickle.dump(data, tmp_file)
                        # Flush to ensure data is written to disk
                        tmp_file.flush()
                        os.fsync(tmp_file.fileno())
                    
                    # Atomically replace the destination file
                    if os.name == 'nt':  # Windows
                        if os.path.exists(filename):
                            os.replace(temp_path, filename)
                        else:
                            shutil.move(temp_path, filename)
                    else:  # Unix-like
                        shutil.move(temp_path, filename)
                    
                    logger.debug(f"Successfully wrote data to {filename}")
                    return True
                    
                except Exception as e:
                    # Clean up the temp file if still exists
                    try:
                        if os.path.exists(temp_path):
                            os.unlink(temp_path)
                    except:
                        pass
                    logger.error(f"Error writing pickle file {filename}: {str(e)}")
                    raise
        
        except Exception as e:
            logger.error(f"Failed to acquire lock for {filename}: {str(e)}")
            return False
    
    def append_to_pickle(self, new_items: List[Any], filename: str, 
                          default_empty: Any = None) -> bool:
        """
        Append items to an existing pickle file. If the file doesn't exist,
        initialize it with default_empty (defaults to empty list).
        
        Args:
            new_items: New items to append to the existing data
            filename: Target file path
            default_empty: Default value if file doesn't exist
            
        Returns:
            bool: True if the append was successful
        """
        if default_empty is None:
            default_empty = []
        
        # Ensure the directory exists
        self._ensure_directory_exists(filename)    
        lock = self._get_lock(filename)
        
        try:
            with lock:
                # Read existing data
                existing_data = self.read_pickle(filename, default=default_empty)
                
                if isinstance(existing_data, list) and isinstance(new_items, list):
                    combined_data = existing_data + new_items
                elif isinstance(existing_data, dict) and isinstance(new_items, dict):
                    combined_data = {**existing_data, **new_items}
                else:
                    raise TypeError(f"Cannot append {type(new_items)} to {type(existing_data)}")
                
                # Write combined data
                return self.write_pickle(combined_data, filename)
                
        except Exception as e:
            logger.error(f"Failed to append to pickle file {filename}: {str(e)}")
            return False
    
    def read_pickle(self, filename: str, default: Any = None, 
                     use_backup_on_failure: bool = True) -> Any:
        """
        Read data from a pickle file with retry logic and backup recovery.
        
        Args:
            filename: File to read
            default: Value to return if file doesn't exist
            use_backup_on_failure: Try reading the backup if the main file fails
            
        Returns:
            The unpickled data or default value
        """
        if not os.path.exists(filename):
            logger.debug(f"File {filename} does not exist, returning default")
            return default
        
        lock = self._get_lock(filename)
        
        try:
            with lock:
                for attempt in range(self.max_retries):
                    try:
                        with open(filename, 'rb') as f:
                            data = pickle.load(f)
                        logger.debug(f"Successfully read data from {filename}")
                        return data
                    except (EOFError, pickle.UnpicklingError) as e:
                        # File might be partially written
                        delay = self.retry_delay * (2 ** attempt) + random.uniform(0, 0.1)
                        logger.warning(f"Read attempt {attempt+1} failed: {str(e)}. "
                                      f"Retrying in {delay:.2f}s...")
                        time.sleep(delay)
                
                # All attempts failed, try backup if requested
                if use_backup_on_failure:
                    backup_file = self._get_backup_filename(filename)
                    if os.path.exists(backup_file):
                        logger.info(f"Attempting to read from backup file {backup_file}")
                        try:
                            with open(backup_file, 'rb') as f:
                                data = pickle.load(f)
                            logger.info(f"Successfully recovered data from backup")
                            return data
                        except Exception as e:
                            logger.error(f"Failed to read backup: {str(e)}")
                
                # If we got here, all attempts failed
                logger.error(f"All attempts to read {filename} failed")
                return default
                
        except Exception as e:
            logger.error(f"Error acquiring lock for reading {filename}: {str(e)}")
            # Try without lock as last resort
            try:
                with open(filename, 'rb') as f:
                    return pickle.load(f)
            except:
                return default


#####################
### Metrics Utils ###
#####################


@numba.njit
def compute_metrics_numba(y_true, y_pred):
    """
    Compute F1, FNR, and FPR for binary classification using numba.

    Returns (np.nan, np.nan, np.nan) when the two arrays differ in length,
    hold fewer than 20 samples, or contain any label other than 0 or 1
    (in either array).

    Special case: if every sample is a true negative (all labels and all
    predictions are 0), the prediction is treated as perfect and
    (1.0, 0.0, 0.0) is returned, even though the F1 formula would give 0.0.

    Args:
        y_true (np.ndarray): 1D array of ground truth labels (expected 0 or 1)
        y_pred (np.ndarray): 1D array of predicted labels (expected 0 or 1)

    Returns:
        tuple: (f1, fnr, fpr)
    """
    n = y_true.shape[0]

    # Length mismatch or too few samples for a meaningful estimate.
    if n != y_pred.shape[0] or n < 20:
        return np.nan, np.nan, np.nan

    # Single pass: validate labels first (so constant non-binary inputs are
    # rejected too) and accumulate the confusion matrix.
    tn = 0
    fp = 0
    fn = 0
    tp = 0
    for i in range(n):
        t = y_true[i]
        p = y_pred[i]
        if (t != 0 and t != 1) or (p != 0 and p != 1):
            return np.nan, np.nan, np.nan
        if t == 1:
            if p == 1:
                tp += 1
            else:
                fn += 1
        else:
            if p == 1:
                fp += 1
            else:
                tn += 1

    # All-negative, perfectly predicted: no positives exist, so the F1 formula
    # degenerates to 0; report it as a perfect prediction instead.
    # (All the other constant-input cases are already handled by the general
    # formulas below.)
    if tn == n:
        return 1.0, 0.0, 0.0

    # False Negative Rate; defined as 0.0 when there are no actual positives.
    if (fn + tp) == 0:
        fnr = 0.0
    else:
        fnr = fn / (fn + tp)

    # False Positive Rate; defined as 0.0 when there are no actual negatives.
    if (fp + tn) == 0:
        fpr = 0.0
    else:
        fpr = fp / (fp + tn)

    # F1 = 2*TP / (2*TP + FP + FN); 0.0 when there are no true positives.
    if tp == 0:
        f1 = 0.0
    else:
        f1 = 2.0 * tp / (2 * tp + fp + fn)

    return f1, fnr, fpr



############################
### Risk Coverage Curves ###
############################


# Rank-based quantile transform (vectorized with NumPy)
def values_to_quantiles_vectorized(arr):
    """
    Convert a 1D array of values to rank-based quantiles in [0, 1].

    Each element maps to rank / (n - 1), where rank is its position in
    ascending ``np.argsort`` order. Tied values receive distinct consecutive
    ranks in argsort order.

    Special cases:
      - all values identical (including a single element): array of ones
      - empty input: empty float array

    Args:
        arr: 1D array-like of numeric values.

    Returns:
        np.ndarray of float64 with the same length as ``arr``.
    """
    arr = np.asarray(arr)

    # Check the degenerate cases first: ranks are meaningless for constant
    # input, and n - 1 == 0 for a single element (which previously divided 0/0).
    if arr.size == 0 or np.all(arr == arr[0]):
        return np.ones_like(arr, dtype=float)

    n = arr.shape[0]
    # One sort plus an inverse-permutation scatter gives the same ranks as
    # argsort(argsort(arr)), at half the sorting cost.
    order = np.argsort(arr)
    ranks = np.empty(n, dtype=float)
    ranks[order] = np.arange(n)
    return ranks / (n - 1)

# Numba-accelerated histogram and AUC computation
@numba.jit(nopython=True, parallel=True)
def compute_histograms_and_auc(all_labels, all_predictions, all_uncertainties):
    """
    Approximate the area under the risk-coverage curve (AURC) with Numba.

    Samples are bucketed by their quantile-transformed uncertainty (assumed to
    lie in [0, 1]) into between 20 and 100 equal-width bins. Walking the bins
    from low to high uncertainty gives, per bin, the cumulative coverage
    (fraction of samples seen so far) and risk (error rate among them). The
    curve is integrated with the trapezoidal rule. When the first bin already
    covers more than 1% of the samples, an extra segment from (0, 0) to the
    first point is included.
    """
    n_samples = len(all_uncertainties)

    # Per-sample correctness flag; iterations are independent, hence prange.
    is_correct = np.zeros(n_samples, dtype=np.bool_)
    for k in numba.prange(n_samples):
        is_correct[k] = all_labels[k] == all_predictions[k]

    # Bin count grows with sqrt(n) but is clamped to [20, 100].
    n_bins = min(100, max(20, int(np.sqrt(n_samples))))
    bin_width = 1.0 / n_bins
    correct_counts = np.zeros(n_bins, dtype=np.int64)
    wrong_counts = np.zeros(n_bins, dtype=np.int64)
    for k in range(n_samples):
        # Values at exactly 1.0 would index one past the end; clamp them.
        b = min(n_bins - 1, int(all_uncertainties[k] / bin_width))
        if is_correct[k]:
            correct_counts[b] += 1
        else:
            wrong_counts[b] += 1

    # Running totals give the curve point at the right edge of each bin.
    coverage = np.zeros(n_bins, dtype=np.float64)
    risk = np.zeros(n_bins, dtype=np.float64)
    running_correct = 0
    running_wrong = 0
    for b in range(n_bins):
        running_correct += correct_counts[b]
        running_wrong += wrong_counts[b]
        seen = running_correct + running_wrong
        coverage[b] = seen / n_samples
        if seen > 0:
            risk[b] = running_wrong / seen

    # Trapezoidal integration over consecutive curve points.
    auc = 0.0
    if coverage[0] > 0.01:
        # Leading segment from the origin (0, 0) to the first point.
        auc += risk[0] * coverage[0] / 2.0
    for b in range(n_bins - 1):
        auc += (risk[b] + risk[b + 1]) * (coverage[b + 1] - coverage[b]) / 2.0

    return auc

def _concat_chunks(chunks):
    """Flatten a list of arrays, a flat sequence of scalars, or an ndarray into one 1D ndarray."""
    if isinstance(chunks, np.ndarray):
        return chunks.ravel()
    # A flat sequence of scalars (e.g. [0, 1, 1]) cannot be passed to
    # np.concatenate directly, which needs a sequence of arrays.
    if len(chunks) > 0 and np.ndim(chunks[0]) == 0:
        return np.asarray(chunks).ravel()
    # Sequence of per-chunk arrays (e.g. one per split/seed); raises
    # ValueError for an empty sequence, as np.concatenate always did.
    return np.concatenate([np.ravel(chunk) for chunk in chunks])


def compute_aurc(labels, predictions, uncertainties, compute_eaurc=True):
    """
    Compute the AURC metric or, if requested via the flag, the Excess-AURC (E-AURC).

    Parameters:
      labels: true labels, either a flat list/array or a list of arrays
              (which are concatenated).
      predictions: predicted labels, same structure as ``labels``.
      uncertainties: uncertainty scores, same structure as ``labels``.
                     These are converted to rank-based quantiles before binning.
      compute_eaurc (bool): If True, return E-AURC, computed as:
          E-AURC = AURC - (r_hat + (1 - r_hat) * ln(1 - r_hat))
          where r_hat is the overall error rate.
          If False, return the raw AURC.

    Returns:
      The computed metric (AURC or E-AURC) as a float.
    """
    all_labels = _concat_chunks(labels)
    all_predictions = _concat_chunks(predictions)
    all_uncertainties = _concat_chunks(uncertainties)

    # Convert uncertainty values to quantiles in [0, 1]
    all_uncertainties = values_to_quantiles_vectorized(all_uncertainties)

    # Raw AURC (area under the risk-coverage curve)
    aurc = compute_histograms_and_auc(all_labels, all_predictions, all_uncertainties)

    if not compute_eaurc:
        return aurc

    # Overall error rate (r_hat)
    r_hat = np.mean(all_labels != all_predictions)

    # Optimal AURC for this error rate: r_hat + (1 - r_hat) * ln(1 - r_hat).
    # At r_hat == 1 the log term is undefined, but its limit
    # (1 - r) * ln(1 - r) -> 0, so the optimum is r_hat itself.
    if r_hat < 1.0:
        optimal_aurc = r_hat + (1 - r_hat) * np.log(1 - r_hat)
    else:
        optimal_aurc = r_hat

    # Excess-AURC: how far the observed curve is above the optimal one.
    return aurc - optimal_aurc

def _nan_auc_stats():
    """Statistics dict returned when no AUC value could be computed."""
    return {
        'mean': float('NaN'),
        'std': float('NaN'),
        'min': float('NaN'),
        'max': float('NaN'),
        'median': float('NaN'),
        'all_values': [],
        'n_seeds': 0
    }


def process_dataset_row_multi_seed(
        labels, 
        predictions, 
        uncertainties, 
        is_multiseed, 
        n_jobs=4,
        compute_eaurc=True
    ):
    """
    Process a single row from a dataset and compute AUC for all seeds.

    Parameters:
    -----------
    labels, predictions, uncertainties: lists
        Single seed: lists of arrays (chunks).
        Multi-seed: lists of chunks where each chunk is a list of per-seed
        arrays; chunk[seed_idx] is the data for that seed. The seed count is
        taken from the first chunk of ``labels``; chunks with fewer seeds are
        skipped for the missing seeds.
    is_multiseed: bool
        Whether the data contains multiple seeds
    n_jobs: int
        Number of worker processes. Values of 1 or less (or None) and fewer
        than 3 seeds run sequentially.
    compute_eaurc: bool
        Passed to compute_aurc (E-AURC if True, raw AURC otherwise).

    Returns:
    --------
    dict: 'mean', 'std' (population, ddof=0), 'min', 'max', 'median',
          'all_values' (per-seed values) and 'n_seeds' (number of seeds that
          produced a value). A seed whose computation raises is reported and
          skipped. If no seed succeeds, all statistics are NaN and n_seeds is 0.
    """
    if not is_multiseed:
        # For single seed, just compute one AUC
        auc = compute_aurc(labels, predictions, uncertainties, compute_eaurc)
        return {
            'mean': auc,
            'std': 0.0,
            'min': auc,
            'max': auc,
            'median': auc,
            'all_values': [auc],
            'n_seeds': 1
        }

    # len() instead of truthiness so numpy arrays are accepted too.
    n_seeds = len(labels[0]) if labels is not None and len(labels) > 0 else 0

    if n_seeds == 0:
        try:
            shape = np.shape(labels)
        except ValueError:
            # Ragged nesting: np.shape cannot build an array from it.
            shape = f"ragged ({len(labels)} chunks)"
        print(f"Warning: No seeds found in multi-seed data. Structure: labels shape = {shape}")
        return _nan_auc_stats()

    auc_values = []

    try:
        # Transpose chunk-major data into seed-major (labels, preds, uncs) tuples.
        seed_data = [
            (
                [chunk[s] for chunk in labels if s < len(chunk)],
                [chunk[s] for chunk in predictions if s < len(chunk)],
                [chunk[s] for chunk in uncertainties if s < len(chunk)],
            )
            for s in range(n_seeds)
        ]

        if n_seeds < 3 or not n_jobs or n_jobs <= 1:
            # Sequential: a failing seed is skipped, matching the parallel path.
            for seed_labels, seed_predictions, seed_uncertainties in seed_data:
                try:
                    auc_values.append(compute_aurc(
                        seed_labels, seed_predictions, seed_uncertainties, compute_eaurc))
                except Exception as e:
                    print(f"Error processing seed: {e}")
        else:
            with ProcessPoolExecutor(max_workers=min(n_jobs, n_seeds)) as executor:
                futures = [
                    executor.submit(compute_aurc, seed_labels, seed_predictions,
                                    seed_uncertainties, compute_eaurc)
                    for seed_labels, seed_predictions, seed_uncertainties in seed_data
                ]
                # Results are collected in seed order; failed seeds are skipped.
                for future in futures:
                    try:
                        auc_values.append(future.result())
                    except Exception as e:
                        print(f"Error processing seed: {e}")

    except Exception as e:
        print(f"Error in multi-seed processing: {e}")
        import traceback
        traceback.print_exc()
        return _nan_auc_stats()

    if not auc_values:
        print("Warning: No valid AUC values calculated for any seed")
        return _nan_auc_stats()

    auc_array = np.array(auc_values)
    return {
        'mean': np.mean(auc_array),
        'std': np.std(auc_array),
        'min': np.min(auc_array),
        'max': np.max(auc_array),
        'median': np.median(auc_array),
        'all_values': auc_values,
        'n_seeds': len(auc_values)
    }

def compute_risk_coverage_aurc_multi_seed(
        dataset_dict, 
        method_name_mpr=None,
        n_jobs=4, 
        return_all_seeds=False, 
        compute_eaurc=True,
        debug=False
    ):
    """
    Compute AUC values for risk-coverage curves across all seeds.

    Every row of every DataFrame is passed to ``process_dataset_row_multi_seed``
    (using its "Labels", "Predictions", "Uncertainties (Month Ahead)" and
    "is_multiseed" columns). The per-row statistics are gathered into one
    DataFrame per statistic whose rows are the sorted union of all input
    indices and whose columns are the dataset names.

    Parameters:
    -----------
    dataset_dict: dict
        Dictionary where keys are dataset names and values are DataFrames.
    method_name_mpr: function, optional
        Currently unused; kept so existing call sites remain valid.
    n_jobs: int, default=4
        Number of worker processes. Datasets with fewer than 5 rows (or when
        ``n_jobs == 1``) are processed sequentially in this process, with
        ``n_jobs`` forwarded to the per-row function. Larger datasets are
        distributed row-wise over a ``ProcessPoolExecutor``, and each row is
        then processed with ``n_jobs=1``.
    return_all_seeds: bool, default=False
        If True, also return all individual seed values as a separate DataFrame.
    compute_eaurc: bool, default=True
        Forwarded unchanged to ``process_dataset_row_multi_seed``.
    debug: bool, default=False
        If True, print extra debugging information.

    Returns:
    --------
    dict: Dictionary of DataFrames containing different AUC statistics:
        - 'mean': DataFrame with mean AUC values
        - 'std': DataFrame with standard deviation of AUC values
        - 'min', 'max', 'median': min/max/median AUC values
        - 'n_seeds': number of seeds per experiment
        - 'all_values': all individual seed values (if return_all_seeds=True)
        Cells whose (index, dataset) pair is absent or whose processing failed
        hold NaN (0 for 'n_seeds', [] for 'all_values'). If every dataset is
        empty, every entry is an empty DataFrame.
    """
    import traceback

    stat_keys = ['mean', 'std', 'min', 'max', 'median', 'n_seeds']
    result_keys = stat_keys + (['all_values'] if return_all_seeds else [])

    def _failed_result():
        # Placeholder for a missing or failed row. A fresh dict (and list) is
        # built on every call so that no mutable object is shared between cells.
        return {
            'mean': float('NaN'),
            'std': float('NaN'),
            'min': float('NaN'),
            'max': float('NaN'),
            'median': float('NaN'),
            'n_seeds': 0,
            'all_values': [],
        }

    # Find a reference DataFrame to understand the index structure
    reference_df = next((df for df in dataset_dict.values() if not df.empty), None)
    if reference_df is None:
        print("Warning: All datasets are empty")
        return {key: pd.DataFrame() for key in result_keys}

    has_multiindex = reference_df.index.nlevels > 1
    index_names = reference_df.index.names

    # Union of all row labels (tuples for MultiIndex rows).
    all_labels = set()
    for df in dataset_dict.values():
        if df.index.nlevels > 1:
            all_labels.update(tuple(x) for x in df.index)
        else:
            all_labels.update(df.index)

    try:
        all_labels_sorted = sorted(all_labels)
    except TypeError:
        # Mutually unorderable labels (e.g. ints mixed with strings): fall back
        # to a deterministic textual order instead of failing.
        all_labels_sorted = sorted(all_labels, key=repr)

    # collected[statistic][dataset][row label] -> value
    collected = {key: {dataset: {} for dataset in dataset_dict} for key in result_keys}

    def _record(dataset_name, idx, result):
        # Store every requested statistic of one processed row.
        for key in result_keys:
            collected[key][dataset_name][idx] = result[key]

    def _report_failure(dataset_name, idx, error):
        # Must be called from inside an ``except`` clause so that the traceback
        # of the exception being handled is printed.
        print(f"Error processing {dataset_name}, index {idx}: {error}")
        traceback.print_exc()

    for dataset_name, df in dataset_dict.items():
        if df.empty:
            if debug:
                print(f"Skipping empty dataset: {dataset_name}")
            continue

        if debug:
            print(f"Processing dataset: {dataset_name} with {len(df)} rows")

        if len(df) < 5 or n_jobs == 1:
            # Sequential processing for small datasets
            for idx, row in df.iterrows():
                try:
                    if debug:
                        # Inside the try so that a malformed row is recorded as a
                        # failure instead of aborting the whole call in debug mode.
                        is_multiseed = row.get("is_multiseed", False)
                        labels_shape = np.shape(row["Labels"])
                        print(f"  Row {idx}: is_multiseed={is_multiseed}, labels shape={labels_shape}")

                    result = process_dataset_row_multi_seed(
                        row["Labels"],
                        row["Predictions"],
                        row["Uncertainties (Month Ahead)"],
                        row["is_multiseed"],
                        n_jobs=n_jobs,
                        compute_eaurc=compute_eaurc
                    )

                    if debug and result['n_seeds'] > 0:
                        print(f"    Success: {result['n_seeds']} seeds, mean AUC = {result['mean']:.4f}")

                    _record(dataset_name, idx, result)
                except Exception as e:
                    _report_failure(dataset_name, idx, e)
                    _record(dataset_name, idx, _failed_result())
        else:
            # Parallel processing for larger datasets
            with ProcessPoolExecutor(max_workers=n_jobs) as executor:
                futures = {}
                for idx, row in df.iterrows():
                    try:
                        futures[idx] = executor.submit(
                            process_dataset_row_multi_seed,
                            row["Labels"],
                            row["Predictions"],
                            row["Uncertainties (Month Ahead)"],
                            row["is_multiseed"],
                            n_jobs=1,  # Use sequential processing for seeds within each parallel task
                            compute_eaurc=compute_eaurc
                        )
                    except Exception as e:
                        # e.g. a missing column: record this row as failed instead of
                        # aborting the whole dataset (matches the sequential branch).
                        _report_failure(dataset_name, idx, e)
                        _record(dataset_name, idx, _failed_result())

                for idx, future in futures.items():
                    try:
                        result = future.result()

                        if debug and result['n_seeds'] > 0:
                            print(f"  Row {idx}: {result['n_seeds']} seeds, mean AUC = {result['mean']:.4f}")

                        _record(dataset_name, idx, result)
                    except Exception as e:
                        _report_failure(dataset_name, idx, e)
                        _record(dataset_name, idx, _failed_result())

    # Row index of the output frames. For a single-level index the plain list
    # is passed so that each DataFrame builds (and names) its own Index object.
    if has_multiindex:
        out_index = pd.MultiIndex.from_tuples(all_labels_sorted, names=index_names)
    else:
        out_index = all_labels_sorted
    single_index_name = None if has_multiindex else reference_df.index.name

    results_dict = {}
    for key in result_keys:
        columns = {}
        for dataset in dataset_dict:
            stored = collected[key][dataset]
            columns[dataset] = [
                stored[idx] if idx in stored else _failed_result()[key]
                for idx in all_labels_sorted
            ]
        frame = pd.DataFrame(columns, index=out_index)
        if single_index_name:
            frame.index.name = single_index_name
        results_dict[key] = frame

    return results_dict

def create_auc_summary_table_multi_seed(auc_results, include_mean=True, include_rank=True, sort_by_mean=True, sort_by_budget=True):
    """
    Turn per-dataset AUC statistics into a table of "mean ± std" strings.

    Parameters:
    - auc_results: dict
        Must provide DataFrames under 'mean' and 'std' (e.g. the output of
        compute_risk_coverage_aurc_multi_seed). Entries of 'std' are looked up
        by the row labels and column names of 'mean'.
    - include_mean: bool, default=True
        Append a 'Mean AUC' column holding the row-wise mean of the per-dataset
        means (NaNs skipped). Its "± std" part is the row-wise mean of the
        per-dataset standard deviations.
    - include_rank: bool, default=True
        Append a 'Rank' column (1 = lowest 'Mean AUC'; ties share the lowest
        rank). Ignored unless include_mean is also True. With a 3-level index
        and sort_by_budget=True, ranks restart within each value of the first
        index level (the monthly label budget).
    - sort_by_mean: bool, default=True
        Order rows by ascending 'Mean AUC' (requires include_mean). With a
        3-level index and sort_by_budget=True, this orders within each budget.
    - sort_by_budget: bool, default=True
        With a 3-level index, order rows by the first index level first.

    Returns:
    - summary_df: pandas DataFrame
        Object-dtype table sharing the (possibly reordered) index and columns
        of the mean table. Cells read "N/A" where the mean or std is missing.
    """
    means = auc_results['mean'].copy()
    stds = auc_results['std'].copy()
    wants_rank = include_rank and include_mean
    is_three_level = means.index.nlevels == 3

    if include_mean:
        means['Mean AUC'] = means.mean(axis=1)
        stds['Mean AUC'] = stds.mean(axis=1)

    if wants_rank:
        if is_three_level and sort_by_budget:
            budget_level = means.index.names[0]
            ranks = means.groupby(level=budget_level)['Mean AUC'].rank(method='min')
        else:
            ranks = means['Mean AUC'].rank(method='min')
        means['Rank'] = ranks

    # Decide on the sort key(s); None keeps the incoming row order.
    can_sort_by_mean = sort_by_mean and include_mean
    order_by = None
    if is_three_level:
        budget_level = means.index.names[0]
        if sort_by_budget:
            order_by = [budget_level, 'Mean AUC'] if can_sort_by_mean else budget_level
        elif can_sort_by_mean:
            order_by = 'Mean AUC'
    elif can_sort_by_mean:
        order_by = 'Mean AUC'

    if order_by is not None:
        means = means.sort_values(by=order_by)
        stds = stds.loc[means.index]

    def as_text(mu, sigma):
        if pd.isna(mu) or pd.isna(sigma):
            return "N/A"
        return f"{mu:.4f} ± {sigma:.4f}"

    summary_df = pd.DataFrame(index=means.index, columns=means.columns)
    aligned_stds = stds.loc[means.index]
    for column in means.columns:
        if column == 'Rank':
            continue
        cells = [as_text(mu, sigma) for mu, sigma in zip(means[column], aligned_stds[column])]
        summary_df[column] = np.array(cells, dtype=object)

    # Ranks are shown as plain integers.
    if wants_rank:
        summary_df['Rank'] = means['Rank'].apply(
            lambda r: "N/A" if pd.isna(r) else str(int(float(r)))
        )

    return summary_df



#######################################
### Rejector and Optimised Rejector ###
#######################################


@njit(parallel=True)
def recover_positive_softmax(classifier_uncertainty: np.ndarray, predictions: np.ndarray) -> np.ndarray:
    """
    Recover the softmax probability for the positive class (class 1).

    The uncertainty is assumed to have been computed as
    ``u = 1 - |p_max - 0.5| / 0.5`` with ``p_max >= 0.5`` the probability of the
    predicted class. This inverts to ``p_max = 0.5 * (2 - u)``. The positive-class
    probability is ``p_max`` where the prediction is 1, and ``1 - p_max`` otherwise.

    Args:
        classifier_uncertainty: Uncertainty values calculated as 1 - (np.abs(max_probs - 0.5) / 0.5)
        predictions: Binary predictions array (0 or 1 for each sample)

    Returns:
        np.ndarray: Softmax probabilities for the positive class

    Raises:
        AssertionError: If any prediction is neither 0 nor 1. The check is an
            explicit ``raise`` rather than an ``assert`` statement so that it is
            not stripped when Python runs with ``-O``.
    """
    # Validate that the predictions really are binary.
    for p in predictions.flat:
        if p != 0 and p != 1:
            raise AssertionError("Predictions must be binary (0 or 1)!")
    max_probs = 0.5 * (2 - classifier_uncertainty)
    positive_softmax = np.where(predictions == 1, max_probs, 1 - max_probs)
    return positive_softmax



@njit
def binary_classification_rejection_thresholds_arbitrary_scale(
    agg_scores_binary: np.ndarray, 
    total_to_reject: int,
) -> tuple[float, float]:
    """
    Compute lower and upper rejection thresholds for scores on an arbitrary scale.

    The median of the scores is used as the decision boundary. The
    ``total_to_reject`` scores closest to it are selected, and the thresholds are
    the smallest and largest of those selected scores. Because both bounds are
    themselves selected scores, rejecting exactly the selected samples requires
    the inclusive test ``lower_threshold <= score <= upper_threshold``. Further
    samples are included only when they tie with a bound.

    Args:
        agg_scores_binary: Array of binary classification scores 
            (arbitrary real values where lower values indicate certainty for one class
             and higher values indicate certainty for the other class)
        total_to_reject: Number of samples to reject

    Returns:
        tuple[float, float]: (lower_threshold, upper_threshold)
            Normal case: lower_threshold <= upper_threshold (equal when
            total_to_reject == 1 or the selected scores all tie).
            (max value + 1, min value - 1) if total_to_reject <= 0 or
            total_to_reject > len(scores): an empty band (reject nothing).
            (inf, -inf) if the score array is empty: also an empty band.

    Note:
        Works with arbitrary real values (including negative numbers).
        Values most uncertain/ambiguous are closest to the median of the scores.
    """
    n_samples = len(agg_scores_binary)

    # Empty input: there is no element to read below, and Numba does not
    # bounds-check array accesses by default, so return an empty band explicitly.
    if n_samples == 0:
        return np.inf, -np.inf

    # Edge cases: return a band that cannot contain any score (reject nothing)
    if total_to_reject <= 0 or total_to_reject > n_samples:
        max_val = agg_scores_binary[0]
        min_val = agg_scores_binary[0]
        for k in range(1, n_samples):
            if agg_scores_binary[k] > max_val:
                max_val = agg_scores_binary[k]
            if agg_scores_binary[k] < min_val:
                min_val = agg_scores_binary[k]
        return max_val + 1.0, min_val - 1.0

    # Median of the scores acts as the decision boundary
    sorted_vals = np.sort(agg_scores_binary)
    half = n_samples // 2
    if n_samples % 2 == 0:
        median = (sorted_vals[half - 1] + sorted_vals[half]) / 2.0
    else:
        median = sorted_vals[half]

    # Distance from the median (smaller = more ambiguous)
    distances = np.empty(n_samples, dtype=np.float64)
    for k in range(n_samples):
        distances[k] = np.abs(agg_scores_binary[k] - median)

    # Indices ordered from most to least ambiguous
    sorted_indices = np.argsort(distances)

    # Seed the band with the most ambiguous score, then widen it to cover the
    # remaining selected scores.
    lower_cut = agg_scores_binary[sorted_indices[0]]
    upper_cut = lower_cut
    for k in range(1, total_to_reject):
        val = agg_scores_binary[sorted_indices[k]]
        if val < lower_cut:
            lower_cut = val
        if val > upper_cut:
            upper_cut = val

    return lower_cut, upper_cut

@njit
def binary_classification_rejection_thresholds(
    agg_scores_binary: np.ndarray, 
    total_to_reject: int,
) -> tuple[float, float]:
    """
    Compute lower and upper rejection thresholds for binary-classification scores.

    The ``total_to_reject`` scores closest to the decision boundary 0.5 are
    selected, and the thresholds are the smallest and largest of those selected
    scores. Samples are therefore meant to be rejected when
    ``lower_threshold <= score <= upper_threshold``. Both bounds are scores of
    selected samples, so an inclusive test is needed to reject all of them.

    Args:
        agg_scores_binary: Array of binary classification scores (probabilities for one class)
        total_to_reject: Number of samples to reject

    Returns:
        tuple[float, float]: (lower_threshold, upper_threshold)
            Normal case: lower_threshold <= upper_threshold (equal when
            total_to_reject == 1 or the selected scores all tie).
            (1.0, 0.0) if total_to_reject <= 0 or total_to_reject > len(scores):
            an empty band (reject nothing).

    Note:
        Assumes binary classification scores of either the positive class or negative class
        (i.e. softmax scores with respect to one class only).
        Scores closest to 0.5 are considered most uncertain.
    """
    n_samples = len(agg_scores_binary)

    # Edge cases: lower > upper is an empty band (reject nothing)
    if total_to_reject <= 0 or total_to_reject > n_samples:
        return 1.0, 0.0

    # Distance from the decision boundary (0.5)
    distances = np.empty(n_samples, dtype=np.float64)
    for k in range(n_samples):
        distances[k] = np.abs(agg_scores_binary[k] - 0.5)

    # Indices ordered from most to least uncertain (closest to 0.5 first)
    sorted_indices = np.argsort(distances)

    # Seed the band with the most uncertain score rather than the constants
    # 1.0 / 0.0. The result is identical for probabilities in [0, 1] and stays
    # correct for scores outside that range.
    lower_cut = float(agg_scores_binary[sorted_indices[0]])
    upper_cut = lower_cut
    for k in range(1, total_to_reject):
        val = agg_scores_binary[sorted_indices[k]]
        if val < lower_cut:
            lower_cut = val
        if val > upper_cut:
            upper_cut = val

    return lower_cut, upper_cut



@njit
def single_value_uncertainty_threshold(
    agg_scores_single: np.ndarray, 
    total_to_reject: int,
) -> float:
    """
    Compute the uncertainty threshold that rejects ``total_to_reject`` samples.

    Samples whose score is strictly greater than the returned threshold are to
    be rejected. The threshold is the (total_to_reject + 1)-th largest score,
    i.e. the largest score that is still accepted. The test
    ``score > threshold`` therefore selects exactly the ``total_to_reject``
    highest scores when there are no ties at the boundary; scores equal to the
    threshold are accepted.

    Args:
        agg_scores_single: Array of uncertainty scores where higher values indicate more
                          uncertainty (e.g. HCC's pseudo-loss). The computation works for
                          any real values; non-negativity is not validated here.
        total_to_reject: Number of samples to reject

    Returns:
        float: Threshold value above which samples should be rejected
              Returns np.inf if total_to_reject <= 0 (reject nothing)
              Returns -np.inf if total_to_reject >= len(scores) (reject everything)
    """
    n_samples = len(agg_scores_single)

    if total_to_reject <= 0:
        return np.inf  # Reject nothing

    if total_to_reject >= n_samples:
        return -np.inf  # Reject everything

    # In ascending order, the (k + 1)-th largest value sits at position n - 1 - k.
    # Returning it (rather than the k-th largest) makes "score > threshold" reject
    # k samples instead of k - 1. This is consistent with the -inf case above.
    sorted_scores = np.sort(agg_scores_single)
    return sorted_scores[n_samples - 1 - total_to_reject]




@njit(parallel=True)
def _PostHocRejectorSimulator_Refactor(
        uncertainties,  # List of np.ndarray (one per month)
        predictions,    # List of np.ndarray (one per month)
        labels,         # List of np.ndarray (one per month)
        rejection_Ns,   # 1D np.ndarray of int64 rejection quotas
        upto_reject,    # Boolean flag to limit rejections per month
        method          # String indicating which method to use
):
    """
    Numba-friendly simulation of post-hoc rejection over a sequence of months.

    For each rejection quota ``rej_N`` (parallelised with ``prange``) and each
    month, a reject mask is built according to ``method``:

      - "single_thresh_simple": one uncertainty threshold. Month 0 is filtered
        with a threshold calibrated on month 0 itself. Every later month is
        filtered with the threshold calibrated on the previous month, and the
        threshold is then recalibrated on the current month.
      - "single_thresh_compounded": like the above, but from month 1 on the
        threshold is calibrated on all previous months pooled, with a quota of
        ``Mi * rej_N``.
      - "dual_thresh_simple": uncertainties are mapped back to positive-class
        probabilities. The ``rej_N`` probabilities closest to 0.5 define an
        inclusive ``[lower, upper]`` rejection band, calibrated as in
        "single_thresh_simple".
      - "dual_thresh_compounded": the same band, calibrated on all previous
        months pooled (quota ``Mi * rej_N``).

    Any other ``method`` string rejects nothing. When ``upto_reject`` is True,
    at most ``rej_N`` samples are rejected in each month after the first (the
    first qualifying samples in array order).

    Metrics are computed per month (with and without rejection). They are also
    computed once on the accepted samples pooled over all months.

    Returns:
        np.ndarray of float64 with shape (len(rejection_Ns), 16, n_months).
        Rows 0-7 hold one scalar per quota, repeated along the month axis:
          0-2: F1, FNR, FPR on the accepted samples pooled over all months
               (0.0 when nothing is accepted)
          3:   average monthly rejections
          4-6: average monthly F1, FNR, FPR on the accepted samples
          7:   average monthly acceptances
        Rows 8-15 hold per-month values:
          8:     rejections
          9:     acceptances
          10-12: baseline F1, FNR, FPR (all samples, no rejection)
          13-15: F1, FNR, FPR on the accepted samples
    """
    n_months = len(uncertainties)
    n_rej = rejection_Ns.shape[0]
    result = np.empty((n_rej, 16, n_months), dtype=np.float64)

    # Total number of samples over all months. It does not depend on the
    # rejection quota, so it is computed once, outside the parallel loop.
    total_possible = 0
    for m in range(n_months):
        total_possible += uncertainties[m].shape[0]

    # Outer loop over each rejection quota (parallelised)
    for j in prange(n_rej):
        rej_N = rejection_Ns[j]

        # History of per-sample scores of the months processed so far; the
        # "compounded" methods calibrate their thresholds on it.
        aggregated_scores = np.empty(total_possible, dtype=np.float64)
        agg_score_index = 0  # running index into aggregated_scores

        # Accepted samples pooled over all months, for the aggregated metrics.
        agg_labels = np.empty(total_possible, dtype=labels[0].dtype)
        agg_preds = np.empty(total_possible, dtype=predictions[0].dtype)
        agg_index = 0

        # Per-month metrics: columns 0-2 = baseline F1/FNR/FPR on all samples,
        # columns 3-5 = F1/FNR/FPR on the accepted samples only.
        month_metrics = np.empty((n_months, 6), dtype=np.float64)
        month_rejections = np.empty(n_months, dtype=np.int64)
        month_acceptances = np.empty(n_months, dtype=np.int64)

        for Mi in range(n_months):
            curr_uncert = uncertainties[Mi]
            curr_preds = predictions[Mi]
            curr_labels = labels[Mi]
            n_samples = curr_uncert.shape[0]

            reject_mask = np.zeros(n_samples, dtype=np.bool_)

            if method == "single_thresh_simple":
                if Mi == 0:
                    # Month 0 is filtered with a threshold calibrated on itself
                    uncert_cutoff = single_value_uncertainty_threshold(curr_uncert, rej_N)
                    for k in range(n_samples):
                        if curr_uncert[k] > uncert_cutoff:
                            reject_mask[k] = True
                else:
                    # Filter with the previous month's threshold ...
                    rejection_counter = 0
                    for k in range(n_samples):
                        if curr_uncert[k] > uncert_cutoff:
                            reject_mask[k] = True
                            rejection_counter += 1
                            if upto_reject and rejection_counter >= rej_N:
                                break

                    # ... then recalibrate on this month for the next one
                    uncert_cutoff = single_value_uncertainty_threshold(curr_uncert, rej_N)

            elif method == "single_thresh_compounded":
                if Mi == 0:
                    total_to_reject = rej_N
                    # Start the score history with month 0
                    for k in range(n_samples):
                        aggregated_scores[k] = curr_uncert[k]
                        agg_score_index += 1

                    # Month 0 is filtered with a threshold calibrated on itself
                    cutoff_point = single_value_uncertainty_threshold(
                        aggregated_scores[:agg_score_index], total_to_reject
                    )
                    for k in range(n_samples):
                        if curr_uncert[k] > cutoff_point:
                            reject_mask[k] = True
                else:
                    # Calibrate on all previous months pooled (quota scaled by
                    # the number of pooled months), then filter this month
                    total_to_reject = Mi * rej_N
                    cutoff_point = single_value_uncertainty_threshold(
                        aggregated_scores[:agg_score_index], 
                        total_to_reject
                    )
                    rejection_counter = 0
                    for k in range(n_samples):
                        if curr_uncert[k] > cutoff_point:
                            reject_mask[k] = True
                            rejection_counter += 1
                            if upto_reject and rejection_counter >= rej_N:
                                break

                    # Append this month's scores to the history
                    prev_agg_score_index = agg_score_index
                    for k in range(n_samples):
                        aggregated_scores[prev_agg_score_index + k] = curr_uncert[k]
                        agg_score_index += 1

            elif method == "dual_thresh_simple":
                recovered_scores = recover_positive_softmax(curr_uncert, curr_preds)
                if Mi == 0:
                    lower_cutoff, upper_cutoff = binary_classification_rejection_thresholds(
                        recovered_scores, rej_N
                    )
                    # Inclusive band: the bounds are scores of selected samples
                    for k in range(n_samples):
                        if recovered_scores[k] >= lower_cutoff and recovered_scores[k] <= upper_cutoff:
                            reject_mask[k] = True
                else:
                    rejection_counter = 0
                    for k in range(n_samples):
                        if recovered_scores[k] >= lower_cutoff and recovered_scores[k] <= upper_cutoff:
                            reject_mask[k] = True
                            rejection_counter += 1
                            if upto_reject and rejection_counter >= rej_N:
                                break
                    # Recalibrate on this month for the next one
                    lower_cutoff, upper_cutoff = binary_classification_rejection_thresholds(
                        recovered_scores, rej_N
                    )

            elif method == "dual_thresh_compounded":
                recovered_scores = recover_positive_softmax(curr_uncert, curr_preds)
                if Mi == 0:
                    total_to_reject = rej_N
                    # Start the score history with month 0
                    for k in range(n_samples):
                        aggregated_scores[k] = recovered_scores[k]
                        agg_score_index += 1

                    # Month 0 is filtered with a band calibrated on itself
                    lower_cutoff, upper_cutoff = binary_classification_rejection_thresholds(
                        aggregated_scores[:agg_score_index], total_to_reject
                    )
                    for k in range(n_samples):
                        if recovered_scores[k] >= lower_cutoff and recovered_scores[k] <= upper_cutoff:
                            reject_mask[k] = True

                else:
                    # Calibrate on all previous months pooled, then filter this month
                    total_to_reject = Mi * rej_N
                    lower_cutoff, upper_cutoff = binary_classification_rejection_thresholds(
                        aggregated_scores[:agg_score_index], total_to_reject
                    )
                    rejection_counter = 0
                    for k in range(n_samples):
                        if recovered_scores[k] >= lower_cutoff and recovered_scores[k] <= upper_cutoff:
                            reject_mask[k] = True
                            rejection_counter += 1
                            if upto_reject and rejection_counter >= rej_N:
                                break

                    # Append this month's scores to the history
                    prev_agg_score_index = agg_score_index
                    for k in range(n_samples):
                        aggregated_scores[prev_agg_score_index + k] = recovered_scores[k]
                        agg_score_index += 1

            ###########################
            # Other stuff and metrics #
            ###########################

            # Count rejections and gather this month's accepted samples
            rej_count = 0
            for k in range(n_samples):
                if reject_mask[k]:
                    rej_count += 1
            accepted_count = n_samples - rej_count

            accepted_labels = np.empty(accepted_count, dtype=curr_labels.dtype)
            accepted_preds = np.empty(accepted_count, dtype=curr_preds.dtype)
            pos = 0
            for k in range(n_samples):
                if not reject_mask[k]:
                    accepted_labels[pos] = curr_labels[k]
                    accepted_preds[pos] = curr_preds[k]
                    pos += 1

            # --- Per-month metrics (baseline and after rejection) and counts ---
            f1_baseline, fnr_baseline, fpr_baseline = compute_metrics_numba(curr_labels, curr_preds)
            f1, fnr, fpr = compute_metrics_numba(accepted_labels, accepted_preds)
            month_metrics[Mi, 0] = f1_baseline
            month_metrics[Mi, 1] = fnr_baseline
            month_metrics[Mi, 2] = fpr_baseline
            month_metrics[Mi, 3] = f1
            month_metrics[Mi, 4] = fnr
            month_metrics[Mi, 5] = fpr

            month_rejections[Mi] = rej_count
            month_acceptances[Mi] = accepted_count

            # --- Accumulate accepted samples for aggregated metrics ---
            for k in range(accepted_count):
                agg_labels[agg_index] = accepted_labels[k]
                agg_preds[agg_index] = accepted_preds[k]
                agg_index += 1

        # --- Aggregated metrics on the accepted samples of all months ---
        if agg_index > 0:
            agg_f1, agg_fnr, agg_fpr = compute_metrics_numba(agg_labels[:agg_index], agg_preds[:agg_index])
        else:
            agg_f1 = 0.0
            agg_fnr = 0.0
            agg_fpr = 0.0

        # Average per-month metrics after rejection (columns 3-5; columns 0-2
        # are the no-rejection baseline) and average per-month counts.
        sum_f1 = 0.0
        sum_fnr = 0.0
        sum_fpr = 0.0
        sum_rej = 0.0
        sum_acc = 0.0
        for Mi in range(n_months):
            sum_f1 += month_metrics[Mi, 3]
            sum_fnr += month_metrics[Mi, 4]
            sum_fpr += month_metrics[Mi, 5]
            sum_rej += month_rejections[Mi]
            sum_acc += month_acceptances[Mi]

        avg_month_f1 = sum_f1 / n_months
        avg_month_fnr = sum_fnr / n_months
        avg_month_fpr = sum_fpr / n_months
        avg_month_rej = sum_rej / n_months
        avg_month_acc = sum_acc / n_months

        # --- Populate the result for this rejection quota ---
        # Scalars are broadcast along the month axis.
        result[j, 0, :] = agg_f1
        result[j, 1, :] = agg_fnr
        result[j, 2, :] = agg_fpr
        result[j, 3, :] = avg_month_rej
        result[j, 4, :] = avg_month_f1
        result[j, 5, :] = avg_month_fnr
        result[j, 6, :] = avg_month_fpr
        result[j, 7, :] = avg_month_acc
        result[j, 8] = month_rejections
        result[j, 9] = month_acceptances

        result[j, 10] = month_metrics[:, 0]
        result[j, 11] = month_metrics[:, 1]
        result[j, 12] = month_metrics[:, 2]

        result[j, 13] = month_metrics[:, 3]
        result[j, 14] = month_metrics[:, 4]
        result[j, 15] = month_metrics[:, 5]

    return result




def PostHocRejectorSimulator(uncertainties, predictions, labels, rejection_Ns, upto_reject, method):
    """
    Run the Numba-accelerated post-hoc rejection simulation and unpack its
    result array into a dictionary keyed by rejection quota.

    This is a thin Python wrapper around ``_PostHocRejectorSimulator_Refactor``.
    The kernel returns an array indexed as ``result_array[quota_idx, column]``.
    Each column holds one value per month. Scalar summaries (columns 0-7) are
    stored repeated across months, so only their first entry is read here.

    Returned structure::

        result[rej_N] = {
            # Scalar summaries (first entry of columns 0-7)
            "F1": ...,                       # column 0
            "FNR": ...,                      # column 1
            "FPR": ...,                      # column 2
            "avg_monthly_Rejections": ...,   # column 3
            "avg_monthly_F1": ...,           # column 4
            "avg_monthly_FNR": ...,          # column 5
            "avg_monthly_FPR": ...,          # column 6
            "avg_monthly_Acceptances": ...,  # column 7

            # Per-month arrays
            "monthly_Rejections": <array>,        # column 8
            "monthly_Acceptances": <array>,       # column 9
            "monthly_Total": <array>,             # column 8 + column 9
            "monthly_F1_no_rejection": <array>,   # column 10
            "monthly_FNR_no_rejection": <array>,  # column 11
            "monthly_FPR_no_rejection": <array>,  # column 12
            "monthly_F1": <array>,                # column 13
            "monthly_FNR": <array>,               # column 14
            "monthly_FPR": <array>,               # column 15
        }

    NOTE(review): in the kernel, ``avg_monthly_F1/FNR/FPR`` (columns 4-6) are
    means of ``month_metrics[:, 0..2]``. Those same per-month columns are
    exposed here as ``monthly_*_no_rejection`` (columns 10-12), while
    ``monthly_F1/FNR/FPR`` come from ``month_metrics[:, 3..5]``. The aggregated
    ``F1/FNR/FPR`` are computed from accepted samples only. Confirm that the
    ``*_no_rejection`` labels are not swapped with the plain ``monthly_*``
    labels.

    Args:
        uncertainties: List (or numba.typed.List) of np.ndarray uncertainty
            values, one per month.
        predictions: List (or numba.typed.List) of np.ndarray predictions,
            one per month.
        labels: List (or numba.typed.List) of np.ndarray ground-truth labels,
            one per month.
        rejection_Ns: Iterable or np.ndarray of rejection quotas
            (e.g. [10, 20, 30, ...]). Converted to an int64 array before
            being passed to the kernel.
        upto_reject: Boolean flag indicating whether to limit the number of
            rejections per month. Forwarded unchanged to the kernel.
        method: String naming the rejection method. Forwarded unchanged to
            the kernel.

    Returns:
        dict: Maps each rejection quota ``int(rej_N)`` to its metrics
        dictionary, as shown above. If ``rejection_Ns`` contains duplicate
        quotas, the entry for the last occurrence wins.
    """
    # The jitted kernel expects a homogeneous int64 array of quotas.
    rejection_Ns_arr = np.array(rejection_Ns, dtype=np.int64)

    result_array = _PostHocRejectorSimulator_Refactor(
        uncertainties, predictions, labels,
        rejection_Ns_arr, upto_reject, method
    )

    result = {}
    for idx, rej_N in enumerate(rejection_Ns_arr):
        # Plain int keys, so callers are not handed np.int64 dict keys.
        result[int(rej_N)] = {
            # Scalars: the kernel repeats each value across months, so
            # element [0] is the value.
            "F1": result_array[idx, 0][0],
            "FNR": result_array[idx, 1][0],
            "FPR": result_array[idx, 2][0],

            "avg_monthly_Rejections": result_array[idx, 3][0],
            "avg_monthly_F1": result_array[idx, 4][0],
            "avg_monthly_FNR": result_array[idx, 5][0],
            "avg_monthly_FPR": result_array[idx, 6][0],
            "avg_monthly_Acceptances": result_array[idx, 7][0],

            # Per-month arrays.
            "monthly_Rejections": result_array[idx, 8],
            "monthly_Acceptances": result_array[idx, 9],
            "monthly_Total": result_array[idx, 8] + result_array[idx, 9],

            "monthly_F1_no_rejection": result_array[idx, 10],
            "monthly_FNR_no_rejection": result_array[idx, 11],
            "monthly_FPR_no_rejection": result_array[idx, 12],

            "monthly_F1": result_array[idx, 13],
            "monthly_FNR": result_array[idx, 14],
            "monthly_FPR": result_array[idx, 15],
        }

    return result




