#!/usr/bin/env python3
"""
AURC Validation Test Suite

Tests the correctness of AURC computation and investigates scaling discrepancy.

Background:
- Paper shows AURC values in 60-100 range
- Aurora shows AURC values in 0-10 range
- Need to determine: scaling factor, E-AURC vs raw AURC, or different metric?
"""

import sys
sys.path.insert(0, 'src')
sys.path.insert(0, '.')

import numpy as np
from aurora import compute_aurc, compute_metrics_numba


def test_perfect_oracle():
    """
    Run the "perfect oracle" scenario and report raw AURC and E-AURC.

    A binary problem with 1000 samples is generated, exactly 10% of the
    predictions are flipped, and the uncertainty score is set to 1.0 on
    every wrong prediction and 0.0 on every correct one.

    The printed status is PASS when |E-AURC| < 0.01.

    Returns:
        dict with keys 'error_rate', 'aurc_raw', 'aurc_excess',
        'aurc_raw_scaled' and 'aurc_excess_scaled' (the last two are the
        AURC values multiplied by 100).
    """
    rule = "=" * 80
    print(rule)
    print("TEST 1: Perfect Oracle (uncertainty = 0 when correct, 1 when wrong)")
    print(rule)

    total = 1000
    flip_count = int(0.1 * total)

    # Ground truth first, then the positions to corrupt (same RNG order as
    # always: randint, then choice).
    y_true = np.random.randint(0, 2, total)
    flipped = np.random.choice(total, size=flip_count, replace=False)

    # Labels are 0/1, so `1 - value` flips the predicted class.
    y_pred = y_true.copy()
    y_pred[flipped] = 1 - y_pred[flipped]

    # Oracle score: 1.0 exactly where the prediction is wrong, else 0.0.
    scores = np.where(y_pred != y_true, 1.0, 0.0)

    # Raw AURC first, then E-AURC; fresh single-element lists for each call.
    aurc_raw, aurc_excess = (
        compute_aurc([y_true], [y_pred], [scores], compute_eaurc=use_excess)
        for use_excess in (False, True)
    )

    error_rate = (y_pred != y_true).mean()
    raw_x100 = aurc_raw * 100
    excess_x100 = aurc_excess * 100

    if abs(aurc_excess) < 0.01:
        verdict = '✓ PASS'
    else:
        verdict = '✗ FAIL'

    print(f"\nError rate: {error_rate:.4f} ({error_rate*100:.2f}%)")
    print(f"Raw AURC: {aurc_raw:.6f} (×100 = {raw_x100:.4f})")
    print(f"E-AURC:   {aurc_excess:.6f} (×100 = {excess_x100:.4f})")
    print("\nExpected: E-AURC ≈ 0 (perfect uncertainty)")
    print(f"Status: {verdict}")

    summary = {}
    summary['error_rate'] = error_rate
    summary['aurc_raw'] = aurc_raw
    summary['aurc_excess'] = aurc_excess
    summary['aurc_raw_scaled'] = raw_x100
    summary['aurc_excess_scaled'] = excess_x100
    return summary


