import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.ticker import PercentFormatter
from colorspacious import cspace_convert

from scipy.spatial.distance import pdist
import colorsys
import numpy as np

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
from matplotlib.lines import Line2D
import matplotlib.gridspec as gridspec
import math
from concurrent.futures import ThreadPoolExecutor
import pandas as pd



import matplotlib.pyplot as plt
import re
import numpy as np
from matplotlib.ticker import PercentFormatter, LogLocator, FuncFormatter
from matplotlib.colors import LinearSegmentedColormap

import matplotlib.pyplot as plt
import re
import numpy as np
from matplotlib.ticker import PercentFormatter, LogLocator, FuncFormatter
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.cm as cm



# Display labels (matplotlib mathtext strings) keyed by sampler identifier.
SAMPLER_NAME_MPR = {
    'full_first_year_subsample_months': r'$B_0 = |\mathcal{D}_0|$',
    'weight_first_year_subsample_months': 'Weighted First Year',
    'subsample_first_year_subsample_months': r'$B_0 = 4800$',
    # Plain "B" subscript: "\B" is not a mathtext command and cannot be rendered.
    'stratk_tau_init_tau_random': r'$B_0 = 4800\text{-}B_{\text{rand}}$',
}

# Matplotlib linestyles keyed by the same sampler identifiers.
SAMPLER_LINE_MPR = {
    'full_first_year_subsample_months': '-',
    'weight_first_year_subsample_months': ':',
    'subsample_first_year_subsample_months': '--',
    'stratk_tau_init_tau_random': ':',
}


##################
### Plot Utils ###
##################

def generate_distinct_colors(n_colors, seed=0):
    """
    Generate visually distinct colors by searching over HSV values.

    Hues are initialized with the golden-ratio sequence (saturation 0.85,
    value 0.95). A simulated-annealing-style search then perturbs one color
    per iteration: hue wraps around, saturation is clipped to [0.6, 0.9] and
    value to [0.8, 1.0]. The search keeps the palette whose *smallest*
    pairwise distance in the CAM02-UCS color space is largest.

    Parameters:
    -----------
    n_colors : int
        Number of distinct colors to generate (must be non-negative).
    seed : int, optional (default=0)
        Seed for a private random generator; the global NumPy random state
        is not modified.

    Returns:
    --------
    np.ndarray
        Array of shape (n_colors, 4) holding RGBA rows with alpha = 1.

    Raises:
    -------
    ValueError
        If ``n_colors`` is negative.
    """
    if n_colors < 0:
        raise ValueError(f"n_colors must be non-negative, got {n_colors}")
    if n_colors == 0:
        return np.empty((0, 4))

    def hsv_to_rgb(hsv):
        """Convert one HSV triple to an RGB triple."""
        return colorsys.hsv_to_rgb(hsv[0], hsv[1], hsv[2])

    def compute_color_distances(colors_hsv):
        """Return the smallest pairwise distance between the colors in CAM02-UCS space."""
        colors_rgb = np.array([hsv_to_rgb(c) for c in colors_hsv])
        colors_cam = cspace_convert(colors_rgb, "sRGB1", "CAM02-UCS")
        return np.min(pdist(colors_cam))

    # Legacy RandomState yields the same stream as np.random.seed(seed) did,
    # but keeps the caller's global random state untouched.
    rng = np.random.RandomState(seed)

    # Initialize with golden ratio method for hue
    golden_ratio = (1 + np.sqrt(5)) / 2
    hues = np.array([(i * golden_ratio) % 1.0 for i in range(n_colors)])

    colors_hsv = np.column_stack([
        hues,
        np.ones(n_colors) * 0.85,  # Slightly reduced saturation for better visibility
        np.ones(n_colors) * 0.95,  # Slightly reduced value for better contrast
    ])
    best_colors = colors_hsv.copy()

    # A single color has no pairwise distances to maximize (pdist would be
    # empty and np.min would raise), so only optimize for two or more colors.
    if n_colors > 1:
        best_distance = compute_color_distances(colors_hsv)

        n_iterations = 1000
        temperature = 1.0
        cooling_rate = 0.95

        for _ in range(n_iterations):
            # Perturb a single randomly chosen color
            perturbed_colors = colors_hsv.copy()
            idx = rng.randint(n_colors)
            perturbed_colors[idx, 0] = (perturbed_colors[idx, 0] + rng.normal(0, 0.1)) % 1.0
            perturbed_colors[idx, 1] = np.clip(perturbed_colors[idx, 1] + rng.normal(0, 0.1), 0.6, 0.9)
            perturbed_colors[idx, 2] = np.clip(perturbed_colors[idx, 2] + rng.normal(0, 0.1), 0.8, 1.0)

            new_distance = compute_color_distances(perturbed_colors)

            # Accept if better, or with a temperature-dependent probability
            # (relative to the best distance seen so far).
            if new_distance > best_distance or rng.random_sample() < np.exp((new_distance - best_distance) / temperature):
                colors_hsv = perturbed_colors
                if new_distance > best_distance:
                    best_distance = new_distance
                    best_colors = perturbed_colors.copy()

            temperature *= cooling_rate

    colors_rgb = np.array([hsv_to_rgb(c) for c in best_colors])

    # Add an opaque alpha channel
    return np.column_stack([colors_rgb, np.ones(n_colors)])

def create_distinct_color_map(unique_classifiers, seed=0):
    """
    Assign a visually distinct RGBA color to every classifier.

    The palette is produced by ``generate_distinct_colors`` with one color per
    entry of ``unique_classifiers``; entries are paired with colors in order.

    Parameters:
    -----------
    unique_classifiers : array-like
        Sized sequence of unique classifier names or identifiers
    seed : int, optional (default=0)
        Seed forwarded to the palette generator for reproducibility

    Returns:
    --------
    dict
        ``{classifier: rgba_color}`` for every classifier
    """
    palette_size = len(unique_classifiers)
    palette = generate_distinct_colors(palette_size, seed)
    return dict(zip(unique_classifiers, palette))



###############################
### Preliminary Experiments ### 
###############################




def _f1_curve_for_budget(df, dataset, budget):
    """
    Extract the sample-size vs. F1 curve of one label budget on one dataset.

    Rows are selected where the first index level equals ``budget``. The
    sample size is parsed from a ``Samples=<n>`` token in the last index level
    and multiplied by 12 (monthly samples). F1 values given as percentage
    strings (e.g. ``'85%'``) are divided by 100; other values go through
    ``float``. Rows that cannot be parsed are reported with a warning and
    skipped, and NaN scores are dropped.

    Parameters:
    -----------
    df : pandas DataFrame
        Row MultiIndex (label budget first, model name last) and column
        MultiIndex ``(dataset, metric)`` with an ``'F1'`` metric.
    dataset : hashable
        First-level column key to read.
    budget : hashable
        First-level row key to select.

    Returns:
    --------
    tuple of np.ndarray
        ``(samples, f1_scores)`` sorted by ascending sample size.
    """
    budget_data = df[df.index.get_level_values(0) == budget]
    samples, f1_scores = [], []
    for idx_tuple in budget_data.index:
        model_name = idx_tuple[-1]  # Get the model name from the index tuple
        try:
            sample_size = int(re.search(r'Samples=(\d+)', model_name).group(1))
            f1_score = budget_data.loc[idx_tuple, (dataset, 'F1')]
            # Convert from percentage string to float if necessary
            if isinstance(f1_score, str):
                f1_score = float(f1_score.strip('%')) / 100
            f1_score = float(f1_score)
        except (AttributeError, ValueError, KeyError, TypeError) as e:
            # AttributeError: no "Samples=" token; KeyError: missing column;
            # ValueError/TypeError: non-numeric F1 or non-string model name.
            print(f"Warning: Could not process {model_name}: {str(e)}")
            continue
        samples.append(12 * sample_size)  # Multiply by 12 for monthly samples
        f1_scores.append(f1_score)

    samples = np.array(samples, dtype=int)
    f1_scores = np.array(f1_scores, dtype=float)
    valid = ~np.isnan(f1_scores)
    samples, f1_scores = samples[valid], f1_scores[valid]
    order = np.argsort(samples)
    return samples[order], f1_scores[order]


