"""
Statistical analysis module for Aurora framework.

This module provides statistical tools for robust analysis:
- Bootstrap confidence intervals
- Pairwise method comparisons
- Statistical significance testing

The bootstrap approach is particularly useful when:
- We have limited seeds (e.g., 5 random seeds)
- We want non-parametric confidence intervals
- We need to compare methods without normality assumptions
"""

from typing import List, Dict, Tuple, Union, Optional, Callable
import numpy as np
import numpy.typing as npt


class BootstrapCI:
    """
    Compute bootstrap confidence intervals.

    Bootstrap resampling provides non-parametric confidence intervals
    without assuming any particular distribution of the data.

    **Why Bootstrap?**
    - Works with limited samples (e.g., 5 seeds × 24 months = 120 observations)
    - No normality assumption required
    - Provides intuitive confidence intervals

    **For Aurora**:
    - 5 seeds × 24 months = 120 observations (AndroZoo)
    - 5 seeds × 72 months = 360 observations (API-Graph)
    - This is sufficient for reliable bootstrap estimates

    All observations are resampled independently with replacement (i.i.d.
    bootstrap). NOTE(review): if observations from the same seed are
    correlated, independent resampling presumably understates the
    uncertainty. Confirm whether a per-seed (cluster) bootstrap is more
    appropriate.

    Example:
        >>> f1_scores = [0.85, 0.87, 0.84, 0.86, 0.88, 0.83, 0.89]
        >>> mean, (ci_low, ci_high) = BootstrapCI.compute(
        ...     f1_scores, np.mean, random_state=0)
        >>> round(mean, 3)
        0.86
        >>> ci_low < mean < ci_high
        True
    """

    @staticmethod
    def _validate_params(n_bootstrap: int, confidence_level: float) -> None:
        """Reject parameter values that cannot produce a meaningful interval.

        Raises:
            ValueError: If ``n_bootstrap < 1`` or ``confidence_level`` is not
                strictly between 0 and 1.
        """
        if n_bootstrap < 1:
            raise ValueError(f"n_bootstrap must be >= 1, got {n_bootstrap}")
        if not 0.0 < confidence_level < 1.0:
            raise ValueError(
                f"confidence_level must be in (0, 1), got {confidence_level}"
            )

    @staticmethod
    def _make_rng(random_state: Optional[int]):
        """Return the random source used to draw resample indices.

        With a seed, a private ``np.random.RandomState`` is returned. It draws
        the same ``randint`` sequence as ``np.random.seed(random_state)`` would,
        but leaves the process-wide NumPy generator untouched. Without a seed,
        the global ``np.random`` stream is used, so an earlier
        ``np.random.seed(...)`` by the caller still makes results reproducible.
        """
        if random_state is None:
            return np.random
        return np.random.RandomState(random_state)

    @staticmethod
    def _percentile_interval(
        samples: npt.NDArray, confidence_level: float
    ) -> Tuple[float, float]:
        """Return the equal-tailed percentile interval of ``samples``."""
        alpha = 1 - confidence_level
        lower = float(np.percentile(samples, 100 * alpha / 2))
        upper = float(np.percentile(samples, 100 * (1 - alpha / 2)))
        return lower, upper

    @staticmethod
    def compute(
        data: Union[List[float], npt.NDArray],
        statistic: Callable[[npt.NDArray], float] = np.mean,
        n_bootstrap: int = 10000,
        confidence_level: float = 0.95,
        random_state: Optional[int] = None
    ) -> Tuple[float, Tuple[float, float]]:
        """
        Compute bootstrap confidence interval for a statistic.

        Args:
            data: Data array to resample (resampling is along the first axis)
            statistic: Function to compute on each resample (default: np.mean)
            n_bootstrap: Number of bootstrap iterations (default: 10000)
            confidence_level: Confidence level in (0, 1) (default: 0.95 for 95% CI)
            random_state: Seed for a private RNG, for reproducibility. The
                global NumPy RNG is not reseeded. If None, the global
                ``np.random`` stream is used.

        Returns:
            Tuple of (point_estimate, (ci_lower, ci_upper)). With a single
            observation the interval collapses to the point estimate.

        Raises:
            ValueError: If ``data`` is empty, ``n_bootstrap < 1`` or
                ``confidence_level`` is outside (0, 1).

        Example:
            >>> data = np.array([0.85, 0.87, 0.84, 0.86, 0.88])
            >>> mean, (ci_low, ci_high) = BootstrapCI.compute(data, np.mean)
            >>> 0.83 < ci_low < ci_high < 0.89
            True
        """
        data = np.asarray(data)
        n = len(data)

        if n == 0:
            raise ValueError("Cannot bootstrap empty data")
        BootstrapCI._validate_params(n_bootstrap, confidence_level)

        point_estimate = float(statistic(data))

        if n < 2:
            # A single observation has no sampling variability to resample.
            return point_estimate, (point_estimate, point_estimate)

        rng = BootstrapCI._make_rng(random_state)

        bootstrap_statistics = np.zeros(n_bootstrap)
        for i in range(n_bootstrap):
            # Draw n indices with replacement, one resample per iteration.
            resample = data[rng.randint(0, n, size=n)]
            bootstrap_statistics[i] = statistic(resample)

        ci = BootstrapCI._percentile_interval(bootstrap_statistics, confidence_level)
        return point_estimate, ci

    @staticmethod
    def compute_multiple(
        data: Union[List[float], npt.NDArray],
        statistics: Dict[str, Callable[[npt.NDArray], float]],
        n_bootstrap: int = 10000,
        confidence_level: float = 0.95,
        random_state: Optional[int] = None
    ) -> Dict[str, Tuple[float, Tuple[float, float]]]:
        """
        Compute bootstrap CIs for multiple statistics efficiently.

        Generates one set of bootstrap samples and computes all statistics,
        which is more efficient than calling compute() multiple times. For the
        same seed, each statistic's result equals what compute() returns.

        Args:
            data: Data array (resampling is along the first axis)
            statistics: Dict mapping statistic_name -> function
            n_bootstrap: Number of bootstrap iterations
            confidence_level: Confidence level in (0, 1)
            random_state: Seed for a private RNG. The global NumPy RNG is not
                reseeded. If None, the global ``np.random`` stream is used.

        Returns:
            Dict mapping statistic_name -> (point_estimate, (ci_lower, ci_upper))

        Raises:
            ValueError: If ``data`` is empty, ``n_bootstrap < 1`` or
                ``confidence_level`` is outside (0, 1).

        Example:
            >>> data = np.array([0.85, 0.87, 0.84, 0.86, 0.88, 0.83, 0.89])
            >>> stats = {
            ...     'mean': np.mean,
            ...     'std': np.std,
            ...     'median': np.median
            ... }
            >>> results = BootstrapCI.compute_multiple(data, stats, random_state=0)
            >>> sorted(results)
            ['mean', 'median', 'std']
        """
        data = np.asarray(data)
        n = len(data)

        if n == 0:
            raise ValueError("Cannot bootstrap empty data")
        BootstrapCI._validate_params(n_bootstrap, confidence_level)

        point_estimates = {name: float(func(data)) for name, func in statistics.items()}

        if n < 2:
            # A single observation has no sampling variability to resample.
            return {
                name: (val, (val, val))
                for name, val in point_estimates.items()
            }

        rng = BootstrapCI._make_rng(random_state)

        # One resample per iteration, shared by every statistic.
        bootstrap_results = {name: np.zeros(n_bootstrap) for name in statistics}
        for i in range(n_bootstrap):
            resample = data[rng.randint(0, n, size=n)]
            for name, func in statistics.items():
                bootstrap_results[name][i] = func(resample)

        return {
            name: (
                point_estimates[name],
                BootstrapCI._percentile_interval(samples, confidence_level),
            )
            for name, samples in bootstrap_results.items()
        }