def test_random_uncertainty():
    """
    Run the "uninformative uncertainty" scenario and report raw AURC / E-AURC.

    A binary problem with 1000 samples is generated, exactly 20% of the
    predictions are flipped, and the uncertainty scores are drawn uniformly
    at random, independently of which predictions are wrong.

    The printed status is PASS when |E-AURC| < 0.05, MARGINAL when it is
    below 0.1, and FAIL otherwise.

    NOTE(review): the "E-AURC ≈ 0" expectation assumes the baseline that
    compute_aurc subtracts is the random-ordering one. If it subtracts the
    optimal-ordering AURC instead, random scores should give a clearly
    positive value -- confirm against the aurora implementation.

    Returns:
        dict with keys 'error_rate', 'aurc_raw', 'aurc_excess',
        'aurc_raw_scaled' and 'aurc_excess_scaled'.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 2: Random Uncertainty (no signal)")
    print(rule)

    total = 1000
    flip_count = int(0.2 * total)

    # RNG call order matters for reproducibility under a fixed global seed:
    # labels, then flip positions, then the uncertainty scores.
    y_true = np.random.randint(0, 2, total)
    flipped = np.random.choice(total, size=flip_count, replace=False)
    y_pred = y_true.copy()
    y_pred[flipped] = 1 - y_pred[flipped]

    # Scores carry no information about correctness.
    scores = np.random.random(total)

    aurc_raw, aurc_excess = (
        compute_aurc([y_true], [y_pred], [scores], compute_eaurc=use_excess)
        for use_excess in (False, True)
    )

    error_rate = (y_pred != y_true).mean()
    raw_x100 = aurc_raw * 100
    excess_x100 = aurc_excess * 100

    # Three-level verdict on the magnitude of the excess.
    if abs(aurc_excess) < 0.05:
        verdict = '✓ PASS'
    elif abs(aurc_excess) < 0.1:
        verdict = '⚠ MARGINAL'
    else:
        verdict = '✗ FAIL'

    print(f"\nError rate: {error_rate:.4f} ({error_rate*100:.2f}%)")
    print(f"Raw AURC: {aurc_raw:.6f} (×100 = {raw_x100:.4f})")
    print(f"E-AURC:   {aurc_excess:.6f} (×100 = {excess_x100:.4f})")
    print("\nExpected: E-AURC ≈ 0 (random uncertainty provides no information)")
    print(f"Status: {verdict}")

    summary = {}
    summary['error_rate'] = error_rate
    summary['aurc_raw'] = aurc_raw
    summary['aurc_excess'] = aurc_excess
    summary['aurc_raw_scaled'] = raw_x100
    summary['aurc_excess_scaled'] = excess_x100
    return summary


def test_worst_case_uncertainty():
    """
    Run the "inverted uncertainty" scenario and report raw AURC / E-AURC.

    A binary problem with 1000 samples is generated, exactly 15% of the
    predictions are flipped, and the uncertainty score is deliberately
    inverted: 0.0 on every wrong prediction and 1.0 on every correct one.

    The printed status is PASS when E-AURC > 0.01.

    Returns:
        dict with keys 'error_rate', 'aurc_raw', 'aurc_excess',
        'aurc_raw_scaled' and 'aurc_excess_scaled'.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("TEST 3: Worst-Case Uncertainty (inverted signal)")
    print(rule)

    total = 1000
    flip_count = int(0.15 * total)

    y_true = np.random.randint(0, 2, total)
    flipped = np.random.choice(total, size=flip_count, replace=False)
    y_pred = y_true.copy()
    y_pred[flipped] = 1 - y_pred[flipped]

    # Inverted signal: most "confident" (0.0) exactly where the model is wrong.
    scores = np.where(y_pred != y_true, 0.0, 1.0)

    aurc_raw, aurc_excess = (
        compute_aurc([y_true], [y_pred], [scores], compute_eaurc=use_excess)
        for use_excess in (False, True)
    )

    error_rate = (y_pred != y_true).mean()
    raw_x100 = aurc_raw * 100
    excess_x100 = aurc_excess * 100

    if aurc_excess > 0.01:
        verdict = '✓ PASS'
    else:
        verdict = '✗ FAIL'

    print(f"\nError rate: {error_rate:.4f} ({error_rate*100:.2f}%)")
    print(f"Raw AURC: {aurc_raw:.6f} (×100 = {raw_x100:.4f})")
    print(f"E-AURC:   {aurc_excess:.6f} (×100 = {excess_x100:.4f})")
    print("\nExpected: E-AURC > 0 (inverted uncertainty is harmful)")
    print(f"Status: {verdict}")

    summary = {}
    summary['error_rate'] = error_rate
    summary['aurc_raw'] = aurc_raw
    summary['aurc_excess'] = aurc_excess
    summary['aurc_raw_scaled'] = raw_x100
    summary['aurc_excess_scaled'] = excess_x100
    return summary