def plot_samples_vs_f1(df, log_scale=True, ylim=None, consistent_ylim=True, fig_height=3):
    """
    Create publication-quality subplots showing F1 scores across sample sizes and label budgets.

    One subplot is drawn per dataset (first level of the column MultiIndex) and
    one line per label budget (first level of the row MultiIndex). The x-axis
    is 12 times the sample size parsed from ``Samples=<n>`` in the model name.

    Parameters:
    -----------
    df : pandas DataFrame
        DataFrame with MultiIndex columns ``(dataset, metric)`` including
        ``'F1'`` and a MultiIndex index whose first level is the label budget
        and whose last level is the model name
    log_scale : bool, default=True
        Whether to use log scale for the x-axis
    ylim : tuple, default=None
        Optional custom y-axis limits (min, max), in percent
    consistent_ylim : bool, default=True
        Whether to use the same y-axis limits for all subplots
    fig_height : float, default=3
        Height of the figure in inches

    Returns:
    --------
    tuple
        ``(fig, axes)``; ``axes`` is a list when there is a single dataset.
    """
    # Configure plot style for research paper
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans'],
        'font.size': 11,
        'axes.titlesize': 12,
        'axes.labelsize': 11,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 9,
        'figure.dpi': 300
    })

    datasets = df.columns.get_level_values(0).unique()
    label_budgets = sorted(df.index.get_level_values(0).unique())

    fig_width = min(9.0, 3.0 * len(datasets))  # 3.0 inches per subplot, max 9 inches
    fig, axes = plt.subplots(1, len(datasets), figsize=(fig_width, fig_height))

    # Make axes iterable if there's only one subplot
    if len(datasets) == 1:
        axes = [axes]

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # takes the same (name, lut) arguments.
    cmap = plt.get_cmap("turbo", max(1, int(1.5 * len(label_budgets))))
    colors = cmap(np.arange(len(label_budgets)))

    marker = 'o'       # Same circle marker for every budget
    line_style = '-'   # Use only a solid line style for all

    def padded_percent_limits(fractions):
        """Convert F1 fractions to percent limits padded by 5 points, clamped to [0, 100]."""
        low = max(0, np.floor(min(fractions) * 100 - 5))
        high = min(100, np.ceil(max(fractions) * 100 + 5))
        return low, high

    def format_fn(x, pos):
        """Render tick values below 1000 as integers and larger ones as '<n>k'."""
        if x < 1000:
            return str(int(x))
        return f'{int(x / 1000)}k'

    # Parse every (dataset, budget) curve once; reused for limits and plotting.
    curves = {
        (dataset, budget): _f1_curve_for_budget(df, dataset, budget)
        for dataset in datasets
        for budget in label_budgets
    }

    global_ylim = None
    if consistent_ylim and not ylim:
        all_f1_values = [f1 for _, scores in curves.values() for f1 in scores]
        if all_f1_values:
            global_ylim = padded_percent_limits(all_f1_values)

    for dataset_idx, dataset in enumerate(datasets):
        ax = axes[dataset_idx]
        # Gather data from *every* line in this subplot for axis limits.
        subplot_samples, subplot_f1 = [], []

        for budget_idx, budget in enumerate(label_budgets):
            samples, f1_scores = curves[(dataset, budget)]
            if samples.size == 0:
                continue
            subplot_samples.extend(samples.tolist())
            subplot_f1.extend(f1_scores.tolist())

            ax.plot(
                samples,
                100 * f1_scores,  # Convert to percentage
                marker=marker,
                color=colors[budget_idx % len(colors)],
                linestyle=line_style,
                linewidth=1.5,
                markersize=2.5,
                label=rf'$B_0 = {budget}$'
            )

        ax.set_title(f'{dataset} Dataset', fontweight='bold')
        ax.set_xlabel(r"$B_0$")

        # Only show y-label on first subplot
        if dataset_idx == 0:
            ax.set_ylabel('F1 Score')

        ax.grid(True, linestyle='--', alpha=0.3)
        ax.yaxis.set_major_formatter(PercentFormatter())

        if log_scale:
            ax.set_xscale('log')
            ax.xaxis.set_major_formatter(FuncFormatter(format_fn))
            ax.xaxis.set_major_locator(LogLocator(base=10, numticks=5))
            # Log axes need a positive left limit; leave room below the smallest sample.
            min_sample = min(subplot_samples) if subplot_samples else 100
            ax.set_xlim(left=min_sample * 0.5)
        else:
            ax.set_xlim(left=0)

        if ylim:
            ax.set_ylim(ylim)
        elif global_ylim:
            ax.set_ylim(global_ylim)
        elif subplot_f1:
            ax.set_ylim(*padded_percent_limits(subplot_f1))

        # Set aspect ratio to make the plot rectangular (wider than tall)
        ax.set_box_aspect(0.6)

    # Add a single legend for the entire figure, placed below the plots
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc='lower center',
        bbox_to_anchor=(0.5, -0.01),  # Position below the plots
        ncol=max(1, len(label_budgets)),  # Put all items in one row
        frameon=True,
        framealpha=0.9,
        edgecolor='lightgray',
        fontsize=9,
        title="Label Budget"
    )

    plt.subplots_adjust(bottom=0.15)
    plt.tight_layout(rect=[0, 0.1, 1, 1])  # Leave space at bottom for legend
    return fig, axes



def _average_f1_by_sample_size(df, dataset):
    """
    Average one dataset's F1 score over all label budgets, per sample size.

    For every row, the sample size is parsed from a ``Samples=<n>`` token in
    the last index level and multiplied by 12 (monthly samples). F1 values
    given as percentage strings (e.g. ``'85%'``) are divided by 100. Rows that
    cannot be parsed, and NaN scores, are skipped silently. All scores sharing
    a sample size (across the first index level) are averaged.

    Parameters:
    -----------
    df : pandas DataFrame
        Row MultiIndex (grouping key first, model name last) and column
        MultiIndex ``(dataset, metric)`` with an ``'F1'`` metric.
    dataset : hashable
        First-level column key to read.

    Returns:
    --------
    dict
        ``{12 * n: mean_f1}`` in first-seen order; empty if nothing parses.
    """
    budget_level = df.index.get_level_values(0)
    f1_by_sample = {}
    for budget in budget_level.unique():
        budget_data = df[budget_level == budget]
        for idx_tuple in budget_data.index:
            model_name = idx_tuple[-1]  # Get the model name from the index tuple
            try:
                sample_size = 12 * int(re.search(r'Samples=(\d+)', model_name).group(1))
                f1_score = budget_data.loc[idx_tuple, (dataset, 'F1')]
                # Convert from percentage string to float if necessary
                if isinstance(f1_score, str):
                    f1_score = float(f1_score.strip('%')) / 100
                if np.isnan(f1_score):
                    continue
            except (AttributeError, ValueError, KeyError, TypeError):
                continue
            f1_by_sample.setdefault(sample_size, []).append(f1_score)

    averages = {}
    for sample_size, f1_values in f1_by_sample.items():
        avg_f1 = np.mean(f1_values)
        if not np.isnan(avg_f1):
            averages[sample_size] = avg_f1
    return averages


def plot_sampling_techniques_vs_f1(
        sampling_techniques_dict, 
        log_scale=True, 
        ylim=None, 
        consistent_ylim=True, 
        fig_height=3,
        figname=None, 
):
    """
    Create publication-quality subplots showing average F1 scores across different sampling techniques.

    One subplot is drawn per dataset (first column level of the first
    DataFrame) and one line per sampling technique. Each point is the F1
    averaged over all label budgets at a given sample size (12 times the
    ``Samples=<n>`` value in the model name).

    Parameters:
    -----------
    sampling_techniques_dict : dict
        Dictionary where keys are sampling technique names and values are DataFrames
        with the same structure as sample_exp_res_df
    log_scale : bool, default=True
        Whether to use log scale for the x-axis
    ylim : tuple, default=None
        Optional custom y-axis limits (min, max), in percent
    consistent_ylim : bool, default=True
        Whether to use the same y-axis limits for all subplots
    fig_height : float, default=3
        Height of the figure in inches
    figname : str or path-like, default=None
        If given, the figure is saved there as a PDF

    Returns:
    --------
    tuple
        ``(fig, axes)``; ``axes`` is a list when there is a single dataset.

    Raises:
    -------
    ValueError
        If ``sampling_techniques_dict`` is empty.
    """
    if not sampling_techniques_dict:
        raise ValueError("sampling_techniques_dict must contain at least one DataFrame")

    # Configure plot style for research paper
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans'],
        'font.size': 11,
        'axes.titlesize': 12,
        'axes.labelsize': 11,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 9,
        'figure.dpi': 300
    })

    # Dataset names are taken from the first DataFrame
    sample_df = next(iter(sampling_techniques_dict.values()))
    datasets = sample_df.columns.get_level_values(0).unique()

    fig_width = min(9.0, 3.0 * len(datasets))  # 3.0 inches per subplot, max 9 inches
    fig, axes = plt.subplots(1, len(datasets), figsize=(fig_width, fig_height))

    # Make axes iterable if there's only one subplot
    if len(datasets) == 1:
        axes = [axes]

    technique_names = list(sampling_techniques_dict.keys())

    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # takes the same (name, lut) arguments. Only the first half of the
    # resampled colormap is used.
    cmap = plt.get_cmap("turbo_r", int(2 * len(technique_names)))
    colors = cmap(np.arange(len(technique_names)))

    marker = 'o'       # Same circle marker for every technique
    line_style = '-'   # Use only a solid line style for all

    def padded_percent_limits(fractions):
        """Convert F1 fractions to percent limits padded by 5 points, clamped to [0, 100]."""
        low = max(0, np.floor(min(fractions) * 100 - 5))
        high = min(100, np.ceil(max(fractions) * 100 + 5))
        return low, high

    def format_fn(x, pos):
        """Render tick values below 1000 as integers and larger ones as '<n>k'."""
        if x < 1000:
            return str(int(x))
        return f'{int(x / 1000)}k'

    # Compute every (technique, dataset) average once; reused for limits and plotting.
    averages = {
        (technique_name, dataset): _average_f1_by_sample_size(df, dataset)
        for technique_name, df in sampling_techniques_dict.items()
        for dataset in datasets
    }

    global_ylim = None
    if consistent_ylim and not ylim:
        all_f1_values = [f1 for avg in averages.values() for f1 in avg.values()]
        if all_f1_values:
            global_ylim = padded_percent_limits(all_f1_values)

    for dataset_idx, dataset in enumerate(datasets):
        ax = axes[dataset_idx]
        # Gather data from *every* line in this subplot for axis limits.
        subplot_samples, subplot_f1 = [], []

        for technique_idx, technique_name in enumerate(technique_names):
            avg_by_sample = averages[(technique_name, dataset)]
            if not avg_by_sample:
                continue

            samples = np.array(list(avg_by_sample.keys()))
            avg_f1_scores = np.array(list(avg_by_sample.values()))
            order = np.argsort(samples)
            subplot_samples.extend(samples.tolist())
            subplot_f1.extend(avg_f1_scores.tolist())

            ax.plot(
                samples[order],
                100 * avg_f1_scores[order],  # Convert to percentage
                marker=marker,
                color=colors[technique_idx % len(colors)],
                linestyle=line_style,
                linewidth=1.5,
                markersize=2.5,
                label=technique_name
            )

        ax.set_title(f'{dataset} Dataset', fontweight='bold')
        ax.set_xlabel(r'$B_0$')

        # Only show y-label on first subplot
        if dataset_idx == 0:
            ax.set_ylabel('Average F1 Score')

        ax.grid(True, linestyle='--', alpha=0.3)
        ax.yaxis.set_major_formatter(PercentFormatter())

        if log_scale:
            ax.set_xscale('log')
            ax.xaxis.set_major_formatter(FuncFormatter(format_fn))
            ax.xaxis.set_major_locator(LogLocator(base=10, numticks=5))
            # Log axes need a positive left limit; leave room below the smallest sample.
            min_sample = min(subplot_samples) if subplot_samples else 100
            ax.set_xlim(left=min_sample * 0.5)
        else:
            ax.set_xlim(left=0)

        if ylim:
            ax.set_ylim(ylim)
        elif global_ylim:
            ax.set_ylim(global_ylim)
        elif subplot_f1:
            ax.set_ylim(*padded_percent_limits(subplot_f1))

        # Set aspect ratio to make the plot rectangular (wider than tall)
        ax.set_box_aspect(0.6)

    # Add a single legend for the entire figure, placed below the plots
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc='lower center',
        bbox_to_anchor=(0.5, -0.01),  # Position below the plots
        ncol=min(len(technique_names), 4),  # Use multiple rows if many techniques
        frameon=True,
        framealpha=0.9,
        edgecolor='lightgray',
        fontsize=9,
        title=r"SubSampling of $\mathcal{D}_0$"
    )

    plt.subplots_adjust(bottom=0.15)
    plt.tight_layout(rect=[0, 0.1, 1, 1])  # Leave space at bottom for legend
    if figname is not None:
        # Save as pdf
        fig.savefig(figname, format='pdf', bbox_inches='tight')
    return fig, axes



