"""
Performance-Based Rejection Simulation

Calibrates rejection thresholds against performance targets (FNR, FPR)
and measures target adherence using statistically principled metrics.

This module extends the budget-based rejection framework with performance-based
calibration, answering: "What threshold maintains target FNR/FPR?"

Metrics Design:
--------------
Based on Murphy decomposition and forecast verification best practices:
- MAE (Mean Absolute Error): Primary adherence metric
- Bias (Mean Signed Error): Systematic over/under-shooting
- RMSE: Penalizes large deviations
- HitRate: Fraction of months hitting target within tolerance
- Reliability Diagrams: Gold standard visualization for calibration

References:
----------
- Murphy (1973): Brier score decomposition
- Dimitriadis et al. (PNAS 2021): Stable reliability diagrams
- Geifman & El-Yaniv (NeurIPS 2017): Selective classification

Example:
    >>> simulator = PerformanceRejectionSimulator()
    >>>
    >>> # Calibrate threshold on validation data
    >>> threshold = simulator.calibrate_threshold(
    ...     val_predictions, val_labels, val_uncertainties,
    ...     target_metric="FNR", target_value=0.05
    ... )
    >>>
    >>> # Simulate deployment with calibrated threshold
    >>> results = simulator.simulate_deployment(
    ...     test_predictions, test_labels, test_uncertainties,
    ...     threshold=threshold, target_metric="FNR", target_value=0.05
    ... )
    >>>
    >>> # Compute adherence metrics
    >>> metrics = simulator.compute_adherence_metrics(results, target_metric="FNR", target_value=0.05)
    >>> print(f"MAE: {metrics.mae:.4f}, HitRate@2%: {metrics.hit_rate_2pct:.2%}")
"""

from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple, Literal, Union
import numpy as np
import numpy.typing as npt
from numba import njit
import warnings


# ============================================================================
# DATA STRUCTURES
# ============================================================================

@dataclass
class TargetAdherenceMetrics:
    """
    Summary of how closely the monthly FNR/FPR tracked a requested target.

    One instance describes one (target_metric, target_value) pair. The scalar
    fields are aggregate statistics over months; the three array fields keep
    the per-month series so they can be plotted later.

    Attributes:
        target_value: The FNR/FPR value the threshold was calibrated for
        target_metric: "FNR" or "FPR"

        mae: Mean of |actual - target| (primary adherence figure)
        bias: Mean of (actual - target); the sign shows over/under-shooting
        rmse: Root mean square of (actual - target)

        hit_rate_1pct: Fraction of months with |actual - target| < 0.01
        hit_rate_2pct: Fraction of months with |actual - target| < 0.02
        hit_rate_5pct: Fraction of months with |actual - target| < 0.05

        mean_coverage: Mean fraction of samples accepted
        cv_coverage: Coefficient of variation of coverage
        min_coverage: Smallest coverage observed

        cv_deviation: Coefficient of variation of |actual - target|
        mann_kendall_tau: Mann-Kendall trend statistic (-1 to 1)
        mann_kendall_p: P-value of the trend test
        max_deviation: Largest |actual - target|

        monthly_actual: Achieved FNR/FPR per month
        monthly_coverage: Coverage per month
        monthly_threshold: Threshold applied in each month
    """
    target_value: float
    target_metric: str  # "FNR" or "FPR"

    # Error of the achieved metric relative to the target
    mae: float
    bias: float
    rmse: float

    # Share of months landing within a fixed tolerance of the target
    hit_rate_1pct: float
    hit_rate_2pct: float
    hit_rate_5pct: float

    # Fraction of samples that survived rejection
    mean_coverage: float
    cv_coverage: float
    min_coverage: float

    # Behaviour of the deviation over time
    cv_deviation: float
    mann_kendall_tau: float
    mann_kendall_p: float
    max_deviation: float

    # Per-month series; each instance gets its own empty float array by default
    monthly_actual: npt.NDArray[np.float64] = field(
        default_factory=lambda: np.empty(0, dtype=np.float64)
    )
    monthly_coverage: npt.NDArray[np.float64] = field(
        default_factory=lambda: np.empty(0, dtype=np.float64)
    )
    monthly_threshold: npt.NDArray[np.float64] = field(
        default_factory=lambda: np.empty(0, dtype=np.float64)
    )

    def to_dict(self) -> Dict:
        """
        Return the scalar fields as a flat dictionary (one DataFrame row).

        The per-month arrays are intentionally left out.
        """
        # Order here fixes the key order of the resulting dictionary.
        scalar_names = (
            'target_value',
            'target_metric',
            'mae',
            'bias',
            'rmse',
            'hit_rate_1pct',
            'hit_rate_2pct',
            'hit_rate_5pct',
            'mean_coverage',
            'cv_coverage',
            'min_coverage',
            'cv_deviation',
            'mann_kendall_tau',
            'mann_kendall_p',
            'max_deviation',
        )
        return {name: getattr(self, name) for name in scalar_names}


