"""
Non-Conformity Measure (NCM) System

This module provides a flexible system for handling different uncertainty quantification
methods (NCMs) for malware detectors. Different methods can produce uncertainties in
different ways:

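- MSP (Maximum Softmax Probability): Computed from softmax scores
- Pseudo-Loss: Uncertainty produced directly by HCC training
- Margin: Distance to the SVM decision hyperplane
- OOD: Out-of-distribution scores from CADE
- NAC / Transcendent-ICE (cred, cred+conf): Uncertainties read directly from the result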

The system supports:
1. Direct field selection (e.g., use "Uncertainties (Month Ahead)" as-is)
2. Computed NCMs (e.g., compute softmax-uncertainty from softmax-score)
3. Automatic method name expansion (e.g., "HCC (warm)" → "HCC (warm) - Pseudo-Loss")
"""

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Any
import numpy as np


# ============================================================================
# NCM COMPUTATION FUNCTIONS
# ============================================================================

def compute_softmax_uncertainty(softmax_score: np.ndarray) -> np.ndarray:
    """
    Compute MSP (Maximum Softmax Probability) uncertainty

    Args:
        softmax_score: Either an array of shape (n_samples, n_classes) with
            softmax probabilities, or a 1-D array of shape (n_samples,) with
            one probability per sample. Any array-like (e.g. a nested list
            loaded from a results file) is accepted via ``np.asarray``.

    Returns:
        Array of shape (n_samples,) with uncertainty scores

    Formula: uncertainty = 1 - |p - 0.5| / 0.5
    where p is the row-wise maximum probability (2-D input) or the given
    probability itself (1-D input).
    - When p = 0.5 (maximum uncertainty): uncertainty = 1
    - When p = 0 or 1 (confident): uncertainty = 0
      (p = 0 can only occur for 1-D input; a row maximum is never 0)

    NOTE: the formula is centred on 0.5, i.e. it assumes binary
    classification. With more than two classes the row maximum can fall
    below 0.5, where the score decreases again as p decreases.
    """
    # Accept lists/tuples as well as ndarrays; ``.ndim`` needs an array.
    scores = np.asarray(softmax_score)

    if scores.ndim == 1:
        # Already one probability per sample
        max_probs = scores
    else:
        # Extract maximum probability across classes
        max_probs = np.max(scores, axis=1)

    return 1 - (np.abs(max_probs - 0.5) / 0.5)


def compute_margin_uncertainty(decision_function: np.ndarray) -> np.ndarray:
    """
    Turn SVM decision-function values into margin-based uncertainty scores

    Args:
        decision_function: Decision function values (signed distance to the
            separating hyperplane)

    Returns:
        The negated absolute distance for each value. Points close to the
        hyperplane score near 0 (more uncertain); points far from it score
        strongly negative (less uncertain).
    """
    # Sign only encodes which side of the hyperplane; uncertainty ignores it.
    distance = np.abs(decision_function)
    return np.negative(distance)


def compute_entropy_uncertainty(softmax_score: np.ndarray) -> np.ndarray:
    """
    Compute entropy-based uncertainty

    Args:
        softmax_score: Array of shape (n_samples, n_classes) with softmax
            probabilities, or a 1-D array of shape (n_samples,) with one
            probability p per sample, which is treated as a binary
            distribution [1 - p, p]. Any array-like is accepted via
            ``np.asarray``.

    Returns:
        Array of shape (n_samples,) with entropy values (natural log).
        A one-hot row has entropy exactly 0.
    """
    probs = np.asarray(softmax_score)

    if probs.ndim == 1:
        # Binary case: the entropy of [1 - p, p] is symmetric in p, so this
        # is correct whether p is the max probability or P(class 1).
        probs = np.stack([1 - probs, probs], axis=1)

    # Clip only inside the log to avoid log(0). Multiplying by the unclipped
    # probability keeps 0 * log(0) terms at exactly 0.
    log_probs = np.log(np.clip(probs, 1e-10, 1.0))
    return -np.sum(probs * log_probs, axis=1)


# ============================================================================
# NCM DESCRIPTOR
# ============================================================================

@dataclass
class NCMDescriptor:
    """
    Descriptor for a Non-Conformity Measure (NCM)

    Defines how to obtain uncertainty scores for a specific method variant.

    Attributes:
        name: NCM identifier (e.g., "MSP", "Pseudo-Loss", "Margin")
              This will be appended to the base method name
        source_field: Direct field name to use (mutually exclusive with compute_fn)
        compute_fn: Function to compute uncertainty from other fields; it is
                    called with the values of ``source_fields`` as positional
                    arguments, in that order
        source_fields: Fields needed for compute_fn (ignored when
                       source_field is set)
        description: Human-readable description of this NCM

    Examples:
        # Direct field selection
        NCMDescriptor(
            name="Pseudo-Loss",
            source_field="Uncertainties (Month Ahead)",
            description="Pseudo-loss uncertainty from HCC training"
        )

        # Computed NCM
        NCMDescriptor(
            name="MSP",
            compute_fn=compute_softmax_uncertainty,
            source_fields=["softmax-score"],
            description="Maximum Softmax Probability uncertainty"
        )
    """
    name: str
    source_field: Optional[str] = None
    compute_fn: Optional[Callable] = None
    source_fields: Optional[List[str]] = None
    description: str = ""

    def __post_init__(self):
        """Validate descriptor configuration

        Raises:
            ValueError: If neither or both of source_field / compute_fn are
                given, or compute_fn is given without source_fields.
        """
        has_source = self.source_field is not None
        has_compute = self.compute_fn is not None

        if not has_source and not has_compute:
            raise ValueError(f"NCM '{self.name}' must specify either source_field or compute_fn")

        if has_source and has_compute:
            raise ValueError(f"NCM '{self.name}' cannot specify both source_field and compute_fn")

        if has_compute and not self.source_fields:
            raise ValueError(f"NCM '{self.name}' with compute_fn must specify source_fields")

    def compute_or_select(self, result_dict: Dict[str, Any]) -> np.ndarray:
        """
        Compute or select uncertainty from result dictionary

        Args:
            result_dict: Result dictionary with raw data

        Returns:
            Uncertainty array. For direct selection this is the stored value
            itself (not a copy).

        Raises:
            KeyError: If required fields are missing
        """
        # Use an explicit None check, matching __post_init__. A truthiness
        # test would route an empty-string source_field into the compute
        # branch, where source_fields is None.
        if self.source_field is not None:
            # Direct field selection
            if self.source_field not in result_dict:
                raise KeyError(f"Source field '{self.source_field}' not found for NCM '{self.name}'")
            return result_dict[self.source_field]

        # Computed NCM: check all source fields exist before calling compute_fn
        missing = [f for f in self.source_fields if f not in result_dict]
        if missing:
            raise KeyError(f"Missing source fields for NCM '{self.name}': {missing}")

        # Pass the inputs positionally, in source_fields order
        return self.compute_fn(*(result_dict[field] for field in self.source_fields))


# ============================================================================
# NCM REGISTRY
# ============================================================================

# Result field read by the direct-selection (source_field) NCMs below.
_MONTH_AHEAD_UNCERTAINTY = "Uncertainties (Month Ahead)"


def _msp_ncm() -> NCMDescriptor:
    """Return a new MSP descriptor computed from the "softmax-score" field.

    A fresh instance is built for each registry entry, so entries never share
    one mutable dataclass object.
    """
    return NCMDescriptor(
        name="MSP",
        compute_fn=compute_softmax_uncertainty,
        source_fields=["softmax-score"],
        description="Maximum Softmax Probability uncertainty",
    )


# Maps trainer mode (base name) to available NCMs. Variants such as
# "HCC (warm)" are resolved to their base key by get_ncms_for_method().
NCM_REGISTRY: Dict[str, List[NCMDescriptor]] = {

    # HCC: Pseudo-Loss (direct) and MSP (computed)
    "HCC": [
        NCMDescriptor(
            name="Pseudo-Loss",
            source_field=_MONTH_AHEAD_UNCERTAINTY,
            description="Pseudo-loss uncertainty from HCC training",
        ),
        _msp_ncm(),
    ],

    # DeepDrebin: Only MSP (DeepDrebin uses "CE" as trainer mode)
    "CE": [
        _msp_ncm(),
    ],

    # CADE: OOD score (direct) and MSP (computed)
    "CADE": [
        NCMDescriptor(
            name="OOD",
            source_field=_MONTH_AHEAD_UNCERTAINTY,
            description="Out-of-distribution score from CADE",
        ),
        _msp_ncm(),
    ],

    # SVC: Margin (computed from decision function)
    "SVC": [
        NCMDescriptor(
            name="Margin",
            compute_fn=compute_margin_uncertainty,
            source_fields=["decision_function"],
            description="Margin-based uncertainty (negated distance to hyperplane)",
        ),
    ],

    # NAC: the stored uncertainty field only (direct selection, no MSP variant)
    "NAC": [
        NCMDescriptor(
            name="NAC",
            source_field=_MONTH_AHEAD_UNCERTAINTY,
            description="Neighborhood aggregation conformal prediction",
        ),
    ],

    # Transcendent-ICE methods: the stored uncertainty field only
    "Trans-ICE-cred": [
        NCMDescriptor(
            name="cred",
            source_field=_MONTH_AHEAD_UNCERTAINTY,
            description="Credibility-based uncertainty from ICE",
        ),
    ],

    "Trans-ICE-cred+conf": [
        NCMDescriptor(
            name="cred+conf",
            source_field=_MONTH_AHEAD_UNCERTAINTY,
            description="Combined credibility+confidence from ICE",
        ),
    ],
}


# ============================================================================
# NCM EXPANSION
# ============================================================================

def get_ncms_for_method(trainer_mode: str) -> List[NCMDescriptor]:
    """
    Look up the NCM descriptors registered for a trainer mode

    The exact trainer mode is tried first. After that comes its base name:
    the text before the first "(" with surrounding whitespace stripped
    (e.g. "HCC (warm)" -> "HCC", "CADE (cold)" -> "CADE").

    Args:
        trainer_mode: Trainer mode string (e.g., "HCC", "CE", "CADE")

    Returns:
        The registry's list of NCM descriptors for this method, or a new
        empty list when neither candidate key is registered
    """
    base_trainer = trainer_mode.partition("(")[0].strip()
    for candidate in (trainer_mode, base_trainer):
        if candidate in NCM_REGISTRY:
            return NCM_REGISTRY[candidate]
    return []


# Raw NCM input fields, dropped from expanded results when clean_temp_fields is True.
_TEMP_FIELDS = ("softmax-score", "decision_function")


def _drop_temp_fields(result: Dict[str, Any]) -> None:
    """Remove the raw NCM input fields from ``result`` in place."""
    for field in _TEMP_FIELDS:
        result.pop(field, None)


def expand_result_with_ncms(result_dict: Dict[str, Any],
                             trainer_mode_key: str = "Trainer-Mode",
                             uncertainty_key: str = "Uncertainties (Month Ahead)",
                             clean_temp_fields: bool = True) -> List[Dict[str, Any]]:
    """
    Expand a single result into multiple NCM variants

    For each NCM registered for the result's trainer mode, a shallow copy of
    ``result_dict`` is made. Its uncertainty is computed or selected and
    stored under ``uncertainty_key``, and " - <NCM name>" is appended to its
    trainer mode.

    Args:
        result_dict: Result dictionary from experiment (not modified)
        trainer_mode_key: Key for trainer mode in result dict
        uncertainty_key: Key where computed uncertainty should be stored
        clean_temp_fields: Remove the raw NCM inputs ("softmax-score",
            "decision_function") from the returned entries

    Returns:
        List of result dictionaries, one per NCM variant. Special cases:
        - No NCMs registered for the trainer mode: ``[result_dict]`` itself.
        - Trainer mode already ends in " - <name>" for one of its NCMs
          (e.g. loaded from file with NCM in name): a single copy, with
          label and uncertainty kept as-is.
        - An NCM whose input fields are missing is skipped with a printed
          warning, so the list may be shorter than the NCM list, or empty.

    Example:
        Input: "HCC (warm)" result with "Uncertainties (Month Ahead)" and "softmax-score"
        Output: [
            {... "Trainer-Mode": "HCC (warm) - Pseudo-Loss", "Uncertainties (Month Ahead)": pseudo_loss},
            {... "Trainer-Mode": "HCC (warm) - MSP", "Uncertainties (Month Ahead)": softmax_unc}
        ]
    """
    trainer_mode = result_dict.get(trainer_mode_key, "")

    # Get NCMs for this method
    ncms = get_ncms_for_method(trainer_mode)

    if not ncms:
        # No NCM expansion needed - return as-is
        return [result_dict]

    # Already expanded (e.g. "HCC (warm) - MSP" still resolves to the HCC
    # NCMs). Re-expanding would emit one variant per NCM, all carrying the
    # same label but different uncertainties, so keep the result as-is.
    if " - " in trainer_mode:
        suffix = trainer_mode.rsplit(" - ", 1)[1]
        if any(ncm.name == suffix for ncm in ncms):
            kept = result_dict.copy()
            if clean_temp_fields:
                _drop_temp_fields(kept)
            return [kept]

    expanded = []
    for ncm in ncms:
        try:
            # Compute or select uncertainty
            uncertainty = ncm.compute_or_select(result_dict)
        except KeyError as e:
            # Missing required fields - skip this NCM variant
            print(f"Warning: Skipping NCM '{ncm.name}' for {trainer_mode}: {e}")
            continue

        # Create a copy for this NCM variant, labelled with the NCM suffix
        variant = result_dict.copy()
        variant[uncertainty_key] = uncertainty
        variant[trainer_mode_key] = f"{trainer_mode} - {ncm.name}"

        if clean_temp_fields:
            _drop_temp_fields(variant)

        expanded.append(variant)

    return expanded


def expand_results_with_ncms(results: List[Dict[str, Any]],
                             **kwargs) -> List[Dict[str, Any]]:
    """
    Expand every result in a list into its NCM variants

    Args:
        results: Result dictionaries to expand, processed in order
        **kwargs: Forwarded unchanged to expand_result_with_ncms for each result

    Returns:
        One flat list holding the variants of each input result, in input order
    """
    return [
        variant
        for result in results
        for variant in expand_result_with_ncms(result, **kwargs)
    ]


# ============================================================================
# NCM-AWARE BASE NAME EXTRACTION
# ============================================================================

def extract_base_name_from_trainer_mode(trainer_mode: str) -> str:
    """
    Strip a trailing " - <NCM>" suffix from a trainer mode string

    Only the text after the LAST " - " is treated as the suffix.

    Args:
        trainer_mode: Full trainer mode string (e.g., "HCC (warm) - MSP")

    Returns:
        Everything before the last " - ", or the input unchanged when it
        contains no " - "

    Examples:
        "HCC (warm) - MSP" → "HCC (warm)"
        "HCC (warm) - Pseudo-Loss" → "HCC (warm)"
        "CE - MSP" → "CE"
        "SVC - Margin" → "SVC"
        "HCC" → "HCC" (no change if no NCM suffix)
    """
    head, separator, _ = trainer_mode.rpartition(" - ")
    # rpartition yields an empty separator when " - " is absent
    return head if separator else trainer_mode


def get_ncm_suffix_from_trainer_mode(trainer_mode: str) -> Optional[str]:
    """
    Return the NCM suffix of a trainer mode string, if it has one

    Args:
        trainer_mode: Full trainer mode string (e.g., "HCC (warm) - MSP")

    Returns:
        The text after the last " - " (e.g., "MSP"), or None when the string
        contains no " - "
    """
    _, separator, tail = trainer_mode.rpartition(" - ")
    if not separator:
        return None
    return tail
