"""
Data Quality Module for Aurora Framework

Provides comprehensive data quality checks and validation for experimental results.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Set, Optional, Tuple, Any
from collections import Counter, defaultdict
import pandas as pd
import numpy as np

from .schema import ResultsCollection, ExperimentResult


@dataclass
class DataQualityIssue:
    """A single finding recorded by a data quality check.

    Instances are plain data holders; ``str(issue)`` renders a one-line
    summary of the form ``"<severity label> [<category>] <message>"``.
    """

    # Severity label; __str__ knows 'error', 'warning' and 'info'.
    severity: str
    # Check family, e.g. 'completeness', 'consistency', 'uniqueness', 'progression'.
    category: str
    # Human-readable description of the finding.
    message: str
    # Extra details attached by the check that raised the issue.
    context: Dict[str, Any] = field(default_factory=dict)

    def __str__(self) -> str:
        """Return the one-line summary; an unknown severity is shown as '?'."""
        labels = {
            'error': '❌ ERROR',
            'warning': '⚠️  WARNING',
            'info': 'ℹ️  INFO',
        }
        label = labels.get(self.severity, '?')
        return f"{label} [{self.category}] {self.message}"


@dataclass
class DataQualityReport:
    """Comprehensive data quality report"""

    collection_name: str
    total_results: int
    issues: List[DataQualityIssue] = field(default_factory=list)
    summary_stats: Dict[str, Any] = field(default_factory=dict)

    def add_issue(self, severity: str, category: str, message: str, **context):
        """Add a data quality issue"""
        self.issues.append(DataQualityIssue(
            severity=severity,
            category=category,
            message=message,
            context=context
        ))

    def has_errors(self) -> bool:
        """Check if report contains errors"""
        return any(issue.severity == 'error' for issue in self.issues)

    def has_warnings(self) -> bool:
        """Check if report contains warnings"""
        return any(issue.severity == 'warning' for issue in self.issues)

    def get_errors(self) -> List[DataQualityIssue]:
        """Get all errors"""
        return [issue for issue in self.issues if issue.severity == 'error']

    def get_warnings(self) -> List[DataQualityIssue]:
        """Get all warnings"""
        return [issue for issue in self.issues if issue.severity == 'warning']

    def print_report(self, show_info: bool = False, max_issues: int = 50, declutter: bool = True):
        """
        Print formatted report

        Args:
            show_info: Include info-level issues
            max_issues: Maximum issues to show in detail
            declutter: Group similar issues for cleaner output
        """
        print("=" * 80)
        print(f"DATA QUALITY REPORT: {self.collection_name}")
        print("=" * 80)
        print(f"\nTotal results: {self.total_results}")

        # Summary stats
        if self.summary_stats:
            print("\n--- Summary Statistics ---")
            for key, value in self.summary_stats.items():
                print(f"  {key}: {value}")

        # Count by severity
        error_count = len(self.get_errors())
        warning_count = len(self.get_warnings())
        info_count = len([i for i in self.issues if i.severity == 'info'])

        print(f"\n--- Issue Summary ---")
        print(f"  Errors: {error_count}")
        print(f"  Warnings: {warning_count}")
        print(f"  Info: {info_count}")

        # Show issues
        issues_to_show = self.issues
        if not show_info:
            issues_to_show = [i for i in issues_to_show if i.severity != 'info']

        if declutter:
            self._print_decluttered_issues(issues_to_show, max_issues)
        else:
            self._print_all_issues(issues_to_show, max_issues)

        print("\n" + "=" * 80)

        # Final verdict
        if self.has_errors():
            print("❌ DATA QUALITY CHECK FAILED - Errors detected")
        elif self.has_warnings():
            print("⚠️  DATA QUALITY CHECK PASSED WITH WARNINGS")
        else:
            print("✅ DATA QUALITY CHECK PASSED")
        print("=" * 80)

    def _print_all_issues(self, issues_to_show: List[DataQualityIssue], max_issues: int):
        """Print all issues with full details"""
        if issues_to_show:
            print(f"\n--- Issues (showing up to {max_issues}) ---")
            for i, issue in enumerate(issues_to_show[:max_issues]):
                print(f"\n{i+1}. {issue}")
                if issue.context:
                    for key, value in issue.context.items():
                        # Truncate long values
                        if isinstance(value, (list, set)) and len(value) > 10:
                            value = f"{list(value)[:10]}... ({len(value)} items)"
                        elif isinstance(value, str) and len(value) > 100:
                            value = value[:100] + "..."
                        print(f"   {key}: {value}")

            if len(issues_to_show) > max_issues:
                print(f"\n... and {len(issues_to_show) - max_issues} more issues")

    def _print_decluttered_issues(self, issues_to_show: List[DataQualityIssue], max_issues: int):
        """Print issues grouped by type for cleaner output"""
        if not issues_to_show:
            print("\n✅ No issues found!")
            return

        # Group issues by (severity, category, message pattern)
        from collections import defaultdict
        issue_groups = defaultdict(list)

        for issue in issues_to_show:
            # Create grouping key from first 60 chars of message
            key = (issue.severity, issue.category, issue.message[:60])
            issue_groups[key].append(issue)

        print(f"\n--- Issues by Category (Total: {len(issues_to_show)}) ---")

        # Show errors first
        errors = [g for g in issue_groups.items() if g[0][0] == 'error']
        if errors:
            print(f"\n🔴 ERRORS ({sum(len(g[1]) for g in errors)} total):")
            for (severity, category, msg_prefix), group in sorted(errors, key=lambda x: len(x[1]), reverse=True):
                if len(group) == 1:
                    # Show single issue in detail
                    issue = group[0]
                    print(f"\n  ❌ [{category}] {issue.message}")
                    if issue.context:
                        for key, value in list(issue.context.items())[:5]:
                            if isinstance(value, (list, set)) and len(value) > 5:
                                value = f"{list(value)[:5]}... ({len(value)} total)"
                            print(f"     {key}: {value}")
                else:
                    # Group similar issues
                    print(f"\n  ❌ [{category}] {group[0].message}")
                    print(f"     Occurrences: {len(group)} times")
                    # Show sample contexts
                    sample_contexts = {}
                    for issue in group[:3]:
                        for key, value in issue.context.items():
                            if key not in sample_contexts:
                                sample_contexts[key] = set()
                            if isinstance(value, (list, tuple)):
                                value = str(value)
                            sample_contexts[key].add(str(value))

                    for key, values in list(sample_contexts.items())[:3]:
                        if len(values) <= 3:
                            print(f"     {key}: {', '.join(list(values)[:3])}")
                        else:
                            print(f"     {key}: {len(values)} unique values")

        # Show warnings
        warnings = [g for g in issue_groups.items() if g[0][0] == 'warning']
        if warnings:
            print(f"\n🟡 WARNINGS ({sum(len(g[1]) for g in warnings)} total):")
            shown = 0
            for (severity, category, msg_prefix), group in sorted(warnings, key=lambda x: len(x[1]), reverse=True):
                if shown >= 5:  # Limit warnings shown
                    remaining = sum(len(g[1]) for _, g in warnings[shown:])
                    print(f"\n  ... and {remaining} more warnings (use declutter=False for full list)")
                    break

                if len(group) == 1:
                    issue = group[0]
                    print(f"\n  ⚠️  [{category}] {issue.message}")
                    if issue.context:
                        for key, value in list(issue.context.items())[:3]:
                            if isinstance(value, (list, set)) and len(value) > 3:
                                value = f"{list(value)[:3]}..."
                            print(f"     {key}: {value}")
                else:
                    print(f"\n  ⚠️  [{category}] {group[0].message}")
                    print(f"     Occurrences: {len(group)} times")

                shown += 1


class DataQualityChecker:
    """
    Comprehensive data quality checker for Aurora experimental results

    Checks:
    1. Completeness: All required fields present
    2. Uniqueness: Correct number of results per test-month
    3. Consistency: Same configurations across groups
    4. Progression: Proper monthly progression
    5. Base-name integrity: Consistent base-name groupings
    """

    def __init__(
        self,
        strict_uniqueness: bool = True,
        expected_seeds: Optional[int] = None,
        check_monthly_progression: bool = True,
        check_base_name_consistency: bool = True
    ):
        """
        Args:
            strict_uniqueness: If True, enforce exactly 1 result per (config, test_month)
                              If False, allow multiple results (multi-seed experiments)
            expected_seeds: Expected number of seeds for multi-seed experiments
            check_monthly_progression: Check that months progress incrementally
            check_base_name_consistency: Check base-name groupings are consistent
        """
        self.strict_uniqueness = strict_uniqueness
        self.expected_seeds = expected_seeds
        self.check_monthly_progression = check_monthly_progression
        self.check_base_name_consistency = check_base_name_consistency

    def check_quality(self, collection: ResultsCollection) -> DataQualityReport:
        """
        Perform comprehensive data quality check

        Returns:
            DataQualityReport with all findings
        """
        report = DataQualityReport(
            collection_name=collection.metadata.experiment_name,
            total_results=len(collection)
        )

        # Basic stats
        report.summary_stats = self._compute_summary_stats(collection)

        # Run checks
        self._check_completeness(collection, report)
        self._check_uniqueness(collection, report)
        self._check_consistency(collection, report)

        if self.check_monthly_progression:
            self._check_monthly_progression(collection, report)

        if self.check_base_name_consistency:
            self._check_base_name_consistency(collection, report)

        return report

    def _compute_summary_stats(self, collection: ResultsCollection) -> Dict[str, Any]:
        """Compute summary statistics"""
        stats = {}

        # Unique values
        unique_values = collection.get_unique_values(
            "dataset", "monthly_label_budget", "sampler_mode",
            "trainer_mode", "base_name"
        )

        stats["Datasets"] = sorted(unique_values.get("dataset", []))
        stats["Budgets"] = sorted(unique_values.get("monthly_label_budget", []))
        stats["Unique base names"] = len(unique_values.get("base_name", []))
        stats["Sampler modes"] = len(unique_values.get("sampler_mode", []))
        stats["Trainer modes"] = len(unique_values.get("trainer_mode", []))

        # Test month range
        test_months = [r.test_month for r in collection.results]
        if test_months:
            stats["Test month range"] = f"{min(test_months)} - {max(test_months)}"

        return stats

    def _create_hyperparam_key(self, result: ExperimentResult, exclude_seeds: bool = True) -> tuple:
        """
        Create a hashable key from hyperparameters (excluding Random-Seed by default).

        This ensures different hyperparameter configs (e.g., Num-Epochs=30 vs 50)
        are treated as separate experiments.

        Args:
            result: ExperimentResult to extract hyperparameters from
            exclude_seeds: Whether to exclude Random-Seed from the key

        Returns:
            Tuple of (key, value) pairs sorted by key
        """
        hyperparam_items = []
        for k, v in sorted(result.hyperparameters.items()):
            if exclude_seeds and k == 'Random-Seed':
                continue
            # Convert to hashable type
            if isinstance(v, (list, np.ndarray)):
                v = tuple(v) if len(v) < 100 else f"array_len_{len(v)}"
            hyperparam_items.append((k, v))
        return tuple(hyperparam_items)

    def _check_completeness(self, collection: ResultsCollection, report: DataQualityReport):
        """Check that all required fields are present and valid"""
        for i, result in enumerate(collection.results):
            # Check required arrays
            if result.predictions is None or len(result.predictions) == 0:
                report.add_issue(
                    'error', 'completeness',
                    f"Result {i}: Missing or empty predictions",
                    result_index=i,
                    dataset=result.dataset,
                    test_month=result.test_month
                )

            if result.labels is None or len(result.labels) == 0:
                report.add_issue(
                    'error', 'completeness',
                    f"Result {i}: Missing or empty labels",
                    result_index=i,
                    dataset=result.dataset,
                    test_month=result.test_month
                )

            # Check array length consistency
            if result.predictions is not None and result.labels is not None:
                if len(result.predictions) != len(result.labels):
                    report.add_issue(
                        'error', 'completeness',
                        f"Result {i}: Predictions and labels have different lengths",
                        result_index=i,
                        pred_len=len(result.predictions),
                        label_len=len(result.labels)
                    )

            # Check base_name is not empty
            if not result.base_name or result.base_name.strip() == "":
                report.add_issue(
                    'error', 'completeness',
                    f"Result {i}: base_name is empty or missing",
                    result_index=i,
                    dataset=result.dataset,
                    test_month=result.test_month
                )

    def _check_uniqueness(self, collection: ResultsCollection, report: DataQualityReport):
        """
        Check uniqueness constraints

        For each (dataset, trainer_mode, sampler_mode, monthly_label_budget, test_month, hyperparameters),
        we should have exactly 1 result (single-seed) or N results (multi-seed).

        IMPORTANT: Hyperparameters (except Random-Seed) are included in the uniqueness check!
        This ensures that results with different Num-Epochs, etc. are not treated as duplicates.
        """
        # Group by configuration + test_month + hyperparameters (excluding random seed)
        groups = defaultdict(list)

        for i, result in enumerate(collection.results):
            # Get random seed if available
            random_seed = result.hyperparameters.get('Random-Seed', None)

            # Create hashable hyperparameters key (excluding Random-Seed)
            hyperparam_key = self._create_hyperparam_key(result, exclude_seeds=True)

            # Configuration key (including hyperparameters!)
            config_key = (
                result.dataset,
                result.trainer_mode,
                result.sampler_mode,
                result.monthly_label_budget,
                result.base_name,
                result.test_month,
                hyperparam_key  # Include all hyperparameters except Random-Seed
            )

            groups[config_key].append((i, random_seed))

        # Check each group
        for config_key, items in groups.items():
            dataset, trainer, sampler, budget, base_name, test_month, hyperparam_key = config_key
            count = len(items)
            seeds = [seed for _, seed in items]

            if self.strict_uniqueness:
                # Expect exactly 1 result
                if count > 1:
                    report.add_issue(
                        'error', 'uniqueness',
                        f"Duplicate results for configuration",
                        dataset=dataset,
                        trainer_mode=trainer,
                        sampler_mode=sampler,
                        budget=budget,
                        base_name=base_name,
                        test_month=test_month,
                        count=count,
                        seeds=seeds
                    )
            else:
                # Multi-seed experiments
                if self.expected_seeds is not None:
                    if count != self.expected_seeds:
                        report.add_issue(
                            'warning', 'uniqueness',
                            f"Expected {self.expected_seeds} seeds, found {count}",
                            dataset=dataset,
                            trainer_mode=trainer,
                            sampler_mode=sampler,
                            budget=budget,
                            base_name=base_name,
                            test_month=test_month,
                            expected=self.expected_seeds,
                            actual=count,
                            seeds=seeds
                        )

                # Check for duplicate seeds
                non_none_seeds = [s for s in seeds if s is not None]
                if len(non_none_seeds) != len(set(non_none_seeds)):
                    seed_counts = Counter(non_none_seeds)
                    duplicates = {s: c for s, c in seed_counts.items() if c > 1}
                    report.add_issue(
                        'error', 'uniqueness',
                        f"Duplicate random seeds found",
                        dataset=dataset,
                        trainer_mode=trainer,
                        sampler_mode=sampler,
                        budget=budget,
                        base_name=base_name,
                        test_month=test_month,
                        duplicate_seeds=duplicates
                    )

    def _check_consistency(self, collection: ResultsCollection, report: DataQualityReport):
        """
        Check consistency across configurations

        IMPORTANT: Array sizes varying across MONTHS is EXPECTED (temporal drift).
        We only check within-month consistency (across seeds) and cross-dataset consistency.
        """
        # Check array length consistency within same month (across seeds)
        self._check_array_length_consistency_within_month(collection, report)

        # Check cross-dataset consistency (same month should have same test data)
        self._check_cross_dataset_consistency(collection, report)

    def _check_array_length_consistency_within_month(self, collection: ResultsCollection, report: DataQualityReport):
        """
        Check that for the same (dataset, budget, sampler, base_name, test_month, hyperparameters),
        all results (across different seeds) have the same array lengths.

        This is CRITICAL: multi-seed experiments should all test on the same data!
        """
        # Group by (dataset, budget, sampler, base_name, test_month, hyperparameters)
        groups = defaultdict(list)

        for result in collection.results:
            hyperparam_key = self._create_hyperparam_key(result, exclude_seeds=True)
            group_key = (
                result.dataset,
                result.monthly_label_budget,
                result.sampler_mode,
                result.base_name,
                result.test_month,
                hyperparam_key  # Include hyperparameters!
            )
            groups[group_key].append(result)

        # Check each group
        for group_key, results in groups.items():
            if len(results) <= 1:
                continue  # Only one result, nothing to compare

            dataset, budget, sampler, base_name, test_month, hyperparam_key = group_key

            # Get array lengths
            pred_lengths = []
            label_lengths = []
            uncert_lengths = []

            for r in results:
                if r.predictions is not None:
                    pred_lengths.append(len(r.predictions))
                if r.labels is not None:
                    label_lengths.append(len(r.labels))
                if r.uncertainties_past_month is not None:
                    uncert_lengths.append(len(r.uncertainties_past_month))

            # Check predictions
            if pred_lengths and len(set(pred_lengths)) > 1:
                report.add_issue(
                    'error', 'consistency',
                    f"Different prediction array lengths across seeds for same month",
                    dataset=dataset,
                    budget=budget,
                    sampler=sampler,
                    base_name=base_name,
                    test_month=test_month,
                    num_seeds=len(results),
                    pred_lengths=sorted(set(pred_lengths)),
                    seed_counts=dict(Counter(pred_lengths))
                )

            # Check labels
            if label_lengths and len(set(label_lengths)) > 1:
                report.add_issue(
                    'error', 'consistency',
                    f"Different label array lengths across seeds for same month",
                    dataset=dataset,
                    budget=budget,
                    sampler=sampler,
                    base_name=base_name,
                    test_month=test_month,
                    num_seeds=len(results),
                    label_lengths=sorted(set(label_lengths)),
                    seed_counts=dict(Counter(label_lengths))
                )

            # Check uncertainties
            if uncert_lengths and len(set(uncert_lengths)) > 1:
                report.add_issue(
                    'error', 'consistency',
                    f"Different uncertainty array lengths across seeds for same month",
                    dataset=dataset,
                    budget=budget,
                    sampler=sampler,
                    base_name=base_name,
                    test_month=test_month,
                    num_seeds=len(results),
                    uncert_lengths=sorted(set(uncert_lengths)),
                    seed_counts=dict(Counter(uncert_lengths))
                )

    def _check_cross_dataset_consistency(self, collection: ResultsCollection, report: DataQualityReport):
        """
        Check that for the same (dataset, test_month), all experiments have the same array sizes.

        This is CRITICAL: Different methods testing on the same dataset/month should have
        identical test set sizes. If they differ, it means they're testing on different data!
        """
        # Group by (dataset, test_month)
        dataset_month_sizes = defaultdict(lambda: defaultdict(set))

        for result in collection.results:
            key = (result.dataset, result.test_month)

            if result.predictions is not None:
                dataset_month_sizes[key]['predictions'].add(len(result.predictions))
            if result.labels is not None:
                dataset_month_sizes[key]['labels'].add(len(result.labels))

        # Check each (dataset, month) pair
        for (dataset, test_month), sizes_dict in dataset_month_sizes.items():
            pred_sizes = sizes_dict.get('predictions', set())
            label_sizes = sizes_dict.get('labels', set())

            # Multiple prediction sizes = different test sets!
            if len(pred_sizes) > 1:
                report.add_issue(
                    'error', 'consistency',
                    f"Different test set sizes across experiments for same dataset/month",
                    dataset=dataset,
                    test_month=test_month,
                    pred_sizes=sorted(pred_sizes),
                    affected_experiments=f"{len([r for r in collection.results if r.dataset == dataset and r.test_month == test_month])} experiments"
                )

            # Multiple label sizes = different test sets!
            if len(label_sizes) > 1:
                report.add_issue(
                    'error', 'consistency',
                    f"Different label array sizes across experiments for same dataset/month",
                    dataset=dataset,
                    test_month=test_month,
                    label_sizes=sorted(label_sizes),
                    affected_experiments=f"{len([r for r in collection.results if r.dataset == dataset and r.test_month == test_month])} experiments"
                )

    def _check_monthly_progression(self, collection: ResultsCollection, report: DataQualityReport):
        """Check that months progress correctly"""
        # Group by configuration
        groups = defaultdict(list)

        for result in collection.results:
            config_key = (
                result.dataset,
                result.trainer_mode,
                result.sampler_mode,
                result.monthly_label_budget,
                result.base_name,
                result.hyperparameters.get('Random-Seed', 'default')
            )
            groups[config_key].append(result.test_month)

        # Check each configuration
        for config_key, test_months in groups.items():
            dataset, trainer, sampler, budget, base_name, seed = config_key

            # Sort months
            sorted_months = sorted(test_months)

            # Check for gaps
            if len(sorted_months) > 1:
                diffs = np.diff(sorted_months)
                unique_diffs = set(diffs)

                # Expect all differences to be 1 (consecutive months)
                if unique_diffs != {1}:
                    report.add_issue(
                        'warning', 'progression',
                        f"Non-consecutive month progression",
                        dataset=dataset,
                        trainer_mode=trainer,
                        sampler_mode=sampler,
                        budget=budget,
                        base_name=base_name,
                        seed=seed,
                        month_diffs=sorted(unique_diffs),
                        month_range=f"{sorted_months[0]}-{sorted_months[-1]}"
                    )

    def _check_base_name_consistency(self, collection: ResultsCollection, report: DataQualityReport):
        """
        Check that base-name groupings are consistent

        Critical checks:
        1. Same base_name should have same trainer_mode/sampler_mode
        2. Same number of months per base_name (per dataset/budget combo)
        3. Same number of results per month (for multi-seed experiments)
        """
        # Check 1: Map base_name to configurations
        base_name_configs = defaultdict(set)

        for result in collection.results:
            base_name_configs[result.base_name].add((
                result.trainer_mode,
                result.sampler_mode
            ))

        # Check for inconsistencies in trainer/sampler
        for base_name, configs in base_name_configs.items():
            if len(configs) > 1:
                report.add_issue(
                    'warning', 'consistency',
                    f"Base name maps to multiple sampler modes (may affect aggregation)",
                    base_name=base_name,
                    num_combinations=len(configs),
                    configurations=list(configs),
                    note="If sampler mode should be distinguished, include it in base_name"
                )

        # Check 2 & 3: Month count and seed count consistency per base-name
        self._check_base_name_month_consistency(collection, report)

    def _check_base_name_month_consistency(self, collection: ResultsCollection, report: DataQualityReport):
        """
        Check that each base-name has:
        - Same number of months across (dataset, budget, sampler) combinations
        - Same number of results per month (for multi-seed)
        """
        # Group by (dataset, budget, sampler, base_name)
        groups = defaultdict(lambda: defaultdict(list))

        for result in collection.results:
            group_key = (
                result.dataset,
                result.monthly_label_budget,
                result.sampler_mode,
                result.base_name
            )
            groups[group_key][result.test_month].append(result)

        # Check month counts per base-name
        base_name_month_counts = defaultdict(set)

        for group_key, months_dict in groups.items():
            dataset, budget, sampler, base_name = group_key
            num_months = len(months_dict)

            # Track month counts for this base-name
            check_key = (dataset, budget, sampler)
            base_name_month_counts[(base_name, check_key)].add(num_months)

            # Check seed counts per month
            for test_month, results in months_dict.items():
                num_results = len(results)

                # Check if this base-name has consistent result counts across months
                # (This is important for multi-seed experiments)
                all_counts_this_config = [len(r) for r in months_dict.values()]

                if len(set(all_counts_this_config)) > 1:
                    report.add_issue(
                        'warning', 'consistency',
                        f"Inconsistent number of results per month for base-name",
                        base_name=base_name,
                        dataset=dataset,
                        budget=budget,
                        sampler=sampler,
                        counts_per_month=dict(Counter(all_counts_this_config)),
                        example_month=test_month,
                        this_month_count=num_results
                    )

        # Check that each base-name has same month count across configs
        for (base_name, check_key), month_counts in base_name_month_counts.items():
            dataset, budget, sampler = check_key

            if len(month_counts) > 1:
                report.add_issue(
                    'error', 'consistency',
                    f"Base-name has different number of months across configurations",
                    base_name=base_name,
                    dataset=dataset,
                    budget=budget,
                    sampler=sampler,
                    month_counts=sorted(month_counts)
                )


def check_results_df(
    results_df: pd.DataFrame,
    group_by_keys: List[str] = None,
    additional_keys: List[str] = None
) -> pd.DataFrame:
    """
    Legacy compatibility: check_results() from aurora.ipynb

    Groups the results and summarizes the test months of every group, so
    repeated months and irregular month steps can be spotted.

    Args:
        results_df: DataFrame of results; must contain a "test_month" column
            and every grouping column
        group_by_keys: Primary grouping keys; defaults to
            ["dataset", "trainer_mode", "monthly_label_budget"]
        additional_keys: Additional keys (like Random-Seed); defaults to
            ["random_seed"] when that column exists, otherwise none

    Returns:
        DataFrame indexed by the grouping keys with the columns
        "test_month" (list of months in row order), "Test-Month Counter",
        "Max Month Repetition", "Max Month Index" and "Monthly Progression"
        (set of differences between successive listed months)
    """
    if group_by_keys is None:
        primary_keys = ["dataset", "trainer_mode", "monthly_label_budget"]
    else:
        primary_keys = group_by_keys

    if additional_keys is None:
        has_seed_column = "random_seed" in results_df.columns
        extra_keys = ["random_seed"] if has_seed_column else []
    else:
        extra_keys = additional_keys

    summary = results_df.groupby(primary_keys + extra_keys).agg({"test_month": list})

    def tally(months):
        return Counter(months)

    def most_repeated(month_counter):
        return max(month_counter.values()) if month_counter else 0

    def highest_month(month_counter):
        return max(month_counter) if month_counter else None

    def step_sizes(months):
        # Differences follow the listed order; the months are not sorted first
        return set(np.diff(months)) if len(months) > 1 else set()

    summary["Test-Month Counter"] = summary["test_month"].apply(tally)
    summary["Max Month Repetition"] = summary["Test-Month Counter"].apply(most_repeated)
    summary["Max Month Index"] = summary["Test-Month Counter"].apply(highest_month)
    summary["Monthly Progression"] = summary["test_month"].apply(step_sizes)

    return summary