@dataclass
class ReliabilityDiagramData:
    """
    Data for reliability diagram visualization.

    Reliability diagrams plot target vs observed values to assess calibration.
    Perfect calibration = diagonal line (observed = target).

    Following Dimitriadis et al. (PNAS 2021) for stable reliability diagrams.

    Attributes:
        targets: Array of target values tested
        observed_mean: Mean observed value at each target
        observed_std: Std of observed values at each target
        observed_ci_low: Lower 95% CI bound
        observed_ci_high: Upper 95% CI bound
        coverage_mean: Mean coverage at each target
        coverage_std: Std of coverage at each target
        n_months: Number of months per target (for error bar calculation)
    """
    targets: npt.NDArray[np.float64]
    observed_mean: npt.NDArray[np.float64]
    observed_std: npt.NDArray[np.float64]
    observed_ci_low: npt.NDArray[np.float64]
    observed_ci_high: npt.NDArray[np.float64]
    coverage_mean: npt.NDArray[np.float64]
    coverage_std: npt.NDArray[np.float64]
    n_months: int

    def calibration_error(self) -> float:
        """
        Compute scalar calibration error (area between curve and diagonal).

        The absolute deviation |observed_mean - target| is integrated over the
        targets with the trapezoidal rule and divided by the target range, so
        the result is an average deviation. Lower is better; 0 means the
        observed means sit exactly on the diagonal.

        Returns:
            - 0.0 when there are no targets,
            - the single deviation when there is exactly one target,
            - the plain mean deviation when all targets coincide (the range
              is zero, so there is nothing to integrate over),
            - otherwise the range-normalised trapezoidal area.
        """
        targets = np.asarray(self.targets, dtype=np.float64)
        deviations = np.abs(np.asarray(self.observed_mean, dtype=np.float64) - targets)

        if len(targets) < 2:
            return float(deviations[0]) if len(deviations) > 0 else 0.0

        # Integration needs the abscissa in increasing order.
        order = np.argsort(targets)
        sorted_targets = targets[order]
        sorted_deviations = deviations[order]

        target_range = sorted_targets[-1] - sorted_targets[0]
        if not target_range > 0:
            # Duplicate targets only: the integral is 0 regardless of how far
            # off the observations are, which would falsely read as perfect
            # calibration. Report the average deviation instead.
            return float(np.mean(sorted_deviations))

        # np.trapz was renamed to np.trapezoid in NumPy 2.0 and the old name
        # is deprecated; prefer the new one and fall back on older NumPy.
        trapezoid = getattr(np, "trapezoid", None)
        if trapezoid is None:
            trapezoid = np.trapz

        area = trapezoid(sorted_deviations, sorted_targets)
        return float(area / target_range)


@dataclass
class MonthlyRejectionResult:
    """
    Results from applying rejection to a single month.

    Holds the month's metrics before rejection (``baseline_*``) next to the
    same metrics recomputed on the samples that were accepted.

    Note on naming: the ``rejected_*`` fields are the metrics *after
    rejection was applied*, i.e. measured on the accepted samples, not on
    the rejected ones.

    Attributes:
        month_idx: Index of the month this result belongs to
        threshold: Uncertainty threshold that was applied
        total_samples: Number of samples in the month
        baseline_fnr: FNR over all samples (no rejection)
        baseline_fpr: FPR over all samples (no rejection)
        baseline_f1: F1 over all samples (no rejection)
        accepted_samples: Number of samples kept
        rejected_samples: Number of samples rejected
        coverage: accepted_samples / total_samples
        rejected_fnr: FNR over the accepted samples
        rejected_fpr: FPR over the accepted samples
        rejected_f1: F1 over the accepted samples
    """
    month_idx: int
    threshold: float

    # Before rejection (computed on every sample of the month)
    total_samples: int
    baseline_fnr: float
    baseline_fpr: float
    baseline_f1: float

    # After rejection (sample counts)
    accepted_samples: int
    rejected_samples: int
    coverage: float  # accepted / total

    # Metrics on accepted samples; PerformanceRejectionSimulator
    # .apply_threshold_single_month stores NaN here when fewer than 20
    # samples were accepted.
    rejected_fnr: float
    rejected_fpr: float
    rejected_f1: float


# ============================================================================
# STATISTICAL UTILITIES
# ============================================================================

@njit
def _mann_kendall_core(x: npt.NDArray) -> Tuple[float, int]:
    """
    Numba-compiled kernel computing the Mann-Kendall S statistic.

    S is the number of pairs (earlier, later) whose later value is larger,
    minus the number of pairs whose later value is smaller. Pairs with a
    zero (or NaN) difference contribute nothing.

    Returns:
        (S as a float, number of observations)
    """
    length = len(x)
    score = 0

    # Walk every later observation and compare it with all earlier ones.
    for later in range(1, length):
        for earlier in range(later):
            delta = x[later] - x[earlier]
            if delta > 0:
                score += 1
            elif delta < 0:
                score -= 1

    return float(score), length