def test_realistic_malware_scenario():
    """
    Test AURC on a synthetic scenario meant to resemble malware detection.

    Simulates:
    - a balanced binary dataset of 5000 samples
    - 85% accuracy (exactly 15% of predictions flipped)
    - informative but noisy uncertainty: every sample gets a base score in
      [0, 0.3), and wrong predictions get an extra term in [0, 0.7)

    The printed status is PASS when -0.01 < E-AURC < 0.1, otherwise CHECK.

    Returns:
        dict with keys 'f1', 'error_rate', 'aurc_raw', 'aurc_excess',
        'aurc_raw_scaled' and 'aurc_excess_scaled'.
    """
    print("\n" + "=" * 80)
    print("TEST 4: Realistic Malware Detection Scenario")
    print("=" * 80)

    n_samples = 5000
    error_fraction = 0.15

    # Balanced dataset; the class split is derived from n_samples so the two
    # cannot drift apart. np.zeros/np.ones make these labels float64.
    n_negative = n_samples // 2
    labels = np.concatenate([np.zeros(n_negative), np.ones(n_samples - n_negative)])
    np.random.shuffle(labels)

    # 85% accuracy classifier: flip a fixed number of distinct positions.
    predictions = labels.copy()
    n_errors = int(error_fraction * n_samples)
    error_indices = np.random.choice(n_samples, size=n_errors, replace=False)
    predictions[error_indices] = 1 - predictions[error_indices]

    # Good uncertainty: correlated with errors but not perfect.
    is_error = predictions != labels
    uncertainties = np.random.random(n_samples) * 0.3  # Base uncertainty
    # Size the extra noise from the mask itself rather than repeating the
    # error-rate constant, so the shapes always agree.
    uncertainties[is_error] += np.random.random(int(is_error.sum())) * 0.7
    uncertainties = np.clip(uncertainties, 0, 1)

    aurc_raw = compute_aurc([labels], [predictions], [uncertainties], compute_eaurc=False)
    aurc_excess = compute_aurc([labels], [predictions], [uncertainties], compute_eaurc=True)

    # Classification metrics are computed on int64 copies of the arrays.
    f1, fnr, fpr = compute_metrics_numba(labels.astype(np.int64), predictions.astype(np.int64))
    error_rate = np.mean(predictions != labels)

    print(f"\nF1 Score: {f1:.4f} ({f1*100:.2f}%)")
    print(f"Error rate: {error_rate:.4f} ({error_rate*100:.2f}%)")
    print(f"Raw AURC: {aurc_raw:.6f} (×100 = {aurc_raw*100:.4f})")
    print(f"E-AURC:   {aurc_excess:.6f} (×100 = {aurc_excess*100:.4f})")
    print("\nExpected: E-AURC slightly positive (good but not perfect uncertainty)")
    print(f"Status: {'✓ PASS' if -0.01 < aurc_excess < 0.1 else '⚠ CHECK'}")

    return {
        'f1': f1,
        'error_rate': error_rate,
        'aurc_raw': aurc_raw,
        'aurc_excess': aurc_excess,
        'aurc_raw_scaled': aurc_raw * 100,
        'aurc_excess_scaled': aurc_excess * 100
    }


