"""
Schema definitions for Aurora framework

Defines data structures for experimental results with flexible hyperparameters
and support for custom groupings.
"""

from dataclasses import dataclass, field, fields
from datetime import datetime
from typing import Optional, Dict, Any, Literal, List

import numpy as np
import numpy.typing as npt


# Type aliases for clarity
# Dataset: the closed set of dataset names used to annotate
# ExperimentResult.dataset.
Dataset = Literal["androzoo", "apigraph", "transcendent"]
# NumPy array aliases keyed by element dtype (int64 / float64 / bool_).
# They annotate the array fields of ExperimentResult; the validation visible in
# ExperimentResult.__post_init__ checks lengths and binary values, not dtype.
ArrayInt = npt.NDArray[np.int64]
ArrayFloat = npt.NDArray[np.float64]
ArrayBool = npt.NDArray[np.bool_]


@dataclass
class ExperimentMetadata:
    """
    Provenance record attached to a collection of experiment results

    Only ``experiment_name`` and ``source_file`` must be supplied; every other
    field has a default. Mutable defaults (``tags``, ``custom_metadata``) are
    created per instance via ``default_factory``, so instances never share them.
    """
    experiment_name: str
    """Name of the experiment, intended for human readers"""

    source_file: str
    """Path of the file the results were originally loaded from"""

    load_timestamp: str = field(
        default_factory=lambda: datetime.now().isoformat()
    )
    """
    ISO-8601 timestamp of when the results were loaded

    Defaults to ``datetime.now().isoformat()`` evaluated when the instance is
    created (naive local time, no timezone offset).
    """

    description: Optional[str] = None
    """Free-form description of the experiment, or None when not provided"""

    version: str = "1.0"
    """Schema version string, kept for forward/backward compatibility"""

    tags: List[str] = field(default_factory=list)
    """Labels used to categorize the experiment (empty list by default)"""

    custom_metadata: Dict[str, Any] = field(default_factory=dict)
    """Arbitrary extra metadata entries (empty dict by default)"""