class BootstrapComparison:
    """
    Compare two methods using bootstrap resampling.

    Provides:
    - P(A > B): Probability that method A is better than B
    - Difference CI: Confidence interval for (mean_A - mean_B)
    - Significance test: Is the difference significant at given level?

    Example:
        >>> f1_a = np.array([0.85, 0.87, 0.84, 0.86, 0.88])
        >>> f1_b = np.array([0.82, 0.84, 0.81, 0.83, 0.85])
        >>> result = BootstrapComparison.compare(f1_a, f1_b, random_state=0)
        >>> result['prob_a_greater'] > 0.5
        True
    """

    @staticmethod
    def compare(
        data_a: Union[List[float], npt.NDArray],
        data_b: Union[List[float], npt.NDArray],
        n_bootstrap: int = 10000,
        confidence_level: float = 0.95,
        random_state: Optional[int] = None
    ) -> Dict[str, Union[float, bool]]:
        """
        Compare two methods using bootstrap.

        Each dataset is resampled independently (unpaired two-sample
        bootstrap) and the difference of resampled means is recorded.

        Args:
            data_a: Data from method A
            data_b: Data from method B
            n_bootstrap: Number of bootstrap iterations
            confidence_level: Confidence level in (0, 1) for the difference CI
            random_state: Seed for a private RNG. The global NumPy RNG is not
                reseeded. If None, the global ``np.random`` stream is used.

        Returns:
            Dictionary with:
            - 'mean_a': Mean of method A
            - 'mean_b': Mean of method B
            - 'mean_diff': Mean difference (A - B)
            - 'diff_ci_lower': Lower CI for difference
            - 'diff_ci_upper': Upper CI for difference
            - 'prob_a_greater': Fraction of resamples with mean_A > mean_B
            - 'prob_b_greater': Fraction of resamples with mean_B > mean_A
              (ties count for neither side, so the two can sum to < 1)
            - 'significant': True if CI doesn't contain 0

        Raises:
            ValueError: If either dataset is empty, ``n_bootstrap < 1`` or
                ``confidence_level`` is outside (0, 1).

        Example:
            >>> a = np.array([0.85, 0.87, 0.84, 0.86, 0.88])
            >>> b = np.array([0.82, 0.84, 0.81, 0.83, 0.85])
            >>> result = BootstrapComparison.compare(a, b)
            >>> result['prob_a_greater'] > 0.5  # A is likely better
            True
        """
        data_a = np.asarray(data_a)
        data_b = np.asarray(data_b)

        if len(data_a) == 0 or len(data_b) == 0:
            raise ValueError("Cannot compare empty data")
        if n_bootstrap < 1:
            raise ValueError(f"n_bootstrap must be >= 1, got {n_bootstrap}")
        if not 0.0 < confidence_level < 1.0:
            raise ValueError(
                f"confidence_level must be in (0, 1), got {confidence_level}"
            )

        # A seeded private RandomState draws the same randint sequence as
        # np.random.seed(random_state) would, without mutating global state.
        rng = np.random if random_state is None else np.random.RandomState(random_state)

        n_a = len(data_a)
        n_b = len(data_b)

        mean_a = float(np.mean(data_a))
        mean_b = float(np.mean(data_b))
        mean_diff = mean_a - mean_b

        bootstrap_diff = np.zeros(n_bootstrap)
        for i in range(n_bootstrap):
            # Resample each dataset independently; A is drawn before B in every
            # iteration so a given seed always yields the same resamples.
            resample_a = data_a[rng.randint(0, n_a, size=n_a)]
            resample_b = data_b[rng.randint(0, n_b, size=n_b)]
            bootstrap_diff[i] = np.mean(resample_a) - np.mean(resample_b)

        alpha = 1 - confidence_level
        ci_lower = float(np.percentile(bootstrap_diff, 100 * alpha / 2))
        ci_upper = float(np.percentile(bootstrap_diff, 100 * (1 - alpha / 2)))

        # Count strict wins for each side. A resample with equal means is a tie
        # and must not be credited to B (as ``1 - prob_a_greater`` would).
        prob_a_greater = np.count_nonzero(bootstrap_diff > 0) / n_bootstrap
        prob_b_greater = np.count_nonzero(bootstrap_diff < 0) / n_bootstrap

        # Significance: CI doesn't contain 0
        significant = not (ci_lower <= 0 <= ci_upper)

        return {
            'mean_a': mean_a,
            'mean_b': mean_b,
            'mean_diff': mean_diff,
            'diff_ci_lower': ci_lower,
            'diff_ci_upper': ci_upper,
            'prob_a_greater': prob_a_greater,
            'prob_b_greater': prob_b_greater,
            'significant': significant
        }

    @staticmethod
    def compare_multiple(
        data_dict: Dict[str, Union[List[float], npt.NDArray]],
        n_bootstrap: int = 10000,
        confidence_level: float = 0.95,
        random_state: Optional[int] = None
    ) -> Dict[Tuple[str, str], Dict[str, Union[float, bool]]]:
        """
        Compare all unordered pairs of methods.

        Pairs are taken in ``data_dict`` insertion order, so each pair appears
        once as (earlier_method, later_method). The same ``random_state`` is
        passed to every pairwise comparison.

        Args:
            data_dict: Dict mapping method_name -> data array
            n_bootstrap: Number of bootstrap iterations
            confidence_level: Confidence level
            random_state: Random seed

        Returns:
            Dict mapping (method_a, method_b) -> comparison result

        Example:
            >>> data = {
            ...     'DeepDrebin': np.array([0.85, 0.87, 0.84]),
            ...     'HCC': np.array([0.88, 0.89, 0.87]),
            ...     'Drebin': np.array([0.80, 0.82, 0.81])
            ... }
            >>> comparisons = BootstrapComparison.compare_multiple(data, random_state=0)
            >>> len(comparisons)
            3
        """
        methods = list(data_dict)
        results = {}

        for i, method_a in enumerate(methods):
            for method_b in methods[i + 1:]:
                results[(method_a, method_b)] = BootstrapComparison.compare(
                    data_dict[method_a],
                    data_dict[method_b],
                    n_bootstrap=n_bootstrap,
                    confidence_level=confidence_level,
                    random_state=random_state
                )

        return results