def test_paper_comparison():
    """
    Compare hard-coded Aurora AURC values with hard-coded paper values.

    Nothing is loaded from disk: both tables below are sample values typed
    into this function. For every (method, dataset) pair the ratio
    paper / aurora is printed, followed by the mean and standard deviation
    of the ratios and a consistency verdict based on the coefficient of
    variation (consistent when below 30%).

    Returns:
        dict with keys 'mean_ratio', 'std_ratio' and 'all_ratios' (one ratio
        per table cell, NaN where the Aurora value is not positive).
    """
    wide_rule = "-" * 80

    print("\n" + "=" * 80)
    print("TEST 5: Paper Comparison - Scaling Investigation")
    print("=" * 80)

    # Sample values from paper (BM=100)
    paper_values = {
        'CADE (cold) - OOD': {'azoo': 87, 'apig': 73, 'trans': 100},
        'HCC (warm) - MSP': {'azoo': 90, 'apig': 85, 'trans': 92},
        'SVC - Margin': {'azoo': 88, 'apig': 69, 'trans': 93},
    }

    # Sample values from Aurora (BM=100)
    aurora_values = {
        'CADE (cold) - OOD': {'azoo': 5, 'apig': 2, 'trans': 5},
        'HCC (warm) - MSP': {'azoo': 1, 'apig': 1, 'trans': 2},
        'SVC - Margin': {'azoo': 2, 'apig': 1, 'trans': 3},
    }

    print(f"\n{'Method':<25} {'Dataset':<10} {'Paper':<12} {'Aurora':<12} {'Scaling Factor':<15}")
    print(wide_rule)

    all_ratios = []
    for method, per_dataset in paper_values.items():
        for dataset, paper_val in per_dataset.items():
            aurora_val = aurora_values[method][dataset]

            # A non-positive Aurora value has no meaningful ratio.
            if aurora_val > 0:
                ratio = paper_val / aurora_val
            else:
                ratio = np.nan
            all_ratios.append(ratio)

            print(f"{method[:25]:<25} {dataset:<10} {paper_val:<12} {aurora_val:<12} {ratio:<15.2f}")

    # Statistics ignore the NaN placeholders.
    finite_ratios = [value for value in all_ratios if not np.isnan(value)]
    mean_ratio = np.mean(finite_ratios)
    std_ratio = np.std(finite_ratios)

    print("\n" + wide_rule)
    print(f"Mean scaling factor: {mean_ratio:.2f} ± {std_ratio:.2f}")
    print(f"\nHypothesis 1: Paper uses AURC × {mean_ratio:.0f}")
    print("Hypothesis 2: Paper uses different normalization")
    print("Hypothesis 3: Paper uses E-AURC with different baseline")

    # Coefficient of variation decides whether one scaling factor fits all.
    variation = std_ratio / mean_ratio
    if variation < 0.3:  # CV < 30%
        print(f"\n✓ Scaling factor is consistent (CV = {variation*100:.1f}%)")
    else:
        print(f"\n⚠ Scaling factor varies significantly (CV = {variation*100:.1f}%)")

    return dict(mean_ratio=mean_ratio, std_ratio=std_ratio, all_ratios=all_ratios)


def test_aurc_properties():
    """
    Test mathematical properties of AURC.

    Properties:
    1. 0 ≤ AURC ≤ 1 (raw AURC is a probability/area) -- checked below
    2. E-AURC can be negative (better than optimal) -- reported only, not
       checked
    3. AURC increases with error rate (all else equal) -- checked below

    NOTE(review): property 2 reads as contradictory if E-AURC is the excess
    over the optimal ordering; confirm which baseline compute_aurc subtracts.

    One label vector of 2000 samples is shared across all runs. For each
    target error rate, that fraction of predictions is flipped and an
    informative-but-noisy uncertainty is generated (base score in [0, 0.3),
    plus an extra term in [0, 0.7) on wrong predictions).

    Returns:
        list of dicts, one per target error rate, each with keys
        'error_rate', 'aurc_raw' and 'aurc_excess'.
    """
    print("\n" + "=" * 80)
    print("TEST 6: Mathematical Properties of AURC")
    print("=" * 80)

    n_samples = 2000
    labels = np.random.randint(0, 2, n_samples)

    results = []

    print("\nTesting AURC vs Error Rate:")
    print("{:<15} {:<15} {:<15} {:<15}".format(
        "Error Rate", "Raw AURC", "E-AURC", "AURC×100"
    ))
    print("-" * 60)

    for error_rate_target in [0.05, 0.10, 0.15, 0.20, 0.25, 0.30]:
        predictions = labels.copy()
        n_errors = int(error_rate_target * n_samples)
        error_indices = np.random.choice(n_samples, size=n_errors, replace=False)
        predictions[error_indices] = 1 - predictions[error_indices]

        # Good uncertainty (correlated with errors). The extra noise is sized
        # from the error mask so the shapes agree by construction.
        is_error = predictions != labels
        uncertainties = np.random.random(n_samples) * 0.3
        uncertainties[is_error] += np.random.random(int(is_error.sum())) * 0.7
        uncertainties = np.clip(uncertainties, 0, 1)

        aurc_raw = compute_aurc([labels], [predictions], [uncertainties], compute_eaurc=False)
        aurc_excess = compute_aurc([labels], [predictions], [uncertainties], compute_eaurc=True)

        actual_error_rate = np.mean(predictions != labels)

        print("{:<15.3f} {:<15.6f} {:<15.6f} {:<15.2f}".format(
            actual_error_rate, aurc_raw, aurc_excess, aurc_raw * 100
        ))

        results.append({
            'error_rate': actual_error_rate,
            'aurc_raw': aurc_raw,
            'aurc_excess': aurc_excess
        })

    aurc_values = [r['aurc_raw'] for r in results]

    # Property 3: AURC should generally increase with error rate
    # (non-decreasing across consecutive runs).
    is_monotonic = all(
        lower <= higher for lower, higher in zip(aurc_values, aurc_values[1:])
    )

    # Property 1: every raw AURC must lie in the unit interval.
    in_unit_interval = all(0.0 <= value <= 1.0 for value in aurc_values)

    print("\n" + "-" * 60)
    print("Property check: AURC increases with error rate")
    print(f"Status: {'✓ PASS (monotonic)' if is_monotonic else '⚠ Not strictly monotonic (due to random uncertainty)'}")
    print("Property check: 0 ≤ raw AURC ≤ 1")
    print(f"Status: {'✓ PASS' if in_unit_interval else '✗ FAIL'}")

    return results