@dataclass
class ExperimentResult:
    """
    Single-month result from an experiment

    This is the core data structure representing one test month for one
    experimental configuration.

    Design Philosophy:
    - CORE FIELDS: Always required, used for grouping/validation
    - HYPERPARAMETERS: Optional, flexible dict for experiment-specific configs
    - DATA ARRAYS: Raw predictions, labels, uncertainties
    - COMPUTED METRICS: Optional, can be computed during ingestion

    The flexible hyperparameters dict allows different experiments to have
    different configuration fields without requiring schema changes.

    Equality:
    ``==`` compares field by field and uses ``np.array_equal`` for array
    fields. The comparison that ``@dataclass`` would otherwise generate
    evaluates ``array == array`` in a boolean context, which raises
    ``ValueError`` for arrays with more than one element. Instances remain
    unhashable.
    """

    # ===== CORE IDENTIFIERS (Always Required) =====
    test_month: int
    """Month index in test sequence (0-indexed)"""

    dataset: Dataset
    """Dataset name (androzoo, apigraph, transcendent)"""

    monthly_label_budget: int
    """Number of labels available per month (typically 50, 100, 200, 400)"""

    trainer_mode: str
    """Training method (e.g., 'CE', 'SVC', 'DANN+CE-Sampler')"""

    sampler_mode: str
    """Sampling strategy (e.g., 'full_first_year_subsample_months')"""

    # ===== METHOD IDENTIFICATION =====
    base_name: str
    """
    Human-readable base method name for grouping

    CRITICAL: This is the primary method identifier used in Aurora tables.
    Examples:
    - "DeepDrebin (cold) - MSP"
    - "SVM (C=1.0)"
    - "DANN (Combined Uncertainty)"

    This MUST be unique within a (budget, sampler, dataset) combination.
    """

    # ===== DATA ARRAYS (Always Required) =====
    predictions: ArrayInt
    """Binary predictions (0 or 1) for each test sample"""

    labels: ArrayInt
    """Ground truth labels (0 or 1) for each test sample"""

    uncertainties_month_ahead: ArrayFloat
    """
    Uncertainty scores at beginning of month

    These are the uncertainties used for predictions before retraining.
    Higher values = more uncertain.
    """

    uncertainties_past_month: ArrayFloat
    """
    Uncertainty scores at end of month

    These are the uncertainties after observing the month's data,
    used for sample selection for next month.
    """

    # ===== OPTIONAL: HYPERPARAMETERS =====
    hyperparameters: Dict[str, Any] = field(default_factory=dict)
    """
    Flexible dictionary for experiment-specific hyperparameters

    Common fields (all optional):
    - "Num-Epochs": int - Training epochs
    - "CE-Train-Aug-Scale": float - Augmentation scale
    - "DANN-Uncertainty-Method": str - DANN uncertainty type
    - "svm_c": float - SVM regularization parameter
    - "#-Monthly-Samples-First-Year": int - Initial training samples
    - "DANN-Training-Epochs": int - DANN-specific epochs
    - "CE-Training-Epochs": int - CE-specific epochs
    - "Random-Seed": int - Random seed for reproducibility
    - "Rejector-Mode": str - Rejection strategy

    ANY field can be added here without schema changes!
    """

    # ===== OPTIONAL: ADDITIONAL DATA =====
    rejection_mask: Optional[ArrayBool] = None
    """Boolean mask indicating which samples were rejected"""

    selection_mask: Optional[ArrayBool] = None
    """Boolean mask indicating which samples were selected for labeling"""

    # ===== OPTIONAL: PRE-COMPUTED METRICS =====
    f1: Optional[float] = None
    """F1 score (can be computed during ingestion if None)"""

    fnr: Optional[float] = None
    """False Negative Rate"""

    fpr: Optional[float] = None
    """False Positive Rate"""

    def __post_init__(self) -> None:
        """
        Validate array lengths and binary values on initialization

        Raises:
            ValueError: If any required array or provided mask differs in
                length from ``predictions``, or if ``predictions`` / ``labels``
                contain values other than 0 and 1.
        """
        n_samples = len(self.predictions)

        # Required arrays must line up one-to-one with the predictions.
        # The display name for ``labels`` is capitalized to keep the
        # established error message text.
        required_arrays = (
            ("Labels", self.labels),
            ("uncertainties_month_ahead", self.uncertainties_month_ahead),
            ("uncertainties_past_month", self.uncertainties_past_month),
        )
        for display_name, values in required_arrays:
            if len(values) != n_samples:
                raise ValueError(
                    f"{display_name} length {len(values)} "
                    f"!= predictions length {n_samples}"
                )

        # Optional masks are only checked when they were supplied
        optional_masks = (
            ("rejection_mask", self.rejection_mask),
            ("selection_mask", self.selection_mask),
        )
        for mask_name, mask in optional_masks:
            if mask is not None and len(mask) != n_samples:
                raise ValueError(f"{mask_name} length != predictions length")

        # Validate predictions and labels are binary
        if not np.all(np.isin(self.predictions, [0, 1])):
            raise ValueError("Predictions must be binary (0 or 1)")

        if not np.all(np.isin(self.labels, [0, 1])):
            raise ValueError("Labels must be binary (0 or 1)")

    @staticmethod
    def _values_equal(left: Any, right: Any) -> bool:
        """Compare two field values, treating NumPy arrays element-wise"""
        if left is right:
            return True
        if left is None or right is None:
            # Exactly one side is None (e.g. a mask present on one result only)
            return False
        if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
            return bool(np.array_equal(left, right))
        return bool(left == right)

    def __eq__(self, other: object) -> bool:
        """
        Field-by-field equality that is safe for array-valued fields

        Args:
            other: Object to compare against

        Returns:
            True when ``other`` has the identical class and every comparable
            field matches; ``NotImplemented`` for any other class.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        return all(
            self._values_equal(getattr(self, f.name), getattr(other, f.name))
            for f in fields(self)
            if f.compare
        )

    def get_hyperparameter(self, key: str, default: Any = None) -> Any:
        """
        Get a hyperparameter value with fallback

        Args:
            key: Hyperparameter name
            default: Value to return if key not found

        Returns:
            Hyperparameter value or default
        """
        return self.hyperparameters.get(key, default)

    def get_grouping_key(self, *keys: str) -> tuple:
        """
        Get a tuple of values for grouping results

        This allows flexible grouping by any combination of core fields
        and hyperparameters. For each key, an attribute of this instance
        with that name takes precedence; otherwise the hyperparameters dict
        is consulted.

        Args:
            *keys: Field names to include in grouping

        Returns:
            Tuple of values for the specified keys, with None for any key
            found in neither place (allows partial grouping)

        Example:
            >>> result.get_grouping_key("dataset", "monthly_label_budget", "Num-Epochs")
            ("androzoo", 200, 30)
        """
        values = []
        for key in keys:
            if hasattr(self, key):
                # Attributes win over same-named hyperparameters
                values.append(getattr(self, key))
            else:
                # dict.get yields None for unknown keys
                values.append(self.hyperparameters.get(key))
        return tuple(values)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary (for serialization or DataFrame creation)

        Hyperparameters are merged into the same flat namespace after the
        fixed keys, so a hyperparameter whose name equals one of the fixed
        keys (e.g. "Dataset") replaces that entry.

        Returns:
            Dictionary with all fields flattened
        """
        result_dict = {
            # Core fields
            "Test-Month": self.test_month,
            "Dataset": self.dataset,
            "Monthly-Label-Budget": self.monthly_label_budget,
            "Trainer-Mode": self.trainer_mode,
            "Sampler-Mode": self.sampler_mode,
            "Base-Name": self.base_name,

            # Data arrays
            "Predictions": self.predictions,
            "Labels": self.labels,
            "Uncertainties (Month Ahead)": self.uncertainties_month_ahead,
            "Uncertainties (Past Month)": self.uncertainties_past_month,

            # Optional arrays
            "Rejection-Mask": self.rejection_mask,
            "Selection-Mask": self.selection_mask,

            # Metrics
            "F1": self.f1,
            "FNR": self.fnr,
            "FPR": self.fpr,
        }

        # Add all hyperparameters (flattened)
        result_dict.update(self.hyperparameters)

        return result_dict


