"""
Pareto Dominance Analysis for Multi-Objective Malware Classifier Evaluation

This module provides Pareto frontier analysis for comparing malware classifiers
across multiple objectives (reliability, stability, tail risk, temporal).

Key Concepts:
- Pareto Dominance: Solution A dominates B if A is at least as good in ALL
  objectives and strictly better in at least one.
- Pareto Frontier: Set of all non-dominated solutions (rational choices).
- Universal Pareto Optimality: Solutions on frontier across ALL datasets.

The Four Pillars Framework:
- Reliability (R): Mean F1 - "How well does it detect?"
- Stability (S): σ[F1] - "How consistent is it?" (NOT CV[F1] which has issues)
- Tail Risk (T): Min[F1] - "What's the worst case?"
- Temporal (τ): Mann-Kendall tau - "Is it getting worse?"

Usage:
    >>> from aurora.pareto import ParetoAnalyzer, format_pareto_table
    >>>
    >>> # Register the monthly F1 scores of every method/dataset pair
    >>> analyzer = ParetoAnalyzer()
    >>> analyzer.add_method_results("DeepDrebin", "androzoo", monthly_f1)
    >>>
    >>> # Find the Pareto frontier for a single dataset
    >>> frontier = analyzer.analyze_dataset("androzoo").frontier
    >>>
    >>> # Check universal optimality across all registered datasets
    >>> universal = analyzer.analyze_universal()
    >>> print(format_pareto_table(universal))
"""

from typing import Dict, List, Tuple, Optional, Set, Any, Union
from dataclasses import dataclass, field
import numpy as np
import numpy.typing as npt

from .metrics import (
    ReliabilityMetrics,
    StabilityMetrics,
    DrawdownMetrics,
)


# =============================================================================
# DATA STRUCTURES
# =============================================================================

@dataclass
class ParetoMetrics:
    """
    Four Pillars scores of one method evaluated on one dataset.

    Every score is expected to follow the HIGHER IS BETTER convention so
    that instances can be compared directly for Pareto dominance.
    """
    # --- Which method / which dataset --------------------------------------
    method_name: str
    dataset: str

    # --- Four Pillars (larger value == better) -----------------------------
    # Reliability: average F1
    mean_f1: float
    # Stability: derived from σ[F1]; stored inverted (1 - normalized value)
    # so that a larger number means a more stable method
    sigma_f1: float
    # Tail risk: worst-case F1
    min_f1: float
    # Temporal: Mann-Kendall trend statistic
    mann_kendall_tau: float

    # --- Optional extras ---------------------------------------------------
    # CV[F1]; kept for reference only, not recommended as a pillar
    cv_f1: Optional[float] = None
    # 5th percentile of F1
    p5_f1: Optional[float] = None
    # Best-case F1
    max_f1: Optional[float] = None

    # --- Values before normalization ---------------------------------------
    raw_values: Dict[str, float] = field(default_factory=dict)

    def to_vector(self, metrics: List[str] = None) -> np.ndarray:
        """
        Pack selected attributes into a numpy vector for Pareto comparison.

        Args:
            metrics: Attribute names to read, in output order.
                    Default: ['mean_f1', 'sigma_f1', 'min_f1', 'mann_kendall_tau']

        Returns:
            Numpy array holding the requested attribute values, in order.
        """
        selected = (
            ['mean_f1', 'sigma_f1', 'min_f1', 'mann_kendall_tau']
            if metrics is None
            else metrics
        )
        values = [getattr(self, attr) for attr in selected]
        return np.array(values)


@dataclass
class ParetoResult:
    """
    Result of Pareto analysis for a single dataset.

    Instances are produced by ParetoAnalyzer.analyze_dataset(), which fills
    `frontier`/`dominated` from find_pareto_frontier() and `dominance_matrix`
    from compute_dominance_matrix().
    """
    # Name of the dataset this result describes
    dataset: str
    # Names of the methods that no other method dominates
    frontier: List[str]
    # Names of the methods dominated by at least one other method
    dominated: List[str]
    # method name -> names of the methods that this method dominates
    dominance_matrix: Dict[str, List[str]]
    # method name -> ParetoMetrics built for that method on this dataset
    metrics: Dict[str, ParetoMetrics]