def mann_kendall_test(x: npt.NDArray) -> Tuple[float, float]:
    """
    Perform Mann-Kendall trend test.

    Tests for monotonic trend in time series data.

    Args:
        x: 1D array of values over time. The only caller in this module
           passes NaN-free data; NaN handling is not defined here.

    Returns:
        (tau, p_value):
            tau: S divided by the number of pairs n(n-1)/2, in [-1, 1]
                 > 0: increasing trend
                 < 0: decreasing trend
                 ≈ 0: no trend
            p_value: Two-sided p-value for H0: no trend, from the normal
                     approximation with continuity correction and
                     tie-corrected variance

        Series with fewer than 3 points return (0.0, 1.0). A series whose
        values are all identical has zero variance and returns p = 1.0.

    Reference:
        Mann (1945), Kendall (1975)
    """
    x = np.asarray(x, dtype=np.float64)
    n = len(x)

    if n < 3:
        return 0.0, 1.0  # Not enough data

    # Compute S statistic
    s, _ = _mann_kendall_core(x)

    # tau = S / (number of pairs)
    tau = s / (n * (n - 1) / 2)

    # Variance of S with the standard correction for tied groups:
    #   Var(S) = [n(n-1)(2n+5) - sum_g t_g (t_g - 1)(2 t_g + 5)] / 18
    # where t_g is the size of each group of equal values. Groups of size 1
    # contribute 0, so the formula reduces to the untied one when there are
    # no ties. Ignoring ties overstates the variance and inflates p-values.
    _, group_sizes = np.unique(x, return_counts=True)
    tie_term = float(np.sum(group_sizes * (group_sizes - 1) * (2 * group_sizes + 5)))
    var_s = (n * (n - 1) * (2 * n + 5) - tie_term) / 18

    if var_s <= 0:
        # Every value is tied with every other: no evidence of any trend.
        return float(tau), 1.0

    # z-score with continuity correction (S moves in steps, shrink by 1)
    if s > 0:
        z = (s - 1) / np.sqrt(var_s)
    elif s < 0:
        z = (s + 1) / np.sqrt(var_s)
    else:
        z = 0.0

    # Two-sided p-value. The survival function keeps precision in the far
    # tail, where 1 - cdf would round to exactly 0.
    from scipy import stats
    p_value = min(1.0, 2.0 * float(stats.norm.sf(abs(z))))

    return float(tau), float(p_value)


def bootstrap_ci(
    values: npt.NDArray,
    statistic: str = 'mean',
    confidence: float = 0.95,
    n_bootstrap: int = 1000,
    seed: Optional[int] = None
) -> Tuple[float, float]:
    """
    Compute a percentile bootstrap confidence interval.

    The data are resampled with replacement ``n_bootstrap`` times, the
    statistic is evaluated on each resample, and the interval is read off
    the empirical percentiles of those values.

    Args:
        values: 1D array of values
        statistic: 'mean' or 'median'
        confidence: Confidence level in [0, 1] (default: 0.95)
        n_bootstrap: Number of bootstrap samples (at least 1)
        seed: Random seed for reproducibility

    Returns:
        (ci_low, ci_high): Confidence interval bounds, or (nan, nan) when
        ``values`` is empty

    Raises:
        ValueError: If ``statistic`` is unknown, ``confidence`` is outside
            [0, 1], or ``n_bootstrap`` is smaller than 1. Arguments are
            checked before the data, so a bad argument is reported even for
            empty input.
    """
    # Resolve the statistic once, up front, instead of re-testing the string
    # on every resample.
    if statistic == 'mean':
        reducer = np.mean
    elif statistic == 'median':
        reducer = np.median
    else:
        raise ValueError(f"Unknown statistic: {statistic}")

    if not 0.0 <= confidence <= 1.0:
        raise ValueError(f"confidence must be in [0, 1], got {confidence}")
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be at least 1, got {n_bootstrap}")

    rng = np.random.default_rng(seed)
    values = np.asarray(values)
    n = len(values)

    if n == 0:
        return np.nan, np.nan

    # One rng.choice call per resample: keeps the random stream (and hence
    # seeded results) identical to a plain resampling loop.
    bootstrap_stats = np.empty(n_bootstrap)
    for i in range(n_bootstrap):
        bootstrap_stats[i] = reducer(rng.choice(values, size=n, replace=True))

    # Symmetric percentile interval
    alpha = 1 - confidence
    ci_low = np.percentile(bootstrap_stats, 100 * alpha / 2)
    ci_high = np.percentile(bootstrap_stats, 100 * (1 - alpha / 2))

    return float(ci_low), float(ci_high)


# ============================================================================
# CORE REJECTION FUNCTIONS
# ============================================================================