def run_all_tests():
    """
    Execute every validation scenario in order and print a summary.

    The scenarios are run in a fixed sequence and their return values are
    collected under the keys 'perfect_oracle', 'random_uncertainty',
    'worst_case', 'realistic', 'paper_comparison' and 'properties'. The
    summary that follows is static text except for the mean Paper/Aurora
    ratio taken from the paper-comparison result.

    Returns:
        dict mapping scenario key to that scenario's return value.
    """
    rule = "=" * 80

    print("\n" + rule)
    print("AURC VALIDATION TEST SUITE")
    print("Investigating scaling discrepancy: Paper (60-100) vs Aurora (0-10)")
    print(rule)

    # Scenario table; order here is the execution order.
    suite = (
        ('perfect_oracle', test_perfect_oracle),
        ('random_uncertainty', test_random_uncertainty),
        ('worst_case', test_worst_case_uncertainty),
        ('realistic', test_realistic_malware_scenario),
        ('paper_comparison', test_paper_comparison),
        ('properties', test_aurc_properties),
    )
    results = {}
    for key, scenario in suite:
        results[key] = scenario()

    # Summary
    print("\n" + rule)
    print("SUMMARY OF FINDINGS")
    print(rule)

    ratio = results['paper_comparison']['mean_ratio']

    summary_lines = (
        "\n1. AURC Range:",
        "   - Raw AURC is typically 0.0 to 0.3 (0% to 30%)",
        "   - Multiplying by 100 gives 0 to 30 range",
        "   - Aurora shows 0-10 range (raw: 0.0-0.1)",
        "   - Paper shows 60-100 range",
        "\n2. Scaling Factor:",
        f"   - Mean Paper/Aurora ratio: {ratio:.2f}",
        f"   - This suggests paper uses ~{ratio:.0f}× Aurora values",
        "\n3. Hypotheses:",
        "   H1: Paper uses different AURC formula",
        "   H2: Paper uses E-AURC with different baseline",
        "   H3: Paper uses complementary metric (100 - something)",
        "   H4: Paper uses different risk metric in risk-coverage curve",
        "\n4. Recommendations:",
        "   ✓ Read paper Section on AURC definition carefully",
        "   ✓ Check if paper uses 'selection rate' vs 'coverage'",
        "   ✓ Check if paper uses 'accuracy' vs 'risk' on y-axis",
        "   ✓ Verify Aurora is computing E-AURC correctly",
    )
    for line in summary_lines:
        print(line)

    print("\n" + rule)
    print("✓ AURC VALIDATION TESTS COMPLETE")
    print(rule)

    return results


if __name__ == "__main__":
    # Script entry point: run the whole suite, then list the follow-up steps.
    results = run_all_tests()

    # The results dict is only held in memory here; nothing is written to
    # disk, so the banner must not claim the results were saved.
    print("\n" + "=" * 80)
    print("TEST RUN FINISHED (results not written to disk)")
    print("=" * 80)
    print("\nNext steps:")
    print("1. Review paper methodology section for AURC definition")
    print("2. Determine correct scaling factor or formula")
    print("3. Update Aurora implementation if needed")
    print("4. Document findings in metrics guide")
