"""
Tests for Pipeline - written FIRST before implementation.

TDD Phase 2: These tests define the expected behavior of Pipeline.
While the aurora modules cannot be imported, every test in this module is
skipped; once they import, tests should FAIL until implementation is complete.
"""
import pytest
import numpy as np

# Import guard for the TDD workflow: the aurora modules may not exist yet.
# If either import raises ImportError, the three names are bound to None so
# this module can still be imported, and the module-level ``pytestmark``
# applies a skip marker to every test defined in this file.
try:
    from aurora.pipeline import Pipeline
    from aurora.filters import Filter, FilterChain
except ImportError:
    Pipeline = None
    Filter = None
    FilterChain = None
    pytestmark = pytest.mark.skip(reason="Pipeline not yet implemented")


class TestPipelineCreation:
    """Constructing a Pipeline that has no steps."""

    def test_empty_pipeline(self):
        """A freshly constructed Pipeline reports a length of zero."""
        assert len(Pipeline()) == 0

    def test_pipeline_repr(self):
        """repr() mentions the class name and the step count."""
        text = repr(Pipeline())

        for fragment in ("Pipeline", "0 steps"):
            assert fragment in text


class TestPipelineImmutability:
    """Builder methods return a new Pipeline and leave the receiver as is."""

    def test_rename_returns_new_pipeline(self):
        """rename_fields yields a distinct Pipeline holding one step."""
        base = Pipeline()
        renamed = base.rename_fields({"a": "b"})

        assert base is not renamed
        assert (len(base), len(renamed)) == (0, 1)

    def test_filter_returns_new_pipeline(self):
        """filter yields a distinct Pipeline holding one step."""
        base = Pipeline()
        filtered = base.filter(Filter("x", "==", 1))

        assert base is not filtered
        assert (len(base), len(filtered)) == (0, 1)

    def test_chaining_creates_new_pipelines(self):
        """Every link in a chain keeps its own step count."""
        empty = Pipeline()
        one_step = empty.rename_fields({"a": "b"})
        two_steps = one_step.filter(Filter("x", "==", 1))
        three_steps = two_steps.convert_arrays()

        stages = [empty, one_step, two_steps, three_steps]
        assert [len(stage) for stage in stages] == [0, 1, 2, 3]


class TestPipelineRenameFields:
    """Behaviour of the rename_fields step."""

    def test_rename_single_field(self):
        """One mapping entry renames that key; other keys are kept."""
        records = [{"old": 1, "other": 2}]

        result = Pipeline().rename_fields({"old": "new"}).apply(records)

        assert len(result) == 1
        renamed = result[0]
        assert "new" in renamed
        assert "old" not in renamed
        assert renamed["new"] == 1
        assert renamed["other"] == 2

    def test_rename_multiple_fields(self):
        """Several mapping entries are applied in a single step."""
        mapping = {"a": "x", "b": "y"}
        renamer = Pipeline().rename_fields(mapping)

        result = renamer.apply([{"a": 1, "b": 2, "c": 3}])

        assert result[0] == {"x": 1, "y": 2, "c": 3}

    def test_rename_missing_field_ignored(self):
        """A mapping entry whose source key is absent changes nothing."""
        renamer = Pipeline().rename_fields({"missing": "new"})

        result = renamer.apply([{"existing": 1}])

        assert result[0] == {"existing": 1}

    def test_rename_does_not_mutate_original(self):
        """The input record keeps its old key after apply()."""
        source = [{"old": 1}]

        result = Pipeline().rename_fields({"old": "new"}).apply(source)

        # The input still carries the old key; only the output is renamed.
        assert "old" in source[0]
        assert "new" in result[0]


class TestPipelineFilter:
    """Behaviour of the filter step for each supported predicate kind."""

    def test_filter_with_filter_object(self):
        """A single Filter keeps only the matching records."""
        records = [{"x": 3}, {"x": 7}, {"x": 10}]

        kept = Pipeline().filter(Filter("x", ">", 5)).apply(records)

        assert len(kept) == 2
        assert all(row["x"] > 5 for row in kept)

    def test_filter_with_filter_chain(self):
        """A FilterChain of two Filters keeps the record matching both."""
        between = FilterChain([
            Filter("x", ">", 0),
            Filter("x", "<", 10),
        ])
        records = [{"x": -5}, {"x": 5}, {"x": 15}]

        kept = Pipeline().filter(between).apply(records)

        assert len(kept) == 1
        assert kept[0]["x"] == 5

    def test_filter_with_callable(self):
        """A plain callable can be used as the predicate."""
        records = [
            {"valid": True, "data": 1},
            {"valid": False, "data": 2},
            {"data": 3},  # no 'valid' key, so the predicate returns False
        ]

        kept = Pipeline().filter(
            lambda rec: rec.get("valid", False)
        ).apply(records)

        assert len(kept) == 1
        assert kept[0]["data"] == 1

    def test_filter_removes_all(self):
        """When nothing matches, the result is an empty list."""
        reject_all = Pipeline().filter(Filter("x", "==", "nonexistent"))

        kept = reject_all.apply([{"x": 1}, {"x": 2}])

        assert kept == []

    def test_filter_keeps_all(self):
        """When everything matches, no record is dropped."""
        accept_all = Pipeline().filter(Filter("x", ">", 0))

        kept = accept_all.apply([{"x": 1}, {"x": 2}])

        assert len(kept) == 2


