#!/usr/bin/env python3
"""
Rejection Simulation Edge Cases Test Suite

Tests the rejection simulator with problematic inputs to ensure robustness.

Critical edge cases:
1. Empty arrays
2. Insufficient samples (<20)
3. All same predictions
4. Budget = 0 (reject nothing)
5. Budget > N (reject everything)
6. Tied uncertainties
7. NaN/Inf values
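8. Single-month data

Additional coverage with regular (non-degenerate) inputs:
9. Multi-month progression
10. All rejection methods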
"""

import sys
sys.path.insert(0, 'src')
sys.path.insert(0, '.')

import numpy as np
from aurora import PostHocRejectorSimulator, compute_metrics_numba
from numba.typed import List


def test_empty_arrays():
    """Run the simulator on a single month whose three arrays are all empty.

    Returns:
        dict: ``{'status': 'pass', 'results': ...}`` when the call completes,
        otherwise ``{'status': 'fail', 'error': <message>}``.
    """
    rule = "=" * 80
    print(rule)
    print("TEST 1: Empty Arrays")
    print(rule)

    try:
        # The containers are built inside the try on purpose: a failure while
        # wrapping an empty array is reported as a test failure too.
        month_inputs = []
        for _ in range(3):
            holder = List()
            holder.append(np.array([]))
            month_inputs.append(holder)
        uncertainties, predictions, labels = month_inputs

        results = PostHocRejectorSimulator(
            uncertainties,
            predictions,
            labels,
            rejection_Ns=[10],
            upto_reject=False,
            method='dual_thresh_compounded',
        )

        print("✓ No crash with empty arrays")
        print(f"  Results: {results}")
        return {'status': 'pass', 'results': results}

    except Exception as exc:
        print(f"✗ FAIL: {type(exc).__name__}: {exc}")
        return {'status': 'fail', 'error': str(exc)}


def test_insufficient_samples():
    """Run the simulator on one month that holds only ten random samples.

    Returns:
        dict: ``{'status': 'pass', 'returns_nan': True}`` when the F1 for
        budget 3 is NaN, ``{'status': 'warning', 'returns_nan': False}`` when
        it is a number, or ``{'status': 'fail', 'error': <message>}`` when
        anything raises.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 2: Insufficient Samples (<20)")
    print(rule)

    # Deliberately small: the original author notes that
    # compute_metrics_numba wants at least 20 samples.
    sample_count = 10
    budget = 3

    uncertainties = List()
    uncertainties.append(np.random.random(sample_count))

    predictions = List()
    predictions.append(np.random.randint(0, 2, sample_count))

    labels = List()
    labels.append(np.random.randint(0, 2, sample_count))

    try:
        results = PostHocRejectorSimulator(
            uncertainties,
            predictions,
            labels,
            rejection_Ns=[budget],
            upto_reject=False,
            method='dual_thresh_compounded',
        )

        print(f"✓ No crash with {sample_count} samples")
        print(f"  F1 (budget={budget}): {results[budget]['F1']}")

        # The suite treats NaN as the desired answer for such a small month.
        f1_value = results[budget]['F1']
        if np.isnan(f1_value):
            print("  ✓ Correctly returns NaN for insufficient samples")
            return {'status': 'pass', 'returns_nan': True}

        print("  ⚠️ Warning: Returns numeric value for <20 samples")
        return {'status': 'warning', 'returns_nan': False}

    except Exception as exc:
        print(f"✗ FAIL: {type(exc).__name__}: {exc}")
        return {'status': 'fail', 'error': str(exc)}


def test_all_same_predictions():
    """Run the simulator when every prediction is class 1 and labels are mixed.

    Returns:
        dict: ``{'status': 'pass', 'results': ...}`` when the simulation and
        the F1 printing complete, otherwise
        ``{'status': 'fail', 'error': <message>}``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 3: All Same Predictions")
    print(rule)

    sample_count = 100

    uncertainties = List()
    uncertainties.append(np.random.random(sample_count))

    # Constant classifier output: every sample is predicted as 1.
    predictions = List()
    predictions.append(np.ones(sample_count, dtype=np.int64))

    # Ground truth stays random so both classes can occur.
    labels = List()
    labels.append(np.random.randint(0, 2, sample_count))

    try:
        results = PostHocRejectorSimulator(
            uncertainties,
            predictions,
            labels,
            rejection_Ns=[10, 20],
            upto_reject=False,
            method='dual_thresh_compounded',
        )

        print("✓ No crash with constant predictions")
        # NOTE(review): budget 0 is read although only 10 and 20 were
        # requested - presumably the simulator always reports a zero-budget
        # baseline; confirm against PostHocRejectorSimulator.
        for budget in (0, 10, 20):
            print(f"  F1 (budget={budget}): {results[budget]['F1']:.4f}")

        # Reaching this point without an exception is the pass criterion.
        return {'status': 'pass', 'results': results}

    except Exception as exc:
        print(f"✗ FAIL: {type(exc).__name__}: {exc}")
        return {'status': 'fail', 'error': str(exc)}


def test_budget_zero():
    """Run the simulator with budgets 0 and 10 and check that budget 0 rejects nothing.

    Returns:
        dict: ``{'status': 'pass', 'rejects_zero': True}`` when the first
        monthly rejection count for budget 0 equals 0,
        ``{'status': 'warning', 'rejects_zero': False}`` when it does not, or
        ``{'status': 'fail', 'error': <message>}`` when anything raises.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 4: Budget = 0 (Reject Nothing)")
    print(rule)

    sample_count = 100

    uncertainties = List()
    uncertainties.append(np.random.random(sample_count))

    predictions = List()
    predictions.append(np.random.randint(0, 2, sample_count))

    labels = List()
    labels.append(np.random.randint(0, 2, sample_count))

    try:
        results = PostHocRejectorSimulator(
            uncertainties,
            predictions,
            labels,
            rejection_Ns=[0, 10],
            upto_reject=False,
            method='dual_thresh_compounded',
        )

        print("✓ No crash with budget=0")
        for budget in (0, 10):
            print(f"  F1 (budget={budget}): {results[budget]['F1']:.4f}")

        # With a zero budget the first (and only) month must show no
        # rejections; the F1 printed above is then the unrejected baseline.
        rejected = results[0]['monthly_Rejections'][0]
        if rejected == 0:
            print("  ✓ Correctly rejects 0 samples with budget=0")
            return {'status': 'pass', 'rejects_zero': True}

        print(f"  ⚠️ Warning: Rejects {rejected} samples with budget=0")
        return {'status': 'warning', 'rejects_zero': False}

    except Exception as exc:
        print(f"✗ FAIL: {type(exc).__name__}: {exc}")
        return {'status': 'fail', 'error': str(exc)}


def test_budget_exceeds_samples():
    """Run the simulator with a budget (100) larger than the month's 50 samples.

    Returns:
        dict: ``{'status': 'pass', 'results': ...}`` when the first monthly
        rejection count for budget 100 does not exceed the sample count,
        ``{'status': 'fail', 'over_rejection': True}`` when it does, or
        ``{'status': 'fail', 'error': <message>}`` when anything raises.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 5: Budget > N (Budget Exceeds Sample Count)")
    print(rule)

    sample_count = 50

    uncertainties = List()
    uncertainties.append(np.random.random(sample_count))

    predictions = List()
    predictions.append(np.random.randint(0, 2, sample_count))

    labels = List()
    labels.append(np.random.randint(0, 2, sample_count))

    try:
        # 25 fits inside the month; 100 is twice the available samples.
        results = PostHocRejectorSimulator(
            uncertainties,
            predictions,
            labels,
            rejection_Ns=[25, 100],
            upto_reject=False,
            method='dual_thresh_compounded',
        )

        print("✓ No crash with budget > N")
        for budget in (25, 100):
            print(f"  F1 (budget={budget}): {results[budget]['F1']:.4f}")
        for budget in (25, 100):
            print(f"  Rejections (budget={budget}): {results[budget]['monthly_Rejections'][0]}")

        # The simulator must never reject more samples than exist.
        max_rejections = results[100]['monthly_Rejections'][0]
        if max_rejections <= sample_count:
            print(f"  ✓ Correctly limits rejections to {max_rejections} ≤ {sample_count}")
            return {'status': 'pass', 'results': results}

        print(f"  ✗ FAIL: Rejects {max_rejections} > {sample_count} samples!")
        return {'status': 'fail', 'over_rejection': True}

    except Exception as exc:
        print(f"✗ FAIL: {type(exc).__name__}: {exc}")
        return {'status': 'fail', 'error': str(exc)}


def test_tied_uncertainties():
    """Run the simulator when every sample carries the same uncertainty (0.5).

    Returns:
        dict: ``{'status': 'pass', 'results': ...}`` when the simulation and
        the metric printing complete, otherwise
        ``{'status': 'fail', 'error': <message>}``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 6: Tied Uncertainties")
    print(rule)

    sample_count = 100

    # A fully tied month: no sample is more uncertain than any other.
    uncertainties = List()
    uncertainties.append(np.full(sample_count, 0.5))

    predictions = List()
    predictions.append(np.random.randint(0, 2, sample_count))

    labels = List()
    labels.append(np.random.randint(0, 2, sample_count))

    try:
        results = PostHocRejectorSimulator(
            uncertainties,
            predictions,
            labels,
            rejection_Ns=[10, 20],
            upto_reject=False,
            method='dual_thresh_compounded',
        )

        print("✓ No crash with tied uncertainties")
        # NOTE(review): budget 0 is read although only 10 and 20 were
        # requested - presumably the simulator always reports a zero-budget
        # baseline; confirm against PostHocRejectorSimulator.
        for budget in (0, 10):
            print(f"  F1 (budget={budget}): {results[budget]['F1']:.4f}")
        print(f"  Rejections (budget=10): {results[10]['monthly_Rejections'][0]}")

        # Which tied samples get rejected is arbitrary; only completion
        # without an exception is required here.
        return {'status': 'pass', 'results': results}

    except Exception as exc:
        print(f"✗ FAIL: {type(exc).__name__}: {exc}")
        return {'status': 'fail', 'error': str(exc)}


def test_nan_inf_values():
    """Run the simulator on uncertainties that contain NaN and +Inf entries.

    Returns:
        dict: ``{'status': 'pass', 'handles_nan': True}`` when the F1 for
        budget 10 is finite, ``{'status': 'warning', 'returns_nan': True}``
        when it is NaN or infinite, or
        ``{'status': 'expected_fail', 'error': <message>}`` when anything
        raises (an exception is tolerated for this test).
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 7: NaN/Inf Values in Uncertainties")
    print(rule)

    sample_count = 100

    # Start from valid scores, then poison two five-element slices.
    raw_uncertainties = np.random.random(sample_count)
    raw_uncertainties[10:15] = np.nan
    raw_uncertainties[20:25] = np.inf

    uncertainties = List()
    uncertainties.append(raw_uncertainties)

    predictions = List()
    predictions.append(np.random.randint(0, 2, sample_count))

    labels = List()
    labels.append(np.random.randint(0, 2, sample_count))

    try:
        results = PostHocRejectorSimulator(
            uncertainties,
            predictions,
            labels,
            rejection_Ns=[10],
            upto_reject=False,
            method='dual_thresh_compounded',
        )

        print("✓ No crash with NaN/Inf values")
        print(f"  F1 (budget=10): {results[10]['F1']:.4f}")

        f1_value = results[10]['F1']
        if np.isnan(f1_value) or np.isinf(f1_value):
            print("  ⚠️ Warning: Returns NaN/Inf in results")
            return {'status': 'warning', 'returns_nan': True}

        print("  ✓ Handles NaN/Inf gracefully")
        return {'status': 'pass', 'handles_nan': True}

    except Exception as exc:
        # Unlike the other tests, raising here is not counted as a failure:
        # callers are expected to clean their inputs first.
        print(f"⚠️ Expected failure with NaN/Inf: {type(exc).__name__}")
        print("  This is ACCEPTABLE - NaN/Inf should be cleaned before simulation")
        return {'status': 'expected_fail', 'error': str(exc)}


def test_single_month():
    """Run the simulator on exactly one month of 100 random samples.

    Returns:
        dict: ``{'status': 'pass', 'results': ...}`` when the simulation and
        the metric printing complete, otherwise
        ``{'status': 'fail', 'error': <message>}``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 8: Single Month")
    print(rule)

    sample_count = 100

    # Each container receives a single array, i.e. a single month.
    uncertainties = List()
    uncertainties.append(np.random.random(sample_count))

    predictions = List()
    predictions.append(np.random.randint(0, 2, sample_count))

    labels = List()
    labels.append(np.random.randint(0, 2, sample_count))

    try:
        results = PostHocRejectorSimulator(
            uncertainties,
            predictions,
            labels,
            rejection_Ns=[10, 20],
            upto_reject=False,
            method='dual_thresh_compounded',
        )

        print("✓ No crash with single month")
        # NOTE(review): budget 0 is read although only 10 and 20 were
        # requested - presumably the simulator always reports a zero-budget
        # baseline; confirm against PostHocRejectorSimulator.
        for budget in (0, 10):
            print(f"  F1 (budget={budget}): {results[budget]['F1']:.4f}")
        print(f"  Monthly metrics shape: {len(results[10]['monthly_F1'])}")

        return {'status': 'pass', 'results': results}

    except Exception as exc:
        print(f"✗ FAIL: {type(exc).__name__}: {exc}")
        return {'status': 'fail', 'error': str(exc)}


def test_multi_month_progression():
    """Run the simulator over twelve months of random data and sanity-check counts.

    Every month holds 50 random uncertainties, predictions and labels. After
    the simulation the per-month rejection counts for budget 10 are printed
    and checked to be non-negative.

    Returns:
        dict: ``{'status': 'pass', 'results': ...}`` when every monthly count
        is non-negative, ``{'status': 'fail', 'negative_rejections': True}``
        when at least one is negative, or
        ``{'status': 'fail', 'error': <message>}`` when anything raises.
    """
    print("\n" + "=" * 80)
    print("TEST 9: Multi-Month Progression")
    print("=" * 80)

    n_months = 12
    n_samples_per_month = 50

    uncertainties = List()
    predictions = List()
    labels = List()

    for _ in range(n_months):
        uncertainties.append(np.random.random(n_samples_per_month))
        predictions.append(np.random.randint(0, 2, n_samples_per_month))
        labels.append(np.random.randint(0, 2, n_samples_per_month))

    try:
        results = PostHocRejectorSimulator(
            uncertainties, predictions, labels,
            rejection_Ns=[5, 10],
            upto_reject=False,
            method='dual_thresh_compounded'
        )

        print(f"✓ No crash with {n_months} months")
        print(f"  F1 (budget=0): {results[0]['F1']:.4f}")
        print(f"  F1 (budget=5): {results[5]['F1']:.4f}")
        print(f"  F1 (budget=10): {results[10]['F1']:.4f}")

        # Show how many samples were rejected in each month.
        print(f"\n  Monthly rejections (budget=10):")
        rejections = results[10]['monthly_Rejections']
        for month, rej_count in enumerate(rejections, start=1):
            print(f"    Month {month}: {rej_count} rejections")

        # Sanity check only: a rejection count can never be negative.
        # The comparison is done per element so it works for a plain list as
        # well as for an ndarray; `rejections >= 0` on a list raises
        # TypeError, which would be misreported as a simulator failure.
        if all(rej_count >= 0 for rej_count in rejections):
            print(f"  ✓ All monthly rejections are non-negative")
            return {'status': 'pass', 'results': results}
        else:
            print(f"  ✗ FAIL: Negative rejections found!")
            return {'status': 'fail', 'negative_rejections': True}

    except Exception as e:
        print(f"✗ FAIL: {type(e).__name__}: {e}")
        return {'status': 'fail', 'error': str(e)}


def test_all_methods():
    """Run the simulator once per rejection method on three months of random data.

    Returns:
        dict: ``{'status': 'pass', 'results': ...}`` when every method
        completes, otherwise ``{'status': 'partial', 'results': ...}``. In
        both cases ``results`` maps each method name to either
        ``{'status': 'pass', 'f1': ...}`` or
        ``{'status': 'fail', 'error': <message>}``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 10: All Rejection Methods")
    print(rule)

    samples_per_month = 100
    month_count = 3

    uncertainties = List()
    predictions = List()
    labels = List()

    for _ in range(month_count):
        uncertainties.append(np.random.random(samples_per_month))
        predictions.append(np.random.randint(0, 2, samples_per_month))
        labels.append(np.random.randint(0, 2, samples_per_month))

    methods = (
        'single_thresh_simple',
        'single_thresh_compounded',
        'dual_thresh_simple',
        'dual_thresh_compounded',
    )

    outcomes = {}

    # Each method gets its own try so one crash does not hide the others.
    for name in methods:
        try:
            simulated = PostHocRejectorSimulator(
                uncertainties,
                predictions,
                labels,
                rejection_Ns=[10],
                upto_reject=False,
                method=name,
            )

            print(f"  ✓ {name}: F1={simulated[10]['F1']:.4f}")
            outcomes[name] = {'status': 'pass', 'f1': simulated[10]['F1']}

        except Exception as exc:
            print(f"  ✗ {name}: {type(exc).__name__}: {exc}")
            outcomes[name] = {'status': 'fail', 'error': str(exc)}

    # Tally how many methods made it through.
    total = len(methods)
    passing = len([entry for entry in outcomes.values() if entry['status'] == 'pass'])
    print(f"\n  Summary: {passing}/{total} methods pass")

    if passing != total:
        print("  ⚠️ Some methods failed")
        return {'status': 'partial', 'results': outcomes}

    print("  ✓ ALL METHODS PASS")
    return {'status': 'pass', 'results': outcomes}


def run_all_tests():
    """Run every edge-case test, print a summary and return the raw results.

    A test counts as a problem when its status is ``'fail'`` or
    ``'partial'`` (the latter is what ``test_all_methods`` reports when only
    some rejection methods complete). ``'warning'`` and ``'expected_fail'``
    are reported but do not affect the final verdict.

    Returns:
        dict: maps a short test name to the result dict returned by that test.
    """
    print("\n" + "=" * 80)
    print("REJECTION SIMULATION EDGE CASES TEST SUITE")
    print("Testing robustness with problematic inputs")
    print("=" * 80)

    # Ordered so the output follows the TEST 1..10 banners.
    suite = (
        ('empty_arrays', test_empty_arrays),
        ('insufficient_samples', test_insufficient_samples),
        ('all_same_predictions', test_all_same_predictions),
        ('budget_zero', test_budget_zero),
        ('budget_exceeds', test_budget_exceeds_samples),
        ('tied_uncertainties', test_tied_uncertainties),
        ('nan_inf', test_nan_inf_values),
        ('single_month', test_single_month),
        ('multi_month', test_multi_month_progression),
        ('all_methods', test_all_methods),
    )

    results = {}
    for test_name, test_func in suite:
        results[test_name] = test_func()

    # Summary
    print("\n" + "=" * 80)
    print("SUMMARY OF RESULTS")
    print("=" * 80)

    statuses = [r['status'] for r in results.values()]
    pass_count = statuses.count('pass')
    fail_count = statuses.count('fail')
    warning_count = statuses.count('warning')
    expected_fail_count = statuses.count('expected_fail')
    # 'partial' must be tallied too; otherwise crashed rejection methods
    # vanish from the summary and the suite claims that everything passed.
    partial_count = statuses.count('partial')
    problem_count = fail_count + partial_count

    print(f"\n✓ Pass: {pass_count}")
    print(f"⚠️ Warning: {warning_count}")
    print(f"✗ Fail: {fail_count}")
    print(f"⚠️ Partial: {partial_count}")
    print(f"⚠️ Expected Fail: {expected_fail_count}")

    print(f"\nTotal: {len(results)} tests")

    if problem_count == 0:
        print(f"\n✅ ALL CRITICAL TESTS PASS")
        print(f"   Rejection simulator is robust to edge cases")
    else:
        print(f"\n⚠️ SOME TESTS FAILED")
        print(f"   Review failed tests and fix issues")

        # Failed tests detail
        print(f"\nFailed tests:")
        for test_name, result in results.items():
            if result['status'] == 'fail':
                print(f"  - {test_name}: {result.get('error', 'Unknown error')}")
            elif result['status'] == 'partial':
                # Name the individual methods that did not pass.
                for method, outcome in result.get('results', {}).items():
                    if outcome.get('status') != 'pass':
                        print(f"  - {test_name} [{method}]: {outcome.get('error', 'Unknown error')}")

    print("\n" + "=" * 80)
    print("✓ EDGE CASES TEST SUITE COMPLETE")
    print("=" * 80)

    return results


if __name__ == "__main__":
    results = run_all_tests()

    # Static follow-up advice for whoever maintains the simulator. It does not
    # depend on the test outcomes and is written out in a single call.
    print("\n".join((
        "\n" + "=" * 80,
        "RECOMMENDATIONS",
        "=" * 80,
        "\n1. Input Validation:",
        "   - Add checks for empty arrays before simulation",
        "   - Warn if sample count < 20 (metrics may be unreliable)",
        "   - Clean NaN/Inf values before simulation",
        "\n2. Documentation:",
        "   - Document expected input format",
        "   - Document minimum sample requirements",
        "   - Document budget behavior (0, exceeds N)",
        "\n3. Error Handling:",
        "   - Return NaN gracefully for insufficient samples",
        "   - Handle tied uncertainties deterministically",
        "   - Clip budget to valid range [0, N]",
    )))
