"""
Tests for Performance-Based Rejection Module

Tests the statistically principled metrics and threshold calibration
for performance-based rejection.
"""

import pytest
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

from aurora.performance_rejection import (
    PerformanceRejectionSimulator,
    TargetAdherenceMetrics,
    ReliabilityDiagramData,
    MonthlyRejectionResult,
    mann_kendall_test,
    bootstrap_ci,
    run_performance_rejection_analysis,
    compute_metrics_numba,
    apply_uncertainty_threshold,
)


# ============================================================================
# TEST DATA FIXTURES
# ============================================================================

@pytest.fixture
def simple_binary_data():
    """Seeded synthetic binary-classification sample.

    Returns:
        Tuple ``(predictions, labels, uncertainties)`` of 1000-element arrays.
    """
    np.random.seed(42)
    sample_count = 1000

    # Ground truth: 70% class 0 (benign), 30% class 1 (malware).
    y_true = np.random.choice([0, 1], size=sample_count, p=[0.7, 0.3])

    # Flip roughly 10% of the labels so predictions are ~90% accurate.
    flipped = np.random.random(sample_count) < 0.1
    y_pred = np.where(flipped, 1 - y_true, y_true)

    # Base uncertainty in [0, 0.5); misclassified samples get an extra 0.3
    # so that errors are, on average, more uncertain than correct predictions.
    base_uncertainty = np.random.random(sample_count) * 0.5
    wrong = y_pred != y_true
    uncertainty = np.where(wrong, base_uncertainty + 0.3, base_uncertainty)

    return y_pred, y_true, uncertainty


@pytest.fixture
def monthly_data():
    """Seeded per-month samples for the deployment simulation.

    Returns:
        Tuple ``(predictions_by_month, labels_by_month,
        uncertainties_by_month)``; each element is a list holding one
        500-element array per month, 12 months in total.
    """
    np.random.seed(42)
    month_count, month_size = 12, 500

    per_month = []
    for month_idx in range(month_count):
        # Simulated drift: accuracy drops by one point every month.
        hit_probability = 0.92 - 0.01 * month_idx

        y_true = np.random.choice([0, 1], size=month_size, p=[0.7, 0.3])

        # A sample is flipped whenever its draw exceeds the month's accuracy.
        flipped = np.random.random(month_size) > hit_probability
        y_pred = np.where(flipped, 1 - y_true, y_true)

        # Base uncertainty in [0, 0.5) plus 0.25 on misclassified samples.
        base_uncertainty = np.random.random(month_size) * 0.5
        wrong = y_pred != y_true
        uncertainty = np.where(wrong, base_uncertainty + 0.25, base_uncertainty)

        per_month.append((y_pred, y_true, uncertainty))

    # Transpose the per-month triples into three parallel lists.
    predictions_by_month, labels_by_month, uncertainties_by_month = (
        list(series) for series in zip(*per_month)
    )
    return predictions_by_month, labels_by_month, uncertainties_by_month


# ============================================================================
# STATISTICAL UTILITIES TESTS
# ============================================================================

class TestMannKendall:
    """Behavioural checks for the Mann-Kendall trend test."""

    def test_increasing_trend(self):
        """A strictly rising series gives a strongly positive, significant tau."""
        rising = np.arange(10).astype(float)
        tau, p = mann_kendall_test(rising)

        assert tau > 0.9, f"Expected tau > 0.9 for increasing sequence, got {tau}"
        assert p < 0.05, f"Expected significant p-value, got {p}"

    def test_decreasing_trend(self):
        """A strictly falling series gives a strongly negative, significant tau."""
        falling = np.arange(10, 0, -1).astype(float)
        tau, p = mann_kendall_test(falling)

        assert tau < -0.9, f"Expected tau < -0.9 for decreasing sequence, got {tau}"
        assert p < 0.05, f"Expected significant p-value, got {p}"

    def test_no_trend(self):
        """Seeded uniform noise should show only a weak trend."""
        np.random.seed(42)
        noise = np.random.random(100)
        # Only the trend strength is checked here; the p-value is ignored.
        tau, _ = mann_kendall_test(noise)

        assert abs(tau) < 0.3, f"Expected |tau| < 0.3 for random sequence, got {tau}"

    def test_short_sequence(self):
        """A two-element input yields the defaults tau=0.0 and p=1.0."""
        too_short = np.array([1.0, 2.0])
        outcome = mann_kendall_test(too_short)
        tau, p = outcome

        assert tau == 0.0
        assert p == 1.0


