"""
Golden comparison tests - verify new API produces same results as old API.

CRITICAL: These tests ensure paper reproduction parity.
The new UX Revision APIs must produce identical results to the
original analysis workflow.
"""
import os
import pytest
import pickle
import numpy as np
from pathlib import Path
from typing import Dict, List, Any

# Location of the real experiment results consumed by the tests below.
# The data directory is taken from the AURORA_DATA_DIR environment variable;
# when it is unset, "data-for-export" is used, which resolves relative to the
# current working directory (expected to be the repo root).
# REAL_DATA_AVAILABLE is evaluated once, at import time.
_DATA_DIR = Path(os.getenv("AURORA_DATA_DIR", "data-for-export"))
REAL_DATA_PATH = _DATA_DIR.joinpath("deep_drebin_svc", "parallel_ce_no_aug_v2.pkl")
REAL_DATA_AVAILABLE = REAL_DATA_PATH.exists()


@pytest.fixture
def real_results():
    """Return the unpickled real experiment results.

    The requesting test is skipped when the data file was not present at the
    time this module was imported (see ``REAL_DATA_AVAILABLE``).
    """
    if REAL_DATA_AVAILABLE:
        # NOTE(review): pickle.load can execute arbitrary code; this assumes
        # the results file is trusted, locally produced experiment output.
        with REAL_DATA_PATH.open("rb") as handle:
            return pickle.load(handle)

    pytest.skip("Real data not available")


class TestMetricsComputationParity:
    """Sanity checks for the metric helpers in ``aurora.tools`` on real data.

    These tests run the helpers on the first real record and check that the
    outputs fall in their valid ranges; no stored golden values are compared.
    """

    def test_compute_metrics_numba_unchanged(self, real_results):
        """compute_metrics_numba yields F1, FNR and FPR inside [0, 1]."""
        from aurora.tools import compute_metrics_numba

        # The first record serves as the sample.
        first = real_results[0]
        y_pred = np.array(first["Predictions"])
        y_true = np.array(first["Labels"])

        f1, fnr, fpr = compute_metrics_numba(y_true, y_pred)

        # Each metric is a rate and must lie in the closed unit interval.
        for label, value in (("F1", f1), ("FNR", fnr), ("FPR", fpr)):
            assert 0 <= value <= 1, f"{label} out of range: {value}"

    def test_aurc_computation_unchanged(self, real_results):
        """compute_aurc returns a non-negative value for a single record."""
        from aurora.tools import compute_aurc

        first = real_results[0]

        # compute_aurc is called with one-element lists, each wrapping the
        # array built from the corresponding record field.
        field_for_argument = {
            "predictions": "Predictions",
            "labels": "Labels",
            "uncertainties": "Uncertainties (Month Ahead)",
        }
        array_arguments = {
            argument: [np.array(first[field])]
            for argument, field in field_for_argument.items()
        }

        aurc = compute_aurc(compute_eaurc=False, **array_arguments)

        assert aurc >= 0, f"AURC should be non-negative: {aurc}"


class TestPipelineTransformationParity:
    """Verify Pipeline transformations preserve data integrity."""

    def test_convert_arrays_preserves_values(self, real_results):
        """convert_arrays turns a list field into an ndarray with equal values.

        NOTE(review): ``real_results`` is not read here; requesting the
        fixture only causes this test to be skipped when the real data file
        is missing. Confirm that is intended.
        """
        from aurora import Pipeline

        # Synthetic record whose list field should be converted.
        sample = [{"vals": [1, 2, 3, 4, 5]}]

        pipeline = Pipeline().convert_arrays(fields=["vals"])
        processed = pipeline.apply(sample)

        original = sample[0]["vals"]
        converted = processed[0]["vals"]

        # Values should match exactly
        np.testing.assert_array_equal(converted, original)
        assert isinstance(converted, np.ndarray)

    def test_filter_chain_correct_results(self, real_results):
        """FilterChain keeps exactly the records that satisfy both filters.

        Two directions are checked: every kept record matches the criteria
        (soundness), and the number of kept records equals the number found
        by a plain-Python selection (completeness). Without the second
        check a chain that rejects everything would pass vacuously.
        """
        from aurora import Filter, FilterChain

        # Filter for specific config
        chain = FilterChain([
            Filter("Dataset", "==", "androzoo"),
            Filter("Monthly-Label-Budget", "==", 100)
        ], logic="AND")

        # Apply to real data
        filtered = [r for r in real_results if chain.apply(r)]

        # Reference selection written without the Filter API.
        expected = [
            r for r in real_results
            if r.get("Dataset") == "androzoo"
            and r.get("Monthly-Label-Budget") == 100
        ]

        # Soundness: all filtered records match criteria
        for record in filtered:
            assert record["Dataset"] == "androzoo"
            assert record["Monthly-Label-Budget"] == 100

        # Completeness: no matching record was dropped by the chain
        assert len(filtered) == len(expected), (
            f"FilterChain kept {len(filtered)} records, "
            f"expected {len(expected)}"
        )

    def test_rename_fields_preserves_values(self, real_results):
        """rename_fields moves values to the new keys and drops the old keys."""
        from aurora import Pipeline

        # Work on a shallow copy so the fixture's record keeps its keys.
        sample = [real_results[0].copy()]
        original_month = sample[0]["Test-Month"]
        original_dataset = sample[0]["Dataset"]

        pipeline = Pipeline().rename_fields({
            "Test-Month": "month",
            "Dataset": "dataset_name"
        })
        processed = pipeline.apply(sample)

        # Values preserved, just renamed
        assert processed[0]["month"] == original_month
        assert processed[0]["dataset_name"] == original_dataset
        assert "Test-Month" not in processed[0]
        assert "Dataset" not in processed[0]