########################
### Time-Aware Plots ###
########################

def plot_rejection_metrics_by_month(
    results_list, 
    label_budget:int, 
    rejection_point:int, 
    title:str = "", 
    show_rejection:bool = True,
    show_efficiency:bool = False,
    show_relative_improvement:bool = True,
    month_range:list = None,
    method_name_mpr=None,
    fig_width=None,
    fig_height_per_subplot=3,  # Reduced for more compact plots
    publication_style=True, 
    figname=None
):
    """
    Enhanced version of rejection metrics plot with dual legends for publication.
    """
    # Apply publication-friendly styling if requested
    if publication_style:
        plt.style.use('seaborn-v0_8-whitegrid')
        plt.rcParams.update({
            'font.family': 'sans-serif',
            'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans'],
            'font.size': 12,
            'axes.titlesize': 12,
            'axes.labelsize': 12,
            'xtick.labelsize': 12,
            'ytick.labelsize': 12,
            'legend.fontsize': 12,
            'figure.dpi': 300
        })
    
    # Set default method_name_mpr if None
    if method_name_mpr is None:
        method_name_mpr = lambda x: x
    
    if rejection_point == 0:
        show_efficiency = False 
        show_relative_improvement = False 
        show_rejection = False

    # Filter results
    filtered1 = [entry for entry in results_list if label_budget in entry[0]]
    if not filtered1:
        raise ValueError(f"Label Budget {label_budget} not present in results_list.")

    filtered2 = [entry for entry in filtered1 if rejection_point in entry[1]]
    if not filtered2:
        raise ValueError(f"Rejection point {rejection_point} not present in results_list.")
    
    # Extract data
    methods, diffs, data_list = [], [], []
    for entry in filtered2:
        _, diff, method = entry[0]
        method = method_name_mpr(method)
        methods.append(method)
        diffs.append(diff)
        data_list.append(entry[1][rejection_point])
    
    n_months = len(data_list[0]["monthly_Rejections"])
    
    # Month range handling
    start_month = month_range[0] - 1 if month_range and len(month_range) > 0 and month_range[0] else 0
    end_month = month_range[1] if month_range and len(month_range) > 1 and month_range[1] else n_months
    start_month = max(0, min(start_month, n_months - 1))
    end_month = max(start_month + 1, min(end_month, n_months))

    # Setup method pairs
    method_pairs = [(m, d) for m, d in zip(methods, diffs)]
    unique_pairs = sorted(set(method_pairs))
    n_pairs = len(unique_pairs)

    # Figure setup
    n_plots = 1 + sum([show_rejection, show_efficiency, show_relative_improvement])
    
    # Calculate appropriate figure width if not provided
    if fig_width is None:
        width = max(12, (end_month - start_month) * 0.8)
    else:
        width = fig_width
    
    # Create more compact figure with reduced height per subplot
    # For single plot, increase height slightly for better proportions
    if n_plots == 1:
        fig_height = max(5, fig_height_per_subplot)  # Ensure single plot has sufficient height
    else:
        fig_height = fig_height_per_subplot * n_plots
        
    fig, axes = plt.subplots(n_plots, 1, figsize=(width, fig_height))

    if not isinstance(axes, np.ndarray):
        axes = [axes]
    ax_line = axes[0]
    ax_bar = axes[1] if show_rejection else None
    ax_eff = axes[2] if show_efficiency and show_relative_improvement else None
    ax_rel = axes[-1] if show_relative_improvement else None

    # Visual setup - using better color scheme for publication
    unique_methods = sorted(set(methods))
    
    color_map = create_distinct_color_map(list(set(methods)))
    
    # Map method pairs to colors
    method_color_map = {(m, d): color_map[m] for m, d in method_pairs}
    
    # Linestyle mapping for different sampling methods
    base_diff_linestyle_map = {
        'full_first_year_subsample_months': '-',
        'weight_first_year_subsample_months': ':',
        'subsample_first_year_subsample_months': '--',
    }
    
    # Create a more generic mapping for any diff that contains these patterns
    diff_linestyle_map = {}
    for diff in set(diffs):
        for key, style in base_diff_linestyle_map.items():
            if key in diff:
                diff_linestyle_map[diff] = style
                break
        else:
            # Default linestyle if no match is found
            diff_linestyle_map[diff] = '-'
    
    # Get unique differentiators for the second legend
    unique_diffs = sorted(set(diffs))
    
    # Hatch patterns for different differentiator types
    hatch_map = {'full_first': '', 'subsample_first': '//', 'weight_first': '*'}

    # Spacing setup
    group_width = 1.0  # Total width allocated for each group
    month_gap = 0.15   # Gap between month groups

    # Calculate bar widths and positions
    total_group_space = group_width + month_gap
    bar_width = group_width / (n_pairs + 1)  # +1 for padding
    total_bars_width = n_pairs * bar_width

    # First center should be at group_width/2 to ensure first bar isn't cut off
    centers = group_width/2 + np.arange(end_month - start_month) * total_group_space

    # Month separators
    for month in range(1, end_month - start_month):
        sep_x = (centers[month-1] + centers[month])/2
        for ax in axes:
            ax.axvline(sep_x, color='black', linewidth=0.5, alpha=0.3)
    
    # Range indicators
    if month_range:
        ylims = {}
        for ax in axes:
            ylims[ax] = ax.get_ylim()
            if start_month > 0:
                ax.axvline(centers[0], color='black', linewidth=2)
            if end_month < n_months:
                ax.axvline(centers[-1], color='black', linewidth=2)
    
    # Line plot
    for method, diff in unique_pairs:
        idx = method_pairs.index((method, diff))
        data = data_list[idx]["monthly_F1"][start_month:end_month]
        ls = diff_linestyle_map.get(diff, '-')
        
        # Format label for cleaner display - only include method name for the first legend
        label = method
        
        # For single plot, make lines slightly thicker for better visibility
        line_width = 2.0 if n_plots == 1 else 1.5
        
        ax_line.plot(centers, data*100, color=method_color_map[(method, diff)], linestyle=ls,
                    label=label, marker='o', markersize=4, alpha=0.8, linewidth=line_width)
    
    ax_line.yaxis.set_major_formatter(PercentFormatter())

    # Bar plots
    for month in range(end_month - start_month):
        center = centers[month]
        curr_month = month + start_month
        
        # Calculate starting position for first bar in group
        start_x = center - total_bars_width/2
        
        for i, (method, diff) in enumerate(unique_pairs):
            idx = method_pairs.index((method, diff))
            x = start_x + i * bar_width
            
            # Get hatch pattern
            hatch_pattern = next((v for k, v in hatch_map.items() if k in diff), '')
            
            # Rejection bars
            if show_rejection:
                y = data_list[idx]["monthly_Rejections"][curr_month]
                bar = ax_bar.bar(
                    x, y, 
                    width=bar_width, 
                    color=method_color_map[(method, diff)],
                    hatch=hatch_pattern, 
                    edgecolor='white',
                    label=method  # Only include method name for first legend
                )
                
            # Get F1 values for efficiency and improvement
            rejections = data_list[idx]["monthly_Rejections"][curr_month]
            f1_before = data_list[idx]["monthly_F1_no_rejection"][curr_month]
            f1_after = data_list[idx]["monthly_F1"][curr_month]
            
            # Efficiency bars
            if show_efficiency and rejections > 0:
                efficiency = (f1_after - f1_before) / rejections * 100  # Convert to percentage
                bar = ax_eff.bar(x, efficiency, width=bar_width, 
                              color=method_color_map[(method, diff)],
                              hatch=hatch_pattern, edgecolor='white')
            elif show_efficiency:
                ax_eff.plot(x + bar_width/2, 0, 'kx', markersize=8)
            
            # Relative improvement bars
            if show_relative_improvement:
                improvement = (f1_after - f1_before) * 100  # Convert to percentage
                bar = ax_rel.bar(x, improvement, width=bar_width, 
                              color=method_color_map[(method, diff)],
                              hatch=hatch_pattern, edgecolor='white')
   
    # Formatting
    for ax in axes:
        ax.set_xlim(centers[0] - group_width/2, centers[-1] + group_width/2)
        ax.set_xticks(centers)
        ax.grid(True, alpha=0.3)
    
    # Add more padding to title to fix spacing issue
    if n_plots == 1:
        ax_line.set_title("Monthly F1 Scores", fontweight='bold', pad=14)  # Extra padding for single plot
    else:
        ax_line.set_title("Monthly F1 Scores", fontweight='bold', pad=8)
    ax_line.set_ylabel("F1 Score")
    ax_line.set_xticklabels([])
    
    if ax_bar is not None:
        ax_bar.set_title("Monthly Rejections", fontweight='bold', pad=8)
        ax_bar.set_ylabel("Rejections")
        ax_bar.set_xticklabels([])
        target_line = ax_bar.axhline(rejection_point, color='black', linestyle=':', linewidth=2)
        
        # Add target rejection label with better positioning
        _prfx = r"Target: $\rho$"
        ax_bar.text(
            centers[-1] + group_width/2 - 0.2, 
            rejection_point + 0.05 * (ax_bar.get_ylim()[1] - ax_bar.get_ylim()[0]),
            f"{_prfx} = {rejection_point}",
            verticalalignment='bottom', 
            horizontalalignment='right', 
            fontsize=12,
            bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
        )

    # Set x-axis labels on appropriate subplot
    if sum([show_efficiency, show_relative_improvement, show_rejection]) == 0:
        ax_line.set_xticklabels([f"Month {m}" for m in range(start_month+1, end_month+1)], rotation=45)

    if show_efficiency:
        ax_eff.yaxis.set_major_formatter(PercentFormatter())
        ax_eff.set_title("F1 Improvement per Rejection", fontweight='bold', pad=8)
        ax_eff.set_ylabel("ΔF1/Rejection")
        ax_eff.axhline(0, color='black', linestyle='-', linewidth=0.5)
        if not show_relative_improvement:
            ax_eff.set_xlabel("Month")
            ax_eff.set_xticklabels([f"Month {m}" for m in range(start_month+1, end_month+1)], rotation=45)
        else:
            ax_eff.set_xticklabels([])
    
    if show_relative_improvement:
        ax_rel.yaxis.set_major_formatter(PercentFormatter())
        ax_rel.set_title("Absolute F1 Improvement", fontweight='bold', pad=8)
        ax_rel.set_ylabel("ΔF1")
        ax_rel.axhline(0, color='black', linestyle='-', linewidth=0.5)
        ax_rel.set_xlabel("Month")
        ax_rel.set_xticklabels([f"{m}" for m in range(start_month+1, end_month+1)], rotation=0)

    # Get unique legend entries for methods
    if ax_bar is not None:
        handles, labels = ax_bar.get_legend_handles_labels()
    else:
        handles, labels = ax_line.get_legend_handles_labels()

    # Create a dictionary of unique method labels and handles
    method_dict = {}
    for label, handle in zip(labels, handles):
        if label not in method_dict:
            method_dict[label] = handle
    
    # Create custom line objects for the sampling methods legend
    diff_handles = []
    diff_labels = []
    # Create custom Line2D objects for each unique differentiator
    for diff in unique_diffs:
        line_style = diff_linestyle_map.get(diff, '-')
        
        # Get display name from mapping or use the original
        display_name = None
        for key, name in SAMPLER_NAME_MPR.items():
            if key in diff:
                display_name = name
                break
        if display_name is None:
            display_name = diff
            
        # Create line object
        diff_handles.append(Line2D([0], [0], color='black', linestyle=line_style))
        diff_labels.append(display_name)
    
    # Set a tight layout first to optimize subplot spacing
    plt.tight_layout()
    
    # Calculate legend heights based on content
    methods_legend_height = 0.05 if len(method_dict) <= 4 else 0.08
    diffs_legend_height = 0.05
    
    # For single plot case (performance only), place legends side by side
    if n_plots == 1:
        # Define the desired order for classifier methods
        method_order = ['SVC', 'DeepDrebin', 'CADE', 'HCC']  # Adjust this to your preferred order
        
        # Create ordered pairs of (handle, name) for methods
        ordered_method_items = []
        
        # First go through our preferred order
        for preferred in method_order:
            for method_name, method_handle in method_dict.items():
                if preferred in method_name:
                    ordered_method_items.append((method_handle, method_name))
        
        # Then add any remaining methods not in our preferred list
        for method_name, method_handle in method_dict.items():
            if not any(preferred in method_name for preferred in method_order):
                ordered_method_items.append((method_handle, method_name))
        
        # Unzip the pairs
        ordered_method_handles, ordered_method_names = zip(*ordered_method_items) if ordered_method_items else ([], [])
        
        # First legend (Classifier Method) on the left with ordered items
        method_legend = fig.legend(
            ordered_method_handles, 
            ordered_method_names,
            loc='upper center', 
            bbox_to_anchor=(0.3, 0.02), 
            ncol=min(2, len(method_dict)),
            frameon=True,
            framealpha=0.9,
            fontsize=9,
            title="Classifier Method"
        )
        
        # Second legend (Sampling Method) on the right - unchanged
        diff_legend = fig.legend(
            diff_handles, 
            diff_labels,
            loc='upper center', 
            bbox_to_anchor=(0.7, 0.02), 
            ncol=min(2, len(diff_handles)),
            frameon=True,
            framealpha=0.9,
            fontsize=9,
            title="Sampling Method"
        )
        
        # For single plot, ensure there's more space between title and plot
        plt.subplots_adjust(
            top=0.85,      # More space between title and plot
            bottom=0.22,   # Space for legends side by side
            right=0.98,
            left=0.08
        )
    else:
        # For multiple plots, stack legends vertically
        total_legend_height = methods_legend_height + diffs_legend_height + 0.03  # Add spacing between legends
        
        # Define the desired order for classifier methods (same as above)
        method_order = ['Drebin', 'DeepDrebin', 'CADE', 'HCC']  # Adjust this to your preferred order
        
        # Create ordered pairs of (handle, name) for methods
        ordered_method_items = []
        
        # First go through our preferred order
        for preferred in method_order:
            for method_name, method_handle in method_dict.items():
                if preferred in method_name:
                    ordered_method_items.append((method_handle, method_name))
        
        # deduplicate ordered_method_items
        ordered_method_items = list(dict.fromkeys(ordered_method_items))

        # Then add any remaining methods not in our preferred list
        for method_name, method_handle in method_dict.items():
            if not any(preferred in method_name for preferred in method_order):
                ordered_method_items.append((method_handle, method_name))
        
        # Unzip the pairs
        ordered_method_handles, ordered_method_names = zip(*ordered_method_items) if ordered_method_items else ([], [])
        

        methods_legend_height, diffs_legend_height = 0.05, 0.05
        total_legend_height = methods_legend_height + diffs_legend_height + 0.03
        LEGEND_Y = 0.02  # ↑ higher = closer to the sub-plot  ## MOD



        # Create a method legend (first legend) with ordered items
        method_legend = fig.legend(
            ordered_method_handles, 
            ordered_method_names,
            loc='lower center', 
            bbox_to_anchor=(0.5, LEGEND_Y + diffs_legend_height),
            ncol=min(4, len(method_dict)),
            frameon=True,
            framealpha=0.9,
            fontsize=9,
            title="Classifier Method"
        )
        
        # Create a sampling methods legend (second legend) - unchanged
        diff_legend = fig.legend(
            diff_handles, 
            diff_labels,
            loc='lower center', 
            bbox_to_anchor=(0.5, LEGEND_Y),
            ncol=min(3, len(diff_handles)),
            frameon=True,
            framealpha=0.9,
            fontsize=9,
            title="Sampling Method"
        )
        
        # Adjust layout to make room for the dual legends
        plt.subplots_adjust(
            top=0.92,                  # Leave room for the title
            bottom=0.08 + total_legend_height,  # ↓ tighter  ## MOD
            hspace=0.35,                # Space between subplots
            right=0.98,
            left=0.08
        )
    
    if figname is not None:
        # Save as pdf
        fig.savefig(figname, format='pdf', bbox_inches='tight')
    return fig, tuple(axes)