class StatisticalSummary:
    """
    Generate publication-ready statistical summaries.

    Combines point estimates with confidence intervals in various formats.
    """

    # Layout names accepted by format_with_ci().
    _FORMAT_STYLES = ("brackets", "pm", "parens")

    @staticmethod
    def format_with_ci(
        value: float,
        ci: Tuple[float, float],
        decimals: int = 3,
        as_percent: bool = False,
        format_style: str = "brackets"
    ) -> str:
        """
        Format a value with its confidence interval.

        Args:
            value: Point estimate
            ci: (lower, upper) confidence interval
            decimals: Number of decimal places (applied after any % scaling)
            as_percent: If True, multiply by 100 and add %
            format_style: "brackets" for "[a, b]", "pm" for "±", "parens" for "(a, b)".
                "pm" reports the CI half-width, which hides any asymmetry of
                the interval around the point estimate.

        Returns:
            Formatted string

        Raises:
            ValueError: If ``format_style`` is not one of the styles above.

        Examples:
            >>> StatisticalSummary.format_with_ci(0.85, (0.82, 0.88))
            '0.850 [0.820, 0.880]'

            >>> StatisticalSummary.format_with_ci(0.85, (0.82, 0.88), decimals=1, as_percent=True)
            '85.0% [82.0%, 88.0%]'

            >>> StatisticalSummary.format_with_ci(0.85, (0.82, 0.88), format_style="pm")
            '0.850 ± 0.030'
        """
        if as_percent:
            value = value * 100
            ci = (ci[0] * 100, ci[1] * 100)

        fmt = f".{decimals}f"
        suffix = "%" if as_percent else ""
        val_str = f"{value:{fmt}}{suffix}"
        ci_low_str = f"{ci[0]:{fmt}}{suffix}"
        ci_high_str = f"{ci[1]:{fmt}}{suffix}"

        if format_style == "brackets":
            return f"{val_str} [{ci_low_str}, {ci_high_str}]"
        elif format_style == "pm":
            # Use half-width as ± value
            half_width = (ci[1] - ci[0]) / 2
            return f"{val_str} ± {half_width:{fmt}}{suffix}"
        elif format_style == "parens":
            return f"{val_str} ({ci_low_str}, {ci_high_str})"
        else:
            raise ValueError(f"Unknown format_style: {format_style}")

    @staticmethod
    def create_summary_table(
        data_dict: Dict[str, Union[List[float], npt.NDArray]],
        statistic: Callable[[npt.NDArray], float] = np.mean,
        n_bootstrap: int = 10000,
        confidence_level: float = 0.95,
        as_percent: bool = False,
        decimals: int = 2,
        random_state: Optional[int] = None,
        format_style: str = "brackets"
    ) -> Dict[str, str]:
        """
        Create a summary table with CIs for multiple methods.

        Args:
            data_dict: Dict mapping method_name -> data array
            statistic: Function to compute (default: np.mean)
            n_bootstrap: Number of bootstrap iterations
            confidence_level: Confidence level
            as_percent: Format as percentages
            decimals: Number of decimal places
            random_state: Random seed, passed unchanged to every method's bootstrap
            format_style: Layout passed to format_with_ci()
                ("brackets", "pm" or "parens")

        Returns:
            Dict mapping method_name -> formatted string with CI

        Raises:
            ValueError: If ``format_style`` is unknown. This is checked before
                any bootstrap work is done.

        Example:
            >>> data = {
            ...     'DeepDrebin': np.array([0.85, 0.87, 0.84]),
            ...     'HCC': np.array([0.88, 0.89, 0.87])
            ... }
            >>> table = StatisticalSummary.create_summary_table(
            ...     data, as_percent=True, random_state=0)
            >>> sorted(table)
            ['DeepDrebin', 'HCC']
        """
        # Fail fast instead of after bootstrapping the first method.
        if format_style not in StatisticalSummary._FORMAT_STYLES:
            raise ValueError(f"Unknown format_style: {format_style}")

        results = {}

        for method_name, data in data_dict.items():
            value, ci = BootstrapCI.compute(
                data,
                statistic=statistic,
                n_bootstrap=n_bootstrap,
                confidence_level=confidence_level,
                random_state=random_state
            )
            results[method_name] = StatisticalSummary.format_with_ci(
                value, ci,
                decimals=decimals,
                as_percent=as_percent,
                format_style=format_style
            )

        return results