@njit
def compute_metrics_numba(y_true: npt.NDArray, y_pred: npt.NDArray) -> Tuple[float, float, float]:
    """
    Compute F1, FNR and FPR of binary predictions in a single pass.

    Inputs with fewer than 20 samples yield (nan, nan, nan). Only pairs where
    both label and prediction are 0 or 1 are counted; a ratio whose
    denominator is zero is reported as 0.0.

    Returns: (f1, fnr, fpr)
    """
    total = y_true.shape[0]

    # Too few samples to report a meaningful rate.
    if total < 20:
        return np.nan, np.nan, np.nan

    tp = 0
    tn = 0
    fp = 0
    fn = 0

    # Build the confusion matrix, branching on the true label first.
    for k in range(total):
        truth = y_true[k]
        guess = y_pred[k]

        if truth == 1:
            if guess == 1:
                tp += 1
            elif guess == 0:
                fn += 1
        elif truth == 0:
            if guess == 0:
                tn += 1
            elif guess == 1:
                fp += 1

    # F1 = 2TP / (2TP + FP + FN)
    f1_denominator = 2 * tp + fp + fn
    if f1_denominator > 0:
        f1 = (2.0 * tp) / f1_denominator
    else:
        f1 = 0.0

    # FNR = FN / (TP + FN)
    positives = tp + fn
    if positives > 0:
        fnr = fn / positives
    else:
        fnr = 0.0

    # FPR = FP / (FP + TN)
    negatives = fp + tn
    if negatives > 0:
        fpr = fp / negatives
    else:
        fpr = 0.0

    return f1, fnr, fpr


@njit
def apply_uncertainty_threshold(
    predictions: npt.NDArray,
    labels: npt.NDArray,
    uncertainties: npt.NDArray,
    threshold: float,
    reject_above: bool = True
) -> Tuple[npt.NDArray, npt.NDArray, int, int]:
    """
    Split samples into accepted and rejected by an uncertainty threshold.

    Args:
        predictions: Binary predictions
        labels: True labels
        uncertainties: Uncertainty scores
        threshold: Rejection threshold
        reject_above: If True, keep samples with uncertainty <= threshold
            (reject those above); if False, keep samples with
            uncertainty >= threshold (reject those below)

    Returns:
        (accepted_predictions, accepted_labels, n_accepted, n_rejected),
        where the accepted arrays preserve the original sample order and
        the input dtypes
    """
    # Boolean keep-mask; a sample exactly at the threshold is kept.
    if reject_above:
        keep = uncertainties <= threshold
    else:
        keep = uncertainties >= threshold

    # Positions of the kept samples, in increasing order.
    kept_positions = np.nonzero(keep)[0]
    n_accepted = kept_positions.shape[0]
    n_rejected = len(predictions) - n_accepted

    # Gather the kept samples into freshly allocated arrays.
    kept_preds = np.empty(n_accepted, dtype=predictions.dtype)
    kept_labels = np.empty(n_accepted, dtype=labels.dtype)
    for slot in range(n_accepted):
        source = kept_positions[slot]
        kept_preds[slot] = predictions[source]
        kept_labels[slot] = labels[source]

    return kept_preds, kept_labels, n_accepted, n_rejected


# ============================================================================
# PERFORMANCE REJECTION SIMULATOR
# ============================================================================