class TestBootstrapCI:
    """Checks for the bootstrap confidence-interval helper."""

    def test_mean_ci(self):
        """The 95% CI of the mean should bracket the generating mean."""
        np.random.seed(42)
        true_mean = 5.0
        draws = np.random.normal(true_mean, 1.0, size=100)

        interval = bootstrap_ci(draws, statistic='mean', confidence=0.95, seed=42)
        ci_low, ci_high = interval

        assert ci_low < true_mean < ci_high, \
            f"True mean {true_mean} not in CI [{ci_low}, {ci_high}]"

    def test_ci_width(self):
        """A larger sample should produce a narrower interval."""
        np.random.seed(42)

        def interval_width(sample):
            low, high = bootstrap_ci(sample, seed=42)
            return high - low

        # Each sample is drawn immediately before its interval is computed,
        # so the sequence of random draws and bootstrap calls is fixed.
        width_small = interval_width(np.random.normal(0, 1, size=20))
        width_large = interval_width(np.random.normal(0, 1, size=500))

        assert width_large < width_small, \
            f"Large sample CI should be narrower: {width_large} vs {width_small}"

    def test_empty_array(self):
        """With no data, both interval bounds are NaN."""
        lower, upper = bootstrap_ci(np.array([]))

        assert np.isnan(lower)
        assert np.isnan(upper)


# ============================================================================
# CORE REJECTION FUNCTIONS TESTS
# ============================================================================

class TestComputeMetrics:
    """Tests for compute_metrics_numba.

    Expected-zero values are compared with an absolute tolerance: the
    relative tolerance of ``assert_allclose`` scales with ``|expected|``,
    so ``rtol`` alone against 0.0 degenerates into an exact-equality check.
    """

    def test_perfect_classification(self):
        """Perfect classification should give F1=1, FNR=0, FPR=0."""
        labels = np.array([0, 0, 0, 1, 1, 1] * 10)  # 60 samples
        predictions = labels.copy()

        f1, fnr, fpr = compute_metrics_numba(labels, predictions)

        assert_allclose(f1, 1.0, rtol=1e-6)
        # atol, not rtol: the expected value is exactly zero.
        assert_allclose(fnr, 0.0, atol=1e-6)
        assert_allclose(fpr, 0.0, atol=1e-6)

    def test_all_wrong(self):
        """All wrong predictions should give F1=0, FNR=1, FPR=1."""
        labels = np.array([0, 0, 0, 1, 1, 1] * 10)
        predictions = 1 - labels  # invert every binary label

        f1, fnr, fpr = compute_metrics_numba(labels, predictions)

        # atol, not rtol: the expected value is exactly zero.
        assert_allclose(f1, 0.0, atol=1e-6)
        assert_allclose(fnr, 1.0, rtol=1e-6)
        assert_allclose(fpr, 1.0, rtol=1e-6)

    def test_small_sample(self):
        """Small samples should return NaN for every metric."""
        labels = np.array([0, 1, 0, 1, 0])  # Only 5 samples
        predictions = labels.copy()

        f1, fnr, fpr = compute_metrics_numba(labels, predictions)

        assert np.isnan(f1)
        assert np.isnan(fnr)
        assert np.isnan(fpr)


