"""
Uncertainty quality metrics for Aurora framework

This module provides metrics for assessing the quality of uncertainty estimates:
- AURC (Area Under Risk-Coverage curve)
- E-AURC (Excess AURC, normalized)
- Risk-coverage curve data for visualization
- Multi-seed support for robustness analysis

The key question: Can the classifier identify its own mistakes?
Good uncertainty → rejecting uncertain samples reduces error rate
Bad uncertainty → rejecting uncertain samples doesn't help

Example:
    >>> from aurora.uncertainty import UncertaintyQualityMetrics
    >>>
    >>> # Compute E-AURC (lower is better)
    >>> eaurc = UncertaintyQualityMetrics.compute_eaurc(
    ...     predictions=predictions,
    ...     labels=labels,
    ...     uncertainties=uncertainties
    ... )
    >>>
    >>> # Get risk-coverage curve for plotting
    >>> curve_data = UncertaintyQualityMetrics.compute_risk_coverage_curve(
    ...     predictions=predictions,
    ...     labels=labels,
    ...     uncertainties=uncertainties
    ... )
"""

from typing import List, Dict, Tuple, Union, Optional
import numpy as np
import numpy.typing as npt
import numba

# Try to import from tools.py for consistency, fallback to local implementation
try:
    from tools import (
        values_to_quantiles_vectorized,
        compute_histograms_and_auc,
        compute_aurc as tools_compute_aurc
    )
    _USING_TOOLS = True
except ImportError:
    _USING_TOOLS = False


# ============================================================================
# Helper Functions (local implementations if tools.py not available)
# ============================================================================

def _values_to_quantiles_vectorized(arr: npt.NDArray) -> npt.NDArray:
    """
    Map each value to its rank-based quantile within the array.

    The smallest value maps to 0.0 and the largest to 1.0; ranks are obtained
    with a double ``argsort`` and divided by ``len(arr) - 1``.

    Args:
        arr: Array (or array-like) of uncertainty values

    Returns:
        Array of quantiles in range [0, 1]. An empty input is returned as-is
        (after conversion to an array); an input whose values are all equal
        yields an array of ones.
    """
    values = np.asarray(arr)
    n_values = len(values)

    # Nothing to rank
    if n_values == 0:
        return values

    # Constant input (this also covers a single element): every sample gets
    # quantile 1.0, which sidesteps the division by ``n_values - 1`` below
    reference = values[0]
    if np.all(values == reference):
        return np.ones_like(values, dtype=float)

    # argsort of the sorting permutation gives each element's rank
    sort_order = np.argsort(values)
    positions = np.argsort(sort_order)
    scale = n_values - 1

    return positions.astype(float) / scale


@numba.jit(nopython=True, parallel=True)
def _compute_histograms_and_auc(
    all_labels: npt.NDArray,
    all_predictions: npt.NDArray,
    all_uncertainties: npt.NDArray
) -> float:
    """
    Compute AURC (Area Under Risk-Coverage curve) with a Numba-compiled kernel.

    Steps:
    1. Mark every sample as correct / incorrect
    2. Histogram the samples over equal-width bins of the (quantile) uncertainty
    3. Accumulate the bins from lowest to highest uncertainty to obtain one
       (coverage, risk) point per bin
    4. Integrate risk over coverage with the trapezoidal rule

    Args:
        all_labels: True labels
        all_predictions: Predicted labels
        all_uncertainties: Uncertainty scores (already quantile-transformed,
            i.e. expected to lie in [0, 1])

    Returns:
        AURC value (area under the curve)
    """
    n_samples = len(all_uncertainties)

    # Per-sample correctness flag (the only loop that is run with prange)
    is_correct = np.zeros(n_samples, dtype=np.bool_)
    for k in numba.prange(n_samples):
        is_correct[k] = all_labels[k] == all_predictions[k]

    # Bin count: sqrt of the sample size, clamped to [20, 100]
    n_bins = min(100, max(20, int(np.sqrt(n_samples))))
    hits_per_bin = np.zeros(n_bins, dtype=np.int64)
    misses_per_bin = np.zeros(n_bins, dtype=np.int64)

    # Equal-width bins over [0, 1]; a value of exactly 1.0 is clamped into
    # the last bin
    width = 1.0 / n_bins
    top_bin = n_bins - 1
    for k in range(n_samples):
        slot = min(top_bin, int(all_uncertainties[k] / width))
        if is_correct[k]:
            hits_per_bin[slot] += 1
        else:
            misses_per_bin[slot] += 1

    # Running totals turn the histograms into cumulative coverage / risk.
    # Bins reached before any sample has been accumulated keep risk 0.
    coverage = np.zeros(n_bins, dtype=np.float64)
    risk = np.zeros(n_bins, dtype=np.float64)
    hits_so_far = 0
    misses_so_far = 0
    for b in range(n_bins):
        hits_so_far += hits_per_bin[b]
        misses_so_far += misses_per_bin[b]
        kept = hits_so_far + misses_so_far
        coverage[b] = kept / n_samples
        if kept > 0:
            risk[b] = misses_so_far / kept

    # Trapezoidal integration. The leading segment from the origin (0, 0) to
    # the first point is only counted when the first bin already covers more
    # than 1% of the samples.
    area = 0.0
    if coverage[0] > 0.01:
        area += (0.0 + risk[0]) * coverage[0] / 2.0
    for b in range(n_bins - 1):
        area += (risk[b] + risk[b + 1]) * (coverage[b + 1] - coverage[b]) / 2.0

    return area


