"""
Validation module for Aurora framework

Implements comprehensive validation inspired by the critical check_results function.
Validates data integrity, completeness, and consistency.
"""

from dataclasses import dataclass, field
from typing import List, Dict, Set, Any, Optional, Tuple
from collections import Counter
import numpy as np
import pandas as pd

from .schema import ExperimentResult, ResultsCollection


@dataclass
class ValidationIssue:
    """Represents a single validation issue"""
    severity: str  # "error", "warning", "info"
    category: str  # "completeness", "uniqueness", "consistency", "format", "data_integrity"
    message: str
    details: Dict[str, Any] = field(default_factory=dict)

    def __str__(self) -> str:
        """
        Render as ``[SEVERITY] category: message (k=v, ...)``.

        The parenthesised detail list is only appended when ``details`` is
        non-empty, so detail-less issues do not end in a dangling ``()``.
        """
        text = f"[{self.severity.upper()}] {self.category}: {self.message}"
        if self.details:
            detail_str = ", ".join(f"{k}={v}" for k, v in self.details.items())
            text = f"{text} ({detail_str})"
        return text


@dataclass
class ValidationReport:
    """
    Aggregated outcome of a validation run.

    Holds every ``ValidationIssue`` recorded during validation, plus a
    free-form ``summary`` mapping that is rendered after the issue list.
    """
    issues: List[ValidationIssue] = field(default_factory=list)
    summary: Dict[str, Any] = field(default_factory=dict)

    def _record(self, severity: str, category: str, message: str, details: Dict[str, Any]) -> None:
        """Append one issue with the given severity to ``issues``."""
        self.issues.append(ValidationIssue(severity, category, message, details))

    def _select(self, severity: str) -> List[ValidationIssue]:
        """Return the recorded issues whose severity equals *severity*, in order."""
        return [candidate for candidate in self.issues if candidate.severity == severity]

    def add_error(self, category: str, message: str, **details):
        """Record an issue of severity "error"; keyword args become its details."""
        self._record("error", category, message, details)

    def add_warning(self, category: str, message: str, **details):
        """Record an issue of severity "warning"; keyword args become its details."""
        self._record("warning", category, message, details)

    def add_info(self, category: str, message: str, **details):
        """Record an issue of severity "info"; keyword args become its details."""
        self._record("info", category, message, details)

    def has_errors(self) -> bool:
        """True when at least one recorded issue has severity "error"."""
        return any(candidate.severity == "error" for candidate in self.issues)

    def has_warnings(self) -> bool:
        """True when at least one recorded issue has severity "warning"."""
        return any(candidate.severity == "warning" for candidate in self.issues)

    def get_errors(self) -> List[ValidationIssue]:
        """All issues of severity "error", in recording order."""
        return self._select("error")

    def get_warnings(self) -> List[ValidationIssue]:
        """All issues of severity "warning", in recording order."""
        return self._select("warning")

    def __str__(self) -> str:
        """Multi-line, human-readable report: counts, issue list, then summary."""
        out = [
            "=== Validation Report ===",
            f"Total Issues: {len(self.issues)}",
            f"Errors: {len(self.get_errors())}",
            f"Warnings: {len(self.get_warnings())}",
            "",
        ]

        if not self.issues:
            out.append("✓ No issues found")
        else:
            out.append("Issues:")
            out.extend(f"  {issue}" for issue in self.issues)

        if self.summary:
            out.extend(["", "Summary:"])
            out.extend(f"  {key}: {value}" for key, value in self.summary.items())

        return "\n".join(out)