class TestPipelineConvertArrays:
    """Tests for convert_arrays operation.

    Each test checks the resulting container type and also that the element
    values survive the conversion, so an implementation that produces an
    array with the wrong contents cannot pass.
    """

    def test_convert_all_lists(self):
        """Convert all list fields to numpy arrays, leaving other values alone"""
        pipeline = Pipeline().convert_arrays()
        result = pipeline.apply([{
            "list_field": [1, 2, 3],
            "scalar_field": 42,
            "string_field": "hello"
        }])

        record = result[0]
        assert isinstance(record["list_field"], np.ndarray)
        assert record["list_field"].tolist() == [1, 2, 3]

        # A bare `==` would also pass for a 0-d array wrapping the value, so
        # the type of the non-list fields is checked explicitly.
        assert not isinstance(record["scalar_field"], np.ndarray)
        assert record["scalar_field"] == 42
        assert not isinstance(record["string_field"], np.ndarray)
        assert record["string_field"] == "hello"

    def test_convert_specific_fields(self):
        """Convert only specified fields"""
        pipeline = Pipeline().convert_arrays(fields=["a"])
        result = pipeline.apply([{
            "a": [1, 2, 3],
            "b": [4, 5, 6]
        }])

        record = result[0]
        assert isinstance(record["a"], np.ndarray)
        assert record["a"].tolist() == [1, 2, 3]
        assert isinstance(record["b"], list)  # Not converted
        assert record["b"] == [4, 5, 6]

    def test_convert_with_exclusions(self):
        """Exclude specific fields from conversion"""
        pipeline = Pipeline().convert_arrays(exclude=["keep_list"])
        result = pipeline.apply([{
            "convert_me": [1, 2],
            "keep_list": [3, 4]
        }])

        record = result[0]
        assert isinstance(record["convert_me"], np.ndarray)
        assert record["convert_me"].tolist() == [1, 2]
        assert isinstance(record["keep_list"], list)
        assert record["keep_list"] == [3, 4]

    def test_convert_nested_lists(self):
        """Handle nested lists (2D arrays)"""
        pipeline = Pipeline().convert_arrays()
        result = pipeline.apply([{
            "matrix": [[1, 2], [3, 4]]
        }])

        matrix = result[0]["matrix"]
        assert isinstance(matrix, np.ndarray)
        assert matrix.shape == (2, 2)
        assert matrix.tolist() == [[1, 2], [3, 4]]


class TestPipelineAddComputedField:
    """Behaviour of the add_computed_field step."""

    def test_add_simple_computed_field(self):
        """The new field holds the computed value; the source field stays."""
        doubler = Pipeline().add_computed_field(
            "doubled", lambda rec: rec.get("x", 0) * 2
        )

        result = doubler.apply([{"x": 5}])

        record = result[0]
        assert record["doubled"] == 10
        assert record["x"] == 5

    def test_add_computed_field_from_multiple_fields(self):
        """The compute function may read more than one source field."""
        adder = Pipeline().add_computed_field(
            "sum", lambda rec: rec.get("a", 0) + rec.get("b", 0)
        )

        result = adder.apply([{"a": 3, "b": 4}])

        assert result[0]["sum"] == 7

    def test_add_computed_field_overwrites(self):
        """Using an existing field name replaces that field's value."""
        overwriter = Pipeline().add_computed_field(
            "x", lambda rec: rec.get("x", 0) * 2
        )

        result = overwriter.apply([{"x": 5}])

        assert result[0]["x"] == 10

    def test_computed_field_with_conditional(self):
        """The compute function may branch on the record's contents."""
        def base_name(rec):
            trainer = rec.get("trainer", "")
            return "DeepDrebin" if trainer == "CE" else trainer

        records = [{"trainer": "CE"}, {"trainer": "HCC"}]

        result = Pipeline().add_computed_field("base_name", base_name).apply(
            records
        )

        assert result[0]["base_name"] == "DeepDrebin"
        assert result[1]["base_name"] == "HCC"


