"""
Composable pipeline system for Aurora framework.

This module provides a chainable Pipeline that composes transformations.
The Pipeline is general-purpose and knows nothing about Aurora's specific data.

Key principles:
- Immutability: Each operation returns a NEW Pipeline instance
- Lazy evaluation: Pipeline builds up operations but doesn't execute until apply()
- Type preservation: Works with List[Dict] and returns List[Dict]

Example:
    >>> pipeline = (
    ...     Pipeline()
    ...     .rename_fields({"Test-Month": "month"})
    ...     .filter(Filter("month", "<=", 22))
    ...     .convert_arrays(fields=["Predictions", "Labels"])
    ... )
    >>> results = pipeline.apply(raw_data)
"""

from __future__ import annotations

import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np

from .filters import Filter, FilterChain


# Type alias for a single pipeline step: (operation_name, params_dict).
# operation_name selects the handler in Pipeline._apply_step; params_dict
# holds that operation's arguments keyed by name.
PipelineStep = Tuple[str, Dict[str, Any]]


class Pipeline:
    """
    Composable data transformation pipeline.

    Pipelines are immutable - each builder method returns a new Pipeline and
    stores a shallow snapshot of any container argument (rename mapping,
    field lists), so mutating the caller's objects afterwards does not alter
    an already-built pipeline.
    Operations are lazy - nothing executes until apply() is called.

    Attributes:
        _steps: Internal list of (operation_name, params) tuples

    Example:
        >>> pipeline = (
        ...     Pipeline()
        ...     .rename_fields({"Seed": "random_seed"})
        ...     .filter(Filter("month", "<=", 22))
        ...     .convert_arrays()
        ...     .add_computed_field("base_name", lambda r: r.get("trainer"))
        ... )
        >>> results = pipeline.apply(raw_data)
    """

    def __init__(self, steps: Optional[List[PipelineStep]] = None):
        """
        Create a new Pipeline.

        Args:
            steps: Internal list of (operation_name, params) tuples.
                   Users should not pass this directly. The list itself is
                   copied, so the caller's list is never shared.
        """
        self._steps: List[PipelineStep] = list(steps) if steps is not None else []

    def _add_step(self, name: str, params: Dict[str, Any]) -> Pipeline:
        """Create new pipeline with additional step (self is left untouched)."""
        return Pipeline(self._steps + [(name, params)])

    def rename_fields(self, mapping: Dict[str, str]) -> Pipeline:
        """
        Add field renaming step.

        Args:
            mapping: Dict of {old_name: new_name}. A shallow copy is stored,
                     so later changes to the caller's dict have no effect.

        Returns:
            New Pipeline with rename step added

        Example:
            >>> p = Pipeline().rename_fields({"Test-Month": "month"})
        """
        # Snapshot the mapping: keeping the caller's dict by reference would
        # let outside mutation change this supposedly immutable pipeline.
        return self._add_step("rename_fields", {"mapping": copy.copy(mapping)})

    def filter(
        self,
        filter_or_fn: Union[Filter, FilterChain, Callable[[Dict], bool]],
    ) -> Pipeline:
        """
        Add filtering step.

        Args:
            filter_or_fn: Filter, FilterChain, or callable that returns bool

        Returns:
            New Pipeline with filter step added

        Example:
            >>> p = Pipeline().filter(Filter("month", "<=", 22))
            >>> p = Pipeline().filter(lambda r: r.get("valid", False))
        """
        return self._add_step("filter", {"filter": filter_or_fn})

    def convert_arrays(
        self,
        fields: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ) -> Pipeline:
        """
        Convert list fields to numpy arrays.

        Args:
            fields: Specific fields to convert (None = all list fields).
                    An empty list selects no fields at all.
            exclude: Fields to exclude from conversion

        Returns:
            New Pipeline with array conversion step added

        Example:
            >>> p = Pipeline().convert_arrays()  # All lists
            >>> p = Pipeline().convert_arrays(fields=["Predictions", "Labels"])
        """
        # copy.copy keeps the container type (so describe() output is
        # unchanged) while detaching the step from the caller's objects.
        return self._add_step(
            "convert_arrays",
            {
                "fields": None if fields is None else copy.copy(fields),
                "exclude": copy.copy(exclude) if exclude else [],
            },
        )

    def add_computed_field(
        self,
        field_name: str,
        compute_fn: Callable[[Dict], Any],
    ) -> Pipeline:
        """
        Add a computed field to each record.

        Args:
            field_name: Name of new field (an existing field with this name
                        is overwritten)
            compute_fn: Function that takes record and returns value

        Returns:
            New Pipeline with computed field step added

        Example:
            >>> p = Pipeline().add_computed_field(
            ...     "base_name",
            ...     lambda r: r.get("Trainer-Mode", "Unknown")
            ... )
        """
        return self._add_step(
            "add_computed_field",
            {"field_name": field_name, "compute_fn": compute_fn},
        )

    def remove_fields(self, fields: List[str]) -> Pipeline:
        """
        Remove specified fields from records.

        Args:
            fields: List of field names to remove. Names that are absent from
                    a record are ignored. A shallow copy is stored.

        Returns:
            New Pipeline with field removal step added

        Example:
            >>> p = Pipeline().remove_fields(["temp_field", "debug_info"])
        """
        return self._add_step("remove_fields", {"fields": copy.copy(fields)})

    def transform_field(
        self,
        field: str,
        transform_fn: Callable[[Any], Any],
    ) -> Pipeline:
        """
        Transform values in a specific field.

        Args:
            field: Field to transform (records lacking it are left unchanged)
            transform_fn: Function to apply to field values

        Returns:
            New Pipeline with field transform step added

        Example:
            >>> p = Pipeline().transform_field(
            ...     "probability",
            ...     lambda x: x * 100  # Convert to percentage
            ... )
        """
        return self._add_step(
            "transform_field",
            {"field": field, "transform_fn": transform_fn},
        )

    def apply(self, records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Execute pipeline on data.

        Args:
            records: List of dictionaries to process

        Returns:
            Processed list of dictionaries

        Raises:
            ValueError: If the pipeline contains a step with an unknown name.

        Note:
            - Records are processed in order of steps added
            - Filtering can reduce the number of records
            - Original records are NOT mutated
        """
        # Work on deep copies so the in-place steps below never touch the
        # caller's records.
        result = [self._deep_copy_record(r) for r in records]

        # Apply each step in order
        for step_name, params in self._steps:
            result = self._apply_step(step_name, params, result)

        return result

    def _apply_step(
        self,
        step_name: str,
        params: Dict[str, Any],
        records: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Apply a single step to records.

        Dispatches on step_name to the matching _exec_* method, unpacking the
        arguments that the corresponding builder stored in params.

        Raises:
            ValueError: If step_name is not a known operation.
        """
        if step_name == "rename_fields":
            return self._exec_rename_fields(records, params["mapping"])
        elif step_name == "filter":
            return self._exec_filter(records, params["filter"])
        elif step_name == "convert_arrays":
            return self._exec_convert_arrays(
                records, params["fields"], params["exclude"]
            )
        elif step_name == "add_computed_field":
            return self._exec_add_computed_field(
                records, params["field_name"], params["compute_fn"]
            )
        elif step_name == "remove_fields":
            return self._exec_remove_fields(records, params["fields"])
        elif step_name == "transform_field":
            return self._exec_transform_field(
                records, params["field"], params["transform_fn"]
            )
        else:
            raise ValueError(f"Unknown step: {step_name}")

    def _exec_rename_fields(
        self,
        records: List[Dict[str, Any]],
        mapping: Dict[str, str],
    ) -> List[Dict[str, Any]]:
        """
        Execute rename_fields step.

        Builds a new dict per record; keys absent from mapping are kept as-is.
        If a renamed key collides with another key of the same record, the
        one that comes later in the record's iteration order wins.
        """
        return [
            {mapping.get(key, key): value for key, value in record.items()}
            for record in records
        ]

    def _exec_filter(
        self,
        records: List[Dict[str, Any]],
        filter_obj: Union[Filter, FilterChain, Callable[[Dict], bool]],
    ) -> List[Dict[str, Any]]:
        """
        Execute filter step.

        Filter / FilterChain objects are evaluated through their apply()
        method; anything else is called directly with the record. Records for
        which the result is truthy are kept.
        """
        if isinstance(filter_obj, (Filter, FilterChain)):
            predicate = filter_obj.apply
        else:
            # Callable
            predicate = filter_obj
        return [r for r in records if predicate(r)]

    def _exec_convert_arrays(
        self,
        records: List[Dict[str, Any]],
        fields: Optional[List[str]],
        exclude: List[str],
    ) -> List[Dict[str, Any]]:
        """
        Execute convert_arrays step.

        A value is converted with np.array only when it is a list, its key is
        not in exclude, and (if fields is given) its key is in fields. All
        other values are carried over unchanged into the new record.
        """
        converted = []
        for record in records:
            new_record = {}
            for key, value in record.items():
                # Exclusion takes precedence over an explicit field selection.
                selected = key not in exclude and (
                    fields is None or key in fields
                )
                if selected and isinstance(value, list):
                    value = np.array(value)
                new_record[key] = value
            converted.append(new_record)
        return converted

    def _exec_add_computed_field(
        self,
        records: List[Dict[str, Any]],
        field_name: str,
        compute_fn: Callable[[Dict], Any],
    ) -> List[Dict[str, Any]]:
        """Execute add_computed_field step (updates the given records in place)."""
        for record in records:
            record[field_name] = compute_fn(record)
        return records

    def _exec_remove_fields(
        self,
        records: List[Dict[str, Any]],
        fields: List[str],
    ) -> List[Dict[str, Any]]:
        """Execute remove_fields step (updates the given records in place)."""
        for record in records:
            for field in fields:
                # pop with a default: a missing field is not an error
                record.pop(field, None)
        return records

    def _exec_transform_field(
        self,
        records: List[Dict[str, Any]],
        field: str,
        transform_fn: Callable[[Any], Any],
    ) -> List[Dict[str, Any]]:
        """Execute transform_field step (updates the given records in place)."""
        for record in records:
            if field in record:
                record[field] = transform_fn(record[field])
        return records

    @staticmethod
    def _deep_copy_record(record: Dict[str, Any]) -> Dict[str, Any]:
        """Create a deep copy of a record."""
        return copy.deepcopy(record)

    def __len__(self) -> int:
        """Number of steps in pipeline."""
        return len(self._steps)

    def __repr__(self) -> str:
        """Short summary showing only the step count."""
        return f"Pipeline({len(self)} steps)"

    def describe(self) -> str:
        """
        Human-readable description of pipeline steps.

        Returns:
            Multi-line string describing each step (numbered from 1), or a
            single line for an empty pipeline
        """
        if not self._steps:
            return "Pipeline: empty (0 steps)"

        lines = [f"Pipeline: {len(self)} steps"]
        for i, (name, params) in enumerate(self._steps, 1):
            desc = self._describe_step(name, params)
            lines.append(f"  {i}. {desc}")
        return "\n".join(lines)

    def _describe_step(self, name: str, params: Dict[str, Any]) -> str:
        """
        Generate description for a single step.

        Unknown step names fall back to a generic "name(params)" rendering
        instead of raising.
        """
        if name == "rename_fields":
            mapping = params["mapping"]
            # Spell out small mappings; summarise larger ones by size.
            if len(mapping) <= 3:
                items = ", ".join(f"{k}->{v}" for k, v in mapping.items())
            else:
                items = f"{len(mapping)} fields"
            return f"rename_fields({items})"
        elif name == "filter":
            filter_obj = params["filter"]
            return f"filter({filter_obj})"
        elif name == "convert_arrays":
            fields = params["fields"]
            # Only None means "all lists"; an empty selection converts
            # nothing, so it must not be reported as "all lists".
            if fields is not None:
                return f"convert_arrays(fields={fields})"
            return "convert_arrays(all lists)"
        elif name == "add_computed_field":
            return f"add_computed_field({params['field_name']})"
        elif name == "remove_fields":
            return f"remove_fields({params['fields']})"
        elif name == "transform_field":
            return f"transform_field({params['field']})"
        else:
            return f"{name}({params})"


# Utility functions

def make_numpy_arrays(record: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build a copy of a record in which every list value is a numpy array.

    Only values that are instances of list are passed through np.array;
    every other value is placed in the result unchanged (same object).
    The input record itself is not modified.

    Args:
        record: Dictionary with potential list values

    Returns:
        New dictionary with the same keys, lists replaced by numpy arrays
    """
    return {
        name: np.array(item) if isinstance(item, list) else item
        for name, item in record.items()
    }


def deep_copy_record(record: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return an independent duplicate of a record.

    The duplicate is produced with copy.deepcopy, so nested containers are
    copied too and changes to the result do not reach the input.

    Args:
        record: Dictionary to copy

    Returns:
        Deep copy of the dictionary
    """
    duplicate = copy.deepcopy(record)
    return duplicate
