"""
Data ingestion module for Aurora framework

Handles loading experimental results from various sources and converting to
standardized format.
"""

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Dict, Any, Callable, List
from datetime import datetime
import pickle
import json
import numpy as np

from .schema import (
    ExperimentResult,
    ExperimentMetadata,
    ResultsCollection,
)
from .validation import ResultsValidator, ValidationError


class ResultsLoader(ABC):
    """
    Abstract base for every experimental-results loader

    Concrete subclasses read one kind of source (pickle, CSV, database, etc.)
    and produce a ResultsCollection. This base class owns the optional
    post-load validation step that all loaders share.
    """

    def __init__(self, auto_validate: bool = True, strict_validation: bool = False):
        """
        Args:
            auto_validate: When True, ``_validate_if_enabled`` runs the
                validator on each collection it is given
            strict_validation: Forwarded to ResultsValidator as ``strict``
                (treat validation warnings as errors)
        """
        self.auto_validate = auto_validate
        self.validator = ResultsValidator(strict=strict_validation)

    @abstractmethod
    def load(self, source: Any) -> ResultsCollection:
        """
        Read results from ``source`` and return them as a collection

        Args:
            source: Whatever the subclass reads from (file path, connection, ...)

        Returns:
            ResultsCollection holding the loaded results
        """
        pass

    def _validate_if_enabled(self, collection: ResultsCollection) -> ResultsCollection:
        """
        Validate ``collection`` when auto-validation is switched on

        Errors in the validation report abort loading. Warnings are only
        printed to stdout.

        Args:
            collection: Freshly loaded collection

        Returns:
            The same ``collection`` object, unchanged

        Raises:
            ValidationError: If the validator reports any errors
        """
        if not self.auto_validate:
            return collection

        report = self.validator.validate(collection)

        if report.has_errors():
            problems = [str(err) for err in report.get_errors()]
            header = f"Validation failed for {collection.metadata.experiment_name}:\n"
            raise ValidationError(header + "\n".join(problems))

        if report.has_warnings():
            notes = [str(warn) for warn in report.get_warnings()]
            print(f"⚠️  Warnings for {collection.metadata.experiment_name}:")
            for note in notes:
                print(f"  {note}")

        return collection