class TestApplyThreshold:
    """Checks for apply_uncertainty_threshold on a fixed five-sample input."""

    @staticmethod
    def _filter_at(threshold):
        """Run the filter on the shared five-sample input at *threshold*."""
        preds = np.array([0, 1, 0, 1, 0])
        truth = np.array([0, 1, 0, 1, 0])
        uncs = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
        return apply_uncertainty_threshold(preds, truth, uncs, threshold=threshold)

    def test_threshold_accepts_low_uncertainty(self):
        """A cut at 0.35 keeps the three least uncertain samples."""
        kept_preds, kept_labels, accepted, rejected = self._filter_at(0.35)

        assert accepted == 3
        assert rejected == 2
        assert_array_equal(kept_preds, np.array([0, 1, 0]))
        assert_array_equal(kept_labels, np.array([0, 1, 0]))

    def test_threshold_rejects_all(self):
        """A cut below every uncertainty rejects all samples."""
        _, _, accepted, rejected = self._filter_at(0.05)

        assert accepted == 0
        assert rejected == 5

    def test_threshold_accepts_all(self):
        """A cut above every uncertainty accepts all samples."""
        _, _, accepted, rejected = self._filter_at(1.0)

        assert accepted == 5
        assert rejected == 0


# ============================================================================
# PERFORMANCE REJECTION SIMULATOR TESTS
# ============================================================================

class TestPerformanceRejectionSimulator:
    """Tests for PerformanceRejectionSimulator."""

    # Calibration checks are only evaluated when at least this many samples
    # survive the threshold; below it the test is skipped, not passed.
    MIN_ACCEPTED = 20

    @classmethod
    def _calibrate_and_filter(cls, data, target_metric, target_value):
        """Calibrate a threshold on *data* and return the accepted subset.

        Args:
            data: ``(predictions, labels, uncertainties)`` triple.
            target_metric: Metric name passed to ``calibrate_threshold``.
            target_value: Target value passed to ``calibrate_threshold``.

        Returns:
            ``(accepted_predictions, accepted_labels)`` after applying the
            calibrated threshold to the same data.

        Skips the calling test (instead of letting it pass without any
        assertion) when fewer than ``MIN_ACCEPTED`` samples are accepted.
        """
        predictions, labels, uncertainties = data
        sim = PerformanceRejectionSimulator()

        threshold = sim.calibrate_threshold(
            predictions, labels, uncertainties,
            target_metric=target_metric, target_value=target_value
        )
        acc_preds, acc_labels, _, _ = apply_uncertainty_threshold(
            predictions, labels, uncertainties, threshold
        )

        if len(acc_preds) < cls.MIN_ACCEPTED:
            pytest.skip(
                f"only {len(acc_preds)} samples accepted; need at least "
                f"{cls.MIN_ACCEPTED} to evaluate {target_metric}"
            )
        return acc_preds, acc_labels

    def test_calibrate_threshold_fnr(self, simple_binary_data):
        """Calibration should find threshold achieving target FNR."""
        target_fnr = 0.05
        acc_preds, acc_labels = self._calibrate_and_filter(
            simple_binary_data, "FNR", target_fnr
        )

        _, actual_fnr, _ = compute_metrics_numba(acc_labels, acc_preds)
        # Should be close to target (within tolerance)
        assert abs(actual_fnr - target_fnr) < 0.02, \
            f"FNR {actual_fnr} not close to target {target_fnr}"

    def test_calibrate_threshold_fpr(self, simple_binary_data):
        """Calibration should find threshold achieving target FPR."""
        target_fpr = 0.02
        acc_preds, acc_labels = self._calibrate_and_filter(
            simple_binary_data, "FPR", target_fpr
        )

        _, _, actual_fpr = compute_metrics_numba(acc_labels, acc_preds)
        # Should be close to target
        assert abs(actual_fpr - target_fpr) < 0.02, \
            f"FPR {actual_fpr} not close to target {target_fpr}"

    def test_apply_threshold_single_month(self, simple_binary_data):
        """apply_threshold_single_month should return valid results."""
        predictions, labels, uncertainties = simple_binary_data
        sim = PerformanceRejectionSimulator()

        result = sim.apply_threshold_single_month(
            predictions, labels, uncertainties,
            threshold=0.5, month_idx=0
        )

        assert isinstance(result, MonthlyRejectionResult)
        assert result.total_samples == len(predictions)
        # Every sample is either accepted or rejected.
        assert result.accepted_samples + result.rejected_samples == result.total_samples
        assert 0 <= result.coverage <= 1
        assert not np.isnan(result.baseline_f1)

    def test_simulate_deployment(self, simple_binary_data, monthly_data):
        """simulate_deployment should return one result per month."""
        val_preds, val_labels, val_uncs = simple_binary_data
        test_preds, test_labels, test_uncs = monthly_data

        sim = PerformanceRejectionSimulator()

        # Calibrate on validation
        threshold = sim.calibrate_threshold(
            val_preds, val_labels, val_uncs,
            target_metric="FNR", target_value=0.05
        )

        # Deploy on test
        results = sim.simulate_deployment(
            test_preds, test_labels, test_uncs,
            threshold=threshold,
            target_metric="FNR", target_value=0.05
        )

        assert len(results) == len(test_preds)
        for r in results:
            assert isinstance(r, MonthlyRejectionResult)

    def test_compute_adherence_metrics(self, simple_binary_data, monthly_data):
        """compute_adherence_metrics should return metrics within valid ranges."""
        val_preds, val_labels, val_uncs = simple_binary_data
        test_preds, test_labels, test_uncs = monthly_data

        sim = PerformanceRejectionSimulator()

        threshold = sim.calibrate_threshold(
            val_preds, val_labels, val_uncs,
            target_metric="FNR", target_value=0.05
        )

        monthly_results = sim.simulate_deployment(
            test_preds, test_labels, test_uncs,
            threshold=threshold,
            target_metric="FNR", target_value=0.05
        )

        metrics = sim.compute_adherence_metrics(
            monthly_results,
            target_metric="FNR", target_value=0.05
        )

        assert isinstance(metrics, TargetAdherenceMetrics)

        # Check MAE is non-negative
        assert metrics.mae >= 0, "MAE should be non-negative"

        # Check hit rates are in [0, 1]
        assert 0 <= metrics.hit_rate_1pct <= 1
        assert 0 <= metrics.hit_rate_2pct <= 1
        assert 0 <= metrics.hit_rate_5pct <= 1

        # Check coverage is in [0, 1]
        assert 0 <= metrics.mean_coverage <= 1

        # Check Mann-Kendall tau is in [-1, 1]
        assert -1 <= metrics.mann_kendall_tau <= 1