@numba.jit(nopython=True, parallel=True)
def _compute_risk_coverage_curve(
    all_labels: npt.NDArray,
    all_predictions: npt.NDArray,
    all_uncertainties: npt.NDArray,
    n_bins: int = 100
) -> Tuple[npt.NDArray, npt.NDArray]:
    """
    Build the points of the risk-coverage curve for visualization.

    Samples are histogrammed over equal-width bins of the (quantile)
    uncertainty and the bins are accumulated from lowest to highest
    uncertainty; each bin contributes one (coverage, risk) point.

    Args:
        all_labels: True labels
        all_predictions: Predicted labels
        all_uncertainties: Uncertainty scores (expected to be
            quantile-transformed, i.e. in [0, 1])
        n_bins: Requested number of bins; capped at the number of samples

    Returns:
        Tuple ``(coverage, risk)`` of float64 arrays with one entry per bin.
    """
    n_samples = len(all_uncertainties)

    # Per-sample correctness flag (the only loop that is run with prange)
    is_correct = np.zeros(n_samples, dtype=np.bool_)
    for k in numba.prange(n_samples):
        is_correct[k] = all_labels[k] == all_predictions[k]

    # Never use more bins than there are samples
    n_bins = min(n_bins, n_samples)
    hits_per_bin = np.zeros(n_bins, dtype=np.int64)
    misses_per_bin = np.zeros(n_bins, dtype=np.int64)

    # Equal-width bins over [0, 1]; a value of exactly 1.0 is clamped into
    # the last bin
    width = 1.0 / n_bins
    top_bin = n_bins - 1
    for k in range(n_samples):
        slot = min(top_bin, int(all_uncertainties[k] / width))
        if is_correct[k]:
            hits_per_bin[slot] += 1
        else:
            misses_per_bin[slot] += 1

    # Running totals turn the histograms into cumulative coverage / risk.
    # Bins reached before any sample has been accumulated keep risk 0.
    coverage = np.zeros(n_bins, dtype=np.float64)
    risk = np.zeros(n_bins, dtype=np.float64)
    hits_so_far = 0
    misses_so_far = 0
    for b in range(n_bins):
        hits_so_far += hits_per_bin[b]
        misses_so_far += misses_per_bin[b]
        kept = hits_so_far + misses_so_far
        coverage[b] = kept / n_samples
        if kept > 0:
            risk[b] = misses_so_far / kept

    return coverage, risk


# ============================================================================
# UncertaintyQualityMetrics Class
# ============================================================================