class PerformanceRejectionSimulator:
    """
    Simulate rejection with performance-based calibration.

    Calibrates thresholds on validation period to achieve target FNR/FPR,
    then measures adherence on deployment period using statistically
    principled metrics.

    Two modes:
    1. Fixed threshold: Calibrate once on validation, apply fixed threshold
    2. Adaptive threshold: Recalibrate threshold each month using accumulated data

    Example:
        >>> sim = PerformanceRejectionSimulator()
        >>>
        >>> # Calibrate on validation
        >>> threshold = sim.calibrate_threshold(
        ...     val_preds, val_labels, val_uncs,
        ...     target_metric="FNR", target_value=0.05
        ... )
        >>>
        >>> # Deploy on test
        >>> monthly_results = sim.simulate_deployment(
        ...     test_preds_by_month, test_labels_by_month, test_uncs_by_month,
        ...     threshold=threshold,
        ...     target_metric="FNR", target_value=0.05
        ... )
        >>>
        >>> # Compute adherence metrics
        >>> metrics = sim.compute_adherence_metrics(
        ...     monthly_results, target_metric="FNR", target_value=0.05
        ... )
    """

    # Smallest sample count for which metrics are reported. Must stay in
    # sync with the ``n < 20`` cut-off inside compute_metrics_numba, which
    # returns NaN below it.
    _MIN_SAMPLES = 20

    # Metric names accepted by every method taking ``target_metric``.
    _VALID_METRICS = ("FNR", "FPR")

    def __init__(self, tolerance: float = 1e-4, max_iterations: int = 100):
        """
        Initialize simulator.

        Args:
            tolerance: Convergence tolerance for binary search
            max_iterations: Maximum iterations for binary search
        """
        self.tolerance = tolerance
        self.max_iterations = max_iterations

    @classmethod
    def _check_target_metric(cls, target_metric: str) -> None:
        """Raise ValueError unless ``target_metric`` is "FNR" or "FPR"."""
        # Without this check any other string (e.g. a lowercase "fnr")
        # would silently be treated as FPR.
        if target_metric not in cls._VALID_METRICS:
            raise ValueError(
                f"target_metric must be one of {cls._VALID_METRICS}, "
                f"got {target_metric!r}"
            )

    def calibrate_threshold(
        self,
        predictions: npt.NDArray,
        labels: npt.NDArray,
        uncertainties: npt.NDArray,
        target_metric: Literal["FNR", "FPR"],
        target_value: float,
    ) -> float:
        """
        Calibrate rejection threshold to achieve target FNR or FPR.

        Samples with uncertainty above the threshold are rejected. The
        method bisects the threshold between "reject everything" and
        "accept everything", evaluating the metric on the accepted samples
        at each probe. The bisection treats the metric as decreasing when
        the threshold is lowered; it does not verify that this holds for
        the given data.

        If the metric on all samples already meets the target, a threshold
        just above the largest uncertainty is returned (nothing is
        rejected). If no probe lands within ``self.tolerance`` of the
        target, a warning is issued and the probe whose metric was closest
        to the target is returned.

        Args:
            predictions: Binary predictions (pooled from validation months)
            labels: True labels (pooled from validation months)
            uncertainties: Uncertainty scores (pooled from validation months)
            target_metric: "FNR" or "FPR" - which metric to target
            target_value: Target value for the metric (e.g., 0.05 for 5%)

        Returns:
            Calibrated threshold value

        Raises:
            ValueError: If ``target_metric`` is not "FNR"/"FPR", the inputs
                are empty or of different lengths, fewer than 20 samples are
                provided (no metric can be computed), or the search never
                reached a threshold with at least 20 accepted samples.
        """
        self._check_target_metric(target_metric)

        predictions = np.asarray(predictions)
        labels = np.asarray(labels)
        uncertainties = np.asarray(uncertainties)

        n_samples = len(predictions)
        if n_samples == 0:
            raise ValueError("Empty input arrays")
        if len(labels) != n_samples or len(uncertainties) != n_samples:
            raise ValueError(
                "predictions, labels and uncertainties must have the same length"
            )
        if n_samples < self._MIN_SAMPLES:
            # compute_metrics_numba returns NaN below this size, so the
            # search could never evaluate a single candidate.
            raise ValueError(
                f"Need at least {self._MIN_SAMPLES} samples to calibrate, "
                f"got {n_samples}"
            )

        use_fnr = target_metric == "FNR"

        # Search bounds. Samples are accepted when uncertainty <= threshold:
        # just below the smallest score rejects everything, just above the
        # largest score accepts everything. NaN scores are ignored for the
        # bounds (such samples are never accepted by the comparison).
        low_thresh = float(np.nanmin(uncertainties)) - 1e-6   # rejects all
        high_thresh = float(np.nanmax(uncertainties)) + 1e-6  # accepts all

        # Metric without any rejection
        _, fnr_all, fpr_all = compute_metrics_numba(labels, predictions)
        baseline_metric = fnr_all if use_fnr else fpr_all

        if baseline_metric <= target_value:
            # Already achieving target without rejection
            return high_thresh  # Don't reject anything

        # Closest valid probe seen so far, used if the search never
        # converges.
        best_thresh = None
        best_metric = np.nan
        best_gap = np.inf
        iterations = 0

        for _ in range(self.max_iterations):
            mid_thresh = (low_thresh + high_thresh) / 2

            if not (low_thresh < mid_thresh < high_thresh):
                # The interval can no longer be split at float resolution
                # (or a bound is NaN); further probes would repeat.
                break
            iterations += 1

            acc_preds, acc_labels, n_acc, _ = apply_uncertainty_threshold(
                predictions, labels, uncertainties, mid_thresh, reject_above=True
            )

            if n_acc < self._MIN_SAMPLES:
                # Too few accepted samples: raise the threshold to accept more
                low_thresh = mid_thresh
                continue

            _, fnr, fpr = compute_metrics_numba(acc_labels, acc_preds)
            current_metric = fnr if use_fnr else fpr

            if np.isnan(current_metric):
                # Metric undefined for this subset: accept more
                low_thresh = mid_thresh
                continue

            gap = abs(current_metric - target_value)
            if gap < self.tolerance:
                return mid_thresh

            if gap < best_gap:
                best_thresh = mid_thresh
                best_metric = current_metric
                best_gap = gap

            if current_metric > target_value:
                # Metric too high: reject more (lower threshold)
                high_thresh = mid_thresh
            else:
                # Metric below target: can accept more (higher threshold)
                low_thresh = mid_thresh

        if best_thresh is None:
            raise ValueError(
                f"No threshold with at least {self._MIN_SAMPLES} accepted "
                f"samples was evaluated; cannot calibrate {target_metric} "
                f"to {target_value}"
            )

        warnings.warn(
            f"Binary search did not converge after {iterations} iterations. "
            f"Best threshold: {best_thresh}, achieved {target_metric}: {best_metric:.4f} "
            f"(target: {target_value:.4f})",
            stacklevel=2,
        )
        return best_thresh

    def apply_threshold_single_month(
        self,
        predictions: npt.NDArray,
        labels: npt.NDArray,
        uncertainties: npt.NDArray,
        threshold: float,
        month_idx: int = 0
    ) -> MonthlyRejectionResult:
        """
        Apply calibrated threshold to a single month's data.

        Args:
            predictions: Binary predictions for this month
            labels: True labels for this month
            uncertainties: Uncertainty scores for this month
            threshold: Calibrated rejection threshold; samples with a larger
                uncertainty are rejected
            month_idx: Month index (for tracking)

        Returns:
            MonthlyRejectionResult with baseline metrics (all samples) and
            metrics on the accepted samples. The latter are NaN when fewer
            than 20 samples are accepted; coverage is 0.0 for an empty
            month.
        """
        predictions = np.asarray(predictions)
        labels = np.asarray(labels)
        uncertainties = np.asarray(uncertainties)

        # Baseline metrics (no rejection)
        baseline_f1, baseline_fnr, baseline_fpr = compute_metrics_numba(labels, predictions)

        # Apply threshold
        acc_preds, acc_labels, n_accepted, n_rejected = apply_uncertainty_threshold(
            predictions, labels, uncertainties, threshold, reject_above=True
        )

        # Metrics on accepted samples
        if n_accepted >= self._MIN_SAMPLES:
            rejected_f1, rejected_fnr, rejected_fpr = compute_metrics_numba(acc_labels, acc_preds)
        else:
            rejected_f1 = rejected_fnr = rejected_fpr = np.nan

        total = len(predictions)
        coverage = n_accepted / total if total > 0 else 0.0

        return MonthlyRejectionResult(
            month_idx=month_idx,
            threshold=threshold,
            total_samples=total,
            baseline_fnr=baseline_fnr,
            baseline_fpr=baseline_fpr,
            baseline_f1=baseline_f1,
            accepted_samples=n_accepted,
            rejected_samples=n_rejected,
            coverage=coverage,
            rejected_fnr=rejected_fnr,
            rejected_fpr=rejected_fpr,
            rejected_f1=rejected_f1,
        )

    def simulate_deployment(
        self,
        predictions_by_month: List[npt.NDArray],
        labels_by_month: List[npt.NDArray],
        uncertainties_by_month: List[npt.NDArray],
        threshold: float,
        target_metric: Literal["FNR", "FPR"],
        target_value: float,
        adaptive: bool = False,
    ) -> List[MonthlyRejectionResult]:
        """
        Simulate deployment with calibrated threshold.

        Each month is scored with the threshold in force at that time. In
        adaptive mode the month's data are then added to the accumulated
        deployment data and the threshold is recalibrated on all of it, so
        a month never influences its own threshold. Recalibration is skipped
        (the current threshold is kept) until at least 20 samples have
        accumulated.

        Args:
            predictions_by_month: List of prediction arrays (one per month)
            labels_by_month: List of label arrays (one per month)
            uncertainties_by_month: List of uncertainty arrays (one per month)
            threshold: Initial calibrated threshold
            target_metric: "FNR" or "FPR" (only used in adaptive mode)
            target_value: Target value for the metric (only used in
                adaptive mode)
            adaptive: If True, recalibrate threshold each month

        Returns:
            List of MonthlyRejectionResult for each month

        Raises:
            ValueError: If the three per-month lists differ in length, or,
                in adaptive mode, if ``target_metric`` is invalid or a
                recalibration fails.
        """
        n_months = len(predictions_by_month)
        if len(labels_by_month) != n_months or len(uncertainties_by_month) != n_months:
            # zip() would silently drop the surplus months.
            raise ValueError(
                "predictions_by_month, labels_by_month and "
                "uncertainties_by_month must have the same number of months"
            )
        if adaptive:
            self._check_target_metric(target_metric)

        results = []
        current_threshold = threshold

        # Adaptive mode only: everything observed so far.
        seen_preds = []
        seen_labels = []
        seen_uncs = []
        n_seen = 0

        for month_idx, (preds, labels, uncs) in enumerate(
            zip(predictions_by_month, labels_by_month, uncertainties_by_month)
        ):
            # Apply current threshold
            results.append(
                self.apply_threshold_single_month(
                    preds, labels, uncs, current_threshold, month_idx
                )
            )

            if not adaptive:
                continue

            month_preds = np.asarray(preds)
            seen_preds.append(month_preds)
            seen_labels.append(np.asarray(labels))
            seen_uncs.append(np.asarray(uncs))
            n_seen += len(month_preds)

            if n_seen < self._MIN_SAMPLES:
                # Not enough data for any metric yet; keep the threshold.
                continue

            # Recalibrate on accumulated data
            current_threshold = self.calibrate_threshold(
                np.concatenate(seen_preds),
                np.concatenate(seen_labels),
                np.concatenate(seen_uncs),
                target_metric, target_value
            )

        return results

    def compute_adherence_metrics(
        self,
        monthly_results: List[MonthlyRejectionResult],
        target_metric: Literal["FNR", "FPR"],
        target_value: float,
    ) -> TargetAdherenceMetrics:
        """
        Compute all adherence metrics from monthly results.

        Months whose post-rejection metric is NaN are excluded from every
        aggregate, including the coverage statistics; the per-month arrays
        in the result still contain all months. If no month has a valid
        metric, every aggregate is NaN.

        Args:
            monthly_results: List of MonthlyRejectionResult from deployment
            target_metric: "FNR" or "FPR"
            target_value: Target value we're trying to achieve

        Returns:
            TargetAdherenceMetrics with all computed metrics

        Raises:
            ValueError: If ``target_metric`` is not "FNR" or "FPR"
        """
        self._check_target_metric(target_metric)

        # Extract actual values
        if target_metric == "FNR":
            monthly_actual = np.array(
                [r.rejected_fnr for r in monthly_results], dtype=np.float64
            )
        else:
            monthly_actual = np.array(
                [r.rejected_fpr for r in monthly_results], dtype=np.float64
            )

        monthly_coverage = np.array(
            [r.coverage for r in monthly_results], dtype=np.float64
        )
        monthly_threshold = np.array(
            [r.threshold for r in monthly_results], dtype=np.float64
        )

        # Filter out NaN values for metric computation
        valid_mask = ~np.isnan(monthly_actual)
        valid_actual = monthly_actual[valid_mask]
        valid_coverage = monthly_coverage[valid_mask]

        if len(valid_actual) == 0:
            # No valid data
            return TargetAdherenceMetrics(
                target_value=target_value,
                target_metric=target_metric,
                mae=np.nan, bias=np.nan, rmse=np.nan,
                hit_rate_1pct=np.nan, hit_rate_2pct=np.nan, hit_rate_5pct=np.nan,
                mean_coverage=np.nan, cv_coverage=np.nan, min_coverage=np.nan,
                cv_deviation=np.nan, mann_kendall_tau=np.nan, mann_kendall_p=np.nan,
                max_deviation=np.nan,
                monthly_actual=monthly_actual,
                monthly_coverage=monthly_coverage,
                monthly_threshold=monthly_threshold,
            )

        # Compute deviations
        signed_deviations = valid_actual - target_value
        deviations = np.abs(signed_deviations)

        # Primary metrics
        mae = float(np.mean(deviations))
        bias = float(np.mean(signed_deviations))
        rmse = float(np.sqrt(np.mean(signed_deviations ** 2)))

        # Hit rates (strict inequality)
        hit_rate_1pct = float(np.mean(deviations < 0.01))
        hit_rate_2pct = float(np.mean(deviations < 0.02))
        hit_rate_5pct = float(np.mean(deviations < 0.05))

        # Coverage metrics; CV is defined as 0 when the mean is 0
        mean_coverage = float(np.mean(valid_coverage))
        cv_coverage = (
            float(np.std(valid_coverage) / mean_coverage) if mean_coverage > 0 else 0.0
        )
        min_coverage = float(np.min(valid_coverage))

        # Temporal stability
        cv_deviation = float(np.std(deviations) / mae) if mae > 0 else 0.0
        max_deviation = float(np.max(deviations))

        # Mann-Kendall trend test on the absolute deviations
        if len(valid_actual) >= 3:
            tau, p = mann_kendall_test(deviations)
        else:
            tau, p = 0.0, 1.0

        return TargetAdherenceMetrics(
            target_value=target_value,
            target_metric=target_metric,
            mae=mae,
            bias=bias,
            rmse=rmse,
            hit_rate_1pct=hit_rate_1pct,
            hit_rate_2pct=hit_rate_2pct,
            hit_rate_5pct=hit_rate_5pct,
            mean_coverage=mean_coverage,
            cv_coverage=cv_coverage,
            min_coverage=min_coverage,
            cv_deviation=cv_deviation,
            mann_kendall_tau=tau,
            mann_kendall_p=p,
            max_deviation=max_deviation,
            monthly_actual=monthly_actual,
            monthly_coverage=monthly_coverage,
            monthly_threshold=monthly_threshold,
        )

    def simulate_target_grid(
        self,
        val_predictions: npt.NDArray,
        val_labels: npt.NDArray,
        val_uncertainties: npt.NDArray,
        test_predictions_by_month: List[npt.NDArray],
        test_labels_by_month: List[npt.NDArray],
        test_uncertainties_by_month: List[npt.NDArray],
        target_metric: Literal["FNR", "FPR"],
        target_values: List[float],
        adaptive: bool = False,
    ) -> List[TargetAdherenceMetrics]:
        """
        Run full grid simulation over multiple target values.

        For every target value: calibrate a threshold on the validation
        data, simulate deployment on the monthly test data, and summarise
        adherence.

        Args:
            val_*: Validation data (pooled) for calibration
            test_*_by_month: Test data (per month) for evaluation
            target_metric: "FNR" or "FPR"
            target_values: List of target values to test
            adaptive: If True, use adaptive threshold mode

        Returns:
            List of TargetAdherenceMetrics, one per target value
        """
        results = []

        for target_value in target_values:
            # Calibrate threshold
            threshold = self.calibrate_threshold(
                val_predictions, val_labels, val_uncertainties,
                target_metric, target_value
            )

            # Simulate deployment
            monthly_results = self.simulate_deployment(
                test_predictions_by_month,
                test_labels_by_month,
                test_uncertainties_by_month,
                threshold=threshold,
                target_metric=target_metric,
                target_value=target_value,
                adaptive=adaptive,
            )

            # Compute adherence metrics
            results.append(
                self.compute_adherence_metrics(
                    monthly_results, target_metric, target_value
                )
            )

        return results

    def compute_reliability_diagram(
        self,
        adherence_results: List[TargetAdherenceMetrics],
        seed: Optional[int] = None,
    ) -> ReliabilityDiagramData:
        """
        Compute reliability diagram data from adherence results.

        For each target the mean and standard deviation of the valid
        (non-NaN) monthly values are computed, plus a 95% bootstrap interval
        for the mean. With fewer than 3 valid months the interval collapses
        to the mean; with none, mean, std and interval are NaN.

        Args:
            adherence_results: List of TargetAdherenceMetrics from simulate_target_grid
            seed: Optional base seed for the bootstrap. When given, target
                number ``i`` uses ``seed + i`` so the intervals are
                reproducible; when None the bootstrap is unseeded.

        Returns:
            ReliabilityDiagramData for visualization
        """
        targets = []
        observed_mean = []
        observed_std = []
        ci_low = []
        ci_high = []

        for idx, m in enumerate(adherence_results):
            actual = np.asarray(m.monthly_actual, dtype=np.float64)
            valid = actual[~np.isnan(actual)]

            # Computing on the filtered array avoids the "empty slice"
            # RuntimeWarning that nanmean/nanstd emit for all-NaN input.
            if valid.size > 0:
                mean = float(np.mean(valid))
                std = float(np.std(valid))
            else:
                mean = std = np.nan

            # Confidence interval via bootstrap
            if valid.size >= 3:
                low, high = bootstrap_ci(
                    valid, statistic='mean', confidence=0.95,
                    seed=None if seed is None else seed + idx,
                )
            else:
                low = high = mean

            targets.append(m.target_value)
            observed_mean.append(mean)
            observed_std.append(std)
            ci_low.append(low)
            ci_high.append(high)

        coverage_mean = np.array(
            [m.mean_coverage for m in adherence_results], dtype=np.float64
        )
        # cv = std / mean, so std is recovered as cv * mean
        coverage_std = np.array(
            [m.cv_coverage * m.mean_coverage for m in adherence_results],
            dtype=np.float64,
        )

        # Get n_months from first result
        n_months = len(adherence_results[0].monthly_actual) if adherence_results else 0

        return ReliabilityDiagramData(
            targets=np.array(targets, dtype=np.float64),
            observed_mean=np.array(observed_mean, dtype=np.float64),
            observed_std=np.array(observed_std, dtype=np.float64),
            observed_ci_low=np.array(ci_low, dtype=np.float64),
            observed_ci_high=np.array(ci_high, dtype=np.float64),
            coverage_mean=coverage_mean,
            coverage_std=coverage_std,
            n_months=n_months,
        )