class PickleResultsLoader(ResultsLoader):
    """
    Loader for legacy pickle files

    Converts old-style dict results to the ExperimentResult format.
    Supports flexible base name mapping, field transformations and filters.

    Security note: ``pickle.load`` can execute arbitrary code while
    unpickling. Only load pickle files that come from a trusted source.
    """

    # Keys consumed directly by _convert_result or the default base-name
    # mapper. Every other key whose value is not a numpy array is treated as
    # a hyperparameter. Built once here instead of on every call.
    _CORE_FIELDS = frozenset({
        "Test-Month", "Dataset", "Monthly-Label-Budget",
        "Trainer-Mode", "Sampler-Mode", "Base-Name",
        "Predictions", "Labels",
        "Uncertainties (Month Ahead)", "Uncertainties (Past Month)",
        "Rejection-Mask", "Selection-Mask",
        "F1", "FNR", "FPR",
        "Method-Name",  # Legacy field
    })

    def __init__(
        self,
        base_name_mapper: Optional[Callable[[Dict], str]] = None,
        field_transforms: Optional[Dict[str, Callable[[Any], Any]]] = None,
        filters: Optional[List[Callable[[Dict], bool]]] = None,
        auto_validate: bool = True,
        strict_validation: bool = False,
    ):
        """
        Args:
            base_name_mapper: Function to extract/create base_name from dict.
                             If None, uses ``_default_base_name_mapper``
            field_transforms: Dict mapping field names to transformation functions
                             Example: {"Trainer-Mode": lambda x: x.replace("_", " ")}
            filters: List of filter functions (return True to keep item)
                     Example: [lambda x: x["Test-Month"] <= 22]
            auto_validate: Automatically validate after loading
            strict_validation: Treat validation warnings as errors
        """
        super().__init__(auto_validate, strict_validation)
        self.base_name_mapper = base_name_mapper or self._default_base_name_mapper
        self.field_transforms = field_transforms or {}
        self.filters = filters or []

    def load(self, path: Path | str, experiment_name: Optional[str] = None) -> ResultsCollection:
        """
        Load results from a pickle file

        Processing order for each entry: the filters run on the raw dict,
        then the field transforms are applied, then the entry is converted
        to an ExperimentResult.

        Args:
            path: Path to pickle file
            experiment_name: Optional experiment name (defaults to the file stem)

        Returns:
            ResultsCollection with loaded (and, if enabled, validated) results

        Raises:
            FileNotFoundError: If ``path`` does not exist
            ValueError: If the file does not hold a list of dicts, or an entry
                cannot be converted (the original error is chained as cause)
            ValidationError: If auto-validation is enabled and fails
        """
        path = Path(path)

        if not path.exists():
            raise FileNotFoundError(f"Pickle file not found: {path}")

        # SECURITY: unpickling untrusted data can execute arbitrary code.
        # Only point this loader at files you produced or otherwise trust.
        with open(path, 'rb') as f:
            raw_results = pickle.load(f)

        if not isinstance(raw_results, list):
            raise ValueError(f"Expected list of dicts, got {type(raw_results)}")

        metadata = ExperimentMetadata(
            experiment_name=experiment_name or path.stem,
            source_file=str(path.absolute()),
            load_timestamp=datetime.now().isoformat(),
            description=f"Loaded from {path.name}",
        )

        results = []
        for i, raw_result in enumerate(raw_results):
            # Reject malformed entries with a clear message instead of an
            # AttributeError from a filter or from .keys() in the handler below.
            if not isinstance(raw_result, dict):
                raise ValueError(
                    f"Expected result {i} in {path.name} to be a dict, "
                    f"got {type(raw_result)}"
                )

            if not self._apply_filters(raw_result):
                continue

            raw_result = self._apply_transforms(raw_result)

            try:
                result = self._convert_result(raw_result)
            except Exception as e:
                raise ValueError(
                    f"Failed to convert result {i} from {path.name}: {e}\n"
                    f"Result keys: {list(raw_result.keys())}"
                ) from e
            results.append(result)

        collection = ResultsCollection(metadata=metadata, results=results)

        return self._validate_if_enabled(collection)

    def _apply_filters(self, raw_result: Dict) -> bool:
        """Return True if every filter accepts ``raw_result`` (short-circuits)"""
        return all(filter_fn(raw_result) for filter_fn in self.filters)

    def _apply_transforms(self, raw_result: Dict) -> Dict:
        """Apply field transformations in place and return the same dict"""
        for field, transform_fn in self.field_transforms.items():
            if field in raw_result:
                raw_result[field] = transform_fn(raw_result[field])
        return raw_result

    @staticmethod
    def _optional_array(raw: Dict, key: str) -> Optional[np.ndarray]:
        """Return ``raw[key]`` as a numpy array, or None when the key is absent"""
        return np.array(raw[key]) if key in raw else None

    def _convert_result(self, raw: Dict) -> ExperimentResult:
        """
        Convert a legacy result dict to an ExperimentResult

        Raises:
            ValueError: If a required scalar field or data array is missing
        """
        # Core fields: the first three are required; modes default to ""
        try:
            test_month = raw["Test-Month"]
            dataset = raw["Dataset"]
            monthly_label_budget = raw["Monthly-Label-Budget"]
        except KeyError as e:
            raise ValueError(f"Missing required field: {e}") from e
        trainer_mode = raw.get("Trainer-Mode", "")
        sampler_mode = raw.get("Sampler-Mode", "")

        base_name = self.base_name_mapper(raw)

        # Data arrays (all required for pickle files)
        try:
            predictions = np.array(raw["Predictions"])
            labels = np.array(raw["Labels"])
            uncertainties_month_ahead = np.array(raw["Uncertainties (Month Ahead)"])
            uncertainties_past_month = np.array(raw["Uncertainties (Past Month)"])
        except KeyError as e:
            raise ValueError(f"Missing required data array: {e}") from e

        return ExperimentResult(
            test_month=test_month,
            dataset=dataset,
            monthly_label_budget=monthly_label_budget,
            trainer_mode=trainer_mode,
            sampler_mode=sampler_mode,
            base_name=base_name,
            predictions=predictions,
            labels=labels,
            uncertainties_month_ahead=uncertainties_month_ahead,
            uncertainties_past_month=uncertainties_past_month,
            rejection_mask=self._optional_array(raw, "Rejection-Mask"),
            selection_mask=self._optional_array(raw, "Selection-Mask"),
            hyperparameters=self._extract_hyperparameters(raw),
            f1=self._optional_array(raw, "F1"),
            fnr=self._optional_array(raw, "FNR"),
            fpr=self._optional_array(raw, "FPR"),
        )

    def _extract_hyperparameters(self, raw: Dict) -> Dict[str, Any]:
        """
        Extract hyperparameters from raw dict

        Any field not in ``_CORE_FIELDS`` whose value is not a numpy array is
        considered a hyperparameter. Arrays are treated as data and skipped.
        """
        return {
            key: value
            for key, value in raw.items()
            if key not in self._CORE_FIELDS and not isinstance(value, np.ndarray)
        }

    def _default_base_name_mapper(self, raw: Dict) -> str:
        """
        Default base name mapper

        Lookup order: "Base-Name", then legacy "Method-Name", then
        "Trainer-Mode". Falls back to "Unknown" if none is present.

        For custom naming logic, provide your own mapper function.

        Example:
            def my_mapper(raw):
                if raw["Trainer-Mode"] == "CE":
                    return "DeepDrebin (cold) - MSP"
                return raw["Trainer-Mode"]

            loader = PickleResultsLoader(base_name_mapper=my_mapper)
        """
        if "Base-Name" in raw:
            return raw["Base-Name"]

        if "Method-Name" in raw:
            return raw["Method-Name"]

        return raw.get("Trainer-Mode", "Unknown")