@dataclass
class UniversalParetoResult:
    """
    Cross-dataset Pareto analysis result.

    Instances are produced by ParetoAnalyzer.analyze_universal(), which
    classifies every method by the number of datasets in which it appears
    on the Pareto frontier.
    """
    # Methods on the frontier in ALL datasets
    universal_optimal: List[str]

    # Methods on the frontier in SOME (but not all) datasets:
    # method -> list of datasets where it is on the frontier
    partial_optimal: Dict[str, List[str]]

    # Methods NEVER on the frontier (frontier count of zero)
    always_dominated: List[str]

    # Per-dataset results: dataset name -> ParetoResult
    dataset_results: Dict[str, ParetoResult]

    # Summary statistics: method -> number of datasets where it is on the frontier
    frontier_counts: Dict[str, int]


# =============================================================================
# PARETO DOMINANCE FUNCTIONS
# =============================================================================

def pareto_dominates(a: npt.NDArray, b: npt.NDArray, epsilon: float = 1e-9) -> bool:
    """
    Check if solution 'a' Pareto-dominates solution 'b'.

    Definition: A dominates B if:
    1. A is at least as good as B in ALL objectives
    2. A is strictly better than B in AT LEAST ONE objective

    All values should be oriented so that HIGHER IS BETTER.

    Args:
        a: Metric vector for solution A
        b: Metric vector for solution B
        epsilon: Tolerance for numerical comparison. Differences no larger
                 than epsilon are treated as ties: they neither break
                 condition 1 nor satisfy condition 2.

    Returns:
        True if a dominates b, False otherwise. The result is always a
        builtin ``bool`` (never ``numpy.bool_``), so identity checks such as
        ``result is True`` and the doctest below behave as expected.

    Raises:
        ValueError: If the two vectors have different lengths.

    Example:
        >>> a = np.array([0.9, 0.8, 0.7])  # Better in all
        >>> b = np.array([0.8, 0.7, 0.6])
        >>> pareto_dominates(a, b)
        True
        >>> pareto_dominates(b, a)
        False
    """
    a = np.asarray(a)
    b = np.asarray(b)

    if len(a) != len(b):
        raise ValueError(f"Vectors must have same length: {len(a)} vs {len(b)}")

    # Condition 1: a is at least as good as b in all objectives
    at_least_as_good = np.all(a >= b - epsilon)

    # Condition 2: a is strictly better in at least one objective
    strictly_better = np.any(a > b + epsilon)

    # np.all/np.any yield numpy.bool_; coerce so the annotated return type holds
    return bool(at_least_as_good and strictly_better)


def find_pareto_frontier(
    solutions: Dict[str, npt.NDArray],
    return_dominated: bool = False
) -> Union[List[str], Tuple[List[str], List[str]]]:
    """
    Split a set of solutions into non-dominated and dominated ones.

    A solution belongs to the Pareto frontier when no other solution in the
    set dominates it (as decided by pareto_dominates with its default
    tolerance). Names keep the iteration order of `solutions`.

    Args:
        solutions: Dict mapping solution name to metric vector
                  (all metrics oriented so higher is better)
        return_dominated: If True, also return list of dominated solutions

    Returns:
        List of solution names on the Pareto frontier.
        If return_dominated=True, returns (frontier, dominated) tuple.

    Example:
        >>> solutions = {
        ...     'A': np.array([0.9, 0.8]),
        ...     'B': np.array([0.8, 0.9]),
        ...     'C': np.array([0.7, 0.7]),  # Dominated by A and B
        ... }
        >>> find_pareto_frontier(solutions)
        ['A', 'B']
    """
    labels = list(solutions)
    frontier: List[str] = []
    dominated: List[str] = []

    for idx, candidate in enumerate(labels):
        # any() stops at the first rival that dominates the candidate
        beaten = any(
            pareto_dominates(solutions[rival], solutions[candidate])
            for pos, rival in enumerate(labels)
            if pos != idx
        )
        if beaten:
            dominated.append(candidate)
        else:
            frontier.append(candidate)

    if return_dominated:
        return frontier, dominated
    return frontier