class TestLoaderParity:
    """Check that PickleResultsLoader reproduces the raw pickle contents."""

    def test_pickle_loader_load_all_records(self, real_results):
        """The loaded collection has as many entries as the raw pickle."""
        from aurora import PickleResultsLoader

        # Validation is switched off: the real data holds several sampler
        # modes, which the strict validator reports as "duplicates".
        collection = PickleResultsLoader(auto_validate=False).load(REAL_DATA_PATH)

        assert len(collection) == len(real_results)

    def test_pickle_loader_preserves_data(self, real_results):
        """The first loaded record carries the same values as the raw one."""
        from aurora import PickleResultsLoader

        collection = PickleResultsLoader(auto_validate=False).load(REAL_DATA_PATH)

        # Only the first record is compared.
        raw_record = real_results[0]
        loaded_record = collection[0]

        # Scalar fields: (attribute on the loaded record, key in the raw dict)
        scalar_fields = (
            ("test_month", "Test-Month"),
            ("dataset", "Dataset"),
            ("monthly_label_budget", "Monthly-Label-Budget"),
            ("trainer_mode", "Trainer-Mode"),
        )
        for attribute, key in scalar_fields:
            assert getattr(loaded_record, attribute) == raw_record[key]

        # Array fields are compared element-wise.
        array_fields = (
            ("predictions", "Predictions"),
            ("labels", "Labels"),
        )
        for attribute, key in array_fields:
            np.testing.assert_array_equal(
                getattr(loaded_record, attribute), raw_record[key]
            )