def plot_rejection_metrics_by_month_v2(
    results_list,
    label_budget: int,
    rejection_point: int,
    title: str = "",
    show_rejection: bool = True,
    show_efficiency: bool = False,
    show_relative_improvement: bool = True,
    month_range: list = None,
    method_name_mpr=None,
    fig_width=None,
    fig_height_per_subplot: int = 3,
    publication_style: bool = True,
    figname=None,
):
    """Plot monthly F1 and rejection diagnostics as stacked panels with compact legends.

    Panels, top to bottom (only the enabled ones are drawn):
      1. monthly F1 per (method, sampler) pair as lines (always shown);
      2. monthly rejection counts as grouped bars (``show_rejection``);
      3. F1 gain per rejection, ``(F1 - F1_no_rejection) / rejections``
         (``show_efficiency``); months with zero rejections get a black "x";
      4. absolute F1 gain ``F1 - F1_no_rejection`` (``show_relative_improvement``).

    Parameters
    ----------
    results_list : list
        Entries ``(key, by_rejection_point)``. ``key`` unpacks as
        ``(budget, differentiator, method)`` and ``by_rejection_point[rejection_point]``
        is a dict holding ``"monthly_F1"``, ``"monthly_F1_no_rejection"`` and
        ``"monthly_Rejections"`` sequences indexed by month.
    label_budget : int
        Only entries whose key contains this value are plotted.
    rejection_point : int
        Rejection point to plot; ``0`` disables every panel except F1.
    title : str
        Optional figure title (``fig.suptitle``); empty means no title.
    show_rejection, show_efficiency, show_relative_improvement : bool
        Toggle the corresponding panels.
    month_range : list, optional
        ``[first, last]`` 1-based inclusive month range; falsy bounds fall back
        to the first/last available month. Values are clamped to the data.
    method_name_mpr : callable, optional
        Maps raw method names to display names (identity by default).
    fig_width : float, optional
        Figure width in inches; defaults to ``max(12, 0.8 * months_shown)``.
    fig_height_per_subplot : float
        Height per panel in inches (a lone F1 panel is at least 5 inches tall).
    publication_style : bool
        If True, applies a seaborn style and updates the *global* ``rcParams``.
    figname : str, optional
        If given, the figure is saved to this path as PDF.

    Returns
    -------
    (fig, axes)
        The figure and a tuple of the panel Axes, top to bottom.

    Raises
    ------
    ValueError
        If no entry matches ``label_budget`` or ``rejection_point``.
    """
    # ───────────────────────────── Style ─────────────────────────────
    if publication_style:
        plt.style.use('seaborn-v0_8-whitegrid')
        plt.rcParams.update(
            {
                "font.family": "sans-serif",
                "font.sans-serif": ["Arial", "DejaVu Sans", "Liberation Sans"],
                "font.size": 12,
                "axes.titlesize": 12,
                "axes.labelsize": 12,
                "xtick.labelsize": 12,
                "ytick.labelsize": 12,
                "legend.fontsize": 12,
                "figure.dpi": 300,
            }
        )

    # ---------------- Data selection ----------------
    if method_name_mpr is None:
        method_name_mpr = lambda x: x

    if rejection_point == 0:
        # Nothing is rejected at this point, so only the F1 panel is meaningful.
        show_efficiency = show_relative_improvement = show_rejection = False

    filtered1 = [e for e in results_list if label_budget in e[0]]
    if not filtered1:
        raise ValueError(f"Label Budget {label_budget} not in results_list")
    filtered2 = [e for e in filtered1 if rejection_point in e[1]]
    if not filtered2:
        raise ValueError(f"Rejection point {rejection_point} not present")

    methods, diffs, data_list = [], [], []
    for entry in filtered2:
        _, diff, mth = entry[0]
        methods.append(method_name_mpr(mth))
        diffs.append(diff)
        data_list.append(entry[1][rejection_point])

    n_months = len(data_list[0]["monthly_Rejections"])

    # ---------------- Month range ----------------
    sm = (month_range[0] - 1) if month_range and month_range[0] else 0
    em = month_range[1] if month_range and len(month_range) > 1 and month_range[1] else n_months
    start_month = max(0, min(sm, n_months - 1))
    end_month = max(start_month + 1, min(em, n_months))
    n_shown = end_month - start_month

    method_pairs = list(zip(methods, diffs))
    unique_pairs = sorted(set(method_pairs))
    n_pairs = len(unique_pairs)
    # First occurrence of each pair (same choice as list.index, without O(n) scans).
    pair_index = {}
    for i, pair in enumerate(method_pairs):
        pair_index.setdefault(pair, i)

    # ---------------- Figure / panels ----------------
    n_plots = 1 + sum([show_rejection, show_efficiency, show_relative_improvement])
    width = max(12, n_shown * 0.8) if fig_width is None else fig_width
    fig_height = max(5, fig_height_per_subplot) if n_plots == 1 else fig_height_per_subplot * n_plots

    fig, axes = plt.subplots(n_plots, 1, figsize=(width, fig_height))
    if not isinstance(axes, np.ndarray):
        axes = [axes]

    # Hand out Axes in display order so every enabled panel gets its own Axes,
    # whatever combination of panels is enabled.
    panels = iter(axes)
    ax_line = next(panels)
    ax_bar = next(panels) if show_rejection else None
    ax_eff = next(panels) if show_efficiency else None
    ax_rel = next(panels) if show_relative_improvement else None

    color_map = create_distinct_color_map(set(methods))
    method_color_map = {(m, d): color_map[m] for m, d in method_pairs}

    # Sampler styling is matched by substring of the differentiator, using the
    # module-wide line-style table so both plots agree.
    diff_ls = {
        d: next((s for k, s in SAMPLER_LINE_MPR.items() if k in d), "-") for d in set(diffs)
    }
    unique_diffs = sorted(set(diffs))
    hatch_map = {"full_first": "", "subsample_first": "//", "weight_first": "*"}
    diff_hatch = {d: next((v for k, v in hatch_map.items() if k in d), "") for d in unique_diffs}

    group_width, month_gap = 1.0, 0.15
    total_group_space = group_width + month_gap
    bar_width = group_width / (n_pairs + 1)
    total_bars_width = n_pairs * bar_width
    centers = group_width / 2 + np.arange(n_shown) * total_group_space

    # ---------------- Plotting ----------------
    for m in range(1, n_shown):
        sep_x = (centers[m - 1] + centers[m]) / 2
        for ax in axes:
            ax.axvline(sep_x, color="black", linewidth=0.5, alpha=0.3)

    for pair in unique_pairs:
        method, diff = pair
        # asarray: "* 100" on a plain list would repeat it instead of scaling.
        f1_series = np.asarray(
            data_list[pair_index[pair]]["monthly_F1"][start_month:end_month], dtype=float
        ) * 100
        ax_line.plot(
            centers,
            f1_series,
            color=method_color_map[pair],
            linestyle=diff_ls[diff],
            label=method,
            marker="o",
            markersize=4,
            alpha=0.8,
            linewidth=(2.0 if n_plots == 1 else 1.5),
        )
    ax_line.yaxis.set_major_formatter(PercentFormatter())

    for offset in range(n_shown):
        curr = start_month + offset
        # Axes.bar centres each bar on its x, so offset by half a bar to centre
        # the whole group on the month tick.
        first_x = centers[offset] - total_bars_width / 2 + bar_width / 2

        for i, pair in enumerate(unique_pairs):
            method, diff = pair
            record = data_list[pair_index[pair]]
            x = first_x + i * bar_width
            color = method_color_map[pair]
            hatch = diff_hatch[diff]

            rejs = record["monthly_Rejections"][curr]
            f1b = record["monthly_F1_no_rejection"][curr]
            f1a = record["monthly_F1"][curr]

            if ax_bar is not None:
                ax_bar.bar(x, rejs, width=bar_width, color=color,
                           hatch=hatch, edgecolor="white", label=method)

            if ax_eff is not None:
                if rejs > 0:
                    eff = (f1a - f1b) / rejs * 100
                    ax_eff.bar(x, eff, width=bar_width, color=color,
                               hatch=hatch, edgecolor="white")
                else:
                    # Efficiency is undefined without rejections: mark it instead.
                    ax_eff.plot(x, 0, "kx", markersize=8)

            if ax_rel is not None:
                imp = (f1a - f1b) * 100
                ax_rel.bar(x, imp, width=bar_width, color=color,
                           hatch=hatch, edgecolor="white")

    # ---------------- Axis formatting ----------------
    for ax in axes:
        ax.set_xlim(centers[0] - group_width / 2, centers[-1] + group_width / 2)
        ax.set_xticks(centers)
        ax.grid(True, alpha=0.3)

    ax_line.set_title("Monthly F1 Scores", fontweight="bold", pad=14 if n_plots == 1 else 8)
    ax_line.set_ylabel("F1 Score")
    ax_line.set_xticklabels([])

    if ax_bar is not None:
        ax_bar.set_title("Monthly Rejections", fontweight="bold", pad=8)
        ax_bar.set_ylabel("Rejections")
        ax_bar.set_xticklabels([])
        # τ_rej → ρ
        _prfx = r"Target: $\rho$"
        ax_bar.axhline(rejection_point, color="black", linestyle=":", linewidth=2)
        ax_bar.text(
            centers[-1] + group_width / 2 - 0.2,
            rejection_point + 0.05 * (ax_bar.get_ylim()[1] - ax_bar.get_ylim()[0]),
            f"{_prfx} = {rejection_point}",
            va="bottom",
            ha="right",
            fontsize=12,
            bbox=dict(facecolor="white", alpha=0.7, edgecolor="none", pad=2),
        )

    if ax_eff is not None:
        ax_eff.yaxis.set_major_formatter(PercentFormatter())
        ax_eff.set_title("F1 Improvement / Rejection", fontweight="bold", pad=8)
        ax_eff.set_ylabel("ΔF1 / Rejection")
        ax_eff.axhline(0, color="black", linewidth=0.5)
        if ax_rel is None:
            ax_eff.set_xlabel("Month")
            ax_eff.set_xticklabels([f"Month {m}" for m in range(start_month + 1, end_month + 1)], rotation=45)
        else:
            ax_eff.set_xticklabels([])

    if ax_rel is not None:
        ax_rel.yaxis.set_major_formatter(PercentFormatter())
        ax_rel.set_title("Absolute F1 Improvement", fontweight="bold", pad=8)
        ax_rel.set_ylabel("ΔF1")
        ax_rel.axhline(0, color="black", linewidth=0.5)
        ax_rel.set_xlabel("Month")
        ax_rel.set_xticklabels([f"{m}" for m in range(start_month + 1, end_month + 1)], rotation=0)

    if ax_eff is None and ax_rel is None:
        # The F1 or rejection panel is at the bottom: give it the month labels.
        axes[-1].set_xlabel("Month")
        axes[-1].set_xticklabels([f"{m}" for m in range(start_month + 1, end_month + 1)], rotation=0)

    if title:
        fig.suptitle(title, fontweight="bold")

    # ---------------- Legends (tighter) ----------------
    legend_source = ax_bar if ax_bar is not None else ax_line
    handles, labels = legend_source.get_legend_handles_labels()
    method_dict = {}
    for lab, handle in zip(labels, handles):
        method_dict.setdefault(lab, handle)  # keep the first artist per method

    diff_handles = [
        Line2D([0], [0], color="black", linestyle=diff_ls[d]) for d in unique_diffs
    ]
    diff_labels = [
        next((name for k, name in SAMPLER_NAME_MPR.items() if k in d), d) for d in unique_diffs
    ]

    fig.tight_layout(rect=(0, 0.03, 1, 0.97))

    # Preferred methods first (each at most once), then the rest in plot order.
    # Names and handles are built from the same list so they stay aligned.
    method_order = ["SVC", "DeepDrebin", "CADE", "HCC"]
    ordered_names = []
    for preferred in method_order:
        for name in method_dict:
            if preferred in name and name not in ordered_names:
                ordered_names.append(name)
    ordered_names += [name for name in method_dict if name not in ordered_names]
    ordered = [method_dict[name] for name in ordered_names]

    # ##############  SPACING CHANGES ##############
    LEGEND_Y = 0.04  # ↑ higher = closer to the sub-plot  ## MOD

    if n_plots == 1:
        fig.legend(
            ordered,
            ordered_names,
            loc="upper center",
            bbox_to_anchor=(0.3, LEGEND_Y),
            ncol=min(2, len(method_dict)),
            frameon=True,
            framealpha=0.9,
            fontsize=9,
            title="Classifier Method",
        )
        fig.legend(
            diff_handles,
            diff_labels,
            loc="upper center",
            bbox_to_anchor=(0.7, LEGEND_Y),
            ncol=min(2, len(diff_handles)),
            frameon=True,
            framealpha=0.9,
            fontsize=9,
            title="Sampling Method",
        )
        fig.subplots_adjust(top=0.85, bottom=0.14, right=0.98, left=0.08)  # ↓ tighter  ## MOD
    else:
        methods_legend_height, diffs_legend_height = 0.05, 0.05
        total_legend_height = methods_legend_height + diffs_legend_height + 0.03

        fig.legend(
            ordered,
            ordered_names,
            loc="lower center",
            bbox_to_anchor=(0.5, LEGEND_Y + diffs_legend_height),
            ncol=min(4, len(method_dict)),
            frameon=True,
            framealpha=0.9,
            fontsize=9,
            title="Classifier Method",
        )
        fig.legend(
            diff_handles,
            diff_labels,
            loc="lower center",
            bbox_to_anchor=(0.5, LEGEND_Y),
            ncol=min(3, len(diff_handles)),
            frameon=True,
            framealpha=0.9,
            fontsize=9,
            title="Sampling Method",
        )
        fig.subplots_adjust(
            top=0.92,
            bottom=0.08 + total_legend_height,  # ↓ tighter  ## MOD
            hspace=0.35,
            right=0.98,
            left=0.08,
        )

    if figname is not None:
        fig.savefig(figname, format="pdf", bbox_inches="tight")

    return fig, tuple(axes)