def compute_dominance_matrix(solutions: Dict[str, npt.NDArray]) -> Dict[str, List[str]]:
    """
    List, for every solution, the other solutions it Pareto-dominates.

    Args:
        solutions: Dict mapping solution name to metric vector

    Returns:
        Dict mapping each solution to list of solutions it dominates
        (both levels follow the iteration order of `solutions`)

    Example:
        >>> solutions = {'A': np.array([0.9, 0.9]), 'B': np.array([0.7, 0.7])}
        >>> compute_dominance_matrix(solutions)
        {'A': ['B'], 'B': []}
    """
    labels = list(solutions)
    return {
        winner: [
            loser
            for loser in labels
            if loser != winner
            and pareto_dominates(solutions[winner], solutions[loser])
        ]
        for winner in labels
    }


# =============================================================================
# METRIC NORMALIZATION
# =============================================================================

def normalize_metrics(
    metrics_by_method: Dict[str, Dict[str, float]],
    higher_is_better: Optional[Dict[str, bool]] = None
) -> Dict[str, Dict[str, float]]:
    """
    Min-max normalize metrics to [0, 1] range.

    For metrics where lower is better (e.g., σ[F1]), the value is inverted
    so that 1.0 always means BEST and 0.0 always means WORST.

    The range of each metric is taken over every method that reports it, so
    a metric missing from some methods is still scaled against the methods
    that do provide it. When all reported values of a metric are (nearly)
    identical, every method receives 1.0 for it regardless of direction:
    tied methods are all "best".

    Args:
        metrics_by_method: Dict mapping method name to dict of metric values
        higher_is_better: Dict mapping metric name to bool (True if higher is better).
                         Default: mean_f1, min_f1, max_f1, p5_f1, mann_kendall_tau
                                  are higher-is-better;
                                  cv_f1, sigma_f1, mad_f1 are lower-is-better.
                         Metrics absent from the mapping are treated as
                         higher-is-better.

    Returns:
        Dict with same structure but normalized values (builtin floats)

    Example:
        >>> metrics = {
        ...     'A': {'mean_f1': 0.9, 'sigma_f1': 0.05},
        ...     'B': {'mean_f1': 0.7, 'sigma_f1': 0.15},
        ... }
        >>> normalized = normalize_metrics(metrics)
        >>> normalized['A']['mean_f1']  # Best mean_f1 -> 1.0
        1.0
        >>> normalized['A']['sigma_f1']  # Best (lowest) sigma -> 1.0
        1.0
    """
    if higher_is_better is None:
        higher_is_better = {
            'mean_f1': True,
            'min_f1': True,
            'max_f1': True,
            'p5_f1': True,
            'mann_kendall_tau': True,
            'cv_f1': False,      # Lower is better
            'sigma_f1': False,   # Lower is better
            'mad_f1': False,     # Lower is better
        }

    if not metrics_by_method:
        return {}

    # Observed (min, max) of every metric, collected across ALL methods so
    # that metrics not reported by the first method are still scaled properly
    bounds: Dict[str, Tuple[float, float]] = {}
    for method_metrics in metrics_by_method.values():
        for metric, value in method_metrics.items():
            if metric in bounds:
                low, high = bounds[metric]
                bounds[metric] = (min(low, value), max(high, value))
            else:
                bounds[metric] = (value, value)

    normalized: Dict[str, Dict[str, float]] = {}

    for method_name, method_metrics in metrics_by_method.items():
        scaled: Dict[str, float] = {}

        for metric, value in method_metrics.items():
            low, high = bounds[metric]
            span = high - low

            if span < 1e-10:
                # All methods tie on this metric: everyone is "best". Skip the
                # inversion, which would otherwise turn the tie into 0.0.
                score = 1.0
            else:
                score = (value - low) / span
                # Invert if lower is better
                if not higher_is_better.get(metric, True):
                    score = 1.0 - score

            scaled[metric] = float(score)

        normalized[method_name] = scaled

    return normalized