class UncertaintyQualityMetrics:
    """
    Compute uncertainty quality metrics (AURC, E-AURC)

    Uncertainty quality measures whether the classifier can identify its mistakes.
    Good uncertainty: Rejecting uncertain samples → lower error rate
    Bad uncertainty: Rejecting uncertain samples → doesn't help

    **AURC (Area Under Risk-Coverage)**:
    - Lower is better (less area = better uncertainty)
    - Measures average risk across all coverage levels
    - Raw metric, not normalized

    **E-AURC (Excess AURC)**:
    - AURC in excess of the optimal AURC
    - E-AURC = AURC - AURC_optimal
    - AURC_optimal = r + (1-r) * ln(1-r) where r = error rate
    - Lower is better (0 = optimal uncertainty)
    - Recommended metric for comparisons

    **Interpretation** (rule of thumb):
    - E-AURC < 0.05 → Excellent uncertainty calibration
    - E-AURC < 0.10 → Good uncertainty calibration
    - E-AURC > 0.15 → Poor uncertainty calibration

    All public methods accept, for each of predictions / labels /
    uncertainties, either a single array-like or a list/tuple of arrays
    (e.g. one array per evaluation period), which is concatenated.

    Example:
        >>> # Compute E-AURC
        >>> eaurc = UncertaintyQualityMetrics.compute_eaurc(
        ...     predictions=preds,
        ...     labels=labs,
        ...     uncertainties=uncs
        ... )
        >>> print(f"E-AURC: {eaurc:.4f}")
        >>>
        >>> # Get both AURC and E-AURC
        >>> aurc, eaurc = UncertaintyQualityMetrics.compute_both(preds, labs, uncs)
        >>>
        >>> # Get risk-coverage curve for plotting (returned as a dict)
        >>> curve = UncertaintyQualityMetrics.compute_risk_coverage_curve(
        ...     preds, labs, uncs
        ... )
        >>> plt.plot(curve['coverage'], curve['risk'])
    """

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_flat_array(values) -> npt.NDArray:
        """Return *values* as one NumPy array, concatenating chunked input.

        A non-empty list/tuple whose first element is itself array-like
        (has at least one dimension) is treated as a sequence of chunks and
        concatenated; anything else is passed through ``np.asarray``.
        """
        if (
            isinstance(values, (list, tuple))
            and len(values) > 0
            and np.ndim(values[0]) > 0
        ):
            return np.concatenate(values)
        return np.asarray(values)

    @staticmethod
    def _prepare_inputs(
        predictions,
        labels,
        uncertainties,
        metric_name: str
    ) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
        """Normalise the three inputs to arrays and validate their lengths.

        Each input is normalised independently, so chunked and flat inputs
        may be mixed.

        Raises:
            ValueError: If the lengths differ or the inputs are empty.
        """
        all_preds = UncertaintyQualityMetrics._as_flat_array(predictions)
        all_labels = UncertaintyQualityMetrics._as_flat_array(labels)
        all_uncs = UncertaintyQualityMetrics._as_flat_array(uncertainties)

        if len(all_preds) != len(all_labels) or len(all_preds) != len(all_uncs):
            raise ValueError(
                f"Length mismatch: predictions={len(all_preds)}, "
                f"labels={len(all_labels)}, uncertainties={len(all_uncs)}"
            )

        if len(all_preds) == 0:
            raise ValueError(f"Cannot compute {metric_name} on empty arrays")

        return all_preds, all_labels, all_uncs

    @staticmethod
    def _optimal_aurc(error_rate: float) -> float:
        """Return AURC_optimal = r + (1-r) * ln(1-r) for error rate *r*.

        For r == 1 the logarithm is undefined; the value r itself is
        returned in that case (all predictions wrong).
        """
        if error_rate < 1.0:
            return float(
                error_rate + (1 - error_rate) * np.log(1 - error_rate)
            )
        return float(error_rate)

    @staticmethod
    def _coverage_to_counts(coverage, n_total: int) -> npt.NDArray:
        """Convert coverage fractions back to integer sample counts.

        Rounds to the nearest integer instead of truncating: ``(k / n) * n``
        can land just below ``k`` in floating point (e.g. ``(1/49) * 49``),
        and truncation would then report ``k - 1`` samples.
        """
        scaled = np.asarray(coverage, dtype=float) * n_total
        return np.rint(scaled).astype(int)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def compute_aurc(
        predictions: Union[List[npt.NDArray], npt.NDArray],
        labels: Union[List[npt.NDArray], npt.NDArray],
        uncertainties: Union[List[npt.NDArray], npt.NDArray]
    ) -> float:
        """
        Compute AURC (Area Under Risk-Coverage curve)

        Args:
            predictions: Predicted labels (list of arrays or single array)
            labels: True labels (list of arrays or single array)
            uncertainties: Uncertainty scores (list of arrays or single array)
                          Higher uncertainty = less confident

        Returns:
            AURC value (lower is better)

        Raises:
            ValueError: If the inputs differ in length or are empty.

        Example:
            >>> preds = [np.array([1, 0, 1]), np.array([1, 1, 0])]
            >>> labs = [np.array([1, 1, 1]), np.array([1, 0, 0])]
            >>> uncs = [np.array([0.1, 0.9, 0.2]), np.array([0.3, 0.4, 0.8])]
            >>> aurc = UncertaintyQualityMetrics.compute_aurc(preds, labs, uncs)
        """
        all_preds, all_labels, all_uncs = UncertaintyQualityMetrics._prepare_inputs(
            predictions, labels, uncertainties, "AURC"
        )

        # Rank-transform the uncertainties, then integrate the binned
        # risk-coverage curve. Prefer the shared tools.py implementations
        # when they were importable.
        if _USING_TOOLS:
            quantile_uncs = values_to_quantiles_vectorized(all_uncs)
            aurc = compute_histograms_and_auc(all_labels, all_preds, quantile_uncs)
        else:
            quantile_uncs = _values_to_quantiles_vectorized(all_uncs)
            aurc = _compute_histograms_and_auc(all_labels, all_preds, quantile_uncs)

        return float(aurc)

    @staticmethod
    def compute_eaurc(
        predictions: Union[List[npt.NDArray], npt.NDArray],
        labels: Union[List[npt.NDArray], npt.NDArray],
        uncertainties: Union[List[npt.NDArray], npt.NDArray]
    ) -> float:
        """
        Compute E-AURC (Excess AURC)

        E-AURC = AURC - AURC_optimal

        where AURC_optimal = r + (1-r) * ln(1-r) and r = error rate

        Args:
            predictions: Predicted labels
            labels: True labels
            uncertainties: Uncertainty scores

        Returns:
            E-AURC value (lower is better, 0 = optimal)

        Raises:
            ValueError: If the inputs differ in length or are empty.

        Example:
            >>> eaurc = UncertaintyQualityMetrics.compute_eaurc(preds, labs, uncs)
            >>> if eaurc < 0.05:
            ...     print("Excellent uncertainty calibration!")
        """
        # Single source of truth for the E-AURC formula lives in compute_both
        _, eaurc = UncertaintyQualityMetrics.compute_both(
            predictions, labels, uncertainties
        )
        return eaurc

    @staticmethod
    def compute_both(
        predictions: Union[List[npt.NDArray], npt.NDArray],
        labels: Union[List[npt.NDArray], npt.NDArray],
        uncertainties: Union[List[npt.NDArray], npt.NDArray]
    ) -> Tuple[float, float]:
        """
        Compute both AURC and E-AURC (more efficient than calling separately)

        Args:
            predictions: Predicted labels
            labels: True labels
            uncertainties: Uncertainty scores

        Returns:
            Tuple of (AURC, E-AURC)

        Raises:
            ValueError: If the inputs differ in length or are empty.

        Example:
            >>> aurc, eaurc = UncertaintyQualityMetrics.compute_both(preds, labs, uncs)
            >>> print(f"AURC: {aurc:.4f}, E-AURC: {eaurc:.4f}")
        """
        all_preds, all_labels, all_uncs = UncertaintyQualityMetrics._prepare_inputs(
            predictions, labels, uncertainties, "AURC"
        )

        aurc = UncertaintyQualityMetrics.compute_aurc(
            all_preds, all_labels, all_uncs
        )

        # Empirical error rate over all samples
        r_hat = np.mean(all_labels != all_preds)

        # Excess over the AURC of an ideal ranking with the same error rate
        eaurc = aurc - UncertaintyQualityMetrics._optimal_aurc(r_hat)

        return float(aurc), float(eaurc)

    @staticmethod
    def compute_risk_coverage_curve(
        predictions: Union[List[npt.NDArray], npt.NDArray],
        labels: Union[List[npt.NDArray], npt.NDArray],
        uncertainties: Union[List[npt.NDArray], npt.NDArray],
        n_bins: int = 100
    ) -> Dict[str, npt.NDArray]:
        """
        Compute risk-coverage curve data for visualization

        The risk-coverage curve shows:
        - X-axis: Coverage (fraction of samples retained)
        - Y-axis: Risk (error rate on retained samples)

        As we reject more uncertain samples:
        - Coverage decreases (fewer samples retained)
        - Risk should decrease (lower error rate)

        Good uncertainty → steep drop in risk
        Bad uncertainty → flat curve

        Args:
            predictions: Predicted labels
            labels: True labels
            uncertainties: Uncertainty scores
            n_bins: Number of bins for the curve (default: 100); must be
                    at least 1 and is capped at the number of samples

        Returns:
            Dictionary with:
            - 'coverage': Array of coverage values [0, 1]
            - 'risk': Array of risk values [0, 1]
            - 'n_samples': Number of samples at each coverage level

        Raises:
            ValueError: If the inputs differ in length, are empty, or
                ``n_bins`` is smaller than 1.

        Example:
            >>> curve = UncertaintyQualityMetrics.compute_risk_coverage_curve(
            ...     preds, labs, uncs
            ... )
            >>> plt.plot(curve['coverage'], curve['risk'])
            >>> plt.xlabel('Coverage')
            >>> plt.ylabel('Risk (Error Rate)')
            >>> plt.title('Risk-Coverage Curve')
        """
        all_preds, all_labels, all_uncs = UncertaintyQualityMetrics._prepare_inputs(
            predictions, labels, uncertainties, "risk-coverage curve"
        )

        # Reject degenerate bin counts up front; otherwise the compiled
        # helper fails with a division by zero on the bin width.
        if n_bins < 1:
            raise ValueError(f"n_bins must be a positive integer, got {n_bins}")

        # Convert to quantiles
        if _USING_TOOLS:
            quantile_uncs = values_to_quantiles_vectorized(all_uncs)
        else:
            quantile_uncs = _values_to_quantiles_vectorized(all_uncs)

        # Compute curve
        coverage, risk = _compute_risk_coverage_curve(
            all_labels, all_preds, quantile_uncs, n_bins
        )

        # Recover integer sample counts from the coverage fractions
        n_samples = UncertaintyQualityMetrics._coverage_to_counts(
            coverage, len(all_preds)
        )

        return {
            'coverage': coverage,
            'risk': risk,
            'n_samples': n_samples
        }

    @staticmethod
    def compute_multi_seed(
        predictions_seeds: List[List[npt.NDArray]],
        labels_seeds: List[List[npt.NDArray]],
        uncertainties_seeds: List[List[npt.NDArray]],
        metric: str = 'eaurc'
    ) -> Dict[str, Union[float, int, List[float]]]:
        """
        Compute uncertainty metrics across multiple random seeds

        This provides robustness analysis by running the same experiment
        with different random seeds and aggregating results.

        Args:
            predictions_seeds: List of predictions for each seed
                              Format: [[seed0_month0, seed0_month1, ...], [seed1_month0, ...], ...]
            labels_seeds: List of labels for each seed (same format)
            uncertainties_seeds: List of uncertainties for each seed (same format)
            metric: Which metric to compute ('aurc', 'eaurc', or 'both').
                    With 'both', the aggregated value is the E-AURC.

        Returns:
            Dictionary with:
            - 'mean': Mean across seeds
            - 'std': Standard deviation across seeds (population, ddof=0)
            - 'min': Minimum across seeds
            - 'max': Maximum across seeds
            - 'median': Median across seeds
            - 'all_values': List of values for each seed
            - 'n_seeds': Number of seeds

        Raises:
            ValueError: If no seeds are given, the three seed lists differ
                in length, or ``metric`` is not a supported name.

        Example:
            >>> # Data from 3 random seeds
            >>> preds_seeds = [seed0_preds, seed1_preds, seed2_preds]
            >>> labs_seeds = [seed0_labs, seed1_labs, seed2_labs]
            >>> uncs_seeds = [seed0_uncs, seed1_uncs, seed2_uncs]
            >>>
            >>> stats = UncertaintyQualityMetrics.compute_multi_seed(
            ...     preds_seeds, labs_seeds, uncs_seeds, metric='eaurc'
            ... )
            >>> print(f"E-AURC: {stats['mean']:.4f} ± {stats['std']:.4f}")
        """
        n_seeds = len(predictions_seeds)

        if n_seeds == 0:
            raise ValueError("No seeds provided")

        if not (len(labels_seeds) == n_seeds and len(uncertainties_seeds) == n_seeds):
            raise ValueError("Number of seeds must match for predictions, labels, and uncertainties")

        # Resolve the metric once, before any computation is done
        if metric == 'aurc':
            compute = UncertaintyQualityMetrics.compute_aurc
        elif metric in ('eaurc', 'both'):
            # 'both' aggregates the E-AURC
            compute = UncertaintyQualityMetrics.compute_eaurc
        else:
            raise ValueError(f"Invalid metric: {metric}. Must be 'aurc', 'eaurc', or 'both'")

        # Compute metric for each seed
        values = [
            compute(seed_preds, seed_labels, seed_uncs)
            for seed_preds, seed_labels, seed_uncs in zip(
                predictions_seeds, labels_seeds, uncertainties_seeds
            )
        ]

        # Aggregate statistics
        values_arr = np.array(values)

        return {
            'mean': float(np.mean(values_arr)),
            'std': float(np.std(values_arr)),
            'min': float(np.min(values_arr)),
            'max': float(np.max(values_arr)),
            'median': float(np.median(values_arr)),
            'all_values': values,
            'n_seeds': n_seeds
        }