################################
### Plot Metrics vs. Budgets ###
################################


def _budget_metric_series(data, metric):
    """Return ``(x_values, y_values)`` for *metric* from one result mapping.

    ``data`` maps rejection points (ints or int-like strings) to per-point
    metric dicts. Points are sorted numerically and the original key is used
    for the lookup, so string keys work too. A missing metric counts as 0 and
    every value is multiplied by 100 (presumably fraction -> percent).
    """
    points = sorted(data.items(), key=lambda item: int(item[0]))
    x_values = [int(key) for key, _ in points]
    y_values = [values.get(metric, 0) * 100 for _, values in points]
    return x_values, y_values


def plot_metric_by_budget(
        results_list,
        metric='F1',
        fig_width=None,
        fig_height=4,
        ylim=None,
        use_markers=False,
        add_legend: bool = False,
        title: str = None,
        # symbols
        tau_rej=r'\rho',
        tau=r"B_{M_i}",

        method_name_mpr=lambda x: x,
        figname=None
):
    """
    Create a publication-quality plot of one metric, with one subplot per label budget.

    Parameters:
    -----------
    results_list : list
        List of result tuples [(key, data), ...] where key is (budget, differentiator, classifier)
        and data maps rejection points (int or int-like str) to dicts of metric values.
    metric : str, default='F1'
        Metric to plot (F1, FNR, or FPR); values are multiplied by 100 and shown as percentages.
    fig_width : float, default=None
        Width of the figure in inches (if None, 3 inches per subplot, capped at 12).
    fig_height : float, default=4
        Height of the figure in inches.
    ylim : tuple, default=None
        Custom y-axis limits (min, max) as percentages; if None, limits are derived
        from the data with a 5-point margin, clamped to [0, 100].
    use_markers : bool, default=False
        Whether to draw a per-classifier marker at each data point.
    add_legend : bool, default=False
        Whether to add the classifier and sampling-method legends below the plots.
    title : str, default=None
        Optional figure title.
    tau_rej, tau : str
        LaTeX symbols used for the x-axis label and subplot titles.
    method_name_mpr : callable or None, default=identity
        Maps raw classifier names to display names.
    figname : str, default=None
        If given, the figure is saved there as PDF.

    Returns:
    --------
    fig, axes : matplotlib figure and axes objects

    Raises:
    -------
    ValueError
        If no entry of ``results_list`` could be grouped by budget.
    """
    # Apply publication-friendly style (modifies global rcParams)
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans'],
        'font.size': 11,
        'axes.titlesize': 12,
        'axes.labelsize': 11,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 12,
        'figure.dpi': 300,
    })

    # Group data by budget
    grouped_data = {}
    unique_classifiers = set()
    unique_differentiators = set()

    for result in results_list:
        try:
            budget = result[0][0]  # label budget
            differentiator = result[0][1]  # sampling method
            classifier_name = result[0][2]  # classifier name
            if method_name_mpr is not None:
                classifier_name = method_name_mpr(classifier_name)

            unique_classifiers.add(classifier_name)
            unique_differentiators.add(differentiator)

            if budget not in grouped_data:
                grouped_data[budget] = []

            grouped_data[budget].append((differentiator, classifier_name, result[1]))
        except (IndexError, TypeError) as e:
            print(f"Warning: Error processing result: {e}")
            continue

    if not grouped_data:
        raise ValueError("results_list contains no entries that can be grouped by budget")

    color_map = create_distinct_color_map(unique_classifiers)

    # Known samplers use the module-wide style table; others cycle through defaults.
    diff_linestyle_map = {}
    line_styles = ['-', '--', ':', '-.']
    for i, diff in enumerate(sorted(unique_differentiators)):
        if diff in SAMPLER_LINE_MPR:
            diff_linestyle_map[diff] = SAMPLER_LINE_MPR[diff]
        else:
            diff_linestyle_map[diff] = line_styles[i % len(line_styles)]

    # Markers for different classifiers (cycled so every classifier gets one)
    markers = ['o', 's', 'D', '^', 'v', '<', '>', 'p']
    marker_map = {
        name: markers[i % len(markers)] for i, name in enumerate(sorted(unique_classifiers))
    }

    # Sort budgets for consistent presentation
    budgets = sorted(grouped_data.keys())

    num_plots = len(budgets)
    if fig_width is None:
        fig_width = min(12, 3 * num_plots)  # 3 inches per subplot, max 12 inches

    fig, axes = plt.subplots(1, num_plots, figsize=(fig_width, fig_height))

    # Make axes iterable for single subplot case
    if num_plots == 1:
        axes = [axes]

    # First pass: collect y values for a shared y-range across subplots
    all_y_values = []
    for budget in budgets:
        for diff, classifier_name, data in grouped_data[budget]:
            try:
                _, y_values = _budget_metric_series(data, metric)
                all_y_values.extend(y for y in y_values if not np.isnan(y))
            except (KeyError, TypeError, AttributeError) as e:
                print(f"Warning: Error processing {budget}, {diff}, {classifier_name}: {e}")

    if ylim is None and all_y_values:
        y_min = max(0, np.floor(min(all_y_values) - 5))
        y_max = min(100, np.ceil(max(all_y_values) + 5))
        global_ylim = (y_min, y_max)
    else:
        global_ylim = ylim

    # Legend entries, de-duplicated by display name
    classifier_lines = []
    classifier_names = []
    diff_lines = []
    diff_names = []

    for idx, budget in enumerate(budgets):
        ax = axes[idx]

        for diff, classifier_name, data in grouped_data[budget]:
            display_name = next(
                (name for key, name in SAMPLER_NAME_MPR.items() if key in diff), diff
            )
            try:
                x_values, y_values = _budget_metric_series(data, metric)
                if not x_values:
                    continue

                linestyle = diff_linestyle_map.get(diff, '-')
                plot_kwargs = {
                    'color': color_map.get(classifier_name, 'black'),
                    'linestyle': linestyle,
                    'linewidth': 1.5,
                }
                if use_markers:
                    plot_kwargs['marker'] = marker_map.get(classifier_name, 'o')
                    plot_kwargs['markersize'] = 5

                line = ax.plot(x_values, y_values, **plot_kwargs)[0]

                # Register legend entries from every subplot so classifiers or
                # samplers that only appear under some budgets are still listed.
                if classifier_name not in classifier_names:
                    classifier_lines.append(line)
                    classifier_names.append(classifier_name)
                if display_name not in diff_names:
                    diff_lines.append(Line2D([0], [0], color='black', linestyle=linestyle))
                    diff_names.append(display_name)
            except (KeyError, TypeError, AttributeError) as e:
                print(f"Warning: Error plotting {budget}, {diff}, {classifier_name}: {e}")

        ax.set_title(f'${tau}={budget}$', fontweight='bold')
        ax.set_xlabel(f'${tau_rej}$')
        if idx == 0:
            ax.set_ylabel(f'{metric} Score')

        ax.yaxis.set_major_formatter(PercentFormatter(decimals=0 if metric != "FPR" else 2))
        if global_ylim:
            ax.set_ylim(global_ylim)
        ax.grid(True, linestyle='--', alpha=0.3)
        ax.set_box_aspect(0.8)

    if add_legend:
        # DeepDrebin and Drebin first (one entry each), then the rest in plot order.
        ordered_indices = []
        for preferred in ('DeepDrebin', 'Drebin'):
            match = next(
                (i for i, name in enumerate(classifier_names)
                 if i not in ordered_indices and preferred in name),
                None,
            )
            if match is not None:
                ordered_indices.append(match)
        ordered_indices += [i for i in range(len(classifier_names)) if i not in ordered_indices]
        lines_all = [classifier_lines[i] for i in ordered_indices]
        names_all = [classifier_names[i] for i in ordered_indices]

        # fig.legend already registers each legend on the figure; adding it
        # again via add_artist would draw it twice.
        fig.legend(
            lines_all, names_all,
            loc='lower center',
            bbox_to_anchor=(0.5, -0.05),
            # At most two rows; matplotlib needs at least one column.
            ncol=max(1, math.ceil(len(names_all) / 2)),
            columnspacing=1.5,
            handletextpad=0.8,
            frameon=True,
            framealpha=0.9,
            title="Classifier"
        )
        fig.legend(
            diff_lines, diff_names,
            loc='lower center',
            bbox_to_anchor=(0.5, -0.22),
            ncol=max(1, len(diff_names)),
            frameon=True,
            framealpha=0.9,
            title="Sampling Method"
        )

    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.35)  # Make room for legends

    if figname:
        fig.savefig(figname, format='pdf', bbox_inches='tight', dpi=300)
    return fig, axes