# =============================================================================
# PARETO ANALYZER CLASS
# =============================================================================

class ParetoAnalyzer:
    """
    Comprehensive Pareto dominance analyzer for malware classifier evaluation.

    This class provides tools to:
    1. Compute the Four Pillars metrics for each method
    2. Normalize metrics for fair comparison
    3. Find Pareto frontiers per dataset
    4. Identify universally Pareto-optimal methods across datasets

    Normalized metrics and per-dataset results are cached; adding or
    replacing a method on a dataset discards the cache of that dataset.

    Example:
        >>> analyzer = ParetoAnalyzer()
        >>>
        >>> # Add results for each method/dataset
        >>> analyzer.add_method_results(
        ...     method_name="DeepDrebin (Subsample)",
        ...     dataset="androzoo",
        ...     monthly_f1=[0.85, 0.87, 0.84, ...],
        ... )
        >>>
        >>> # Compute Pareto frontier
        >>> result = analyzer.analyze_dataset("androzoo")
        >>> print(f"Frontier: {result.frontier}")
        >>>
        >>> # Find universally optimal methods
        >>> universal = analyzer.analyze_universal()
        >>> print(f"Universal optimal: {universal.universal_optimal}")
    """

    def __init__(self, metrics: Optional[List[str]] = None):
        """
        Initialize analyzer.

        Args:
            metrics: List of metrics to use for Pareto comparison.
                    Default: ['mean_f1', 'sigma_f1', 'min_f1', 'mann_kendall_tau']
        """
        self.metrics = metrics or ['mean_f1', 'sigma_f1', 'min_f1', 'mann_kendall_tau']

        # Storage: dataset -> method_name -> {metric name -> raw value}
        self._raw_metrics: Dict[str, Dict[str, Dict[str, float]]] = {}
        self._pareto_metrics: Dict[str, Dict[str, ParetoMetrics]] = {}

        # Cache: dataset -> normalized metrics / finished analysis
        self._normalized: Dict[str, Dict[str, Dict[str, float]]] = {}
        self._results: Dict[str, ParetoResult] = {}

    def _store_raw_metrics(
        self,
        dataset: str,
        method_name: str,
        values: Dict[str, float],
    ) -> None:
        """Record raw metrics for a method and drop the dataset's cached analysis."""
        self._raw_metrics.setdefault(dataset, {})[method_name] = values

        # Anything derived from the previous contents of this dataset is stale
        self._normalized.pop(dataset, None)
        self._results.pop(dataset, None)

    def add_method_results(
        self,
        method_name: str,
        dataset: str,
        monthly_f1: List[float],
        monthly_predictions: List[npt.NDArray] = None,
        monthly_labels: List[npt.NDArray] = None,
    ) -> None:
        """
        Add results for a method on a dataset.

        Computes the Four Pillars metrics from monthly F1 scores. An existing
        entry for the same (dataset, method_name) pair is replaced.

        Args:
            method_name: Name of the method (e.g., "DeepDrebin (Subsample)")
            dataset: Dataset name (e.g., "androzoo")
            monthly_f1: List of F1 scores, one per month
            monthly_predictions: Optional; accepted but currently not used
            monthly_labels: Optional; accepted but currently not used
        """
        # Compute everything first: if this raises, no empty dataset entry is
        # left behind to distort later cross-dataset analysis
        suite = StabilityMetrics.compute_stability_suite(monthly_f1)

        raw = {
            'mean_f1': float(np.mean(monthly_f1)),
            'sigma_f1': suite['sigma'],
            'cv_f1': suite['cv'],
            'min_f1': suite['min'],
            'max_f1': suite['max'],
            'p5_f1': suite['p5'],
            'mann_kendall_tau': suite['mann_kendall_tau'],
        }

        self._store_raw_metrics(dataset, method_name, raw)

    def add_precomputed_metrics(
        self,
        method_name: str,
        dataset: str,
        metrics: Dict[str, float]
    ) -> None:
        """
        Add pre-computed metrics for a method.

        A shallow copy of `metrics` is stored, so later changes to the
        caller's dict do not affect the analyzer.

        Args:
            method_name: Name of the method
            dataset: Dataset name
            metrics: Dict with metric values (mean_f1, sigma_f1, min_f1, etc.).
                    It must contain every metric the analyzer compares on
                    (see `self.metrics`), otherwise analyze_dataset() fails.
        """
        self._store_raw_metrics(dataset, method_name, metrics.copy())

    def _ensure_normalized(self, dataset: str) -> None:
        """
        Ensure metrics are normalized for the given dataset.

        Raises:
            ValueError: If nothing has been added for `dataset`.
        """
        if dataset in self._normalized:
            return

        if dataset not in self._raw_metrics:
            raise ValueError(f"No data for dataset: {dataset}")

        self._normalized[dataset] = normalize_metrics(self._raw_metrics[dataset])

    def analyze_dataset(self, dataset: str) -> ParetoResult:
        """
        Analyze Pareto frontier for a single dataset.

        The result is cached until the dataset's metrics change.

        Args:
            dataset: Dataset name

        Returns:
            ParetoResult with frontier, dominated methods, and metrics

        Raises:
            ValueError: If nothing has been added for `dataset`.
            KeyError: If a method lacks one of the metrics in `self.metrics`.
        """
        if dataset in self._results:
            return self._results[dataset]

        self._ensure_normalized(dataset)
        normalized = self._normalized[dataset]

        # Convert to vectors, one per method, in the order of self.metrics
        solutions = {}
        for method_name, norm_metrics in normalized.items():
            missing = [m for m in self.metrics if m not in norm_metrics]
            if missing:
                # Still a KeyError (as a plain lookup would raise), but one
                # that says which method and dataset are incomplete
                raise KeyError(
                    f"Method '{method_name}' on dataset '{dataset}' is missing "
                    f"metrics required for Pareto comparison: {missing}"
                )
            solutions[method_name] = np.array(
                [norm_metrics[m] for m in self.metrics]
            )

        # Find frontier
        frontier, dominated = find_pareto_frontier(solutions, return_dominated=True)

        # Compute dominance matrix
        dominance_matrix = compute_dominance_matrix(solutions)

        # Create ParetoMetrics objects (normalized values + raw values)
        pareto_metrics = {}
        for method_name, norm_metrics in normalized.items():
            pareto_metrics[method_name] = ParetoMetrics(
                method_name=method_name,
                dataset=dataset,
                mean_f1=norm_metrics.get('mean_f1', 0.0),
                sigma_f1=norm_metrics.get('sigma_f1', 0.0),
                min_f1=norm_metrics.get('min_f1', 0.0),
                mann_kendall_tau=norm_metrics.get('mann_kendall_tau', 0.0),
                cv_f1=norm_metrics.get('cv_f1'),
                p5_f1=norm_metrics.get('p5_f1'),
                max_f1=norm_metrics.get('max_f1'),
                raw_values=self._raw_metrics[dataset][method_name].copy(),
            )

        result = ParetoResult(
            dataset=dataset,
            frontier=frontier,
            dominated=dominated,
            dominance_matrix=dominance_matrix,
            metrics=pareto_metrics,
        )

        self._results[dataset] = result
        return result

    def analyze_universal(self) -> UniversalParetoResult:
        """
        Analyze Pareto optimality across all datasets.

        Identifies:
        - Methods on frontier in ALL datasets (universally optimal)
        - Methods on frontier in SOME datasets (partial)
        - Methods NEVER on frontier (always dominated)

        A method that was not added for some dataset cannot be on that
        dataset's frontier, so it can never be universally optimal.

        Methods are reported in the order in which they were first added
        (datasets in insertion order, then methods within each dataset), so
        the output is reproducible from run to run.

        Returns:
            UniversalParetoResult with cross-dataset analysis

        Raises:
            ValueError: If no data has been added at all.
        """
        if not self._raw_metrics:
            raise ValueError("No data added. Use add_method_results() first.")

        datasets = list(self._raw_metrics)

        # Analyze each dataset
        dataset_results = {ds: self.analyze_dataset(ds) for ds in datasets}

        # Register every method in first-seen order. A dict is used instead of
        # a set because set iteration order of strings varies between runs.
        frontier_counts: Dict[str, int] = {}
        for ds in datasets:
            for method in self._raw_metrics[ds]:
                frontier_counts.setdefault(method, 0)

        optimal_in: Dict[str, List[str]] = {m: [] for m in frontier_counts}

        # Count frontier appearances
        for ds, result in dataset_results.items():
            for method in result.frontier:
                frontier_counts[method] += 1
                optimal_in[method].append(ds)

        n_datasets = len(datasets)

        # Categorize methods
        universal_optimal = [
            m for m, count in frontier_counts.items() if count == n_datasets
        ]
        always_dominated = [
            m for m, count in frontier_counts.items() if count == 0
        ]
        # Partial = on the frontier somewhere, but not everywhere
        partial_optimal = {
            m: ds_list
            for m, ds_list in optimal_in.items()
            if 0 < frontier_counts[m] < n_datasets
        }

        return UniversalParetoResult(
            universal_optimal=universal_optimal,
            partial_optimal=partial_optimal,
            always_dominated=always_dominated,
            dataset_results=dataset_results,
            frontier_counts=frontier_counts,
        )

    def get_raw_metrics(self, dataset: str = None) -> Dict:
        """
        Get raw (non-normalized) metrics.

        Args:
            dataset: Dataset name, or None for all datasets.

        Returns:
            For a dataset name: method -> metrics dict ({} if the dataset is
            unknown). For None: the full dataset -> method -> metrics mapping.
        """
        # Compare against None so that an empty-string dataset name is still
        # treated as a dataset name rather than as "all datasets"
        if dataset is not None:
            return self._raw_metrics.get(dataset, {})
        return self._raw_metrics

    def get_normalized_metrics(self, dataset: str) -> Dict[str, Dict[str, float]]:
        """
        Get normalized metrics for a dataset.

        Raises:
            ValueError: If nothing has been added for `dataset`.
        """
        self._ensure_normalized(dataset)
        return self._normalized[dataset]