# ============================================================================
# CONVENIENCE FUNCTIONS
# ============================================================================

def run_performance_rejection_analysis(
    val_predictions: npt.NDArray,
    val_labels: npt.NDArray,
    val_uncertainties: npt.NDArray,
    test_predictions_by_month: List[npt.NDArray],
    test_labels_by_month: List[npt.NDArray],
    test_uncertainties_by_month: List[npt.NDArray],
    fnr_targets: Optional[List[float]] = None,
    fpr_targets: Optional[List[float]] = None,
    adaptive: bool = False,
) -> Dict[str, Union[List[TargetAdherenceMetrics], ReliabilityDiagramData]]:
    """
    Run the target-grid simulation for FNR and FPR in one call.

    For each metric with a non-empty target list, a fresh
    PerformanceRejectionSimulator (default settings) runs the grid and the
    reliability diagram data are derived from the grid results.

    Args:
        val_*: Validation data (pooled) for calibration
        test_*_by_month: Test data (per month) for evaluation
        fnr_targets: List of FNR targets (default: [0.01, 0.02, 0.05, 0.10, 0.15, 0.20])
        fpr_targets: List of FPR targets (default: [0.01, 0.02, 0.05, 0.10, 0.15])
        adaptive: If True, use adaptive threshold mode

    Returns:
        Dictionary with:
            'fnr_metrics': List[TargetAdherenceMetrics]
            'fpr_metrics': List[TargetAdherenceMetrics]
            'fnr_reliability': ReliabilityDiagramData
            'fpr_reliability': ReliabilityDiagramData
        The two keys of a metric are absent when its target list is empty.
    """
    # None means "use the default grid"; an explicitly empty list skips the metric.
    fnr_grid = [0.01, 0.02, 0.05, 0.10, 0.15, 0.20] if fnr_targets is None else fnr_targets
    fpr_grid = [0.01, 0.02, 0.05, 0.10, 0.15] if fpr_targets is None else fpr_targets

    simulator = PerformanceRejectionSimulator()

    # (metric name, target grid, key for the metrics, key for the diagram);
    # FNR first so the output keys keep their order.
    plan = (
        ("FNR", fnr_grid, 'fnr_metrics', 'fnr_reliability'),
        ("FPR", fpr_grid, 'fpr_metrics', 'fpr_reliability'),
    )

    results = {}
    for metric_name, grid, metrics_key, reliability_key in plan:
        if not grid:
            continue

        grid_metrics = simulator.simulate_target_grid(
            val_predictions, val_labels, val_uncertainties,
            test_predictions_by_month, test_labels_by_month, test_uncertainties_by_month,
            target_metric=metric_name,
            target_values=grid,
            adaptive=adaptive,
        )
        results[metrics_key] = grid_metrics
        results[reliability_key] = simulator.compute_reliability_diagram(grid_metrics)

    return results