############################
### Risk-Coverage Curves ###
############################


# Reuse our optimized vectorized quantile function
def values_to_quantiles_vectorized(arr):
    """Map each value to its rank-based quantile in ``[0, 1]``.

    The smallest value maps to 0 and the largest to 1. Intermediate values are
    spaced evenly by rank (``rank / (n - 1)``). Ties are broken by the order
    ``np.argsort`` returns, so equal values may receive different quantiles.

    Parameters
    ----------
    arr : array-like
        Sequence of values (assumed one-dimensional; not checked).

    Returns
    -------
    np.ndarray
        Float array of quantiles with the same length as ``arr``. If every
        value is identical (including a single-element input) all quantiles
        are 1.0. An empty input yields an empty array.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        return np.empty(arr.shape, dtype=float)

    # Constant input (incl. one element): also avoids the 0/0 in rank / (n - 1).
    if np.all(arr == arr[0]):
        return np.ones_like(arr, dtype=float)

    # argsort of argsort yields each element's 0-based rank.
    ranks = np.argsort(np.argsort(arr)).astype(float)
    return ranks / (len(arr) - 1)

def select_first_from_multiseed(items):
    """Keep only the first trial of each multi-seed entry.

    ``items`` is an iterable whose elements are indexable (one entry per
    seed). A new list holding element ``[0]`` of each is returned.
    """
    first_trials = []
    for seeded_entry in items:
        first_trials.append(seeded_entry[0])
    return first_trials

def compute_risk_coverage_curve(labels, predictions, uncertainties, is_multiseed, force_start_zero=True):
    """
    Compute a binned risk-coverage curve for a single method.

    Uncertainties are converted to rank-based quantiles in [0, 1] and grouped
    into equal-width quantile bins. Sweeping the bins from least to most
    uncertain, coverage is the fraction of samples retained so far and risk
    is the error rate among the retained samples.

    Parameters:
    - labels, predictions, uncertainties: sequences of array-likes
        One entry per evaluation chunk; the entries are concatenated.
        When ``is_multiseed`` is true, each entry is itself indexable per seed
        and only its first element is used.
    - is_multiseed: bool
        Whether the inputs hold several seeds per entry.
    - force_start_zero: bool, default=True
        Prepend a (0, 0) point whenever the first bin already has non-zero
        coverage, so every curve starts at the origin.

    Returns:
    - coverage: np.ndarray of non-decreasing coverage values
    - risk: np.ndarray of risk values (0 where nothing is covered yet)

    Raises:
    - ValueError: if the concatenated inputs contain no samples.
    """
    # If multiple seeds, select the first one
    if is_multiseed:
        labels = select_first_from_multiseed(labels)
        predictions = select_first_from_multiseed(predictions)
        uncertainties = select_first_from_multiseed(uncertainties)

    # Concatenate all chunks into flat arrays
    all_labels = np.concatenate(labels)
    all_predictions = np.concatenate(predictions)
    all_uncertainties = np.concatenate(uncertainties)

    n_total = len(all_uncertainties)
    if n_total == 0:
        # Otherwise coverage below would be 0/0 (NaN) for every bin.
        raise ValueError("Cannot compute a risk-coverage curve from zero samples")

    # Determine correct/incorrect predictions
    correct_mask = all_labels == all_predictions
    incorrect_mask = ~correct_mask

    # Put uncertainties on a common [0, 1] rank scale
    all_uncertainties = values_to_quantiles_vectorized(all_uncertainties)

    # Square-root rule for the number of bins, clamped to [20, 100]
    n_bins = min(100, max(20, int(np.sqrt(n_total))))
    bins = np.linspace(0, 1, n_bins + 1)

    # Per-bin counts of correct and incorrect predictions
    hist_correct, _ = np.histogram(all_uncertainties[correct_mask], bins=bins)
    hist_incorrect, _ = np.histogram(all_uncertainties[incorrect_mask], bins=bins)

    # Cumulative counts from the least uncertain bin upwards
    cum_correct = np.cumsum(hist_correct)
    cum_incorrect = np.cumsum(hist_incorrect)
    cum_total = cum_correct + cum_incorrect

    # Risk is undefined (0/0) for leading empty bins; report it as 0
    with np.errstate(divide='ignore', invalid='ignore'):
        risk = np.divide(cum_incorrect, cum_total)
    risk = np.nan_to_num(risk)
    coverage = cum_total / n_total

    # Ensure the curve starts at (0, 0) if requested. When coverage[0] is 0
    # the curve already starts there, because risk[0] was mapped to 0.
    if force_start_zero and coverage[0] > 0:
        coverage = np.concatenate([[0], coverage])
        risk = np.concatenate([[0], risk])

    return coverage, risk


def _split_risk_coverage_index(index):
    """Return ``(sampler_method, method_name)`` for one row label of a results DataFrame.

    Supported label shapes:
    - tuple of 3+ elements: ``(label_budget, sampler_method, method_name, ...)``
    - tuple of 2 elements: ``(sampler_method, method_name)``; a
      ``(label_budget, method_name)`` label is parsed the same way, so its
      first element is reported as the sampler.
    - anything else: a plain method name, reported under sampler ``"default"``.
    """
    if isinstance(index, tuple) and len(index) >= 2:
        if len(index) >= 3:
            return index[1], index[2]
        return index[0], index[1]
    return "default", index


def _filter_by_label_budget(df, label_budget):
    """Restrict ``df`` to the rows of ``label_budget`` when its index allows it.

    Only a MultiIndex with at least two levels is filtered, on its first
    level. Any other index is returned unchanged.
    - With 3+ levels the first level is assumed to be the label budget. A
      missing budget yields an empty frame instead of a ``KeyError``.
    - With 2 levels a missing budget leaves the frame unchanged, since the
      first level may not be the label budget.
    """
    if label_budget is None or not isinstance(df.index, pd.MultiIndex):
        return df
    nlevels = df.index.nlevels
    if nlevels < 2:
        return df
    if label_budget in df.index.get_level_values(0):
        return df.xs(label_budget, level=0, drop_level=False)
    return df.iloc[0:0] if nlevels >= 3 else df


def _order_method_legend_items(method_lines, method_names, method_order):
    """Pair legend handles with method names, sorted by ``method_order``.

    Each name is assigned to the longest ``method_order`` entry it contains.
    For example, 'DeepDrebin ...' groups with 'DeepDrebin', not with its
    substring 'Drebin'. Names containing no entry come last. Within a group
    the original order is kept.
    """
    def rank(name):
        matches = [key for key in method_order if key in str(name)]
        if not matches:
            return len(method_order)
        return method_order.index(max(matches, key=len))

    order = sorted(range(len(method_names)), key=lambda i: rank(method_names[i]))
    return [(method_lines[i], method_names[i]) for i in order]


def plot_risk_coverage_curves_optimized_v3(
        dataset_dict, 
        label_budget=None,
        method_name_mpr=None, 
        figsize=None, 
        force_start_zero=True, 
        title=None, 
        legend_max_width=0.9,
        n_jobs=4,
        subplot_width=5,
        subplot_height=4,
        ylim=None,
        legend_fontsize=14,
        legend_spacing=0.8,  # Increased spacing between legends
        show_legend=True, 
        figname=None, 
):
    """
    Plot risk-coverage curves with one subplot per dataset and shared legends.

    Colors encode the method and line styles encode the sampling method.
    Two figure-level legends sit below the plots:
    - a method legend, with up to 4 methods per row, ordered Drebin,
      DeepDrebin, CADE, HCC, then everything else;
    - a sampling legend for samplers that match a SAMPLER_NAME_MPR key.

    Note: this applies the 'seaborn-v0_8-whitegrid' style and updates
    ``plt.rcParams`` globally.

    Parameters:
    - dataset_dict: dict
        Maps dataset names to DataFrames. Each DataFrame is expected to have
        a MultiIndex (label_budget, sampler_method, method_name) and the
        columns "Labels", "Predictions", "Uncertainties (Month Ahead)" and
        "is_multiseed". If any of these columns is missing, the dataset
        gets an empty subplot.
    - label_budget: int or None, default=None
        If provided, only plot curves for this label budget. Datasets with a
        3-level index that lack the budget are skipped.
    - method_name_mpr: callable, optional
        Maps raw method names to display names.
    - figsize: tuple, optional
        Overall figure size. If None, it is
        (subplot_width * n_datasets, subplot_height + 3).
    - force_start_zero: bool, default=True
        Force all curves to start at (0,0).
    - title: str, optional
        Figure title. If None and label_budget is set, a default title that
        includes the budget is used.
    - legend_max_width: float, default=0.9
        Currently unused. Kept for backward compatibility.
    - n_jobs: int, default=4
        Number of worker threads for curve computation.
    - subplot_width: float, default=5
        Width of each subplot in inches.
    - subplot_height: float, default=4
        Height of each subplot in inches.
    - ylim: tuple, optional
        Y-axis limits for all subplots (min, max). If None, (0, 0.1) is used.
    - legend_fontsize: int, default=14
        Font size for legend text.
    - legend_spacing: float, default=0.8
        Vertical spacing (GridSpec hspace) between the two legends.
    - show_legend: bool, default=True
        Whether to draw the method and sampling legends.
    - figname: str, optional
        If given, the figure is saved there as PDF.

    Returns:
    - fig: matplotlib Figure with one subplot per dataset and shared legends.
      If no data survives filtering, an empty figure is returned.
    """
    required_columns = ("Labels", "Predictions", "Uncertainties (Month Ahead)", "is_multiseed")

    # Apply publication-friendly styling
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans'],
        'font.size': 11,
        'axes.titlesize': 12,
        'axes.labelsize': 11,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': legend_fontsize,
        'figure.dpi': 300
    })

    # Filter datasets by label budget and drop those left empty
    filtered_dataset_dict = {}
    for dataset_name, df in dataset_dict.items():
        if df.empty:
            continue
        filtered_df = _filter_by_label_budget(df, label_budget)
        if not filtered_df.empty:
            filtered_dataset_dict[dataset_name] = filtered_df

    if not filtered_dataset_dict:
        print("No data to plot after filtering by label budget")
        return plt.figure()

    n_datasets = len(filtered_dataset_dict)
    if figsize is None:
        # Extra height leaves room for the legends below the plots
        figsize = (subplot_width * n_datasets, subplot_height + 3)

    fig = plt.figure(figsize=figsize, constrained_layout=False)

    # Top row holds the plots, bottom row holds the legends
    gs = gridspec.GridSpec(2, 1, height_ratios=[3, 2], figure=fig, hspace=0.4)
    subfig_plots = fig.add_subfigure(gs[0, 0])
    axs = subfig_plots.subplots(1, n_datasets, squeeze=False)[0]

    # First pass: collect every (display) method and sampler across datasets
    unique_methods = set()
    unique_samplers = set()
    for dataset_slice in filtered_dataset_dict.values():
        for index in dataset_slice.index:
            sampler_method, method_name = _split_risk_coverage_index(index)
            if method_name_mpr is not None:
                method_name = method_name_mpr(method_name)
            unique_methods.add(method_name)
            unique_samplers.add(sampler_method)

    method_color_map = create_distinct_color_map(unique_methods)

    # Line style per sampler: the fixed mapping if known, else cycle styles
    line_styles = ['-', '--', ':', '-.']
    diff_linestyle_map = {
        diff: SAMPLER_LINE_MPR.get(diff, line_styles[i % len(line_styles)])
        for i, diff in enumerate(sorted(unique_samplers))
    }

    # Legend entries shared across subplots
    method_lines = []
    method_names = []
    sampler_lines = []
    sampler_names = []

    for dataset_idx, (dataset_name, dataset_slice) in enumerate(filtered_dataset_dict.items()):
        ax = axs[dataset_idx]
        ax.set_title(dataset_name + " Dataset", fontsize=12, fontweight='bold')

        # Default y-range. Setting it here also turns off y autoscaling, so
        # the curves drawn below do not expand it.
        ax.set_ylim(0, 0.1)

        # Collect curve computation tasks for this dataset
        tasks = []
        task_info = []
        if all(col in dataset_slice.columns for col in required_columns):
            for index, row in dataset_slice.iterrows():
                sampler_method, method_name = _split_risk_coverage_index(index)
                if method_name_mpr is not None:
                    method_name = method_name_mpr(method_name)
                tasks.append((
                    row["Labels"],
                    row["Predictions"],
                    row["Uncertainties (Month Ahead)"],
                    row["is_multiseed"],
                    force_start_zero
                ))
                task_info.append({
                    'dataset_name': dataset_name,
                    'sampler_method': sampler_method,
                    'method_name': method_name,
                    'linestyle': diff_linestyle_map.get(sampler_method, '-'),
                    'color': method_color_map.get(method_name),
                })

        # Compute curves in worker threads. A failing curve is reported and
        # replaced by a flat line so the rest of the figure still renders.
        curve_results = []
        with ThreadPoolExecutor(max_workers=n_jobs) as executor:
            futures = [executor.submit(compute_risk_coverage_curve, *task) for task in tasks]
            for info, future in zip(task_info, futures):
                try:
                    curve_results.append(future.result())
                except Exception as e:
                    print(f"Error computing curve for {info['dataset_name']} - {info['method_name']}: {e}")
                    curve_results.append((np.array([0, 1]), np.array([0, 0])))

        for (coverage, risk), info in zip(curve_results, task_info):
            # str() so that non-string index levels (e.g. integer label
            # budgets) do not break the substring checks
            sampler_str = str(info['sampler_method'])
            sampler_key = next((key for key in SAMPLER_NAME_MPR if key in sampler_str), None)

            # The default (full first year) sampler is labelled by method only
            if "full_first_year" in sampler_str or sampler_key is None:
                label = info['method_name']
            else:
                label = f"{info['method_name']} ({SAMPLER_NAME_MPR[sampler_key]})"

            ax.plot(
                coverage, risk,
                linewidth=2,
                label=label,
                color=info['color'],
                linestyle=info['linestyle'],
                alpha=0.8
            )

            if info['method_name'] not in method_names:
                method_names.append(info['method_name'])
                method_lines.append(Line2D([0], [0], color=info['color'], lw=2))

            # Only samplers with a display name get a legend entry. The
            # handle reuses the style actually used for the plotted curve.
            if sampler_key is not None and sampler_key not in sampler_names:
                sampler_names.append(sampler_key)
                sampler_lines.append(Line2D([0], [0], color='black',
                                            linestyle=info['linestyle']))

        # Format this subplot
        ax.set_xlim(0, 1)
        if ylim:
            ax.set_ylim(ylim)
        else:
            ax.set_ylim(bottom=0)

        ax.xaxis.set_major_formatter(PercentFormatter(1.0))
        ax.yaxis.set_major_formatter(PercentFormatter(1.0, decimals=1))

        # Y label only on the leftmost plot to avoid redundancy
        if dataset_idx == 0:
            ax.set_ylabel("Risk (Error Rate)", fontsize=11)
        ax.set_xlabel("Coverage", fontsize=11)
        ax.grid(True, linestyle='--', alpha=0.3)

    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')
    elif label_budget is not None:
        denot = r"$B_{M_i}$"
        fig.suptitle(rf"Risk–Coverage Curves ({denot}$={label_budget}$)", 
                     fontsize=14, fontweight='bold')

    # Leave the bottom of the figure for the legends
    plt.tight_layout(rect=[0, 0.35, 1, 0.95])

    # Legend area: method legend on top, sampler legend below
    subfig_legends = fig.add_subfigure(gs[1, 0])
    legend_gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], figure=subfig_legends, hspace=legend_spacing)

    if method_lines and show_legend:
        ax_method_legend = subfig_legends.add_subplot(legend_gs[0, 0])
        ax_method_legend.axis('off')

        ordered_items = _order_method_legend_items(
            method_lines, method_names, ['Drebin', 'DeepDrebin', 'CADE', 'HCC']
        )
        ordered_lines, display_method_names = zip(*ordered_items)

        # At most 4 methods per row
        method_ncol = min(4, len(method_names))
        ax_method_legend.legend(
            ordered_lines, display_method_names,
            loc='center',
            ncol=method_ncol,
            frameon=True,
            framealpha=0.9,
            title="Method",
            fontsize=legend_fontsize,
            bbox_to_anchor=(0.5, 4.)
        )

    if sampler_lines and show_legend:
        ax_sampler_legend = subfig_legends.add_subplot(legend_gs[1, 0])
        ax_sampler_legend.axis('off')

        display_sampler_names = [SAMPLER_NAME_MPR.get(name, name) for name in sampler_names]
        sampler_ncol = min(6, len(sampler_names))
        ax_sampler_legend.legend(
            sampler_lines, display_sampler_names,
            loc='center',
            ncol=sampler_ncol,
            frameon=True,
            framealpha=0.9,
            title=r"Sampling of $\mathcal{D}_0$",
            fontsize=legend_fontsize,
            bbox_to_anchor=(0.5, 4.3)
        )

    if figname is not None:
        fig.savefig(figname, format='pdf', bbox_inches='tight', dpi=300)
    return fig