# =============================================================================
# CONVENIENCE FUNCTIONS
# =============================================================================

# Sentinel distinguishing "field absent" from a field whose value is None
_MISSING_FIELD = object()


def _result_field(result: Any, key: str, default: Any = None) -> Any:
    """Read `key` from a result given either as an object (attribute) or a mapping (item)."""
    if hasattr(result, key):
        return getattr(result, key)
    try:
        return result[key]
    except (KeyError, IndexError, TypeError):
        return default


def compute_pareto_metrics_from_results(
    results: List[Any],
    method_key: str = 'base_name',
    dataset_key: str = 'dataset',
) -> ParetoAnalyzer:
    """
    Compute Pareto metrics from a list of experiment results.

    This is a convenience function that creates a ParetoAnalyzer and
    populates it from experiment results. Results are grouped by
    (dataset, method); each result that carries both `predictions` and
    `labels` contributes one F1 score (the 'F1' entry returned by
    ReliabilityMetrics.compute_aggregated), in the order the results appear
    in `results`. Results lacking either field are skipped, and groups that
    yield no F1 score are not added to the analyzer.

    Both attribute-style objects and dict-style results are supported, for
    the method/dataset keys as well as for `predictions`/`labels`.

    Args:
        results: List of experiment result objects/dicts
        method_key: Key for method name in results
        dataset_key: Key for dataset name in results

    Returns:
        Populated ParetoAnalyzer
    """
    from collections import defaultdict

    analyzer = ParetoAnalyzer()

    # Group results by (dataset, method)
    grouped = defaultdict(list)

    for result in results:
        if hasattr(result, method_key):
            method = getattr(result, method_key)
            dataset = getattr(result, dataset_key)
        else:
            method = result[method_key]
            dataset = result[dataset_key]

        grouped[(dataset, method)].append(result)

    # Compute metrics for each group
    for (dataset, method), group_results in grouped.items():
        # One F1 score per result
        # NOTE(review): presumably each result covers one month — confirm with callers
        monthly_f1 = []

        for r in group_results:
            preds = _result_field(r, 'predictions', _MISSING_FIELD)
            labs = _result_field(r, 'labels', _MISSING_FIELD)
            if preds is _MISSING_FIELD or labs is _MISSING_FIELD:
                continue

            # Compute F1 from predictions/labels
            scores = ReliabilityMetrics.compute_aggregated(
                np.asarray(preds), np.asarray(labs)
            )
            monthly_f1.append(scores['F1'])

        if monthly_f1:
            analyzer.add_method_results(method, dataset, monthly_f1)

    return analyzer


