"""
Core metrics computation for Aurora framework

This module provides the fundamental metrics used in Aurora evaluation:
- Reliability: F1, FNR, FPR (traditional performance metrics)
- Stability: CV[F1], CV[FNR], CV[FPR] (temporal consistency)
- Drawdown: Min[F1], Max[F1] (worst/best case performance)

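Reliability metrics are computed from prediction and label arrays; stability and
drawdown metrics are computed from the resulting per-month metric series
(convenience methods accept predictions/labels directly).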
"""

from typing import List, Dict, Tuple, Union, Optional
import numpy as np
import numpy.typing as npt

# Import the optimized metrics computation from tools
# This is the same function used throughout the codebase
try:
    from tools import compute_metrics_numba
except ImportError:
    # Fallback implementation if tools.py not available
    def compute_metrics_numba(labels: npt.NDArray, predictions: npt.NDArray) -> Tuple[float, float, float]:
        """
        Fallback (pure NumPy) implementation of the confusion-matrix metrics.

        Value 1 is treated as the positive (malware) class and 0 as the
        negative (benign) class.

        Args:
            labels: Ground-truth labels (array-like of 0/1)
            predictions: Predicted labels (array-like of 0/1), same length as labels

        Returns:
            (F1, FNR, FPR). Each ratio is 0.0 when its denominator is zero.
        """
        # Coerce to arrays: on plain Python lists, `labels == 1` compares the
        # whole list against 1 (a single False), which would silently yield
        # all-zero counts.
        labels = np.asarray(labels)
        predictions = np.asarray(predictions)

        is_pos_label = labels == 1
        is_neg_label = labels == 0
        is_pos_pred = predictions == 1
        is_neg_pred = predictions == 0

        # True positives, false positives, true negatives, false negatives
        tp = int(np.sum(is_pos_label & is_pos_pred))
        fp = int(np.sum(is_neg_label & is_pos_pred))
        tn = int(np.sum(is_neg_label & is_neg_pred))
        fn = int(np.sum(is_pos_label & is_neg_pred))

        # F1 score
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        # False Negative Rate (missed malware)
        fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0

        # False Positive Rate (false alarms)
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0

        return f1, fnr, fpr


class ReliabilityMetrics:
    """
    Compute reliability metrics (F1, FNR, FPR)

    These are traditional classification performance metrics that measure
    how well the classifier performs overall.

    Metrics:
    - F1: Harmonic mean of precision and recall (higher is better)
    - FNR (False Negative Rate): Proportion of malware samples missed (lower is better)
    - FPR (False Positive Rate): Proportion of benign samples flagged as malware (lower is better)

    The confusion-matrix counting is delegated to the module-level
    ``compute_metrics_numba`` function.

    Example:
        >>> predictions = [np.array([0, 1, 0, 1]), np.array([1, 1, 0, 0])]
        >>> labels = [np.array([0, 1, 1, 1]), np.array([1, 0, 0, 1])]
        >>>
        >>> # Aggregated across all months
        >>> metrics = ReliabilityMetrics.compute_aggregated(predictions, labels)
        >>> print(f"F1: {metrics['F1']:.3f}")
        >>>
        >>> # Per-month for stability analysis
        >>> monthly = ReliabilityMetrics.compute_per_month(predictions, labels)
        >>> print(f"Monthly F1: {monthly['F1']}")
    """

    @staticmethod
    def _flatten_months(
        values: Union[List[npt.NDArray], npt.NDArray]
    ) -> npt.NDArray:
        """
        Concatenate a per-month list/tuple of arrays into one 1-D array.

        Non-sequence inputs (e.g. an ndarray) are returned as arrays. An empty
        list/tuple yields an empty array so the caller can raise its own
        "empty" error instead of numpy's concatenate error.
        """
        if isinstance(values, (list, tuple)):
            if len(values) == 0:
                return np.array([])
            # np.ravel accepts arrays, nested lists and scalars alike.
            return np.concatenate([np.ravel(month) for month in values])
        return np.asarray(values)

    @staticmethod
    def compute_aggregated(
        predictions: Union[List[npt.NDArray], npt.NDArray],
        labels: Union[List[npt.NDArray], npt.NDArray]
    ) -> Dict[str, float]:
        """
        Compute reliability metrics by aggregating all predictions/labels

        This combines all months into one large dataset and computes metrics once.
        This is the preferred method for overall performance assessment.

        Args:
            predictions: List/tuple of prediction arrays (one per month) or single array
            labels: List/tuple of label arrays (one per month) or single array.
                Each argument may independently be per-month or flat.

        Returns:
            Dictionary with keys:
            - 'F1': F1 score [0-1]
            - 'FNR': False Negative Rate [0-1]
            - 'FPR': False Positive Rate [0-1]

        Raises:
            ValueError: If the month counts differ (when both inputs are
                per-month), any month has mismatched lengths, the total
                lengths differ, or there are no samples.

        Example:
            >>> preds = [np.array([1, 1, 0]), np.array([1, 0, 1])]
            >>> labs = [np.array([1, 1, 1]), np.array([1, 0, 0])]
            >>> metrics = ReliabilityMetrics.compute_aggregated(preds, labs)
            >>> metrics
            {'F1': 0.75, 'FNR': 0.25, 'FPR': 0.5}
        """
        preds_by_month = isinstance(predictions, (list, tuple))
        labels_by_month = isinstance(labels, (list, tuple))

        # When both inputs are split by month, validate month-by-month:
        # misaligned months could otherwise cancel out, pass the total-length
        # check below, and silently pair the wrong samples.
        if preds_by_month and labels_by_month:
            if len(predictions) != len(labels):
                raise ValueError(
                    f"Number of prediction arrays ({len(predictions)}) must match labels ({len(labels)})"
                )
            for month_idx, (month_preds, month_labels) in enumerate(zip(predictions, labels)):
                if np.size(month_preds) != np.size(month_labels):
                    raise ValueError(
                        f"Month {month_idx}: prediction length {np.size(month_preds)} "
                        f"!= label length {np.size(month_labels)}"
                    )

        all_preds = ReliabilityMetrics._flatten_months(predictions)
        all_labels = ReliabilityMetrics._flatten_months(labels)

        # Validate inputs
        if len(all_preds) != len(all_labels):
            raise ValueError(f"Predictions ({len(all_preds)}) and labels ({len(all_labels)}) must have same length")

        if len(all_preds) == 0:
            raise ValueError("Cannot compute metrics on empty arrays")

        # Compute using optimized numba function
        f1, fnr, fpr = compute_metrics_numba(all_labels, all_preds)

        return {
            'F1': float(f1),
            'FNR': float(fnr),
            'FPR': float(fpr)
        }

    @staticmethod
    def compute_per_month(
        predictions: List[npt.NDArray],
        labels: List[npt.NDArray]
    ) -> Dict[str, List[float]]:
        """
        Compute reliability metrics for each month separately

        This computes metrics month-by-month, which is needed for:
        - Stability analysis (computing CV)
        - Drawdown analysis (finding min/max)
        - Temporal trend analysis

        Args:
            predictions: List of prediction arrays (one per month)
            labels: List of label arrays (one per month)

        Returns:
            Dictionary with keys:
            - 'F1': List of monthly F1 scores
            - 'FNR': List of monthly FNR values
            - 'FPR': List of monthly FPR values
            An empty month contributes 0.0 to every list. Note that this
            pulls down Min[F1] and averages.

        Raises:
            ValueError: If the month counts differ, the list is empty, or a
                month's prediction and label lengths differ.

        Example:
            >>> preds = [np.array([1, 1, 0]), np.array([1, 0, 1])]
            >>> labs = [np.array([1, 1, 1]), np.array([1, 0, 0])]
            >>> monthly = ReliabilityMetrics.compute_per_month(preds, labs)
            >>> len(monthly['F1'])
            2
        """
        if len(predictions) != len(labels):
            raise ValueError(f"Number of prediction arrays ({len(predictions)}) must match labels ({len(labels)})")

        if len(predictions) == 0:
            raise ValueError("Cannot compute metrics on empty list")

        monthly_f1 = []
        monthly_fnr = []
        monthly_fpr = []

        for raw_preds, raw_labs in zip(predictions, labels):
            # Coerce so list-based months reach the metric kernel as arrays.
            preds = np.asarray(raw_preds)
            labs = np.asarray(raw_labs)

            if len(preds) != len(labs):
                raise ValueError(f"Prediction length {len(preds)} != label length {len(labs)}")

            if len(preds) == 0:
                # Handle empty month (shouldn't happen, but be defensive)
                monthly_f1.append(0.0)
                monthly_fnr.append(0.0)
                monthly_fpr.append(0.0)
                continue

            f1, fnr, fpr = compute_metrics_numba(labs, preds)
            monthly_f1.append(float(f1))
            monthly_fnr.append(float(fnr))
            monthly_fpr.append(float(fpr))

        return {
            'F1': monthly_f1,
            'FNR': monthly_fnr,
            'FPR': monthly_fpr
        }

    @staticmethod
    def compute_average(
        predictions: List[npt.NDArray],
        labels: List[npt.NDArray]
    ) -> Dict[str, float]:
        """
        Compute reliability metrics by averaging per-month metrics

        Alternative to compute_aggregated() that:
        1. Computes metrics for each month
        2. Takes the mean across months

        Note: This can give different results than compute_aggregated()!
        - compute_aggregated(): Weights each sample equally
        - compute_average(): Weights each month equally

        Use compute_aggregated() unless you specifically need equal month weighting.

        Args:
            predictions: List of prediction arrays (one per month)
            labels: List of label arrays (one per month)

        Returns:
            Dictionary with averaged metrics (empty months count as 0.0,
            as produced by compute_per_month())

        Example:
            >>> preds = [np.array([1, 1, 0]), np.array([1, 0, 1])]
            >>> labs = [np.array([1, 1, 1]), np.array([1, 0, 0])]
            >>> avg = ReliabilityMetrics.compute_average(preds, labs)
            >>> # This averages the two monthly F1 scores
        """
        monthly = ReliabilityMetrics.compute_per_month(predictions, labels)

        return {
            'F1': float(np.mean(monthly['F1'])),
            'FNR': float(np.mean(monthly['FNR'])),
            'FPR': float(np.mean(monthly['FPR']))
        }


class StabilityMetrics:
    """
    Compute stability metrics for temporal consistency analysis.

    This class provides multiple metrics to measure temporal stability:

    **Dispersion Metrics** (lower is better):
    - CV[F1]: Coefficient of Variation = std / mean (conflates with performance level)
    - σ[F1]: Standard deviation (absolute variability)
    - MAD[F1]: Median Absolute Deviation (robust to outliers)
    - IQR[F1]: Interquartile Range (spread without outliers)

    All standard deviations here are population std (numpy default, ddof=0).

    **Temporal Metrics** (capture trend/structure):
    - Mann-Kendall τ: Monotonic trend (-1 to +1). For F1, negative = degrading;
      for error rates (FNR/FPR), positive = degrading.
    - ACF(1): Autocorrelation at lag 1 (high = clustered good/bad periods)

    **Tail Risk Metrics** (higher is better):
    - Min[F1]: Worst-case performance
    - 5th percentile: Robust worst-case

    Interpretation:
    - CV = 0.05 → Very stable (5% variation)
    - CV = 0.10 → Stable (10% variation)
    - CV = 0.20 → Moderately unstable (20% variation)
    - CV = 0.30+ → Unstable (30%+ variation)

    Lower CV/σ/MAD/IQR = more stable = more trustworthy in production!

    Key Insight:
        CV[F1] alone is insufficient because:
        1. It conflates stability with performance level (high F1 → low CV by construction)
        2. It ignores temporal structure (random noise vs. monotonic decline have same CV)
        3. It's sensitive to outliers

        The paper frames stability as an *operational concern*. Operators ask:
        1. "Will performance collapse?" → Min[F1], 5th percentile
        2. "Is it getting worse over time?" → Mann-Kendall τ
        3. "How much variance should I expect?" → σ[F1] or IQR[F1]

    Example:
        >>> monthly_f1 = [0.85, 0.87, 0.84, 0.86, 0.88]
        >>> suite = StabilityMetrics.compute_stability_suite(monthly_f1)
        >>> print(f"CV[F1]: {suite['cv']:.3f}")
        >>> print(f"σ[F1]: {suite['sigma']:.3f}")
        >>> print(f"Mann-Kendall τ: {suite['mann_kendall_tau']:.3f}")
    """

    @staticmethod
    def coefficient_of_variation(values: Union[List[float], npt.NDArray]) -> float:
        """
        Compute coefficient of variation: CV = std / mean (population std)

        Args:
            values: List or array of metric values (e.g., monthly F1 scores)

        Returns:
            Coefficient of variation (0.0 if mean is ~0 or all values identical).
            The sign follows the mean, so a negative mean gives a negative CV.

        Raises:
            ValueError: If values is empty

        Example:
            >>> values = [0.9, 0.9, 0.9]  # Perfectly stable
            >>> StabilityMetrics.coefficient_of_variation(values)
            0.0
            >>>
            >>> values = [0.5, 0.9, 0.6]  # High variation
            >>> cv = StabilityMetrics.coefficient_of_variation(values)
            >>> cv > 0.2
            True
        """
        arr = np.array(values)

        if len(arr) == 0:
            raise ValueError("Cannot compute CV on empty array")

        mean = np.mean(arr)

        # If mean is 0 or very close to 0, CV is undefined
        # Return 0.0 to indicate "no variation" (though technically undefined)
        if np.abs(mean) < 1e-10:
            return 0.0

        std = np.std(arr)

        # If std is 0, all values are identical → perfect stability
        if std < 1e-10:
            return 0.0

        return float(std / mean)

    @staticmethod
    def compute_cv_metrics(
        monthly_metrics: Dict[str, List[float]]
    ) -> Dict[str, float]:
        """
        Compute CV for F1, FNR, and FPR

        Args:
            monthly_metrics: Dictionary with keys 'F1', 'FNR', 'FPR'
                            and values as lists of monthly scores

        Returns:
            Dictionary with keys 'CV[F1]', 'CV[FNR]', 'CV[FPR]'

        Raises:
            ValueError: If a required key is missing

        Example:
            >>> monthly = {
            ...     'F1': [0.85, 0.87, 0.84],
            ...     'FNR': [0.10, 0.12, 0.11],
            ...     'FPR': [0.05, 0.04, 0.06]
            ... }
            >>> cv_metrics = StabilityMetrics.compute_cv_metrics(monthly)
            >>> cv_metrics
            {'CV[F1]': 0.0146..., 'CV[FNR]': 0.0742..., 'CV[FPR]': 0.163...}
        """
        required_keys = ['F1', 'FNR', 'FPR']
        for key in required_keys:
            if key not in monthly_metrics:
                raise ValueError(f"Missing required key: {key}")

        return {
            'CV[F1]': StabilityMetrics.coefficient_of_variation(monthly_metrics['F1']),
            'CV[FNR]': StabilityMetrics.coefficient_of_variation(monthly_metrics['FNR']),
            'CV[FPR]': StabilityMetrics.coefficient_of_variation(monthly_metrics['FPR']),
        }

    @staticmethod
    def compute_from_predictions(
        predictions: List[npt.NDArray],
        labels: List[npt.NDArray]
    ) -> Dict[str, float]:
        """
        Convenience method: Compute CV directly from predictions/labels

        This combines:
        1. ReliabilityMetrics.compute_per_month()
        2. StabilityMetrics.compute_cv_metrics()

        Args:
            predictions: List of prediction arrays (one per month)
            labels: List of label arrays (one per month)

        Returns:
            Dictionary with CV metrics

        Example:
            >>> preds = [np.array([1, 1, 0]), np.array([1, 0, 1]), np.array([1, 1, 1])]
            >>> labs = [np.array([1, 1, 1]), np.array([1, 0, 0]), np.array([1, 1, 0])]
            >>> cv = StabilityMetrics.compute_from_predictions(preds, labs)
            >>> 'CV[F1]' in cv
            True
        """
        monthly_metrics = ReliabilityMetrics.compute_per_month(predictions, labels)
        return StabilityMetrics.compute_cv_metrics(monthly_metrics)

    @staticmethod
    def standard_deviation(values: Union[List[float], npt.NDArray]) -> float:
        """
        Compute standard deviation (σ) - absolute variability measure.

        Uses the population standard deviation (ddof=0). Unlike CV, this
        doesn't depend on the mean, so it doesn't conflate stability with
        performance level.

        Args:
            values: List or array of metric values (e.g., monthly F1 scores)

        Returns:
            Standard deviation (0.0 if all values identical)

        Raises:
            ValueError: If values is empty

        Example:
            >>> values = [0.85, 0.87, 0.84, 0.86, 0.88]
            >>> sigma = StabilityMetrics.standard_deviation(values)
            >>> print(f"σ[F1]: {sigma:.4f}")  # ~0.0141
        """
        arr = np.array(values)
        if len(arr) == 0:
            raise ValueError("Cannot compute std on empty array")
        return float(np.std(arr))

    @staticmethod
    def median_absolute_deviation(values: Union[List[float], npt.NDArray]) -> float:
        """
        Compute Median Absolute Deviation (MAD) - robust variability measure.

        MAD = median(|x_i - median(x)|)  (unscaled; no 1.4826 normal-consistency factor)

        More robust to outliers than standard deviation. A single bad month
        won't skew MAD as much as it would std.

        Args:
            values: List or array of metric values

        Returns:
            MAD value (0.0 if all values identical)

        Raises:
            ValueError: If values is empty

        Example:
            >>> values = [0.85, 0.87, 0.84, 0.86, 0.88]
            >>> mad = StabilityMetrics.median_absolute_deviation(values)
            >>> print(f"MAD[F1]: {mad:.4f}")  # ~0.01
        """
        arr = np.array(values)
        if len(arr) == 0:
            raise ValueError("Cannot compute MAD on empty array")
        median = np.median(arr)
        return float(np.median(np.abs(arr - median)))

    @staticmethod
    def interquartile_range(values: Union[List[float], npt.NDArray]) -> float:
        """
        Compute Interquartile Range (IQR) - spread without outliers.

        IQR = Q3 - Q1 (75th percentile - 25th percentile, linear interpolation)

        Robust measure of spread that ignores extreme values.

        Args:
            values: List or array of metric values

        Returns:
            IQR value. For fewer than 4 values, the population standard
            deviation is returned instead as a fallback spread measure, and
            0.0 is returned for an empty input.

        Example:
            >>> values = [0.85, 0.87, 0.84, 0.86, 0.88, 0.83, 0.89]
            >>> iqr = StabilityMetrics.interquartile_range(values)
            >>> print(f"IQR[F1]: {iqr:.4f}")
        """
        arr = np.array(values)
        if len(arr) < 4:
            # Need at least 4 values for meaningful IQR; fall back to std.
            return float(np.std(arr)) if len(arr) > 0 else 0.0
        q1, q3 = np.percentile(arr, [25, 75])
        return float(q3 - q1)

    @staticmethod
    def percentile(values: Union[List[float], npt.NDArray], p: float = 5.0) -> float:
        """
        Compute percentile value (default: 5th percentile for robust worst-case).

        The 5th percentile is less sensitive than Min[F1] to a single
        anomalous bad month. With few months and numpy's linear
        interpolation, it still sits close to the minimum.

        Args:
            values: List or array of metric values
            p: Percentile to compute (default 5.0 for 5th percentile)

        Returns:
            The p-th percentile value (linear interpolation)

        Raises:
            ValueError: If values is empty

        Example:
            >>> values = [0.85, 0.87, 0.84, 0.86, 0.88, 0.30]  # One outlier
            >>> min_val = np.min(values)  # 0.30 (affected by outlier)
            >>> p5 = StabilityMetrics.percentile(values, 5)  # ~0.435
        """
        arr = np.array(values)
        if len(arr) == 0:
            raise ValueError("Cannot compute percentile on empty array")
        return float(np.percentile(arr, p))

    @staticmethod
    def mann_kendall_tau(values: Union[List[float], npt.NDArray]) -> Tuple[float, float]:
        """
        Compute Mann-Kendall trend test statistic (τ) and p-value.

        Detects monotonic trends in time series data. Unlike CV/std, this
        captures whether performance is systematically changing over time.

        τ = S / (n(n-1)/2) ranges from -1 to +1, where
        S = #(increasing pairs) - #(decreasing pairs):
        - τ > 0: Increasing trend
        - τ < 0: Decreasing trend (degrading, for F1)
        - τ ≈ 0: No monotonic trend

        p-value (two-tailed, normal approximation with continuity correction
        and the standard tie correction of Var(S)):
        - p < 0.05: Significant trend
        - p >= 0.05: No significant trend

        This is a non-parametric test that doesn't assume normality.

        Args:
            values: List or array of metric values in temporal order

        Returns:
            Tuple of (tau, p_value). Returns (0.0, 1.0) for fewer than 3 values.

        Example:
            >>> # Degrading performance
            >>> values = [0.90, 0.88, 0.85, 0.82, 0.80]
            >>> tau, p = StabilityMetrics.mann_kendall_tau(values)
            >>> print(f"τ: {tau:.3f}, p: {p:.3f}")  # τ = -1.0, p ≈ 0.027

            >>> # Random variation (no trend)
            >>> values = [0.85, 0.88, 0.84, 0.87, 0.86]
            >>> tau, p = StabilityMetrics.mann_kendall_tau(values)
            >>> print(f"τ: {tau:.3f}, p: {p:.3f}")  # τ = 0, p = 1.0
        """
        from math import erfc, sqrt

        arr = np.asarray(values, dtype=float)
        n = len(arr)

        if n < 3:
            return 0.0, 1.0  # Not enough data for trend test

        # S over all pairs i < j. Ties (and NaN comparisons) contribute nothing.
        i_idx, j_idx = np.triu_indices(n, k=1)
        diffs = arr[j_idx] - arr[i_idx]
        s = int(np.count_nonzero(diffs > 0)) - int(np.count_nonzero(diffs < 0))

        # Kendall's tau (tau-a: ties are kept in the pair count)
        n_pairs = n * (n - 1) / 2
        tau = s / n_pairs

        # Variance of S under the null hypothesis, corrected for tied groups:
        # Var(S) = [n(n-1)(2n+5) - sum_t t(t-1)(2t+5)] / 18
        _, tie_counts = np.unique(arr[~np.isnan(arr)], return_counts=True)
        tie_term = int(np.sum(tie_counts * (tie_counts - 1) * (2 * tie_counts + 5)))
        var_s = (n * (n - 1) * (2 * n + 5) - tie_term) / 18.0

        if var_s > 0:
            # Continuity correction moves S one unit toward zero.
            if s > 0:
                z = (s - 1) / sqrt(var_s)
            elif s < 0:
                z = (s + 1) / sqrt(var_s)
            else:
                z = 0.0

            # Two-tailed p-value; erfc avoids the cancellation of 1 - erf(...)
            # for large |z|.
            p_value = erfc(abs(z) / sqrt(2))
        else:
            # All values tied: no evidence of a trend.
            p_value = 1.0

        return float(tau), float(p_value)

    @staticmethod
    def autocorrelation_lag1(values: Union[List[float], npt.NDArray]) -> float:
        """
        Compute autocorrelation at lag 1 (ACF(1)).

        Measures month-to-month persistence. High ACF(1) indicates that
        good/bad periods tend to cluster (if F1 is high this month,
        it's likely high next month too).

        Estimator: the lag-1 cross-product sum averaged over its n-1 terms,
        divided by the population variance. Equivalently, the standard
        estimator scaled by n/(n-1).

        ACF(1) is roughly in [-1, +1]. Because of the n/(n-1) scaling it can
        slightly exceed ±1 for short series:
        - ACF(1) > 0: Positive persistence (clustered good/bad periods)
        - ACF(1) < 0: Alternating pattern
        - ACF(1) ≈ 0: No serial dependence

        Args:
            values: List or array of metric values in temporal order

        Returns:
            Autocorrelation at lag 1 (0.0 for fewer than 3 values or no variation)

        Example:
            >>> # Clustered good/bad months
            >>> values = [0.85, 0.86, 0.87, 0.70, 0.71, 0.72]
            >>> acf1 = StabilityMetrics.autocorrelation_lag1(values)
            >>> acf1 > 0.5  # High positive autocorrelation
            True
        """
        arr = np.array(values)
        n = len(arr)

        if n < 3:
            return 0.0  # Not enough data

        mean = np.mean(arr)
        var = np.var(arr)

        if var < 1e-10:
            return 0.0  # No variation

        # Autocorrelation at lag 1
        acf = np.sum((arr[:-1] - mean) * (arr[1:] - mean)) / ((n - 1) * var)

        return float(acf)

    @staticmethod
    def compute_stability_suite(
        values: Union[List[float], npt.NDArray],
        include_percentiles: bool = True
    ) -> Dict[str, float]:
        """
        Compute full stability metrics suite for a single metric (e.g., F1).

        This provides a comprehensive stability analysis that addresses
        the limitations of CV[F1] alone.

        Args:
            values: List or array of metric values in temporal order
            include_percentiles: Whether to include 5th/95th percentiles

        Returns:
            Dictionary with all stability metrics:
            - 'cv': Coefficient of variation
            - 'sigma': Standard deviation
            - 'mad': Median absolute deviation
            - 'iqr': Interquartile range
            - 'mann_kendall_tau': Trend statistic (-1 to +1)
            - 'mann_kendall_p': P-value for trend test
            - 'acf_lag1': Autocorrelation at lag 1
            - 'min': Minimum value
            - 'max': Maximum value
            - 'p5': 5th percentile (if include_percentiles=True)
            - 'p95': 95th percentile (if include_percentiles=True)

        Raises:
            ValueError: If values is empty

        Example:
            >>> monthly_f1 = [0.85, 0.87, 0.84, 0.86, 0.88, 0.83, 0.89]
            >>> suite = StabilityMetrics.compute_stability_suite(monthly_f1)
            >>> print(f"CV: {suite['cv']:.3f}")
            >>> print(f"σ: {suite['sigma']:.3f}")
            >>> print(f"Mann-Kendall τ: {suite['mann_kendall_tau']:.3f}")
            >>> if suite['mann_kendall_p'] < 0.05:
            ...     print("Significant trend detected!")
        """
        arr = np.array(values)

        if len(arr) == 0:
            raise ValueError("Cannot compute stability suite on empty array")

        tau, p_value = StabilityMetrics.mann_kendall_tau(arr)

        result = {
            'cv': StabilityMetrics.coefficient_of_variation(arr),
            'sigma': StabilityMetrics.standard_deviation(arr),
            'mad': StabilityMetrics.median_absolute_deviation(arr),
            'iqr': StabilityMetrics.interquartile_range(arr),
            'mann_kendall_tau': tau,
            'mann_kendall_p': p_value,
            'acf_lag1': StabilityMetrics.autocorrelation_lag1(arr),
            'min': float(np.min(arr)),
            'max': float(np.max(arr)),
        }

        if include_percentiles:
            result['p5'] = StabilityMetrics.percentile(arr, 5)
            result['p95'] = StabilityMetrics.percentile(arr, 95)

        return result

    @staticmethod
    def compute_stability_suite_from_predictions(
        predictions: List[npt.NDArray],
        labels: List[npt.NDArray],
        metric: str = 'F1'
    ) -> Dict[str, float]:
        """
        Convenience method: Compute full stability suite from predictions/labels.

        Args:
            predictions: List of prediction arrays (one per month)
            labels: List of label arrays (one per month)
            metric: Which metric to analyze ('F1', 'FNR', or 'FPR')

        Returns:
            Full stability suite for the specified metric

        Raises:
            ValueError: If metric is not one of 'F1', 'FNR', 'FPR'

        Example:
            >>> preds = [np.array([1, 1, 0]), np.array([1, 0, 1]), ...]
            >>> labs = [np.array([1, 1, 1]), np.array([1, 0, 0]), ...]
            >>> suite = StabilityMetrics.compute_stability_suite_from_predictions(
            ...     preds, labs, metric='F1'
            ... )
            >>> print(f"CV[F1]: {suite['cv']:.3f}")
            >>> print(f"σ[F1]: {suite['sigma']:.3f}")
        """
        if metric not in ['F1', 'FNR', 'FPR']:
            raise ValueError(f"Metric must be 'F1', 'FNR', or 'FPR', got: {metric}")

        monthly_metrics = ReliabilityMetrics.compute_per_month(predictions, labels)
        return StabilityMetrics.compute_stability_suite(monthly_metrics[metric])


class DrawdownMetrics:
    """
    Compute drawdown metrics (worst-case and best-case performance)

    Drawdown analysis identifies:
    - Min[F1]: Worst monthly performance (critical for risk assessment)
    - Max[F1]: Best monthly performance (upside potential)
    - Max Drawdown: Largest performance drop (vs baseline, target, or running peak)

    Use cases:
    - Risk assessment: What's the worst-case performance?
    - Rejection analysis: How much does performance degrade with rejection?
    - SLA planning: Can we guarantee performance above threshold?

    Example:
        >>> monthly_f1 = [0.85, 0.90, 0.82, 0.88]
        >>> minmax = DrawdownMetrics.compute_minmax(monthly_f1)
        >>> print(f"Worst case: {minmax['Min[F1]']:.2f}")  # 0.82
        >>> print(f"Best case: {minmax['Max[F1]']:.2f}")   # 0.90
    """

    @staticmethod
    def compute_minmax(
        monthly_values: Union[List[float], npt.NDArray],
        metric_name: str = 'F1'
    ) -> Dict[str, float]:
        """
        Compute minimum and maximum values

        Typically used with F1 scores to find worst/best months.

        Args:
            monthly_values: List or array of monthly metric values
            metric_name: Label used in the returned keys (default 'F1',
                giving the historical 'Min[F1]'/'Max[F1]' keys)

        Returns:
            Dictionary with f'Min[{metric_name}]' and f'Max[{metric_name}]'

        Raises:
            ValueError: If monthly_values is empty

        Example:
            >>> monthly_f1 = [0.85, 0.90, 0.82, 0.88]
            >>> minmax = DrawdownMetrics.compute_minmax(monthly_f1)
            >>> minmax['Min[F1]']
            0.82
            >>> minmax['Max[F1]']
            0.9
        """
        arr = np.array(monthly_values)

        if len(arr) == 0:
            raise ValueError("Cannot compute min/max on empty array")

        return {
            f'Min[{metric_name}]': float(np.min(arr)),
            f'Max[{metric_name}]': float(np.max(arr))
        }

    @staticmethod
    def compute_max_drawdown(
        monthly_values: Union[List[float], npt.NDArray],
        baseline_values: Optional[Union[List[float], npt.NDArray]] = None,
        target_value: Optional[float] = None
    ) -> float:
        """
        Compute maximum drawdown (worst degradation)

        Three modes (checked in this order; the first that applies wins):
        1. baseline_values provided: max(baseline - values) over months, floored at 0
        2. target_value provided: max(target - values) over months, floored at 0
        3. neither provided: largest drop from the running peak, i.e.
           max_t(max_{s<=t} values[s] - values[t]). Only declines after a
           peak count, so a steadily improving series has drawdown 0.

        Args:
            monthly_values: Values to analyze in temporal order (e.g., F1 with rejection)
            baseline_values: Optional baseline (e.g., F1 without rejection)
            target_value: Optional target threshold (ignored if baseline_values given)

        Returns:
            Maximum drawdown (positive value = degradation, 0.0 = none)

        Raises:
            ValueError: If monthly_values is empty or baseline length differs

        Example:
            >>> # Drawdown vs baseline
            >>> with_rejection = [0.85, 0.80, 0.82]
            >>> without_rejection = [0.90, 0.88, 0.89]
            >>> drawdown = DrawdownMetrics.compute_max_drawdown(
            ...     with_rejection,
            ...     baseline_values=without_rejection
            ... )
            >>> drawdown  # Largest drop: 0.88 - 0.80 = 0.08
            0.08
            >>>
            >>> # Drawdown vs target
            >>> values = [0.85, 0.80, 0.82]
            >>> drawdown = DrawdownMetrics.compute_max_drawdown(
            ...     values,
            ...     target_value=0.90
            ... )
            >>> drawdown  # Largest gap: 0.90 - 0.80 = 0.10
            0.10
            >>>
            >>> # Drawdown from running peak
            >>> values = [0.90, 0.85, 0.92, 0.80]
            >>> drawdown = DrawdownMetrics.compute_max_drawdown(values)
            >>> drawdown  # Drop from peak 0.92 to later 0.80 = 0.12
            0.12
            >>> DrawdownMetrics.compute_max_drawdown([0.80, 0.92])  # Improvement only
            0.0
        """
        arr = np.array(monthly_values)

        if len(arr) == 0:
            raise ValueError("Cannot compute drawdown on empty array")

        # Mode 1: Drawdown vs baseline
        if baseline_values is not None:
            baseline_arr = np.array(baseline_values)
            if len(baseline_arr) != len(arr):
                raise ValueError(f"Baseline length {len(baseline_arr)} != values length {len(arr)}")

            differences = baseline_arr - arr
            # Only consider positive differences (degradations: baseline > current)
            positive_diffs = differences[differences > 0]

            if len(positive_diffs) == 0:
                return 0.0  # No degradation

            return float(np.max(positive_diffs))  # Largest positive = worst drawdown

        # Mode 2: Drawdown vs target
        if target_value is not None:
            differences = target_value - arr
            max_diff = np.max(differences)
            return float(max(0.0, max_diff))  # Only positive differences count

        # Mode 3: Drawdown from the running peak. Using the global maximum
        # would count a trough that occurs *before* the peak (an improvement)
        # as a degradation.
        running_peak = np.maximum.accumulate(arr)
        return float(np.max(running_peak - arr))

    @staticmethod
    def compute_from_predictions(
        predictions: List[npt.NDArray],
        labels: List[npt.NDArray],
        baseline_predictions: Optional[List[npt.NDArray]] = None,
        metric: str = 'F1'
    ) -> Dict[str, float]:
        """
        Convenience method: Compute drawdown metrics from predictions/labels

        Args:
            predictions: Prediction arrays (one per month)
            labels: Label arrays (one per month); also used to score
                baseline_predictions, so both must be aligned with them
            baseline_predictions: Optional baseline predictions (e.g., without rejection)
            metric: Which metric to analyze ('F1', 'FNR', or 'FPR')

        Returns:
            Dictionary with 'Min[F1]', 'Max[F1]', and optionally 'Max Drawdown'.
            The keys keep the '[F1]' label even when metric is 'FNR' or 'FPR'.
            Note that 'Max Drawdown' measures baseline - current, which is a
            degradation only for higher-is-better metrics such as F1.

        Raises:
            ValueError: If metric is not one of 'F1', 'FNR', 'FPR'

        Example:
            >>> preds = [np.array([1, 1, 0]), np.array([1, 0, 1])]
            >>> labs = [np.array([1, 1, 1]), np.array([1, 0, 0])]
            >>> drawdown = DrawdownMetrics.compute_from_predictions(preds, labs)
            >>> 'Min[F1]' in drawdown
            True
        """
        if metric not in ['F1', 'FNR', 'FPR']:
            raise ValueError(f"Metric must be 'F1', 'FNR', or 'FPR', got: {metric}")

        # Compute monthly metrics
        monthly_metrics = ReliabilityMetrics.compute_per_month(predictions, labels)
        monthly_values = monthly_metrics[metric]

        # Compute min/max (keys kept as Min[F1]/Max[F1] for backward compatibility)
        result = DrawdownMetrics.compute_minmax(monthly_values)

        # Compute max drawdown if baseline provided
        if baseline_predictions is not None:
            baseline_monthly = ReliabilityMetrics.compute_per_month(baseline_predictions, labels)
            baseline_values = baseline_monthly[metric]

            max_drawdown = DrawdownMetrics.compute_max_drawdown(
                monthly_values,
                baseline_values=baseline_values
            )
            result['Max Drawdown'] = max_drawdown

        return result


class SelectiveClassificationMetrics:
    """
    Metrics for selective classification (rejection/abstention)

    These are the "starred" metrics from the paper that measure
    performance when the classifier can reject uncertain samples.

    The paper evaluates selective classification by:
    1. Simulating rejection across multiple budgets (e.g., 100, 200, ..., 1500)
    2. Computing metrics at each budget (F1*, MAPD*, MD[F1]*)
    3. Aggregating across budgets using AUC/Mean/Median

    Key metrics:
    - **F1***: F1 score with selective classification
    - **MAPD***: Mean Absolute Percentage Deviation from target rejection rate
    - **MD[F1]***: Maximum Drawdown in F1 vs baseline
    - **CV[F1]***: Coefficient of Variation under rejection
    - **Benefit Fraction**: Share of the total per-month F1 change caused by
      rejection that is an improvement
    """

    @staticmethod
    def _trapezoid(y, x) -> float:
        """Integrate ``y`` over ``x`` with the trapezoidal rule.

        ``np.trapz`` is deprecated since NumPy 2.0 in favour of
        ``np.trapezoid``. Use whichever the installed NumPy provides.
        """
        integrate = getattr(np, "trapezoid", None)
        if integrate is None:  # NumPy < 2.0
            integrate = np.trapz
        return float(integrate(y, x=x))

    @staticmethod
    def _max_monthly_drawdown(
        baseline: npt.NDArray,
        with_rejection: npt.NDArray
    ) -> float:
        """Return the largest per-month F1 drop caused by rejection (a proportion).

        Args:
            baseline: Monthly F1 values without rejection
            with_rejection: Monthly F1 values with rejection

        Returns:
            ``max(baseline - with_rejection)`` over months where F1 decreased,
            or 0.0 if no month decreased.

        Raises:
            ValueError: If the two monthly series have different lengths. Without
                this check NumPy would silently broadcast a length-1 series.
        """
        if len(baseline) != len(with_rejection):
            raise ValueError(
                f"Monthly F1 arrays must have same length: baseline={len(baseline)}, "
                f"with_rejection={len(with_rejection)}"
            )
        # Negative difference = performance degraded in that month.
        differences = with_rejection - baseline
        drops = differences[differences < 0]
        return float(-np.min(drops)) if len(drops) > 0 else 0.0

    @staticmethod
    def compute_mapd(
        monthly_rejections: Union[List[float], npt.NDArray],
        target_rejection_rate: float
    ) -> float:
        """
        Compute Mean Absolute Percentage Deviation (MAPD)

        Measures how consistently the rejection mechanism hits the target rate.

        MAPD = mean(|actual - target| / target) × 100

        Lower values indicate more consistent rejection behavior.
        Higher values reveal greater variability or systematic bias.

        Args:
            monthly_rejections: Actual number of rejections per month
            target_rejection_rate: Target number of rejections per month

        Returns:
            MAPD value (percentage)

        Raises:
            ValueError: If target rate is non-positive or array is empty

        Example:
            >>> monthly_rej = [95, 105, 98, 102]   # average 3.5% deviation
            >>> round(SelectiveClassificationMetrics.compute_mapd(monthly_rej, 100), 6)
            3.5

            >>> # Perfect consistency
            >>> SelectiveClassificationMetrics.compute_mapd([100, 100, 100], 100)
            0.0

            >>> # High variability
            >>> SelectiveClassificationMetrics.compute_mapd([50, 150, 75, 125], 100)
            37.5
        """
        arr = np.array(monthly_rejections)

        if target_rejection_rate <= 0:
            raise ValueError("Target rejection rate must be positive")

        if len(arr) == 0:
            raise ValueError("Cannot compute MAPD on empty array")

        deviations = np.abs(arr - target_rejection_rate)
        mapd = 100 * np.mean(deviations / target_rejection_rate)

        return float(mapd)

    @staticmethod
    def compute_benefit_fraction(
        f1_baseline: Union[List[float], npt.NDArray],
        f1_with_rejection: Union[List[float], npt.NDArray]
    ) -> float:
        """
        Compute Benefit Fraction (BF): percentage of SC's total impact that's positive.

        BF = Sum(positive ΔF1) / Sum(|ΔF1|) × 100%

        where ΔF1_i = F1_with_rejection_i - F1_baseline_i

        This metric captures BOTH frequency AND magnitude of SC's effect:
        - BF = 100%: SC always helps, never hurts
        - BF = 80%: 80% of SC's total impact is positive
        - BF = 50%: SC helps and hurts equally (break-even)
        - BF = 20%: SC mostly hurts (80% negative impact)
        - BF = 0%: SC always hurts, never helps

        Args:
            f1_baseline: Monthly F1 values WITHOUT rejection
            f1_with_rejection: Monthly F1 values WITH rejection

        Returns:
            Benefit Fraction as percentage (0-100)
            Returns 50.0 if total movement is zero (no effect)

        Raises:
            ValueError: If arrays are empty or different lengths

        Example:
            >>> baseline = [0.80, 0.82, 0.78, 0.81, 0.79]
            >>> with_rej = [0.85, 0.84, 0.75, 0.82, 0.77]
            >>> # ΔF1 = [+0.05, +0.02, -0.03, +0.01, -0.02]
            >>> # Gains = 0.08, Losses = 0.05, Total = 0.13 -> BF = 61.5%
            >>> bf = SelectiveClassificationMetrics.compute_benefit_fraction(baseline, with_rej)
            >>> print(f"BF = {bf:.1f}%")
            BF = 61.5%
        """
        baseline = np.array(f1_baseline)
        with_rej = np.array(f1_with_rejection)

        if len(baseline) == 0 or len(with_rej) == 0:
            raise ValueError("Cannot compute Benefit Fraction on empty arrays")

        if len(baseline) != len(with_rej):
            raise ValueError(
                f"Arrays must have same length: baseline={len(baseline)}, with_rejection={len(with_rej)}"
            )

        # Compute per-month differences
        delta = with_rej - baseline

        # Sum of positive changes (gains)
        total_gain = float(np.sum(np.maximum(delta, 0)))

        # Sum of absolute changes (total movement)
        total_movement = float(np.sum(np.abs(delta)))

        # Handle edge case: no change at all
        if total_movement == 0:
            return 50.0  # Neutral - no effect

        return (total_gain / total_movement) * 100

    @staticmethod
    def aggregate_across_budgets(
        budget_values: Dict[int, float],
        method: str = "Median"
    ) -> float:
        """
        Aggregate metric values across rejection budgets

        The paper computes metrics at multiple rejection budgets (e.g., 100, 200, ..., 1500)
        and then aggregates them into a single value for the table.

        Args:
            budget_values: Dict mapping budget -> metric value
            method: Aggregation method:
                - "AUC": Area under the value-vs-budget curve (trapezoidal rule),
                  divided by the budget range. With a single budget the curve has
                  no width, and that budget's value is returned.
                - "Mean": Arithmetic mean across budgets
                - "Median": Median value across budgets
                - "Min": Minimum value across budgets
                - "Max": Maximum value across budgets (used for worst-case Max Drawdown)

        Returns:
            Aggregated value

        Raises:
            ValueError: If budget_values is empty or method is unknown

        Example:
            >>> values = {100: 75.0, 200: 78.0, 400: 82.0, 800: 85.0}
            >>> SelectiveClassificationMetrics.aggregate_across_budgets(values, "Mean")
            80.0

            >>> # AUC: (7650 + 16000 + 33400) / (800 - 100)
            >>> SelectiveClassificationMetrics.aggregate_across_budgets(values, "AUC")
            81.5

            >>> # Median for robustness to outliers
            >>> SelectiveClassificationMetrics.aggregate_across_budgets(values, "Median")
            80.0
        """
        if not budget_values:
            raise ValueError("Cannot aggregate empty dict")

        budgets = sorted(budget_values.keys())
        values = [budget_values[b] for b in budgets]

        if method == "AUC":
            span = budgets[-1] - budgets[0]
            if span == 0:
                # A single budget has no interval to integrate over. Dividing would
                # be 0/0 = NaN. The normalised area of a single point is its value.
                return float(values[0])
            return SelectiveClassificationMetrics._trapezoid(values, budgets) / span
        if method == "Mean":
            return float(np.mean(values))
        if method == "Median":
            return float(np.median(values))
        if method == "Min":
            return float(np.min(values))
        if method == "Max":
            # Used by compute_from_rejection_results for true worst-case Max Drawdown
            return float(np.max(values))
        raise ValueError(f"Unknown aggregation method: {method}. Must be 'AUC', 'Mean', 'Median', 'Min', or 'Max'")

    @staticmethod
    def compute_from_rejection_results(
        rejection_results: Dict[int, Dict],
        baseline_f1: Optional[float] = None,
        use_monthly_f1: bool = False,
        metrics: Optional[List[str]] = None,
        aggregation_method: str = "Median"
    ) -> Dict[str, float]:
        """
        Compute selective classification metrics from rejection simulation results

        Takes the output from PostHocRejectorSimulator (dict mapping budget -> metrics)
        and computes aggregated paper metrics. Budget 0 is treated as the
        no-rejection baseline and is not itself aggregated.

        Args:
            rejection_results: Dict mapping rejection_budget -> metrics dict
                Keys read from each metrics dict: 'F1', 'monthly_F1',
                'monthly_Rejections'.
            baseline_f1: Baseline F1 (without rejection, proportion 0-1). It is used
                for Max Drawdown only when monthly F1 values are unavailable.
                If None, it is derived from rejection_results[0] as described
                under ``use_monthly_f1``.
            use_monthly_f1: If True, compute baseline F1 as mean of
                rejection_results[0]['monthly_F1'] (per-month averaging, as used
                in paper). If False, use overall F1 from rejection_results[0]['F1']
                when present (flatten-first approach).
            metrics: List of metrics to compute. If None, computes all:
                - 'F1': F1 with selective classification
                - 'MAPD': Mean Absolute Percentage Deviation
                - 'Max Drawdown (F1)': Maximum degradation vs baseline
                - 'CV[F1]': Coefficient of Variation
                - 'Benefit Fraction': % of rejection's per-month F1 impact that is positive
            aggregation_method: How to aggregate across budgets ("AUC", "Mean",
                "Median", "Min", "Max"). Default is "Median" as per paper
                methodology. 'Max Drawdown (F1)' is always aggregated with "Max".

        Returns:
            Dictionary with aggregated metrics (all expressed as percentages)

        Raises:
            ValueError: If rejection_results is empty. Also raised if
                use_monthly_f1=True without a baseline 'monthly_F1', or if
                monthly F1 series differ in length from the baseline.

        Example:
            >>> # Simulate rejection (from PostHocRejectorSimulator)
            >>> rejection_results = {
            ...     100: {'F1': 0.75, 'monthly_Rejections': [95, 105, 98]},
            ...     200: {'F1': 0.78, 'monthly_Rejections': [195, 205, 198]},
            ...     400: {'F1': 0.82, 'monthly_Rejections': [395, 405, 398]},
            ... }
            >>> baseline_f1 = 0.70
            >>>
            >>> metrics = SelectiveClassificationMetrics.compute_from_rejection_results(
            ...     rejection_results,
            ...     baseline_f1=baseline_f1,
            ...     aggregation_method="Mean"
            ... )
            >>> 'F1' in metrics
            True
            >>> 'MAPD' in metrics
            True
            >>> 'Max Drawdown (F1)' in metrics
            True
        """
        if not rejection_results:
            raise ValueError("rejection_results cannot be empty")

        if metrics is None:
            metrics = ['F1', 'MAPD', 'Max Drawdown (F1)', 'CV[F1]', 'Benefit Fraction']

        # The budget-0 entry (if any) is the no-rejection baseline. Its monthly F1
        # series is loop-invariant, so convert it once.
        baseline_entry = rejection_results.get(0, {})
        baseline_monthly_f1 = (
            np.array(baseline_entry['monthly_F1']) if 'monthly_F1' in baseline_entry else None
        )

        if baseline_f1 is None:
            if use_monthly_f1:
                if baseline_monthly_f1 is None:
                    raise ValueError(
                        "use_monthly_f1=True requires rejection_results[0] to contain 'monthly_F1' key"
                    )
                baseline_f1 = float(np.mean(baseline_monthly_f1))
            elif 'F1' in baseline_entry:
                # Documented flatten-first default: overall F1 of the no-rejection run.
                baseline_f1 = float(baseline_entry['F1'])

        result = {}

        # Collect values across rejection budgets
        budget_f1 = {}
        budget_mapd = {}
        budget_cv_f1 = {}
        budget_drawdown = {}
        budget_benefit_fraction = {}

        for rej_budget, rej_metrics in rejection_results.items():
            if rej_budget == 0:
                continue  # Skip baseline (no rejection)

            # F1 with selective classification
            if 'F1' in metrics and 'F1' in rej_metrics:
                budget_f1[rej_budget] = rej_metrics['F1']

            # MAPD: the budget is the per-month rejection target
            if 'MAPD' in metrics and 'monthly_Rejections' in rej_metrics:
                monthly_rej = rej_metrics['monthly_Rejections']
                if len(monthly_rej) > 0:
                    budget_mapd[rej_budget] = SelectiveClassificationMetrics.compute_mapd(
                        monthly_rej,
                        target_rejection_rate=rej_budget
                    )

            # CV[F1]: needs at least two months to measure variation
            if 'CV[F1]' in metrics and 'monthly_F1' in rej_metrics:
                monthly_f1 = rej_metrics['monthly_F1']
                if len(monthly_f1) > 1:
                    cv = StabilityMetrics.coefficient_of_variation(monthly_f1)
                    budget_cv_f1[rej_budget] = cv * 100  # Convert to percentage

            monthly_f1_with_rej = (
                np.array(rej_metrics['monthly_F1']) if 'monthly_F1' in rej_metrics else None
            )
            has_monthly_pair = monthly_f1_with_rej is not None and baseline_monthly_f1 is not None

            # Max Drawdown (paper methodology: worst per-month F1 drop vs baseline)
            if 'Max Drawdown (F1)' in metrics:
                if has_monthly_pair:
                    drawdown = SelectiveClassificationMetrics._max_monthly_drawdown(
                        baseline_monthly_f1, monthly_f1_with_rej
                    )
                    budget_drawdown[rej_budget] = drawdown * 100  # Convert to percentage
                elif baseline_f1 is not None and 'F1' in rej_metrics:
                    # Fallback: use overall F1 if monthly values not available
                    drawdown = max(0.0, baseline_f1 - rej_metrics['F1'])
                    budget_drawdown[rej_budget] = drawdown * 100

            # Benefit Fraction (BF): % of SC's impact that's positive
            if 'Benefit Fraction' in metrics and has_monthly_pair:
                budget_benefit_fraction[rej_budget] = SelectiveClassificationMetrics.compute_benefit_fraction(
                    baseline_monthly_f1,
                    monthly_f1_with_rej
                )

        # Aggregate across budgets
        if budget_f1:
            # F1 is stored as proportion (0-1), convert to percentage
            f1_aggregated = SelectiveClassificationMetrics.aggregate_across_budgets(
                budget_f1, method=aggregation_method
            )
            result['F1'] = f1_aggregated * 100

        if budget_mapd:
            result['MAPD'] = SelectiveClassificationMetrics.aggregate_across_budgets(
                budget_mapd, method=aggregation_method
            )

        if budget_cv_f1:
            result['CV[F1]'] = SelectiveClassificationMetrics.aggregate_across_budgets(
                budget_cv_f1, method=aggregation_method
            )

        if budget_drawdown:
            # MAX aggregation reports the largest degradation across all budgets
            # (the true worst case), independent of aggregation_method.
            result['Max Drawdown (F1)'] = SelectiveClassificationMetrics.aggregate_across_budgets(
                budget_drawdown, method="Max"
            )

        if budget_benefit_fraction:
            result['Benefit Fraction'] = SelectiveClassificationMetrics.aggregate_across_budgets(
                budget_benefit_fraction, method=aggregation_method
            )

        return result


class ThresholdIndependentMetrics:
    """
    Compute threshold-independent metrics (ROC-AUC, PR-AUC).

    These metrics evaluate classifier performance across all possible
    classification thresholds, making them robust to threshold selection.

    **Why These Matter**:
    - F1/FNR/FPR depend on the classification threshold (typically 0.5)
    - Different thresholds may be optimal for different operational costs
    - ROC-AUC and PR-AUC summarize performance across ALL thresholds
    - PR-AUC is especially important for imbalanced data (like 9:1 malware ratio)

    **Metrics**:
    - **ROC-AUC**: Area under Receiver Operating Characteristic curve
      - Measures trade-off between TPR and FPR
      - 0.5 = random, 1.0 = perfect
      - Can be misleading for highly imbalanced data

    - **PR-AUC**: Area under Precision-Recall curve
      - Measures trade-off between precision and recall
      - More informative for imbalanced datasets
      - Focuses on positive class (malware) performance

    **Data Requirements**:
    - Requires soft predictions (scores), not just 0/1
    - Samples with equal scores share one threshold, so ties are handled
      deterministically (a tied positive/negative pair counts as half).
    - Uncertainty + hard predictions can be turned into scores with
      ``uncertainty_to_malware_score``.
    - For SVM: distance to hyperplane (may need Platt scaling for probability)

    Example:
        >>> labels = np.array([1, 1, 0, 0, 1, 0])
        >>> scores = np.array([0.9, 0.7, 0.3, 0.2, 0.8, 0.4])  # Soft predictions
        >>> roc_auc = ThresholdIndependentMetrics.compute_roc_auc(labels, scores)
        >>> pr_auc = ThresholdIndependentMetrics.compute_pr_auc(labels, scores)
    """

    @staticmethod
    def _trapezoid(y, x) -> float:
        """Integrate ``y`` over ``x`` with the trapezoidal rule.

        ``np.trapz`` is deprecated since NumPy 2.0 in favour of
        ``np.trapezoid``. Use whichever the installed NumPy provides.
        """
        integrate = getattr(np, "trapezoid", None)
        if integrate is None:  # NumPy < 2.0
            integrate = np.trapz
        return float(integrate(y, x=x))

    @staticmethod
    def _validate_binary_inputs(
        labels: npt.NDArray,
        scores: npt.NDArray
    ) -> Tuple[npt.NDArray, npt.NDArray, int, int]:
        """Validate label/score inputs shared by the ROC and PR computations.

        Returns:
            (labels, scores, n_pos, n_neg). Both inputs are converted with
            ``np.asarray``, and the counts are of labels equal to 1 and 0.

        Raises:
            ValueError: On empty input, length mismatch, a single class, or
                a missing positive (1) / negative (0) label
        """
        labels = np.asarray(labels)
        scores = np.asarray(scores)

        if len(labels) == 0 or len(scores) == 0:
            raise ValueError("Labels and scores cannot be empty")

        if len(labels) != len(scores):
            raise ValueError(f"Labels length ({len(labels)}) != scores length ({len(scores)})")

        unique_labels = np.unique(labels)
        if len(unique_labels) < 2:
            raise ValueError(f"Need both classes present, got only: {unique_labels}")

        n_pos = int(np.sum(labels == 1))
        n_neg = int(np.sum(labels == 0))
        if n_pos == 0 or n_neg == 0:
            raise ValueError("Need at least one positive and one negative sample")

        return labels, scores, n_pos, n_neg

    @staticmethod
    def _threshold_counts(
        labels: npt.NDArray,
        scores: npt.NDArray
    ) -> Tuple[npt.NDArray, npt.NDArray]:
        """Cumulative true/false positive counts at each distinct score, highest first.

        Tied scores form a single threshold: only the last position of each run
        of equal scores is kept. Otherwise the curve would depend on the arbitrary
        order argsort gives to ties (e.g. labels [1, 0] with equal scores would
        yield ROC-AUC 0 or 1 instead of 0.5).
        """
        order = np.argsort(scores, kind="mergesort")[::-1]
        sorted_scores = scores[order]
        sorted_labels = labels[order]

        run_ends = np.flatnonzero(np.diff(sorted_scores))
        threshold_idx = np.append(run_ends, len(sorted_labels) - 1)

        tps = np.cumsum(sorted_labels == 1)[threshold_idx]
        fps = np.cumsum(sorted_labels == 0)[threshold_idx]
        return tps, fps

    @staticmethod
    def compute_roc_auc(
        labels: npt.NDArray,
        scores: npt.NDArray
    ) -> float:
        """
        Compute Area Under ROC Curve (ROC-AUC).

        The ROC curve plots True Positive Rate (TPR) vs False Positive Rate (FPR)
        at every distinct score threshold. The curve starts at (0, 0) and is
        integrated with the trapezoidal rule.

        ROC-AUC interpretation:
        - 1.0 = Perfect classifier (TPR=1 at FPR=0)
        - 0.5 = Random classifier (diagonal line)
        - < 0.5 = Worse than random (flip predictions!)

        Note: ROC-AUC can be overly optimistic for imbalanced data because
        it gives equal weight to all thresholds. Consider PR-AUC for
        highly imbalanced datasets.

        Args:
            labels: Binary ground truth labels (0 or 1)
            scores: Prediction scores (higher = more likely positive/malware)

        Returns:
            ROC-AUC value [0, 1]

        Raises:
            ValueError: If labels or scores are empty, or if only one class present

        Example:
            >>> labels = np.array([1, 1, 0, 0, 1, 0])
            >>> scores = np.array([0.9, 0.7, 0.3, 0.2, 0.8, 0.4])
            >>> roc_auc = ThresholdIndependentMetrics.compute_roc_auc(labels, scores)
            >>> 0.8 < roc_auc <= 1.0  # Good classifier
            True
        """
        labels, scores, n_pos, n_neg = ThresholdIndependentMetrics._validate_binary_inputs(
            labels, scores
        )
        tps, fps = ThresholdIndependentMetrics._threshold_counts(labels, scores)

        # Prepend (0, 0): at a threshold above every score nothing is flagged.
        tpr = np.concatenate([[0.0], tps / n_pos])
        fpr = np.concatenate([[0.0], fps / n_neg])

        return ThresholdIndependentMetrics._trapezoid(tpr, fpr)

    @staticmethod
    def compute_pr_auc(
        labels: npt.NDArray,
        scores: npt.NDArray
    ) -> float:
        """
        Compute Area Under Precision-Recall Curve (PR-AUC).

        The PR curve plots Precision vs Recall at every distinct score threshold.
        The point (recall=0, precision=1) is prepended, and the curve is integrated
        with the trapezoidal rule over recall.
        PR-AUC is more informative than ROC-AUC for imbalanced datasets
        because it focuses on the positive class (malware).

        PR-AUC interpretation:
        - 1.0 = Perfect classifier
        - baseline = proportion of positives (e.g., 0.9 for 9:1 malware ratio)
        - Values close to baseline indicate poor discrimination

        For a 9:1 malware-to-benign ratio:
        - Random classifier: PR-AUC ≈ 0.9 (high by chance!)
        - Good classifier: PR-AUC > 0.95
        - Excellent classifier: PR-AUC > 0.99

        Args:
            labels: Binary ground truth labels (0 or 1)
            scores: Prediction scores (higher = more likely positive/malware)

        Returns:
            PR-AUC value [0, 1]

        Raises:
            ValueError: If labels or scores are empty, or if only one class present

        Example:
            >>> labels = np.array([1, 1, 0, 0, 1, 0])
            >>> scores = np.array([0.9, 0.7, 0.3, 0.2, 0.8, 0.4])
            >>> pr_auc = ThresholdIndependentMetrics.compute_pr_auc(labels, scores)
            >>> pr_auc > 0.5  # Better than random for balanced data
            True
        """
        labels, scores, n_pos, _ = ThresholdIndependentMetrics._validate_binary_inputs(
            labels, scores
        )
        tps, fps = ThresholdIndependentMetrics._threshold_counts(labels, scores)

        # Every threshold flags at least one sample, so tps + fps > 0.
        precision = tps / (tps + fps)
        recall = tps / n_pos

        # Prepend (recall=0, precision=1), the standard PR-curve starting point
        precision = np.concatenate([[1.0], precision])
        recall = np.concatenate([[0.0], recall])

        # recall is the x-axis, precision is the y-axis
        return ThresholdIndependentMetrics._trapezoid(precision, recall)

    @staticmethod
    def compute_from_monthly_data(
        monthly_labels: List[npt.NDArray],
        monthly_scores: List[npt.NDArray],
        aggregation: str = "pooled"
    ) -> Dict[str, float]:
        """
        Compute ROC-AUC and PR-AUC from monthly data.

        Two aggregation modes:
        - "pooled": Combine all months into one dataset, compute AUC once
        - "average": Compute AUC per month, then average. Months that fail
          validation (e.g. only one class, or empty) are skipped.

        Pooled is generally preferred as it gives more stable estimates.

        Args:
            monthly_labels: List of label arrays (one per month)
            monthly_scores: List of score arrays (one per month)
            aggregation: "pooled" or "average"

        Returns:
            Dictionary with 'ROC-AUC' and 'PR-AUC'. In "average" mode it also
            contains 'ROC-AUC_std', 'PR-AUC_std' and 'n_valid_months'.

        Raises:
            ValueError: If the month counts differ, or if any month's label and
                score arrays differ in length. Also raised if no data is given,
                the aggregation is unknown, or (average mode) no month is valid.

        Example:
            >>> labels_m1 = np.array([1, 1, 0, 0])
            >>> labels_m2 = np.array([1, 0, 0, 1])
            >>> scores_m1 = np.array([0.9, 0.7, 0.3, 0.2])
            >>> scores_m2 = np.array([0.8, 0.4, 0.3, 0.6])
            >>> metrics = ThresholdIndependentMetrics.compute_from_monthly_data(
            ...     [labels_m1, labels_m2],
            ...     [scores_m1, scores_m2],
            ...     aggregation="pooled"
            ... )
            >>> 'ROC-AUC' in metrics and 'PR-AUC' in metrics
            True
        """
        if len(monthly_labels) != len(monthly_scores):
            raise ValueError(
                f"Number of label arrays ({len(monthly_labels)}) != "
                f"number of score arrays ({len(monthly_scores)})"
            )

        if len(monthly_labels) == 0:
            raise ValueError("Cannot compute AUC on empty data")

        # Validate alignment per month. Pooling could otherwise hide an offset
        # between months, and "average" mode would silently skip the month as
        # if it were single-class.
        for month_idx, (labs, scs) in enumerate(zip(monthly_labels, monthly_scores)):
            if len(labs) != len(scs):
                raise ValueError(
                    f"Month {month_idx}: labels length ({len(labs)}) != "
                    f"scores length ({len(scs)})"
                )

        if aggregation == "pooled":
            all_labels = np.concatenate(monthly_labels)
            all_scores = np.concatenate(monthly_scores)

            return {
                'ROC-AUC': ThresholdIndependentMetrics.compute_roc_auc(all_labels, all_scores),
                'PR-AUC': ThresholdIndependentMetrics.compute_pr_auc(all_labels, all_scores)
            }

        if aggregation == "average":
            monthly_roc = []
            monthly_pr = []

            for labels, scores in zip(monthly_labels, monthly_scores):
                try:
                    roc = ThresholdIndependentMetrics.compute_roc_auc(labels, scores)
                    pr = ThresholdIndependentMetrics.compute_pr_auc(labels, scores)
                except ValueError:
                    # Skip months with only one class (or no samples)
                    continue
                monthly_roc.append(roc)
                monthly_pr.append(pr)

            if len(monthly_roc) == 0:
                raise ValueError("No valid months for AUC computation")

            return {
                'ROC-AUC': float(np.mean(monthly_roc)),
                'PR-AUC': float(np.mean(monthly_pr)),
                'ROC-AUC_std': float(np.std(monthly_roc)),
                'PR-AUC_std': float(np.std(monthly_pr)),
                'n_valid_months': len(monthly_roc)
            }

        raise ValueError(f"Unknown aggregation: {aggregation}. Must be 'pooled' or 'average'")

    @staticmethod
    def uncertainty_to_malware_score(
        uncertainty: npt.NDArray,
        prediction: npt.NDArray
    ) -> npt.NDArray:
        """
        Convert uncertainty scores to malware probability scores.

        In our framework, uncertainty is typically computed as:
        - Softmax: uncertainty = 1 - |max_prob - 0.5| / 0.5
        - SVM: uncertainty = -|distance to hyperplane|

        For ROC/PR curves, we need a score where HIGHER = more likely malware.
        With confidence c = 1 - uncertainty, the score is:
        - predicted malware (pred=1): 0.5 + c / 2
        - predicted benign (pred=0):  0.5 - c / 2

        For the softmax definition above, this recovers P(malware) exactly,
        because max_prob = 1 - uncertainty / 2. For the SVM definition, the score
        is monotone in the signed distance. In both cases every malware prediction
        ranks above every benign prediction, and uncertain samples sit near 0.5.
        (Using ``1 - uncertainty`` for malware and ``uncertainty`` for benign is
        not monotone: an uncertain malware prediction would rank below a fairly
        confident benign one.)

        Args:
            uncertainty: Uncertainty scores (higher = more uncertain)
            prediction: Binary predictions (0=benign, 1=malware)

        Returns:
            Malware scores (higher = more likely malware)

        Raises:
            ValueError: If the two arrays differ in length

        Example:
            >>> uncertainty = np.array([0.1, 0.8, 0.2, 0.9])  # Low=confident
            >>> prediction = np.array([1, 1, 0, 0])  # Malware, Malware, Benign, Benign
            >>> scores = ThresholdIndependentMetrics.uncertainty_to_malware_score(
            ...     uncertainty, prediction
            ... )
            >>> np.round(scores, 2).tolist()
            [0.95, 0.6, 0.1, 0.45]
        """
        uncertainty = np.asarray(uncertainty)
        prediction = np.asarray(prediction)

        if len(uncertainty) != len(prediction):
            raise ValueError(
                f"Uncertainty length ({len(uncertainty)}) != "
                f"prediction length ({len(prediction)})"
            )

        confidence = 1.0 - uncertainty
        # +1 pushes malware predictions above 0.5, -1 pushes benign ones below.
        direction = np.where(prediction == 1, 1.0, -1.0)
        return 0.5 + 0.5 * direction * confidence

    @staticmethod
    def compute_from_predictions_and_uncertainty(
        monthly_labels: List[npt.NDArray],
        monthly_predictions: List[npt.NDArray],
        monthly_uncertainties: List[npt.NDArray],
        aggregation: str = "pooled"
    ) -> Dict[str, float]:
        """
        Compute ROC-AUC and PR-AUC from predictions and uncertainty scores.

        This is a convenience method that:
        1. Converts uncertainty + predictions to malware scores
           (``uncertainty_to_malware_score``)
        2. Computes threshold-independent metrics (``compute_from_monthly_data``)

        Args:
            monthly_labels: List of label arrays (one per month)
            monthly_predictions: List of prediction arrays (0/1)
            monthly_uncertainties: List of uncertainty arrays
            aggregation: "pooled" or "average", forwarded to compute_from_monthly_data

        Returns:
            Dictionary with 'ROC-AUC' and 'PR-AUC'

        Raises:
            ValueError: If the three monthly lists differ in length (or any error
                raised by the two steps above)

        Example:
            >>> labels = [np.array([1, 1, 0, 0])]
            >>> preds = [np.array([1, 0, 0, 1])]
            >>> uncs = [np.array([0.1, 0.8, 0.2, 0.9])]
            >>> metrics = ThresholdIndependentMetrics.compute_from_predictions_and_uncertainty(
            ...     labels, preds, uncs
            ... )
        """
        if not (len(monthly_labels) == len(monthly_predictions) == len(monthly_uncertainties)):
            raise ValueError("All monthly arrays must have the same length")

        monthly_scores = [
            ThresholdIndependentMetrics.uncertainty_to_malware_score(unc, pred)
            for unc, pred in zip(monthly_uncertainties, monthly_predictions)
        ]

        return ThresholdIndependentMetrics.compute_from_monthly_data(
            monthly_labels, monthly_scores, aggregation
        )