# Convenience functions for common use cases
def bootstrap_ci(
    data: Union[List[float], npt.NDArray],
    statistic: Callable[[npt.NDArray], float] = np.mean,
    n_bootstrap: int = 10000,
    ci: float = 0.95,
    random_state: Optional[int] = None
) -> Tuple[float, Tuple[float, float]]:
    """
    Convenience wrapper around ``BootstrapCI.compute``.

    Behaves exactly like the class method. The only difference is the
    shorter name of the confidence-level argument (``ci`` instead of
    ``confidence_level``).

    Args:
        data: Observations to resample
        statistic: Function evaluated on the data and on every resample
        n_bootstrap: How many bootstrap resamples to draw
        ci: Confidence level of the interval (e.g. 0.95)
        random_state: Optional seed that makes the result reproducible

    Returns:
        (point_estimate, (ci_lower, ci_upper))

    Example:
        >>> mean, (low, high) = bootstrap_ci([0.85, 0.87, 0.84, 0.86, 0.88], random_state=0)
        >>> low <= high
        True
    """
    estimate_with_interval = BootstrapCI.compute(
        data,
        statistic=statistic,
        n_bootstrap=n_bootstrap,
        confidence_level=ci,
        random_state=random_state,
    )
    return estimate_with_interval


def bootstrap_compare(
    data_a: Union[List[float], npt.NDArray],
    data_b: Union[List[float], npt.NDArray],
    n_bootstrap: int = 10000,
    random_state: Optional[int] = None
) -> float:
    """
    Shortcut for comparing two methods.

    Returns the ``'prob_a_greater'`` entry of ``BootstrapComparison.compare``,
    i.e. the fraction of bootstrap resamples in which mean_A > mean_B.

    Args:
        data_a: Data from method A
        data_b: Data from method B
        n_bootstrap: Number of iterations
        random_state: Optional seed for a reproducible result (new, optional;
            defaults to the previous unseeded behavior)

    Returns:
        Probability that A is better than B

    Example:
        >>> a = [0.85, 0.87, 0.84, 0.86, 0.88]
        >>> b = [0.82, 0.84, 0.81, 0.83, 0.85]
        >>> bootstrap_compare(a, b, random_state=0) > 0.5
        True
    """
    result = BootstrapComparison.compare(
        data_a,
        data_b,
        n_bootstrap=n_bootstrap,
        random_state=random_state
    )
    return result['prob_a_greater']