class TestPipelineRemoveFields:
    """Behaviour of the remove_fields step."""

    def test_remove_single_field(self):
        """The named field is dropped and the remaining one is kept."""
        result = Pipeline().remove_fields(["temp"]).apply(
            [{"temp": 1, "keep": 2}]
        )

        record = result[0]
        assert "temp" not in record
        assert "keep" in record

    def test_remove_multiple_fields(self):
        """Every listed field is dropped in one step."""
        dropper = Pipeline().remove_fields(["a", "b"])

        result = dropper.apply([{"a": 1, "b": 2, "c": 3}])

        assert result[0] == {"c": 3}

    def test_remove_missing_field_ignored(self):
        """Listing a field the record lacks leaves the record as it was."""
        dropper = Pipeline().remove_fields(["missing"])

        result = dropper.apply([{"existing": 1}])

        assert result[0] == {"existing": 1}


class TestPipelineTransformField:
    """Behaviour of the transform_field step."""

    def test_transform_numeric_field(self):
        """A numeric value is replaced by the function's result."""
        doubler = Pipeline().transform_field("x", lambda value: value * 2)

        result = doubler.apply([{"x": 5}])

        assert result[0]["x"] == 10

    def test_transform_string_field(self):
        """An unbound str method can serve as the transform."""
        shouter = Pipeline().transform_field("name", str.upper)

        result = shouter.apply([{"name": "hello"}])

        assert result[0]["name"] == "HELLO"

    def test_transform_missing_field_ignored(self):
        """A transform targeting an absent field leaves the record as it was."""
        doubler = Pipeline().transform_field(
            "missing", lambda value: value * 2
        )

        result = doubler.apply([{"existing": 1}])

        assert result[0] == {"existing": 1}


class TestPipelineChaining:
    """Several steps combined into one Pipeline."""

    def test_rename_then_filter(self):
        """The filter step sees the field name produced by the rename step."""
        pipeline = Pipeline().rename_fields({"old_month": "month"})
        pipeline = pipeline.filter(Filter("month", "<=", 22))
        records = [{"old_month": 20}, {"old_month": 25}]

        result = pipeline.apply(records)

        assert len(result) == 1
        assert result[0]["month"] == 20

    def test_complex_pipeline(self):
        """Five chained steps over two records leave one processed record."""
        renames = {"Test-Month": "month", "Trainer-Mode": "trainer"}

        pipeline = Pipeline().rename_fields(renames)
        pipeline = pipeline.filter(Filter("month", "<=", 22))
        pipeline = pipeline.add_computed_field(
            "base_name", lambda rec: rec.get("trainer", "Unknown")
        )
        pipeline = pipeline.convert_arrays(fields=["data"])
        pipeline = pipeline.remove_fields(["temp"])

        records = [
            {"Test-Month": 20, "Trainer-Mode": "CE", "data": [1, 2], "temp": "x"},
            {"Test-Month": 25, "Trainer-Mode": "HCC", "data": [3, 4], "temp": "y"},
        ]
        result = pipeline.apply(records)

        assert len(result) == 1
        survivor = result[0]
        assert survivor["month"] == 20
        assert survivor["base_name"] == "CE"
        assert isinstance(survivor["data"], np.ndarray)
        assert "temp" not in survivor


class TestPipelineApplyEmpty:
    """Tests for applying pipeline to empty data"""

    def test_apply_to_empty_list(self):
        """Pipeline handles empty input"""
        pipeline = Pipeline().rename_fields({"a": "b"})
        result = pipeline.apply([])

        assert result == []

    def test_empty_pipeline_passthrough(self):
        """Empty pipeline passes data through unchanged, as copies"""
        pipeline = Pipeline()
        data = [{"x": 1}, {"x": 2}]
        result = pipeline.apply(data)

        assert len(result) == 2
        # "Unchanged" means the same values in the same order, not just the
        # same number of records.
        assert result == data
        # Every output record should be a copy, not only the first one.
        assert all(out is not src for out, src in zip(result, data))
        # The caller's input must be left exactly as it was passed in.
        assert data == [{"x": 1}, {"x": 2}]


class TestPipelineDescribe:
    """Behaviour of the describe() summary."""

    def test_describe_empty(self):
        """An empty pipeline's description says "empty" or contains a zero."""
        summary = Pipeline().describe()

        says_empty = "empty" in summary.lower()
        assert says_empty or "0" in summary

    def test_describe_with_steps(self):
        """The description mentions each kind of step that was added."""
        pipeline = Pipeline().rename_fields({"a": "b"})
        pipeline = pipeline.filter(Filter("x", ">", 0))

        summary = pipeline.describe().lower()

        for keyword in ("rename", "filter"):
            assert keyword in summary