class TestAnalyzerParity:
    """Verify AuroraAnalyzer output alongside a manual metric computation."""

    def test_compute_base_metrics_matches_manual(self, real_results):
        """Check compute_base_metrics and a manual recomputation for one config.

        Records for one configuration (dataset ``androzoo``, budget 100,
        sampler mode ``full_first_year_subsample_months``) are selected both
        through the loader/collection API and directly from the raw dicts.
        Per-record metrics are recomputed with ``compute_metrics_numba`` and
        range-checked, and the analyzer's F1-related columns are validated.

        The test is skipped when no raw record matches the configuration,
        instead of passing without asserting anything.

        NOTE(review): the aggregation applied by ``compute_base_metrics`` is
        not visible from this file, so no value-for-value comparison between
        the analyzer output and the manual metrics is made - TODO add one
        once the expected aggregation is pinned down.
        """
        from aurora import PickleResultsLoader, AuroraAnalyzer
        from aurora.tools import compute_metrics_numba

        # Raw records for the configuration under test.
        matching_raw = [
            r for r in real_results
            if r["Dataset"] == "androzoo"
            and r["Monthly-Label-Budget"] == 100
            and r["Sampler-Mode"] == "full_first_year_subsample_months"
        ]
        if not matching_raw:
            pytest.skip("No matching records")

        # Load with loader (disable validation for real data)
        loader = PickleResultsLoader(auto_validate=False)
        collection = loader.load(REAL_DATA_PATH)

        # Filter to the same config through the collection API
        filtered = collection.filter(
            lambda r: r.dataset == "androzoo"
                      and r.monthly_label_budget == 100
                      and r.sampler_mode == "full_first_year_subsample_months"
        )

        # Compute via AuroraAnalyzer
        analyzer = AuroraAnalyzer(filtered)
        metrics_df = analyzer.compute_base_metrics(
            group_by=["dataset", "monthly_label_budget", "base_name"]
        )

        # Compute manually, grouped by trainer mode
        manual_metrics = {}
        for r in matching_raw:
            f1, fnr, fpr = compute_metrics_numba(
                np.array(r["Labels"]),
                np.array(r["Predictions"])
            )
            per_trainer = manual_metrics.setdefault(
                r["Trainer-Mode"], {"f1": [], "fnr": [], "fpr": []}
            )
            per_trainer["f1"].append(f1)
            per_trainer["fnr"].append(fnr)
            per_trainer["fpr"].append(fpr)

        # Manual metrics are rates: every finite value must lie in [0, 1].
        # Non-finite values are ignored here because this file does not show
        # what compute_metrics_numba returns for degenerate inputs.
        for trainer, series in manual_metrics.items():
            for metric_name, values in series.items():
                as_array = np.asarray(values, dtype=float)
                finite = as_array[np.isfinite(as_array)]
                assert ((finite >= 0) & (finite <= 1)).all(), (
                    f"Manual {metric_name} out of range for {trainer}"
                )

        # Verify analyzer metrics are reasonable
        assert len(metrics_df) > 0, "Should have computed metrics"

        f1_columns = [col for col in metrics_df.columns if "F1" in str(col)]
        assert f1_columns, "Analyzer output has no F1 column"

        for col in f1_columns:
            if str(col).startswith("CV["):
                # Coefficient of variation can exceed 1 but not be negative,
                # so the [0, 1] range check must not be applied to it.
                assert metrics_df[col].ge(0).all(), f"Negative CV values in {col}"
            else:
                assert metrics_df[col].between(0, 1).all(), f"Invalid F1 values in {col}"


class TestEndToEndParity:
    """Exercise the new load -> filter -> analyze workflow on real data."""

    def test_full_workflow_produces_valid_results(self, real_results):
        """The loader/analyzer workflow yields a non-empty, in-range metrics table."""
        from aurora import PickleResultsLoader, AuroraAnalyzer

        def in_subset(rec):
            # Restrict to one dataset/budget pair to keep the run manageable.
            return rec.dataset == "androzoo" and rec.monthly_label_budget == 100

        # Validation is disabled for the real data.
        records = PickleResultsLoader(auto_validate=False).load(REAL_DATA_PATH)
        subset = records.filter(in_subset)

        if len(subset) < 1:
            pytest.skip("No matching records")

        grouping = ["dataset", "monthly_label_budget", "base_name"]
        metrics = AuroraAnalyzer(subset).compute_base_metrics(group_by=grouping)

        # Structure: at least one row and both expected columns present.
        assert len(metrics) > 0
        for required_column in ("F1", "CV[F1]"):
            assert required_column in metrics.columns

        # Values: F1 lies in [0, 1]; CV may exceed 1 but is never negative.
        f1_values = metrics["F1"]
        assert f1_values.between(0, 1).all()
        cv_values = metrics["CV[F1]"]
        assert cv_values.ge(0).all()


class TestConfigWorkflowParity:
    """Compare config-driven filtering with a hand-built FilterChain."""

    def test_config_filter_matches_manual(self, real_results):
        """A pipeline built from a config dict selects like the manual chain."""
        from aurora import ConfigLoader, Pipeline, Filter, FilterChain

        # One criteria table drives both approaches: (field, operator, value).
        criteria = (
            ("Dataset", "==", "androzoo"),
            ("Monthly-Label-Budget", "==", 100),
        )

        # Config-based approach; records are shallow-copied before being
        # handed to the pipeline.
        filter_specs = [
            {"field": field, "op": op, "value": value}
            for field, op, value in criteria
        ]
        config = ConfigLoader.from_dict({"filters": filter_specs})
        record_copies = [rec.copy() for rec in real_results]
        config_filtered = config.to_pipeline().apply(record_copies)

        # Manual approach: the same criteria as an AND-combined FilterChain.
        chain = FilterChain(
            [Filter(field, op, value) for field, op, value in criteria],
            logic="AND",
        )
        manual_filtered = [rec for rec in real_results if chain.apply(rec)]

        # Same number of survivors...
        assert len(config_filtered) == len(manual_filtered)

        # ...and pairwise-equal values on the filtered fields.
        for from_config, from_manual in zip(config_filtered, manual_filtered):
            for key in ("Dataset", "Monthly-Label-Budget"):
                assert from_config[key] == from_manual[key]