class TestTargetAdherenceMetrics:
    """Checks for the TargetAdherenceMetrics dataclass."""

    def test_to_dict(self):
        """to_dict exposes the constructor values under their field names."""
        field_values = {
            "target_value": 0.05,
            "target_metric": "FNR",
            "mae": 0.02,
            "bias": -0.01,
            "rmse": 0.025,
            "hit_rate_1pct": 0.5,
            "hit_rate_2pct": 0.7,
            "hit_rate_5pct": 0.9,
            "mean_coverage": 0.85,
            "cv_coverage": 0.1,
            "min_coverage": 0.7,
            "cv_deviation": 0.5,
            "mann_kendall_tau": 0.2,
            "mann_kendall_p": 0.3,
            "max_deviation": 0.08,
        }

        exported = TargetAdherenceMetrics(**field_values).to_dict()

        # Spot-check key presence, then two concrete values.
        for key in ('target_value', 'mae', 'hit_rate_2pct'):
            assert key in exported
        assert exported['mae'] == 0.02
        assert exported['target_metric'] == "FNR"


class TestReliabilityDiagramData:
    """Checks for ReliabilityDiagramData.calibration_error."""

    @staticmethod
    def _diagram(targets, observed):
        """Build a diagram with zero spread and CI bounds equal to *observed*."""
        no_spread = np.zeros_like(targets)
        return ReliabilityDiagramData(
            targets=targets,
            observed_mean=observed,
            observed_std=no_spread,
            observed_ci_low=observed,
            observed_ci_high=observed,
            coverage_mean=np.full_like(targets, 0.9),
            coverage_std=np.zeros_like(targets),
            n_months=12,
        )

    def test_calibration_error_perfect(self):
        """Observed values equal to the targets give zero calibration error."""
        grid = np.array([0.01, 0.05, 0.10, 0.15])
        diagram = self._diagram(grid, grid.copy())

        assert_allclose(diagram.calibration_error(), 0.0, atol=1e-6)

    def test_calibration_error_systematic(self):
        """A constant 2% overshoot gives a positive calibration error."""
        grid = np.array([0.01, 0.05, 0.10, 0.15])
        diagram = self._diagram(grid, grid + 0.02)

        error = diagram.calibration_error()
        assert error > 0, "Systematic error should give positive calibration error"