class ResultsValidator:
    """
    Validates experimental results for integrity and completeness

    This class implements comprehensive validation inspired by the critical
    check_results function, with additional checks for data quality.
    """

    # Expected number of test months per dataset. The same number is also
    # used as the largest test-month value that does not trigger a warning
    # (see _validate_grouped and _validate_cutoff_months).
    EXPECTED_MONTHS = {
        "androzoo": 22,
        "apigraph": 70,
        "transcendent": 46,
    }

    # ResultsCollection field names -> DataFrame column names produced by
    # _collection_to_dataframe. Grouping keys not listed here are treated as
    # hyperparameter names and used verbatim as column names.
    _COLUMN_MAPPING = {
        "dataset": "Dataset",
        "trainer_mode": "Trainer-Mode",
        "monthly_label_budget": "Monthly-Label-Budget",
        "sampler_mode": "Sampler-Mode",
        "base_name": "Base-Name",
    }

    def __init__(
        self,
        strict: bool = True,
        expected_months: Optional[Dict[str, int]] = None
    ):
        """
        Initialize validator

        Args:
            strict: Intended to treat warnings as errors.
                NOTE(review): the flag is stored but no check consults it yet,
                so warnings never cause validation to fail.
            expected_months: Override default expected months per dataset.
                ``None`` selects ``EXPECTED_MONTHS``; an explicit empty dict
                disables the per-dataset month checks.
        """
        self.strict = strict
        # Compare against None so an explicit {} is honoured. Copy the class
        # default so mutating one validator's mapping cannot alter
        # EXPECTED_MONTHS, which every instance shares.
        if expected_months is None:
            expected_months = dict(self.EXPECTED_MONTHS)
        self.expected_months = expected_months

    def validate(
        self,
        collection: ResultsCollection,
        grouping_keys: Optional[List[str]] = None
    ) -> ValidationReport:
        """
        Comprehensive validation of results collection

        Args:
            collection: Results to validate
            grouping_keys: Keys to use for grouping validation. Field names
                ("dataset", "trainer_mode", "monthly_label_budget",
                "sampler_mode", "base_name") are translated to DataFrame
                columns; any other key is used as a hyperparameter name.
                Default: ["dataset", "trainer_mode", "monthly_label_budget"],
                plus "Random-Seed" when any result defines that hyperparameter.

        Returns:
            ValidationReport with all issues found. An empty collection yields
            a report containing a single error and no summary.

        This implements the logic from check_results() plus additional checks:
        1. Completeness: All expected months present
        2. Uniqueness: No duplicate months
        3. Progression: Months progress sequentially
        4. Data integrity: Arrays are valid
        5. Consistency: Field values are consistent within groups
        """
        report = ValidationReport()

        if not collection.results:
            report.add_error("completeness", "Empty results collection")
            return report

        if grouping_keys is None:
            grouping_keys = ["dataset", "trainer_mode", "monthly_label_budget"]
            # Split by seed whenever *any* result records one. Checking only the
            # first result merges different seeds into one group (and reports
            # spurious duplicate months) whenever that result lacks the key.
            if any(
                result.get_hyperparameter("Random-Seed") is not None
                for result in collection.results
            ):
                grouping_keys.append("Random-Seed")

        # Grouped month checks work on a flat DataFrame, like check_results.
        df = self._collection_to_dataframe(collection)
        self._validate_grouped(df, grouping_keys, report)

        # Additional validations
        self._validate_base_names(collection, report)
        self._validate_data_arrays(collection, report)
        self._validate_cutoff_months(collection, report)

        self._generate_summary(collection, report)
        return report

    def _collection_to_dataframe(self, collection: ResultsCollection) -> pd.DataFrame:
        """
        Convert collection to DataFrame for validation

        One row per result with the fixed metadata columns. Hyperparameters
        are merged in afterwards, so a hyperparameter sharing a column name
        overrides the fixed value.
        """
        rows = []
        for result in collection.results:
            row = {
                "Test-Month": result.test_month,
                "Dataset": result.dataset,
                "Monthly-Label-Budget": result.monthly_label_budget,
                "Trainer-Mode": result.trainer_mode,
                "Sampler-Mode": result.sampler_mode,
                "Base-Name": result.base_name,
            }
            row.update(result.hyperparameters)
            rows.append(row)

        return pd.DataFrame(rows)

    def _validate_grouped(
        self,
        df: pd.DataFrame,
        grouping_keys: List[str],
        report: ValidationReport
    ):
        """
        Validate grouped results (implements check_results logic)

        Groups rows by ``grouping_keys`` and checks each group's test months:
        - duplicate months: error
        - gaps between consecutive months: error
        - a month above the dataset's expected value: warning
        - fewer months than expected: info

        A grouping key absent from the data is reported as a "format" error,
        and the grouped checks are skipped.

        This is the CRITICAL validation that mirrors check_results()!
        """
        df_grouping_keys = [self._COLUMN_MAPPING.get(key, key) for key in grouping_keys]

        # dropna=False keeps rows whose grouping value is missing (e.g. a result
        # without a Random-Seed hyperparameter). pandas drops such rows by
        # default, which would silently exclude them from validation.
        try:
            grouped = df.groupby(df_grouping_keys, dropna=False).agg({"Test-Month": list})
        except KeyError as e:
            report.add_error(
                "format",
                f"Missing grouping key in data: {e}",
                grouping_keys=df_grouping_keys
            )
            return

        for group_key, row in grouped.iterrows():
            months = row["Test-Month"]
            sorted_months = sorted(months)

            # Uniqueness: each test month should appear at most once per group.
            duplicates = {m: c for m, c in Counter(months).items() if c > 1}
            if duplicates:
                report.add_error(
                    "uniqueness",
                    "Duplicate test months found in group",
                    group=group_key,
                    duplicates=duplicates
                )

            # Progression: consecutive sorted months should differ by exactly 1.
            if len(sorted_months) > 1:
                diffs = np.diff(sorted_months)
                irregular_diffs = set(diffs) - {1}

                if irregular_diffs:
                    if max(irregular_diffs) > 1:
                        report.add_error(
                            "completeness",
                            "Non-sequential month progression (missing months)",
                            group=group_key,
                            months=sorted_months,
                            diffs=list(diffs)
                        )
                    else:
                        # For integer months only zero steps can reach this
                        # branch, i.e. the duplicates already reported above.
                        report.add_warning(
                            "consistency",
                            "Irregular month progression",
                            group=group_key,
                            diffs=list(irregular_diffs)
                        )

            # Expected-month checks need the group's Dataset value, if grouped on.
            if isinstance(group_key, tuple):
                group_dict = dict(zip(df_grouping_keys, group_key))
            else:
                group_dict = {df_grouping_keys[0]: group_key}
            dataset = group_dict.get("Dataset")

            if dataset not in self.expected_months:
                continue

            expected = self.expected_months[dataset]
            actual = len(sorted_months)
            # NOTE(review): `expected` serves both as a month count and as the
            # largest allowed month value. That is only consistent if months
            # are 1-indexed; confirm the indexing convention.
            max_month = sorted_months[-1]

            if max_month > expected:
                report.add_warning(
                    "completeness",
                    f"Test month {max_month} exceeds expected maximum {expected}",
                    group=group_key,
                    dataset=dataset
                )

            if actual < expected:
                report.add_info(
                    "completeness",
                    f"Fewer months than expected ({actual} < {expected})",
                    group=group_key,
                    dataset=dataset
                )

    def _validate_base_names(
        self,
        collection: ResultsCollection,
        report: ValidationReport
    ):
        """
        Validate that base names are unique within groupings

        Base-Name should uniquely identify a method within
        (dataset, monthly_label_budget, sampler_mode) combinations. Therefore
        all results sharing those values and a base name must also share one
        trainer_mode; otherwise a "consistency" error is recorded.
        """
        groups = collection.group_by(
            "dataset",
            "monthly_label_budget",
            "sampler_mode",
            "base_name"
        )

        for (dataset, budget, sampler, base_name), results in groups.items():
            trainer_modes = {r.trainer_mode for r in results}
            if len(trainer_modes) > 1:
                report.add_error(
                    "consistency",
                    "Same base name used with different trainer modes",
                    base_name=base_name,
                    dataset=dataset,
                    budget=budget,
                    # Include the sampler so errors from different samplers
                    # are distinguishable.
                    sampler_mode=sampler,
                    trainer_modes=list(trainer_modes)
                )

    def _validate_data_arrays(
        self,
        collection: ResultsCollection,
        report: ValidationReport
    ):
        """
        Validate data array integrity

        For each result:
        - an empty ``predictions`` array is an error;
        - non-finite values in either uncertainty array are a warning;
        - a positive-label rate below 1% or above 99% is a warning.
        """
        for i, result in enumerate(collection.results):
            if len(result.predictions) == 0:
                report.add_error(
                    "data_integrity",
                    "Empty predictions array",
                    result_index=i,
                    month=result.test_month
                )

            if np.any(~np.isfinite(result.uncertainties_month_ahead)):
                report.add_warning(
                    "data_integrity",
                    "Non-finite values in uncertainties_month_ahead",
                    result_index=i,
                    month=result.test_month
                )

            if np.any(~np.isfinite(result.uncertainties_past_month)):
                report.add_warning(
                    "data_integrity",
                    "Non-finite values in uncertainties_past_month",
                    result_index=i,
                    month=result.test_month
                )

            # Class balance is undefined without labels. np.mean([]) emits a
            # RuntimeWarning and returns NaN (which never trips the check).
            labels = np.asarray(result.labels)
            if labels.size > 0:
                pos_rate = labels.mean()
                if pos_rate < 0.01 or pos_rate > 0.99:
                    report.add_warning(
                        "data_integrity",
                        f"Extreme class imbalance (positive rate: {pos_rate:.1%})",
                        result_index=i,
                        month=result.test_month,
                        dataset=result.dataset
                    )

    def _validate_cutoff_months(
        self,
        collection: ResultsCollection,
        report: ValidationReport
    ):
        """
        Validate that no results exceed expected cutoff months

        Emits one warning per individual result whose test month is above its
        dataset's entry in ``expected_months``. Datasets without an entry are
        skipped. This overlaps with the per-group maximum check in
        _validate_grouped.
        """
        for result in collection.results:
            if result.dataset not in self.expected_months:
                continue
            max_expected = self.expected_months[result.dataset]
            if result.test_month > max_expected:
                report.add_warning(
                    "completeness",
                    f"Month {result.test_month} exceeds expected cutoff {max_expected}",
                    dataset=result.dataset,
                    month=result.test_month
                )

    def _generate_summary(
        self,
        collection: ResultsCollection,
        report: ValidationReport
    ):
        """
        Generate summary statistics

        Overwrites ``report.summary`` with overall counts and a pass flag (no
        errors recorded so far). Adds one ``"<dataset>_months"`` entry per
        dataset, describing its month range and the number of distinct months.
        """
        report.summary = {
            "total_results": len(collection.results),
            "datasets": len(collection.get_unique_values("dataset")["dataset"]),
            "unique_base_names": len(collection.get_unique_values("base_name")["base_name"]),
            "validation_passed": not report.has_errors(),
        }

        groups = collection.group_by("dataset")
        for (dataset,), results in groups.items():
            months = sorted({r.test_month for r in results})
            report.summary[f"{dataset}_months"] = f"{min(months)}-{max(months)} ({len(months)} total)"


def quick_validate(collection: ResultsCollection) -> bool:
    """
    Validate *collection* with a default ResultsValidator and fail loudly.

    Only error-severity issues cause a failure; warnings and info messages
    are ignored.

    Args:
        collection: Results to validate

    Returns:
        True if the validation report contains no errors

    Raises:
        ValidationError: If any errors were found. The message begins with
            "Validation failed:" and lists each error on its own line.
    """
    report = ResultsValidator().validate(collection)
    errors = report.get_errors()
    if not errors:
        return True

    rendered = "\n".join(map(str, errors))
    raise ValidationError(f"Validation failed:\n{rendered}")


class ValidationError(Exception):
    """Raised by ``quick_validate`` when a validation report contains errors."""