def format_pareto_table(
    universal_result: UniversalParetoResult,
    datasets: List[str] = None,
) -> str:
    """
    Format Pareto analysis results as a text table.

    One row per method, one column per dataset. A ``*`` marks a dataset in
    which the method is on the Pareto frontier, ``-`` any other case. Rows
    are sorted by frontier count (descending), then by method name. Dataset
    names are truncated to four characters in the header.

    The "Partial (k/N)" status reports the method's frontier count k out of
    the N datasets that were analyzed, independent of which columns are
    displayed via `datasets`.

    Args:
        universal_result: Result from ParetoAnalyzer.analyze_universal()
        datasets: Order of datasets in columns (default: sorted alphabetically)

    Returns:
        Formatted string table
    """
    results_by_dataset = universal_result.dataset_results
    counts = universal_result.frontier_counts

    if datasets is None:
        datasets = sorted(results_by_dataset.keys())

    # Denominator for "Partial": counts cover every analyzed dataset, so the
    # total must too — not just the columns the caller chose to show
    total_datasets = len(results_by_dataset)

    universal = set(universal_result.universal_optimal)
    never_optimal = set(universal_result.always_dominated)

    # Header
    header = f"{'Method':<35} | " + " | ".join(f"{ds[:4]:^4}" for ds in datasets) + " | Status"
    lines = [header, "-" * len(header)]

    # Get all methods sorted by frontier count (descending)
    all_methods = set()
    for result in results_by_dataset.values():
        all_methods.update(result.frontier)
        all_methods.update(result.dominated)

    sorted_methods = sorted(all_methods, key=lambda m: (-counts.get(m, 0), m))

    # Rows
    for method in sorted_methods:
        count = counts.get(method, 0)

        # Determine status
        if method in universal:
            status = "UNIVERSAL"
        elif method in never_optimal:
            status = "Dominated"
        else:
            status = f"Partial ({count}/{total_datasets})"

        # Dataset columns
        cols = []
        for ds in datasets:
            result = results_by_dataset.get(ds)
            on_frontier = result is not None and method in result.frontier
            cols.append(" *  " if on_frontier else " -  ")

        lines.append(f"{method:<35} | " + " | ".join(cols) + f" | {status}")

    return "\n".join(lines)