class JSONResultsLoader(ResultsLoader):
    """
    Loader for JSON result files

    Converts every list-valued field to a numpy array and handles
    JSON-specific field naming conventions (e.g., "Seed" → "Random-Seed").

    Supports the same flexible transformations as PickleResultsLoader.
    """

    # Keys that are never reported as hyperparameters. Values that are numpy
    # arrays are also excluded. Because _convert_arrays turns every JSON list
    # into an array, list-valued fields never become hyperparameters either.
    _CORE_FIELDS = frozenset({
        "Test-Month", "Dataset", "Monthly-Label-Budget",
        "Trainer-Mode", "Sampler-Mode", "Base-Name",
        "Predictions", "Labels",
        "Uncertainties (Month Ahead)", "Uncertainties (Past Month)",
        "Rejection-Mask", "Selection-Mask",
        "F1", "FNR", "FPR",
        "Method-Name", "Seed", "Random-Seed",
        "softmax-score", "softmax-uncertainty",  # JSON-specific
    })

    def __init__(
        self,
        base_name_mapper: Optional[Callable[[Dict], str]] = None,
        field_transforms: Optional[Dict[str, Callable[[Any], Any]]] = None,
        filters: Optional[List[Callable[[Dict], bool]]] = None,
        auto_validate: bool = True,
        strict_validation: bool = False,
        rename_seed_field: bool = True,
    ):
        """
        Args:
            base_name_mapper: Function to extract/create base_name from dict.
                             If None, uses ``_default_base_name_mapper``
            field_transforms: Dict mapping field names to transformation functions
            filters: List of filter functions (return True to keep item)
            auto_validate: Automatically validate after loading
            strict_validation: Treat validation warnings as errors
            rename_seed_field: If True, automatically rename "Seed" → "Random-Seed"
        """
        super().__init__(auto_validate, strict_validation)
        self.base_name_mapper = base_name_mapper or self._default_base_name_mapper
        self.field_transforms = field_transforms or {}
        self.filters = filters or []
        self.rename_seed_field = rename_seed_field

    def load(
        self,
        path: Path | str,
        experiment_name: Optional[str] = None
    ) -> ResultsCollection:
        """
        Load results from a JSON file

        Processing order for each entry:
        1. Lists are converted to numpy arrays.
        2. "Seed" is renamed to "Random-Seed" if enabled.
        3. Filters run.
        4. Field transforms are applied.
        5. The entry is converted to an ExperimentResult.

        Args:
            path: Path to JSON file (decoded as UTF-8)
            experiment_name: Optional experiment name (defaults to the file stem)

        Returns:
            ResultsCollection with loaded (and, if enabled, validated) results

        Raises:
            FileNotFoundError: If ``path`` does not exist
            ValueError: If the file is not valid JSON, does not hold a list of
                dicts, or an entry cannot be converted (cause is chained)
            ValidationError: If auto-validation is enabled and fails
        """
        path = Path(path)

        if not path.exists():
            raise FileNotFoundError(f"JSON file not found: {path}")

        # JSON text is UTF-8 (RFC 8259). Be explicit so decoding does not
        # depend on the platform's locale encoding (e.g. cp1252 on Windows).
        with open(path, 'r', encoding='utf-8') as f:
            raw_results = json.load(f)

        if not isinstance(raw_results, list):
            raise ValueError(f"Expected list of dicts, got {type(raw_results)}")

        metadata = ExperimentMetadata(
            experiment_name=experiment_name or path.stem,
            source_file=str(path.absolute()),
            load_timestamp=datetime.now().isoformat(),
            description=f"Loaded from {path.name}",
        )

        results = []
        for i, raw_result in enumerate(raw_results):
            # Reject malformed entries with a clear message instead of an
            # AttributeError from _convert_arrays (.items() on a non-dict).
            if not isinstance(raw_result, dict):
                raise ValueError(
                    f"Expected result {i} in {path.name} to be a dict, "
                    f"got {type(raw_result)}"
                )

            # Convert lists to numpy arrays (internalizes make_np_arrays)
            raw_result = self._convert_arrays(raw_result)

            # Handle common JSON-specific field naming
            if self.rename_seed_field and "Seed" in raw_result:
                raw_result["Random-Seed"] = raw_result.pop("Seed")

            if not self._apply_filters(raw_result):
                continue

            raw_result = self._apply_transforms(raw_result)

            try:
                result = self._convert_result(raw_result)
            except Exception as e:
                raise ValueError(
                    f"Failed to convert result {i} from {path.name}: {e}\n"
                    f"Result keys: {list(raw_result.keys())}"
                ) from e
            results.append(result)

        collection = ResultsCollection(metadata=metadata, results=results)

        return self._validate_if_enabled(collection)

    def _convert_arrays(self, raw_result: Dict) -> Dict:
        """
        Convert every list-valued field to a numpy array, in place

        This internalizes the old make_np_arrays() function. Only values are
        replaced (no keys are added or removed), so updating while iterating
        is safe.
        """
        for key, value in raw_result.items():
            if isinstance(value, list):
                raw_result[key] = np.array(value)
        return raw_result

    def _apply_filters(self, raw_result: Dict) -> bool:
        """Return True if every filter accepts ``raw_result`` (short-circuits)"""
        return all(filter_fn(raw_result) for filter_fn in self.filters)

    def _apply_transforms(self, raw_result: Dict) -> Dict:
        """Apply field transformations in place and return the same dict"""
        for field, transform_fn in self.field_transforms.items():
            if field in raw_result:
                raw_result[field] = transform_fn(raw_result[field])
        return raw_result

    @staticmethod
    def _optional_array(raw: Dict, key: str) -> Optional[np.ndarray]:
        """Return ``raw[key]`` as a numpy array, or None when the key is absent"""
        return np.array(raw[key]) if key in raw else None

    def _convert_result(self, raw: Dict) -> ExperimentResult:
        """
        Convert a JSON result dict to an ExperimentResult

        Raises:
            ValueError: If a required scalar field or data array is missing
        """
        # Core fields: the first three are required; modes default to ""
        try:
            test_month = raw["Test-Month"]
            dataset = raw["Dataset"]
            monthly_label_budget = raw["Monthly-Label-Budget"]
        except KeyError as e:
            raise ValueError(f"Missing required field: {e}") from e
        trainer_mode = raw.get("Trainer-Mode", "")
        sampler_mode = raw.get("Sampler-Mode", "")

        base_name = self.base_name_mapper(raw)

        # Data arrays required for JSON files
        try:
            predictions = np.array(raw["Predictions"])
            labels = np.array(raw["Labels"])
            uncertainties_month_ahead = np.array(raw["Uncertainties (Month Ahead)"])
        except KeyError as e:
            raise ValueError(f"Missing required data array: {e}") from e

        # "Uncertainties (Past Month)" is optional for JSON files. When absent,
        # the month-ahead array is reused, so both fields refer to the same
        # array object.
        uncertainties_past_month = (
            np.array(raw["Uncertainties (Past Month)"])
            if "Uncertainties (Past Month)" in raw
            else uncertainties_month_ahead
        )

        return ExperimentResult(
            test_month=test_month,
            dataset=dataset,
            monthly_label_budget=monthly_label_budget,
            trainer_mode=trainer_mode,
            sampler_mode=sampler_mode,
            base_name=base_name,
            predictions=predictions,
            labels=labels,
            uncertainties_month_ahead=uncertainties_month_ahead,
            uncertainties_past_month=uncertainties_past_month,
            rejection_mask=self._optional_array(raw, "Rejection-Mask"),
            selection_mask=self._optional_array(raw, "Selection-Mask"),
            hyperparameters=self._extract_hyperparameters(raw),
            f1=self._optional_array(raw, "F1"),
            fnr=self._optional_array(raw, "FNR"),
            fpr=self._optional_array(raw, "FPR"),
        )

    def _extract_hyperparameters(self, raw: Dict) -> Dict[str, Any]:
        """
        Extract hyperparameters from raw dict

        Any field not in ``_CORE_FIELDS`` whose value is not a numpy array is
        considered a hyperparameter.
        """
        return {
            key: value
            for key, value in raw.items()
            if key not in self._CORE_FIELDS and not isinstance(value, np.ndarray)
        }

    def _default_base_name_mapper(self, raw: Dict) -> str:
        """
        Default base name mapper

        Lookup order: "Base-Name", then legacy "Method-Name", then
        "Trainer-Mode". Falls back to "Unknown" if none is present.

        For custom naming logic, provide your own mapper function.

        Example:
            def my_mapper(raw):
                if "HCC" in raw["Trainer-Mode"]:
                    return "HCC (warm) - Pseudo-Loss"
                return raw["Trainer-Mode"]

            loader = JSONResultsLoader(base_name_mapper=my_mapper)
        """
        if "Base-Name" in raw:
            return raw["Base-Name"]

        if "Method-Name" in raw:
            return raw["Method-Name"]

        return raw.get("Trainer-Mode", "Unknown")