@dataclass
class ResultsCollection:
    """
    A list of experiment results bundled with the metadata describing them

    Behaves like a read-only sequence over ``results`` (``len``, iteration,
    indexing) and adds helpers for filtering, grouping and flattening.
    """
    metadata: ExperimentMetadata
    """Provenance information for this collection"""

    results: List[ExperimentResult]
    """The individual per-month results"""

    def __len__(self) -> int:
        """Number of results held"""
        return len(self.results)

    def __iter__(self):
        """Iterate over the results in stored order"""
        return iter(self.results)

    def __getitem__(self, idx: int) -> ExperimentResult:
        """Index into the underlying results list"""
        return self.results[idx]

    def filter(self, predicate) -> 'ResultsCollection':
        """
        Build a new collection containing only the results accepted by a predicate

        The returned collection references the same metadata object as this one.

        Args:
            predicate: Callable taking an ExperimentResult; truthy return keeps it

        Returns:
            New ResultsCollection holding the accepted results in original order

        Example:
            >>> # Filter for specific dataset
            >>> androzo_results = collection.filter(lambda r: r.dataset == "androzoo")
            >>> # Filter for specific hyperparameter
            >>> epoch30_results = collection.filter(
            ...     lambda r: r.get_hyperparameter("Num-Epochs") == 30
            ... )
        """
        kept = [entry for entry in self.results if predicate(entry)]
        return ResultsCollection(metadata=self.metadata, results=kept)

    def get_unique_values(self, *keys: str) -> Dict[str, set]:
        """
        Collect the distinct values seen for each requested field

        For every result and key, an attribute with that name is used when
        present; otherwise the result's hyperparameters are consulted. Results
        that have neither contribute nothing for that key.

        Args:
            *keys: Field names to collect distinct values for

        Returns:
            Dictionary mapping each field name to the set of values seen

        Example:
            >>> collection.get_unique_values("dataset", "monthly_label_budget")
            {
                "dataset": {"androzoo", "apigraph"},
                "monthly_label_budget": {50, 100, 200, 400}
            }
        """
        seen = {name: set() for name in keys}

        for entry in self.results:
            for name in keys:
                if hasattr(entry, name):
                    # Attribute on the result itself
                    found = getattr(entry, name)
                elif name in entry.hyperparameters:
                    # Fall back to the hyperparameters dict
                    found = entry.hyperparameters[name]
                else:
                    continue
                seen[name].add(found)

        return seen

    def group_by(self, *keys: str) -> Dict[tuple, List[ExperimentResult]]:
        """
        Partition the results by the tuple each one yields for the given keys

        The tuple for a result comes from its ``get_grouping_key(*keys)``.

        Args:
            *keys: Field names to group by

        Returns:
            Dictionary mapping each grouping tuple to the results that share it

        Example:
            >>> # Group by dataset and budget
            >>> groups = collection.group_by("dataset", "monthly_label_budget")
            >>> for (dataset, budget), results in groups.items():
            ...     print(f"{dataset} @ budget {budget}: {len(results)} results")
        """
        buckets = {}

        for entry in self.results:
            bucket_key = entry.get_grouping_key(*keys)
            buckets.setdefault(bucket_key, []).append(entry)

        return buckets

    def to_dict_list(self) -> List[Dict[str, Any]]:
        """
        Flatten every result into a dictionary (legacy format)

        Returns:
            List with one ``to_dict()`` dictionary per result, in stored order
        """
        flattened = []
        for entry in self.results:
            flattened.append(entry.to_dict())
        return flattened
