"""
Base Name Mapping - Generic Utilities

Provides simple utilities for tracking base-name mappings.
Users define their own mapping functions based on their data.

Philosophy:
- Framework provides the mechanism (loaders accept mapper functions)
- Users provide the policy (what constitutes a base-name)
"""

from typing import Dict, List, Optional, Any
from collections import defaultdict


class BaseNameRegistry:
    """
    Simple registry for tracking base-name mappings

    Useful for debugging and ensuring consistency across experiments.
    Tracks which detailed configurations map to which base names.

    Each configuration key maps to exactly one base name, while a base name
    may group any number of configuration keys. Remapping an already
    registered key to a different base name raises ``ValueError``.
    """

    def __init__(self):
        self.mapping: Dict[str, str] = {}  # detailed_config → base_name
        # base_name → [configs], kept in registration order
        self.reverse_mapping: Dict[str, List[str]] = defaultdict(list)

    def register(self, config_key: str, base_name: str):
        """
        Register a mapping from configuration to base name

        Registering the same (config_key, base_name) pair again is a no-op.

        Args:
            config_key: Some identifier for the configuration (e.g., full method name)
            base_name: The base name for grouping

        Raises:
            ValueError: If config_key is already mapped to a different base name.
        """
        if config_key in self.mapping:
            if self.mapping[config_key] != base_name:
                raise ValueError(
                    f"Inconsistent mapping: '{config_key}' already maps to "
                    f"'{self.mapping[config_key]}', cannot remap to '{base_name}'"
                )
        else:
            self.mapping[config_key] = base_name
            self.reverse_mapping[base_name].append(config_key)

    def get_base_name(self, config_key: str) -> Optional[str]:
        """Get base name for a configuration, or None if it is not registered"""
        return self.mapping.get(config_key)

    def get_configs(self, base_name: str) -> List[str]:
        """
        Get all configurations that map to a base name

        Returns a new list (in registration order), so callers cannot corrupt
        the registry by mutating the result. Unknown base names yield [].
        """
        # .get() rather than [] so the defaultdict does not grow an empty
        # entry for an unknown base name.
        return list(self.reverse_mapping.get(base_name, ()))

    def get_all_base_names(self) -> List[str]:
        """Get all base names, in the order they were first registered"""
        return list(self.reverse_mapping.keys())

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the registry

        Returns:
            Dict with "total_configs", "total_base_names" and
            "avg_configs_per_base" (0 when the registry is empty).
        """
        return {
            "total_configs": len(self.mapping),
            "total_base_names": len(self.reverse_mapping),
            "avg_configs_per_base": (
                len(self.mapping) / len(self.reverse_mapping)
                if self.reverse_mapping else 0
            )
        }

    def print_summary(self, max_per_base: int = 5):
        """
        Print summary of mappings to stdout

        Args:
            max_per_base: Maximum configs to show per base name; values below
                zero are treated as zero (only the count line is printed).
        """
        print(f"\n{'='*80}")
        print("BASE NAME REGISTRY SUMMARY")
        print(f"{'='*80}")

        stats = self.get_stats()
        print(f"\nTotal base names: {stats['total_base_names']}")
        print(f"Total configurations: {stats['total_configs']}")
        print(f"Avg configs per base: {stats['avg_configs_per_base']:.1f}")

        print(f"\n{'='*80}")
        print("MAPPINGS")
        print(f"{'='*80}")

        for base_name in sorted(self.reverse_mapping.keys()):
            configs = self.reverse_mapping[base_name]
            # Clamp the limit: a negative slice bound would count from the
            # end and make the "... and N more" line report the wrong number.
            shown = configs[:max(max_per_base, 0)]
            print(f"\n📊 Base Name: {base_name}")
            print(f"   ({len(configs)} configuration(s))")

            for i, config in enumerate(shown, start=1):
                print(f"   {i}. {config}")

            hidden = len(configs) - len(shown)
            if hidden > 0:
                print(f"   ... and {hidden} more")


# Simple default mapper - use a single field's value as the base name
def simple_field_mapper(field_name: str = "Trainer-Mode", *, default: str = "Unknown"):
    """
    Create a simple mapper that returns a single field's value as the base name

    Args:
        field_name: Field to use as base name
        default: Base name returned when the field is absent from the result dict

    Returns:
        Mapper function taking a result dict and returning a string base name.
        Non-string field values are converted with ``str()``, the same way
        ``composite_field_mapper`` treats its fields. This keeps base names
        uniformly typed, so they can be sorted and compared.

    Example:
        >>> mapper = simple_field_mapper("Trainer-Mode")
        >>> mapper({"Trainer-Mode": "CE"})
        'CE'
        >>> mapper({})
        'Unknown'
    """
    def mapper(result_dict: Dict[str, Any]) -> str:
        # str() honours the declared return type even for numeric/None values.
        return str(result_dict.get(field_name, default))

    return mapper


# Example: Composite mapper based on multiple fields
def composite_field_mapper(*field_names: str, separator: str = " - "):
    """
    Build a mapper that joins several fields into one base name

    Args:
        field_names: Names of the fields to combine, in output order
        separator: Text placed between consecutive field values

    Returns:
        Mapper function taking a result dict and returning the joined base
        name. Every value is passed through ``str()``; a field missing from
        the dict contributes ``"?"``.

    Example:
        >>> mapper = composite_field_mapper("Trainer-Mode", "Sampler-Mode")
        >>> mapper({"Trainer-Mode": "CE", "Sampler-Mode": "full_year"})
        'CE - full_year'
    """
    def mapper(result_dict: Dict[str, Any]) -> str:
        # Look each field up in declaration order, using "?" as the placeholder
        raw_values = (result_dict.get(name, "?") for name in field_names)
        return separator.join(map(str, raw_values))

    return mapper