def create_cutoff_month_filter(
    cutoff_months: Optional[Dict[str, int]] = None
) -> Callable[[Dict], bool]:
    """
    Build a predicate that drops results past a per-dataset cutoff month

    Args:
        cutoff_months: Dict mapping dataset names to the last month index to
                       keep. Only ``None`` selects the built-in defaults
                       {"androzoo": 22, "apigraph": 70, "transcendent": 46};
                       an empty dict disables every cutoff.

    Returns:
        Predicate that is True when the item's "Test-Month" (treated as 0 if
        missing) is <= its dataset's cutoff, or when the dataset has no
        cutoff configured. The dict is referenced, not copied.

    Example:
        >>> filter_fn = create_cutoff_month_filter()
        >>> loader = PickleResultsLoader(filters=[filter_fn])
        >>> results = loader.load("data.pkl")
    """
    limits = (
        cutoff_months
        if cutoff_months is not None
        else {"androzoo": 22, "apigraph": 70, "transcendent": 46}
    )

    def filter_fn(item: Dict) -> bool:
        dataset = item.get("Dataset")
        if dataset not in limits:
            # Unknown datasets are never filtered out
            return True
        return item.get("Test-Month", 0) <= limits[dataset]

    return filter_fn


def create_hyperparameter_filter(
    hyperparameter: str,
    values: List[Any],
    exclude: bool = False
) -> Callable[[Dict], bool]:
    """
    Build a predicate keyed on one hyperparameter's value

    Args:
        hyperparameter: Field name to look up in each result dict. A missing
            field is treated as ``None``.
        values: Values to match against (membership via ``in``)
        exclude: False keeps only matching results; True keeps only
            non-matching ones

    Returns:
        Predicate suitable for the loaders' ``filters`` argument

    Example:
        >>> # Keep only results with 30 or 50 epochs
        >>> filter_fn = create_hyperparameter_filter("Num-Epochs", [30, 50])
        >>> loader = PickleResultsLoader(filters=[filter_fn])

        >>> # Exclude results with 10 epochs
        >>> filter_fn = create_hyperparameter_filter("Num-Epochs", [10], exclude=True)
    """
    def filter_fn(item: Dict) -> bool:
        matched = item.get(hyperparameter) in values
        return (not matched) if exclude else matched

    return filter_fn


def combine_collections(*collections: ResultsCollection) -> ResultsCollection:
    """
    Merge several results collections into a new one

    Results keep their original order: all of the first collection's
    results, then the second's, and so on. The input collections are not
    modified.

    Args:
        *collections: ResultsCollection instances to merge (at least one)

    Returns:
        New ResultsCollection whose metadata names every source collection
        and records the total result count

    Raises:
        ValueError: If no collections are given

    Example:
        >>> results1 = loader.load("experiment1.pkl")
        >>> results2 = loader.load("experiment2.pkl")
        >>> combined = combine_collections(results1, results2)
    """
    if not collections:
        raise ValueError("No collections provided")

    source_names = [coll.metadata.experiment_name for coll in collections]
    joined_names = ', '.join(source_names)
    merged_metadata = ExperimentMetadata(
        experiment_name=f"Combined({joined_names})",
        source_file="multiple",
        description=f"Combined from {len(collections)} collections",
        custom_metadata={
            "source_collections": source_names,
            "total_results": sum(len(coll) for coll in collections),
        }
    )

    merged_results = [
        result
        for coll in collections
        for result in coll.results
    ]

    return ResultsCollection(metadata=merged_metadata, results=merged_results)