# ============================================================================
# CONVENIENCE FUNCTION TESTS
# ============================================================================

class TestRunPerformanceRejectionAnalysis:
    """Checks for the run_performance_rejection_analysis convenience wrapper."""

    def test_full_analysis(self, simple_binary_data, monthly_data):
        """The result holds metrics and a reliability diagram for FNR and FPR."""
        # Validation triple first, then the per-month test triple.
        results = run_performance_rejection_analysis(
            *simple_binary_data,
            *monthly_data,
            fnr_targets=[0.05, 0.10],
            fpr_targets=[0.02, 0.05],
        )

        for key in ('fnr_metrics', 'fpr_metrics', 'fnr_reliability', 'fpr_reliability'):
            assert key in results

        # One metrics entry per requested target (two each).
        for key in ('fnr_metrics', 'fpr_metrics'):
            assert len(results[key]) == 2

        for key in ('fnr_reliability', 'fpr_reliability'):
            assert isinstance(results[key], ReliabilityDiagramData)


# ============================================================================
# INTEGRATION TESTS
# ============================================================================

class TestIntegration:
    """End-to-end checks that chain calibration, deployment and metrics."""

    def test_end_to_end_workflow(self, simple_binary_data, monthly_data):
        """Calibrate on validation data, deploy on monthly data, check metrics."""
        sim = PerformanceRejectionSimulator()
        target_fnr = 0.05

        # Step 1: calibrate on the validation triple.
        threshold = sim.calibrate_threshold(
            *simple_binary_data,
            target_metric="FNR", target_value=target_fnr
        )
        assert threshold is not None
        assert not np.isnan(threshold)

        # Step 2: replay the monthly test data with that threshold.
        monthly_results = sim.simulate_deployment(
            *monthly_data,
            threshold=threshold,
            target_metric="FNR", target_value=target_fnr
        )
        assert len(monthly_results) == 12  # one result per month

        # Step 3: summarise adherence to the target.
        metrics = sim.compute_adherence_metrics(
            monthly_results,
            target_metric="FNR", target_value=target_fnr
        )

        # Step 4: loose sanity bounds on error, hit rate and coverage.
        assert metrics.mae < 0.2, f"MAE too large: {metrics.mae}"
        assert metrics.hit_rate_5pct > 0.3, f"HitRate@5% too low: {metrics.hit_rate_5pct}"
        assert metrics.mean_coverage > 0.5, f"Coverage too low: {metrics.mean_coverage}"

    def test_target_grid_simulation(self, simple_binary_data, monthly_data):
        """Sweep several FNR targets and build a reliability diagram."""
        sim = PerformanceRejectionSimulator()
        fnr_targets = [0.02, 0.05, 0.10, 0.15]

        adherence_results = sim.simulate_target_grid(
            *simple_binary_data,
            *monthly_data,
            target_metric="FNR",
            target_values=fnr_targets,
        )
        assert len(adherence_results) == len(fnr_targets)

        # The strictest target (first) must reject more than the loosest (last).
        strictest, *_, loosest = (m.mean_coverage for m in adherence_results)
        assert strictest < loosest, \
            "Stricter targets should require lower coverage"

        reliability = sim.compute_reliability_diagram(adherence_results)

        assert len(reliability.targets) == len(fnr_targets)
        assert all(reliability.observed_mean >= 0)


if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script reports
    # test failures to the shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
