import numpy as np
import os 
from collections import defaultdict
from typing import Dict, Optional
from typing import Union, List, Tuple, Dict, Any
from functools import partial
from sklearn.model_selection import KFold, StratifiedKFold
from tqdm.notebook import tqdm
import pandas as pd
import pickle
import datetime as dt
from dateutil.relativedelta import relativedelta
import json
import torch 
from copy import deepcopy
from tqdm.notebook import tqdm
from sklearn.svm import LinearSVC as _LinearSCV
import numpy as np
from kneed import KneeLocator

from torch.utils.data import DataLoader, TensorDataset
from torch import nn
import torch.nn.functional as F
import os
from scipy.sparse import csr_matrix

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from uda import UDATrainer
from augmentations import TransformFixMatchCSR, TransformFixMatchCSR_ApplyBatch, TransformTensorDataset


##################################
### Additional Baseline Models ###
##################################


class LinearSVC(_LinearSCV):
    """scikit-learn ``LinearSVC`` extended with a Transcendent-style non-conformity measure."""

    def get_ncms(self, X):
        """
        Return hard predictions and the SVM non-conformity measure (NCM) used in Transcendent.

        ``decision_function`` yields the signed distance of every sample to the
        separating hyperplane: negative distances map to class 0, every other
        value (including exactly 0) maps to class 1.

        The NCM expresses confidence through that distance. Points close to the
        hyperplane are uncertain and points far from it are certain. Both sides
        of the hyperplane are treated alike by taking the absolute distance and
        negating it. Larger scores (closer to 0) therefore mean more uncertainty,
        and far-away points get large negative scores.

        Returns:
            (predictions, single_score_uncertainty): 1-D integer labels in {0, 1}
            and the 1-D ``-|distance|`` scores.
        """
        margins = self.decision_function(X).ravel()

        # -|d|: distance from the hyperplane turned into a "higher = less certain" score
        single_score_uncertainty = -np.abs(margins)

        # Negative side of the hyperplane -> class 0, otherwise class 1
        predictions = np.where(margins < 0, 0, 1)

        return predictions, single_score_uncertainty
    
class DrebinMLP(nn.Module):
    """
    MLP used by Grosse et al. in 'Adversarial Examples for Malware Detection'.

    The architecture is unchanged; only its definition is reorganised:
    * ``backbone``: two hidden blocks of Linear(200) -> ReLU -> Dropout(0.5);
    * ``classifier``: the final Linear layer, with a configurable ``output_size``
      (2 by default) so the head can also project to a larger output space.
    """

    def __init__(
            self,
            input_size,
            output_size=2
    ):
        super().__init__()
        hidden_units = 200

        # Build both hidden blocks in order; the Sequential indices (0..5) and
        # hence the state_dict keys match a hand-written layer list.
        layers = []
        fan_in = input_size
        for _ in range(2):
            layers.extend([nn.Linear(fan_in, hidden_units), nn.ReLU(), nn.Dropout(0.5)])
            fan_in = hidden_units

        self.backbone = nn.Sequential(*layers)
        self.classifier = nn.Linear(hidden_units, output_size)

    def forward(self, x):
        """Return the classifier logits for ``x``."""
        features = self.backbone(x)
        return self.classifier(features)

    def embed(self, x):
        """Return the 200-dimensional backbone features for ``x``."""
        return self.backbone(x)
    

def train_dd(
        curr_X_train,
        curr_y_train,
        curr_weights=None,
        num_epochs=30,
        augment_apply_batch: Optional[Any] = None
):
    """
    Train a fresh two-class ``DrebinMLP`` on the given data.

    Training uses Adam (lr=1e-3), shuffled mini-batches of 512, and cross-entropy
    loss for ``num_epochs`` epochs.

    Args:
        curr_X_train: 2-D feature matrix; ``shape[1]`` sets the model's input size.
        curr_y_train: integer class labels (0/1).
        curr_weights: optional per-sample weights. When given, each batch loss is
            the weighted mean ``sum(w_i * l_i) / sum(w_i)`` of the per-sample
            cross-entropy losses. The weights are used as-is, without normalisation.
            A batch whose weights sum to 0 yields a NaN loss.
        num_epochs: number of passes over the data.
        augment_apply_batch: optional callable applied to every input batch
            tensor before the forward pass.

    Returns:
        The trained model. It is left in training mode.
    """
    # Produce the DataLoader from the data (weights become a third tensor if given)
    tensors = [
        torch.tensor(curr_X_train, dtype=torch.float32),
        torch.tensor(curr_y_train, dtype=torch.long),
    ]
    if curr_weights is not None:
        tensors.append(torch.tensor(curr_weights, dtype=torch.float32))

    train_loader = DataLoader(
        TensorDataset(*tensors),
        batch_size=512,
        shuffle=True,
    )

    # Training utils. reduction="none" keeps per-sample losses so that sample
    # weights can actually be applied. With reduction="mean" the loss would
    # already be a scalar and weighting it would have no effect.
    model = DrebinMLP(input_size=curr_X_train.shape[1], output_size=2)
    criterion = nn.CrossEntropyLoss(reduction="none")
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

    for _ in range(num_epochs):
        model.train()
        for batch in train_loader:
            if len(batch) == 3:
                X_batch, y_batch, w_batch = batch
            else:
                X_batch, y_batch = batch
                w_batch = None

            if augment_apply_batch is not None:
                X_batch = augment_apply_batch(X_batch)

            optimiser.zero_grad()
            per_sample_loss = criterion(model(X_batch), y_batch)
            if w_batch is None:
                # Identical to CrossEntropyLoss(reduction="mean")
                loss = per_sample_loss.mean()
            else:
                loss = (per_sample_loss * w_batch).sum() / w_batch.sum()

            loss.backward()
            optimiser.step()
    return model


##########################
### Additional Modules ###
##########################


def select_most_uncertain_points(X_test, y_test, model=None, uncertainty = None, N=100, return_uncertainty=False):
    """
    Select the ``N`` samples with the highest uncertainty.

    Exactly one of ``model`` / ``uncertainty`` must be given:
    * ``model``: NumPy inputs are converted to tensors, and the model is evaluated
      in eval mode under ``torch.no_grad()``. Dropout is therefore inactive and
      the scores are deterministic. The model's previous train/eval mode is
      restored afterwards. The score is ``1 - |p_max - 0.5| / 0.5``, where
      ``p_max`` is the maximum softmax probability. For a 2-class model this is
      1 at ``p_max = 0.5`` and 0 at ``p_max = 1``. With more classes ``p_max``
      can fall below 0.5 and the score is no longer monotone.
    * ``uncertainty``: precomputed per-sample scores, used as-is.

    Returns:
        ``(X_sel, y_sel)`` with the selected rows in their original order, or
        ``(X_sel, y_sel, uncertainty)`` when ``return_uncertainty`` is True. The
        third element holds the scores for all samples, not just the selected ones.

    Raises:
        AssertionError: on an invalid model/uncertainty combination, or when
        fewer than ``N`` samples can be selected.
    """
    if model is not None:
        assert uncertainty is None, "If model is provided, uncertainty must be None"
        if isinstance(X_test, np.ndarray):
            X_test = torch.tensor(X_test, dtype=torch.float32)
        if isinstance(y_test, np.ndarray):
            y_test = torch.tensor(y_test, dtype=torch.long)

        # Evaluate without dropout/grad tracking, then restore the caller's mode
        is_module = isinstance(model, nn.Module)
        was_training = model.training if is_module else False
        if is_module:
            model.eval()
        try:
            with torch.no_grad():
                logits = model(X_test)
        finally:
            if is_module:
                model.train(was_training)

        class_probs = F.softmax(logits, dim=1)
        # Compute uncertainty from max softmax probability
        max_probs, _ = torch.max(class_probs, dim=1)
        uncertainty = 1 - ((max_probs - 0.5).abs() / 0.5)
        uncertainty = uncertainty.detach().cpu().numpy()
    else:
        assert uncertainty is not None, "If model is not provided, uncertainty must be provided"

    # Select the top-N most uncertain points (boolean mask keeps original order)
    idx_uncertainties = np.argsort(uncertainty)[::-1][:N]
    curr_mask = np.zeros(len(X_test), dtype=bool)
    curr_mask[idx_uncertainties] = True
    assert np.sum(curr_mask) == N, f"Expected {N} samples, got {np.sum(curr_mask)}"
    if not return_uncertainty:
        return X_test[curr_mask], y_test[curr_mask]
    return X_test[curr_mask], y_test[curr_mask], uncertainty
    

def weights_from_training_buffer(
        X_train_buffer: Union[np.ndarray, torch.Tensor],
        y_train_buffer: Union[np.ndarray, torch.Tensor],
        num_epochs: int = 30,
        n_folds: int = 6,
        mode = "kfold", #strat_kfold otherwise
) -> np.ndarray:
    """
    Compute an out-of-fold uncertainty score for every sample of the buffer.

    The buffer is split into ``n_folds`` shuffled folds (``KFold`` for
    ``mode="kfold"``, ``StratifiedKFold`` on the labels for
    ``mode="strat_kfold"``). For each fold a model is trained with ``train_dd``
    on the other folds. It is then scored on the held-out fold with
    ``select_most_uncertain_points``.

    Returns:
        1-D array whose entry ``k`` is the uncertainty of row ``k`` of
        ``X_train_buffer``.

    Raises:
        ValueError: if ``mode`` is not ``"kfold"`` or ``"strat_kfold"``.
    """
    # With 12 months in the first year and 6 splits, 2*MONTHLY_REJECTION_BUDGET are selected per round
    if mode == "kfold":
        splitter = KFold(n_splits=n_folds, shuffle=True).split(X_train_buffer)
    elif mode == "strat_kfold":
        splitter = StratifiedKFold(n_splits=n_folds, shuffle=True).split(X_train_buffer, y_train_buffer)
    else:
        raise ValueError(f"Unknown mode {mode!r}; expected 'kfold' or 'strat_kfold'")

    fold_uncertainties, fold_indices = [], []
    for train_index, test_index in tqdm(
        splitter,
        desc="KFold Split and Uncertainty Sampling over Training-Set",
        total=n_folds
    ):
        X_train, X_val = X_train_buffer[train_index], X_train_buffer[test_index]
        y_train, y_val = y_train_buffer[train_index], y_train_buffer[test_index]
        DD = train_dd(X_train, y_train, num_epochs=num_epochs)
        # Score with dropout disabled so the uncertainties are deterministic
        DD.eval()
        _, _, uncertainty = select_most_uncertain_points(
            X_test=X_val,
            y_test=y_val,
            N=2,
            model=DD,
            return_uncertainty=True,
        )
        fold_uncertainties.append(uncertainty)
        fold_indices.append(test_index)

    # Position p of the concatenation holds the score of sample fold_indices[p];
    # scatter the scores back so entry k belongs to row k of the buffer.
    # (Gathering with uncertainties[indices] would apply the inverse permutation.)
    fold_uncertainties = np.concatenate(fold_uncertainties)
    fold_indices = np.concatenate(fold_indices)
    uncertainties = np.empty_like(fold_uncertainties)
    uncertainties[fold_indices] = fold_uncertainties

    return uncertainties


def find_elbow_point_int(weights):
    """
    Find the elbow point in the cumulative sum curve of weights.

    The weights are sorted in descending order and accumulated. Both axes of
    that curve are scaled to [0, 1], and ``KneeLocator`` (concave, increasing,
    ``S=1.0``) locates the knee. The knee is then mapped back to a position in
    the sorted order.

    Parameters:
    weights (array-like): The uncertainty weights (at least two values).

    Returns:
    int: The index of the elbow point, in ``[0, len(weights) - 1]``.

    Raises:
    ValueError: if fewer than two weights are given or no elbow is found.
    """
    weights = np.asarray(weights).ravel()
    if weights.size < 2:
        raise ValueError("find_elbow_point_int needs at least two weights")

    # Cumulative sum of the weights in descending order
    cumsum = np.sort(weights)[::-1].cumsum()

    # X values (sample indices), normalised to 0-1 together with the curve
    x = np.arange(cumsum.size)
    last_index = x[-1]
    x_norm = x / last_index
    y_norm = cumsum / np.max(cumsum)

    # S controls how aggressive the elbow detection is (1.0 is moderate)
    kneedle = KneeLocator(
        x_norm, y_norm,
        curve="concave", direction="increasing",
        S=1.0
    )
    if kneedle.elbow is None:
        raise ValueError("No elbow point found in the cumulative weight curve")

    # kneedle.elbow is an x_norm value, i.e. index / last_index: undo that scaling.
    # round() avoids e.g. 1999.9999999 being truncated to 1999.
    return int(round(kneedle.elbow * last_index))



###########################
### Loading the Dataset ###
###########################

DATASET_NUM_MONTHS_MPR = {
    'androzoo': 24,
    'apigraph': 72,
    'transcendent_vmonth': 48,
    'transcendent': 48,
}

# Resolve against this file's absolute location (as the sys.path setup at the
# top of the file does) so the default cannot collapse to "/data/" when
# __file__ is a bare relative filename. The trailing separator is kept.
DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "")


def load_range_dataset_w_benign(data_name, start_month, end_month, folder=DATA_PATH):
    """
    Load one pre-selected dataset file and return its training arrays.

    The file read is ``<folder>/<data_name>/<start_month>to<end_month>_selected.npz``,
    or ``<folder>/<data_name>/<start_month>_selected.npz`` when both months are equal.

    Returns:
        ``(X_train, y_train, y_mal_family)`` as stored under those keys in the file.
    """
    if start_month != end_month:
        dataset_name = f'{start_month}to{end_month}'
    else:
        dataset_name = f'{start_month}'
    saved_data_file = os.path.join(folder, data_name, f'{dataset_name}_selected.npz')

    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only point this at trusted data files.
    # The context manager closes the underlying .npz file handle. Arrays are
    # fully read when indexed, so they stay valid afterwards.
    with np.load(saved_data_file, allow_pickle=True) as data:
        X_train, y_train = data['X_train'], data['y_train']
        y_mal_family = data['y_mal_family']
    return X_train, y_train, y_mal_family

class LoadHCCDatasets():
    """
    Load one of the monthly-split Drebin feature datasets.

    The training set is a single ``<train_start>to<train_end>`` file. The test
    set is a list with one array per month from ``test_start`` to ``test_end``
    (inclusive). All files are read with ``load_range_dataset_w_benign`` from
    its default data folder.
    """

    # which -> (data_name, train_start, train_end, test_start, test_end), months as 'YYYY-MM'.
    #   androzoo:     train on 2019, test months 2020-01 .. 2021-12. Earlier notes described
    #                 2020-01 .. 2020-06 as a validation window and 2020-07 .. 2021-12 as the
    #                 test window; this class returns all of 2020-01 .. 2021-12 as test months.
    #   apigraph:     train on 2012, test months 2013-01 .. 2018-12.
    #   transcendent: train on 2014, test months 2015-01 .. 2018-12.
    _SPLITS = {
        "androzoo": ("gen_androzoo_drebin", "2019-01", "2019-12", "2020-01", "2021-12"),
        "apigraph": ("gen_apigraph_drebin", "2012-01", "2012-12", "2013-01", "2018-12"),
        "transcendent": ("gen_transcendent_drebin", "2014-01", "2014-12", "2015-01", "2018-12"),
    }

    def __init__(self, which:str = "apigraph", test_month_granularity: Optional[str] = None):
        """
        Args:
            which: dataset key, one of ``"androzoo"``, ``"apigraph"``, ``"transcendent"``.
            test_month_granularity: accepted for API compatibility; currently ignored.

        Raises:
            ValueError: if ``which`` is not a known dataset. Previously this only
            failed later, when the ``None`` returned by ``_load`` was unpacked.
        """
        self.which = which
        if which not in self._SPLITS:
            raise ValueError(
                f"Unknown dataset {which!r}; expected one of {sorted(self._SPLITS)}"
            )
        (self.data_name, self.train_start, self.train_end,
         self.test_start, self.test_end) = self._SPLITS[which]

    def _load(self):
        """Return ``(X_train, X_tests, y_train, y_tests)``; ``X_tests``/``y_tests`` hold one array per test month."""
        X_train, y_train, _ = load_range_dataset_w_benign(
            self.data_name,
            self.train_start,
            self.train_end,
        )

        # Load the test set month by month
        cur_month = dt.datetime.strptime(self.test_start, '%Y-%m')
        end = dt.datetime.strptime(self.test_end, '%Y-%m')

        X_tests, y_tests = [], []
        while cur_month <= end:
            cur_month_str = cur_month.strftime('%Y-%m')
            X_test, y_test, _ = load_range_dataset_w_benign(
                self.data_name,
                cur_month_str,
                cur_month_str,
            )
            X_tests.append(X_test)
            y_tests.append(y_test)
            cur_month += relativedelta(months=1)

        return X_train, X_tests, y_train, y_tests

    def load_family_labels(self):
        """Return the data with the labels exactly as stored in the files."""
        X_train, X_tests, y_train, y_tests = self._load()
        return X_train, X_tests, y_train, y_tests

    def load_binary_labels(self):
        """Return the data with labels mapped to 0 (label == 0) and 1 (any other label)."""
        X_train, X_tests, y_train, y_tests = self._load()
        y_train = np.array([1 if item != 0 else 0 for item in y_train])
        y_tests = [np.array([1 if item != 0 else 0 for item in y_test]) for y_test in y_tests]
        return X_train, X_tests, y_train, y_tests


class LoadDataset(LoadHCCDatasets):
    """Generic alias of ``LoadHCCDatasets``; construction arguments are passed through unchanged."""



###########################
### Sampler Code and Co ###
###########################


def resample_from_training_buffer(
        X_train_buffer: Union[np.ndarray, torch.Tensor],
        y_train_buffer: Union[np.ndarray, torch.Tensor],
        num_samples_total: int,
        num_epochs: int = 30,
        n_folds: int = 6,
        show_progress: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Keep the buffer points that cross-validated models are most uncertain about.

    The buffer is split with a shuffled ``KFold``. For every fold a
    ``UDATrainer`` (``loss_function="CE"``) is fitted on the other folds. The
    held-out points are then ranked by the uncertainty that
    ``trainer.predict_and_uncertainty`` reports in its first method tuple, and
    the most uncertain ones are kept.

    ``num_samples_total`` is spread over the folds as evenly as possible. Every
    fold contributes ``num_samples_total // n_folds`` points, and the first
    ``num_samples_total % n_folds`` folds contribute one more. The requested
    total is therefore honoured; fewer points are returned only if a fold is too
    small.

    Returns:
        ``(X_selected, y_selected)`` stacked over all folds.
    """
    # E.g. the first year has 12 months; with 6 splits, 2*MONTHLY_REJECTION_BUDGET are kept per fold
    base, remainder = divmod(num_samples_total, n_folds)
    per_fold_budget = [base + (1 if k < remainder else 0) for k in range(n_folds)]

    kf = KFold(n_splits=n_folds, shuffle=True)

    def _dummy_tqdm(iterable, desc, total):
        return iterable
    progressbar = tqdm if show_progress else _dummy_tqdm

    X_new_buffer, y_new_buffer = [], []
    folds = progressbar(
        kf.split(X_train_buffer),
        desc="KFold Split and Uncertainty Sampling over Training-Set",
        total=n_folds
    )
    for fold_idx, (train_index, test_index) in enumerate(folds):
        X_train, X_val = X_train_buffer[train_index], X_train_buffer[test_index]
        y_train, y_val = y_train_buffer[train_index], y_train_buffer[test_index]

        trainer = UDATrainer(
            loss_function="CE",
            eval_every=150,
            num_epochs = num_epochs,
        )
        trainer = trainer.fit(
            X_train,
            y_train,
        )
        method_tuples, _ = trainer.predict_and_uncertainty(X_val, y_val)
        _, _, muncs = method_tuples[0]

        # Indices of this fold's most uncertain held-out samples
        top_idx = np.argsort(muncs)[::-1][:per_fold_budget[fold_idx]]
        X_new_buffer.append(X_val[top_idx])
        y_new_buffer.append(y_val[top_idx])

    return np.vstack(X_new_buffer), np.concatenate(y_new_buffer)

class DsSampler():
    """
    Build the per-test-month data splits for the drift experiments.

    For test month ``i`` the training ("source") data consists of the initial
    training set, possibly subsampled depending on ``SAMPLER_MODE``, plus the
    points of earlier test months chosen by ``monthly_test_set_selector``.
    """

    # Modes handled by the "first year + selected past months" branch of sample()
    _FIRST_YEAR_MODES = (
        "zero_first_year_subsample_months",  # No first-year data at all (empty training set)
        "full_first_year_subsample_months",  # No subsampling, just use the full first year

        "subsample_first_year_subsample_months",  # Default kfold mode for selection + fixed sampling rate
        "stratk_subsample_first_year_subsample_months",  # Same as default but with Stratified KFold

        "random_subsample_first_year_subsample_months",  # Random subsampling
        "stratk_random_subsample_first_year_subsample_months",  # Random subsampling, class-stratified

        "elbow_subsample_first_year_subsample_months",  # Subsample first year with elbow point method
        "elbow_stratk_subsample_first_year_subsample_months",  # Elbow point method and stratified kfold
    )

    def __init__(
        self,
        # Data
        X_train,
        y_train,
        X_tests,
        y_tests,

        # Hyperparameters
        SAMPLER_MODE: str = "rolling_window", # e.g. "rolling_window" or one of _FIRST_YEAR_MODES (see sample())
        MONTHLY_BUDGET_FIRST_YEAR: int = 400, # How many samples per month over the first year are chosen
        MONTHLY_BUDGET_THEREAFTER: int = 200,
        NUM_MONTHS_DANN_BACKLOG: int = 1,
        NUM_EPOCHS: int = 100,

        show_sampler_progress: bool = True,

    ):
        self.X_train_init = X_train
        self.y_train_init = y_train
        self.X_tests = X_tests
        self.y_tests = y_tests

        # Hyperparameters
        self.SAMPLER_MODE = SAMPLER_MODE
        self.MONTHLY_BUDGET_FIRST_YEAR = MONTHLY_BUDGET_FIRST_YEAR
        self.MONTHLY_BUDGET_THEREAFTER = MONTHLY_BUDGET_THEREAFTER
        self.NUM_MONTHS_DANN_BACKLOG = NUM_MONTHS_DANN_BACKLOG
        self.NUM_EPOCHS = NUM_EPOCHS

        self.show_sampler_progress = show_sampler_progress

        # Internals
        self._initial_sampling_result = {}  # Cache of the i == 0 result, keyed per mode
        self._is_first_year_sampled = False  # Flag to track if first year sampling has been done

    def sample(
            self,
            monthly_test_set_selector: Optional[Dict[int, np.ndarray]] = None, # Mapping of monthly index to boolean masks
            i: int = 0, # Test-Month Counter
            return_weights:bool = False,
    ):
        """
        Return the data splits for test month ``i``.

        Args:
            monthly_test_set_selector: maps a past test-month index ``j`` to a
                boolean mask over ``X_tests[j]`` selecting the points added to the
                training data. Required for ``i > 0`` in the "*subsample_months"
                modes.
            i: index of the current test month.
            return_weights: currently unused.

        Returns:
            Most modes return the 6-tuple
            ``(X_source, X_target, X_test, y_source, y_target, y_test)``.
            The "experimental" modes return ``(X_train, y_train, X_test, y_test)``
            for ``i == 0`` and an 8-tuple otherwise.

        Raises:
            ValueError: if ``SAMPLER_MODE`` is not recognised.
        """
        if i > 0 and "subsample_months" in self.SAMPLER_MODE:
            assert monthly_test_set_selector is not None, "monthly_test_set_selector must be provided for i > 0"

        # (1) Select the Train- and Test-Sets
        if self.SAMPLER_MODE in self._FIRST_YEAR_MODES:
            if i == 0:  # First Iteration of the Sampler
                cache_key = (self.SAMPLER_MODE, self.MONTHLY_BUDGET_FIRST_YEAR)
                # Re-use the first-year selection if it has already been computed
                if self._is_first_year_sampled and cache_key in self._initial_sampling_result:
                    print(f"Using cached first year sampling for {self.SAMPLER_MODE}")
                    return self._initial_sampling_result[cache_key]

                self.X_train, self.y_train = deepcopy(self.X_train_init), deepcopy(self.y_train_init)
                curr_X_target, curr_y_target = [], []

                if "subsample_first_year" in self.SAMPLER_MODE:
                    if "random" in self.SAMPLER_MODE:
                        sample_count = 12 * self.MONTHLY_BUDGET_FIRST_YEAR
                        if "stratk" in self.SAMPLER_MODE:
                            # Stratified random subsampling: keep the training-set class ratios
                            # (np.bincount assumes non-negative integer labels)
                            class_ratios = np.bincount(self.y_train) / len(self.y_train)
                            class_samples = (class_ratios * sample_count).astype(int)
                            all_selected_indices = []
                            for class_label in np.unique(self.y_train):
                                class_indices = np.where(self.y_train == class_label)[0]
                                selected_indices = np.random.choice(class_indices, class_samples[class_label], replace=False)
                                all_selected_indices += list(selected_indices)
                            self.X_train = self.X_train[all_selected_indices]
                            self.y_train = self.y_train[all_selected_indices]
                        else:
                            # Pure random subsampling
                            selected_indices = np.random.choice(len(self.y_train), sample_count, replace=False)
                            self.X_train = self.X_train[selected_indices]
                            self.y_train = self.y_train[selected_indices]
                    else:
                        mode = "strat_kfold" if "stratk" in self.SAMPLER_MODE else "kfold"
                        weights = weights_from_training_buffer(
                            self.X_train,
                            self.y_train,
                            num_epochs=self.NUM_EPOCHS,
                            n_folds=6,
                            mode=mode
                        )
                        if "elbow" in self.SAMPLER_MODE:
                            # Use elbow point to determine sample count
                            sample_count = find_elbow_point_int(weights)
                            print(f"Selecting {sample_count} top uncertain samples")
                        else:
                            # For regular subsample, use predefined budget
                            sample_count = 12 * self.MONTHLY_BUDGET_FIRST_YEAR

                        self.X_train, self.y_train = select_most_uncertain_points(
                            self.X_train,
                            self.y_train,
                            uncertainty=weights,
                            N=sample_count,
                            model=None,
                            return_uncertainty=False
                        )

                if self.SAMPLER_MODE == "zero_first_year_subsample_months":
                    self.X_train, self.y_train = [], []

                package = (self.X_train, curr_X_target, self.X_tests[i], self.y_train, curr_y_target, self.y_tests[i])
                self._initial_sampling_result[cache_key] = package
                self._is_first_year_sampled = True
                return package

            # (2) Create the Source-Data: selected points of all past test months
            X_source_select, y_source_select = [], []
            for j in range(i):
                curr_test_set_selector = monthly_test_set_selector[j]
                X_source_select.append(self.X_tests[j][curr_test_set_selector])
                y_source_select.append(self.y_tests[j][curr_test_set_selector])

            # (3) Compile the Target Data for DANN: all points of the last backlog months
            X_target_select, y_target_select = [], []
            for j in range(max(0, i - self.NUM_MONTHS_DANN_BACKLOG), i):
                curr_test_set_selector = np.ones_like(monthly_test_set_selector[j], dtype=bool)
                X_target_select.append(self.X_tests[j][curr_test_set_selector])
                y_target_select.append(self.y_tests[j][curr_test_set_selector])

            # (4) Compile the Data. X and y follow the same branch: concatenating
            # the empty list of the zero-first-year mode would promote labels to float.
            if len(self.X_train) > 0:
                curr_X_source = np.vstack([self.X_train] + X_source_select)
                curr_y_source = np.concatenate([self.y_train] + y_source_select)
            else:
                curr_X_source = np.vstack(X_source_select)
                curr_y_source = np.concatenate(y_source_select)
            curr_X_target = np.vstack(X_target_select)
            curr_y_target = np.concatenate(y_target_select)  # Just for evaluation

            return curr_X_source, curr_X_target, self.X_tests[i], curr_y_source, curr_y_target, self.y_tests[i]

        elif self.SAMPLER_MODE == "rolling_window":
            if i == 0:
                cache_key = self.SAMPLER_MODE
                if self._is_first_year_sampled and cache_key in self._initial_sampling_result:
                    return self._initial_sampling_result[cache_key]

                self.X_train, self.y_train = deepcopy(self.X_train_init), deepcopy(self.y_train_init)
                curr_X_target, curr_y_target = [], []
                package = self.X_train, curr_X_target, self.X_tests[i], self.y_train, curr_y_target, self.y_tests[i]

                self._initial_sampling_result[cache_key] = package
                self._is_first_year_sampled = True
                return package

            # Train on the initial set plus every full past test month
            curr_X_target, curr_y_target = [], []
            curr_X_train = np.vstack([self.X_train] + self.X_tests[:i])
            curr_y_train = np.concatenate([self.y_train] + self.y_tests[:i])
            return curr_X_train, curr_X_target, self.X_tests[i], curr_y_train, curr_y_target, self.y_tests[i]

        ###################################
        #### Experimental / less common ###
        ###################################

        elif "experimental" in self.SAMPLER_MODE:
            if i == 0: # First Iteration of the Sampler
                self.X_train, self.y_train = deepcopy(self.X_train_init), deepcopy(self.y_train_init)
                if "subsample_first_year" in self.SAMPLER_MODE:
                    # With 12 months in the first year and 6 splits, 2*MONTHLY_REJECTION_BUDGET are kept per fold
                    self.X_train, self.y_train = resample_from_training_buffer(
                        self.X_train,
                        self.y_train,
                        num_samples_total=12*self.MONTHLY_BUDGET_FIRST_YEAR,
                        num_epochs=self.NUM_EPOCHS,
                        show_progress=self.show_sampler_progress
                    )
                return self.X_train, self.y_train, self.X_tests[i], self.y_tests[i]

            # X_source_without_M0, y_source_without_M0, X_target_M0, y_target_M0,
            # X_train_with_M0, y_train_with_M0, X_test_M1, y_test_M1
            X_source_without_M0, y_source_without_M0 = [self.X_train], [self.y_train]
            for j in range(0, i-1):
                curr_test_set_selector = monthly_test_set_selector[j]
                X_source_without_M0.append(self.X_tests[j][curr_test_set_selector])
                y_source_without_M0.append(self.y_tests[j][curr_test_set_selector])

            if len(X_source_without_M0) > 1:
                X_source_without_M0, y_source_without_M0 = np.vstack(X_source_without_M0), np.concatenate(y_source_without_M0)
            else:
                X_source_without_M0, y_source_without_M0 = X_source_without_M0[0], y_source_without_M0[0]

            X_target_M0 = self.X_tests[i-1][monthly_test_set_selector[i-1]]
            y_target_M0 = self.y_tests[i-1][monthly_test_set_selector[i-1]]

            X_train_with_M0, y_train_with_M0 = [self.X_train], [self.y_train]
            for j in range(0, i):
                curr_test_set_selector = monthly_test_set_selector[j]
                X_train_with_M0.append(self.X_tests[j][curr_test_set_selector])
                y_train_with_M0.append(self.y_tests[j][curr_test_set_selector])

            X_train_with_M0, y_train_with_M0 = np.vstack(X_train_with_M0), np.concatenate(y_train_with_M0)

            X_test_M1, y_test_M1 = self.X_tests[i], self.y_tests[i]

            return X_source_without_M0, y_source_without_M0, X_target_M0, y_target_M0,\
                X_train_with_M0, y_train_with_M0, X_test_M1, y_test_M1

        elif self.SAMPLER_MODE in ["resample_monthly_fixed", "resample_monthly_grow"]:
            # No past months are needed at i == 0, so a missing selector is treated as empty
            selector = {} if monthly_test_set_selector is None else monthly_test_set_selector
            assert i not in selector, "monthly_test_set_selector must not contain the current month"
            if i > 0:
                assert i - 1 in selector, "monthly_test_set_selector must contain the previous month"

            X_train_buffer, y_train_buffer = [deepcopy(self.X_train_init)], [deepcopy(self.y_train_init)]
            for j in range(i):
                curr_test_set_selector = selector[j]
                X_train_buffer.append(self.X_tests[j][curr_test_set_selector])
                y_train_buffer.append(self.y_tests[j][curr_test_set_selector])
            if len(X_train_buffer) > 1:
                X_train_buffer, y_train_buffer = np.vstack(X_train_buffer), np.concatenate(y_train_buffer)
            else:
                X_train_buffer, y_train_buffer = X_train_buffer[0], y_train_buffer[0]

            # (A) Either keep the buffer-size fixed
            # (B) Add the capacity of one month's budget for every iteration
            if "grow" in self.SAMPLER_MODE:
                NUM_SAMPLES_TOTAL = 12 * self.MONTHLY_BUDGET_FIRST_YEAR + i * self.MONTHLY_BUDGET_THEREAFTER
            else:
                NUM_SAMPLES_TOTAL = 12 * self.MONTHLY_BUDGET_FIRST_YEAR

            self.X_train, self.y_train = resample_from_training_buffer(
                X_train_buffer,
                y_train_buffer,
                num_samples_total=NUM_SAMPLES_TOTAL,
                num_epochs=self.NUM_EPOCHS,
                show_progress=self.show_sampler_progress
            )
            # The resampled set is the training source. Returning the full buffer
            # would discard the resampling and make "fixed" and "grow" identical.
            curr_X_source, curr_y_source = self.X_train, self.y_train

            # For backward compatibility:
            curr_X_target, curr_y_target = [], []
            return curr_X_source, curr_X_target, self.X_tests[i], curr_y_source, curr_y_target, self.y_tests[i]

        else:
            raise ValueError(f"Unknown SAMPLER_MODE: {self.SAMPLER_MODE!r}")




class Rejector:
    """
    Decide which samples of a month are "rejected", i.e. handed over for manual labelling.

    The strategy is selected by ``REJECTOR_MODE``:

    * ``"rolling_forward"``: select every sample of the upcoming month.
    * ``"reject_topn_overall"``: select the ``MONTHLY_LABEL_BUDGET`` most uncertain samples of the past month.
    * ``"reject_topn_random"``: select ``MONTHLY_LABEL_BUDGET`` uniformly random samples of the past month.
    * ``"reject_topn_neg"``: prefer the most uncertain samples predicted as label 1, then fill up with the rest.
    * ``"reject_topn_pos"``: prefer the most uncertain samples predicted as label 0, then fill up with the rest.

    The selection mask returned by every ``update`` call is stored in
    ``monthly_test_set_selector`` under the month counter ``i``.
    """

    def __init__(
        self,
        REJECTOR_MODE: str = "reject_topn_overall",  # reject_topn_neg, reject_topn_pos
        MONTHLY_LABEL_BUDGET: int = 50,
    ):
        # month counter -> boolean selection mask over that month's samples
        self.monthly_test_set_selector = {}
        self.monthly_hyperparameters = defaultdict(dict)

        self.REJECTOR_MODE = REJECTOR_MODE
        self.MONTHLY_LABEL_BUDGET = MONTHLY_LABEL_BUDGET

    def update(
        self,
        muncs_month_ahead,  # Uncertainties from the beginning of the month
        muncs_past_month,  # Uncertainties from the end of the month for the past month
        i: int,  # Test-Month Counter
        preds_past_month: Optional[np.ndarray] = None,  # Predictions from the end of the month for the past month
    ):
        """
        Compute the rejection/selection masks for month ``i`` and record the selection.

        Args:
            muncs_month_ahead: Uncertainties from the beginning of the month (only used by "rolling_forward").
            muncs_past_month: Uncertainties for the past month, computed at the end of that month.
            i: Test-month counter; key under which the selection mask is stored.
            preds_past_month: Predictions for the past month; required by "reject_topn_neg"/"reject_topn_pos".

        Returns:
            (reject_mask, select_mask): boolean masks over the samples.

        Raises:
            ValueError: If ``REJECTOR_MODE`` is not one of the supported modes.
        """
        mode = self.REJECTOR_MODE
        if mode == "rolling_forward":
            reject_mask, select_mask = self._reject_none(muncs_month_ahead)

        elif mode == "reject_topn_overall":  # HCC default behaviour
            reject_mask, select_mask = self._reject_topn_overall(muncs_past_month, self.MONTHLY_LABEL_BUDGET)

        elif mode == "reject_topn_random":
            reject_mask, select_mask = self._reject_topn_random(muncs_past_month, self.MONTHLY_LABEL_BUDGET)

        elif mode in ("reject_topn_neg", "reject_topn_pos"):
            assert preds_past_month is not None, "Predictions for the past month must be provided for this mode"
            # "neg" prioritises samples predicted as 1, "pos" those predicted as 0.
            reject_mask, select_mask = self._reject_topn_neg_sampling(
                muncs_past_month=muncs_past_month,
                predictions_past_month=preds_past_month,
                neg_label=1 if mode == "reject_topn_neg" else 0,
                MONTHLY_LABEL_BUDGET=self.MONTHLY_LABEL_BUDGET,
            )

        else:
            # Previously this fell through to an UnboundLocalError on `select_mask`.
            raise ValueError(
                f"Unknown REJECTOR_MODE {mode!r}; expected one of "
                "'rolling_forward', 'reject_topn_overall', 'reject_topn_random', "
                "'reject_topn_neg', 'reject_topn_pos'"
            )

        self.monthly_test_set_selector[i] = select_mask
        return reject_mask, select_mask

    def _reject_none(self, muncs_month_ahead):
        """Base-Case (Rolling-Forward Evaluation) - Accepting all points.

        Returns two independent all-True masks so that mutating one never affects the other.
        """
        select_mask = np.ones(len(muncs_month_ahead), dtype=bool)
        return select_mask.copy(), select_mask

    def _reject_topn_neg_sampling(
        self,
        muncs_past_month: np.ndarray,
        predictions_past_month: np.ndarray,
        neg_label: int = 1,
        MONTHLY_LABEL_BUDGET: int = 50,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Rejects N points according to uncertainty, prioritizing samples predicted as ``neg_label``.

        If there are not enough such samples to meet the budget, the remaining slots are
        filled with the most uncertain samples of the other prediction(s).

        Args:
            muncs_past_month: 1D uncertainties over the last month (array-like).
            predictions_past_month: 1D predictions over the last month (array-like, same length).
            neg_label: The prediction value whose samples are taken first (e.g. 1).
            MONTHLY_LABEL_BUDGET: Target number of points to reject; values <= 0 select nothing.

        Returns:
            Tuple[np.ndarray, np.ndarray]: (reject_mask, select_mask) boolean masks
                indicating which samples to reject/select.
        """
        # asarray lets plain lists through: `list == label` would compare the whole list to a scalar.
        muncs_past_month = np.asarray(muncs_past_month)
        predictions_past_month = np.asarray(predictions_past_month)

        total_points = len(muncs_past_month)
        # Never select more points than available; a non-positive budget selects nothing.
        actual_budget = max(0, min(MONTHLY_LABEL_BUDGET, total_points))
        if actual_budget == 0:
            return np.zeros(total_points, dtype=bool), np.zeros(total_points, dtype=bool)

        neg_mask = predictions_past_month == neg_label
        neg_indices = np.flatnonzero(neg_mask)
        pos_indices = np.flatnonzero(~neg_mask)

        # Rank each group by descending uncertainty (argsort of an empty array is an empty index array).
        sorted_neg_indices = neg_indices[np.argsort(muncs_past_month[neg_indices])[::-1]]
        sorted_pos_indices = pos_indices[np.argsort(muncs_past_month[pos_indices])[::-1]]

        # Prioritised group first, then the rest: taking the first `actual_budget` entries fills
        # the budget with `neg_label` samples and tops it up with the most uncertain others.
        ranked_indices = np.concatenate([sorted_neg_indices, sorted_pos_indices])
        selected_indices = ranked_indices[:actual_budget]

        select_mask = np.zeros(total_points, dtype=bool)
        select_mask[selected_indices] = True
        return select_mask.copy(), select_mask

    def _reject_topn_overall(
            self,
            muncs_past_month: np.ndarray,
            MONTHLY_LABEL_BUDGET: int
    ):
        """
        Rejects the most uncertain points over the last month according to a predefined monthly budget.

        Args:
            - muncs_past_month: The 1D uncertainties over the last month (at the end of the month)
            - MONTHLY_LABEL_BUDGET: The budget for the monthly rejections (e.g. 50); values <= 0 select nothing

        Returns:
            - reject_mask: The boolean mask of the selected points for rejection.
            - select_mask: Equal to reject_mask (independent copy).
        """
        muncs_past_month = np.asarray(muncs_past_month)
        # Clamp: a negative slice bound would otherwise select "all but the last |budget|" points.
        budget = max(0, MONTHLY_LABEL_BUDGET)
        topn_indices = np.argsort(muncs_past_month)[::-1][:budget]
        select_mask = np.zeros(len(muncs_past_month), dtype=bool)
        select_mask[topn_indices] = True
        return select_mask.copy(), select_mask

    def _reject_topn_random(
            self,
            muncs_past_month: np.ndarray,
            MONTHLY_LABEL_BUDGET: int
    ):
        """
        Rejects random points over the last month according to a predefined monthly budget.

        Uses numpy's global random state (``np.random.choice``), so seed it for reproducibility.

        Args:
            - muncs_past_month: The 1D uncertainties over the last month (only its length is used)
            - MONTHLY_LABEL_BUDGET: The budget for the monthly rejections (e.g. 50); values <= 0 select nothing

        Returns:
            - reject_mask: The boolean mask of the selected points for rejection.
            - select_mask: Equal to reject_mask (independent copy).
        """
        total_points = len(muncs_past_month)
        num = max(0, min(MONTHLY_LABEL_BUDGET, total_points))

        # Randomly select indices without replacement
        random_indices = np.random.choice(total_points, num, replace=False)

        select_mask = np.zeros(total_points, dtype=bool)
        select_mask[random_indices] = True
        return select_mask.copy(), select_mask



def _select_uncertainty_method(method_tuples, method_name):
    """
    Return the ``(name, predictions, uncertainties)`` triple whose name equals ``method_name``.

    ``method_tuples`` is the first value returned by ``UDATrainer.predict_and_uncertainty``.

    Raises:
        ValueError: If no triple carries ``method_name``.
    """
    method_tuples = list(method_tuples)
    for name, preds, uncs in method_tuples:
        if name == method_name:
            return name, preds, uncs
    available = [name for name, _, _ in method_tuples]
    raise ValueError(f"Uncertainty method {method_name!r} not found; available methods: {available}")


def _dataset_augmentation_kwargs(curr_X_source, scale):
    """
    Build the ``train_dataset_cls`` keyword argument for ``UDATrainer.fit`` that enables
    ``TransformFixMatchCSR`` augmentation at ``scale``, or ``{}`` when ``scale <= 0``.
    """
    if scale > 0:
        train_transformer = TransformFixMatchCSR(curr_X_source, scale=scale)
        return {"train_dataset_cls": partial(TransformTensorDataset, transform=train_transformer)}
    return {}


def Trainer(
    # Data
    curr_X_source,
    curr_X_target,
    curr_X_test,
    curr_y_source,
    curr_y_target,
    curr_y_test,
    curr_weights_source=None,

    # Hyperparameters
    i: int = 0,  # Iteration-Counter
    TRAINER_MODE: str = "CE",  # Choice of ["CE", "HCCMLP", "SVC", "DANN+CE-Sampler", "CE+DANN-Sampler", "DANN+DANN-Sampler"]
    CE_TRAIN_SCALE: float = 0.,
    DANN_TRAIN_SCALE: float = 0.,
    DANN_UNCERTAINTY_METHOD: str = "DANN - Combined (average)",  # Something like: "DANN - Combined (average)" or  "DANN - Classifier"

    # General Hyperparameters
    EVAL_EVERY: int = 150,
    NUM_EPOCHS: int = 100,
    SVM_C: float = 1.0,

    device: str = "mps",
):
    """
    Train the model(s) for one test month and collect predictions and uncertainties on ``curr_X_test``.

    Args:
        curr_X_source, curr_y_source: Labelled training data.
        curr_X_target, curr_y_target: Target-domain data (used by the DANN modes for ``i > 0``).
        curr_X_test, curr_y_test: The month being evaluated.
        curr_weights_source: Optional per-sample weights (forwarded as ``w_train`` in "CE"/"HCCMLP" only).
        i: Iteration counter; for the "DANN+*" modes, ``i == 0`` trains a plain CE model instead.
        TRAINER_MODE: Which training recipe to run (see the branches below).
        CE_TRAIN_SCALE, DANN_TRAIN_SCALE: Augmentation scales; ``0`` disables augmentation.
        DANN_UNCERTAINTY_METHOD: Name of the DANN uncertainty method used for ranking.
        EVAL_EVERY, NUM_EPOCHS, SVM_C: Trainer / SVM hyperparameters.
        device: Device forwarded to ``UDATrainer`` (only in "CE"/"HCCMLP" mode).

    Returns:
        ``(mname, mpreds, muncs_first, muncs, curr_y_test)``. ``muncs_first`` comes from the model
        trained at the beginning of the month; ``muncs`` comes from the model used to rank the
        month for hand-labelling. For "CE"/"HCCMLP"/"SVC" both are the same array.

    Raises:
        ValueError: If ``TRAINER_MODE`` is unknown, or ``DANN_UNCERTAINTY_METHOD`` is not produced by the trainer.
    """
    if TRAINER_MODE in ["CE", "HCCMLP"]:  # DeepDrebin
        additional_kwargs = {}
        if CE_TRAIN_SCALE > 0:
            additional_kwargs["apply_source_aug"] = TransformFixMatchCSR_ApplyBatch(curr_X_source, scale=CE_TRAIN_SCALE)

        trainer = UDATrainer(
            loss_function=TRAINER_MODE,
            eval_every=EVAL_EVERY,
            num_epochs=NUM_EPOCHS,
            device=device,
        )
        trainer = trainer.fit(
            curr_X_source,
            curr_y_source,
            w_train=curr_weights_source,
            **additional_kwargs
        )

        method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)
        mname, mpreds, muncs = method_tuples[0]
        return mname, mpreds, muncs, muncs, curr_y_test

    elif TRAINER_MODE == "SVC":
        # LinearSVC accepts sparse input, so only the feature matrices are converted. The labels are
        # returned untouched (as in every other branch), and the target set is not used by the SVM.
        # Note: the target set may be a plain list here, which has no `.toarray()`.
        svm = LinearSVC(dual=True, C=SVM_C)  # max_iter=100
        svm.fit(csr_matrix(curr_X_source), curr_y_source)
        mpreds, muncs = svm.get_ncms(csr_matrix(curr_X_test))
        return "SVC", mpreds, muncs, muncs, curr_y_test

    elif TRAINER_MODE == "DANN+CE-Sampler":
        dann_kwargs = _dataset_augmentation_kwargs(curr_X_source, DANN_TRAIN_SCALE)

        if i == 0:
            # (1) First round via normal CE training
            trainer = UDATrainer(
                loss_function="CE",
                eval_every=EVAL_EVERY,
                num_epochs=NUM_EPOCHS,
            )
            trainer = trainer.fit(curr_X_source, curr_y_source, **dann_kwargs)
            method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)

            # For CE the list has a single entry: [("CE - Classifier", predictions, uncertainties)]
            mname, mpreds, muncs = method_tuples[0]
            muncs_first = muncs
        else:
            # (2) Beginning of the month: DANN on previous labelled + unlabelled data gives the predictions.
            # NOTE(review): epochs are hard-coded to 100 here, unlike NUM_EPOCHS elsewhere — confirm intent.
            trainer = UDATrainer(
                loss_function="DANN",
                eval_every=EVAL_EVERY,
                num_epochs=100,
            )
            trainer = trainer.fit(
                curr_X_source,
                curr_y_source,
                curr_X_target,
                curr_y_target,
                **dann_kwargs
            )
            method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)
            mname, mpreds, muncs_first = _select_uncertainty_method(method_tuples, DANN_UNCERTAINTY_METHOD)

            # (3) CE model whose uncertainties rank the month for hand-labelling.
            ce_kwargs = _dataset_augmentation_kwargs(curr_X_source, CE_TRAIN_SCALE)
            trainer = UDATrainer(
                loss_function="CE",
                eval_every=EVAL_EVERY,
                num_epochs=NUM_EPOCHS,
            )
            trainer = trainer.fit(
                curr_X_source,
                curr_y_source,
                curr_X_test,
                curr_y_test,
                **ce_kwargs
            )
            method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)
            _, _, muncs = method_tuples[0]

        return mname, mpreds, muncs_first, muncs, curr_y_test

    elif TRAINER_MODE == "CE+DANN-Sampler":  # DeepDrebin + UDA Sampler
        # (1) Beginning of the month: plain CE training on previous labelled data gives the predictions.
        ce_kwargs = _dataset_augmentation_kwargs(curr_X_source, CE_TRAIN_SCALE)
        trainer = UDATrainer(
            loss_function="CE",
            eval_every=EVAL_EVERY,
            num_epochs=NUM_EPOCHS,
        )
        trainer = trainer.fit(
            curr_X_source,
            curr_y_source,
            curr_X_test,
            curr_y_test,
            **ce_kwargs
        )
        method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)
        mname, mpreds, muncs_first = method_tuples[0]

        # (2) End of the month: the full month has been observed - train DANN as a selector.
        # NOTE(review): epochs are hard-coded to 100 here, unlike NUM_EPOCHS elsewhere — confirm intent.
        dann_kwargs = _dataset_augmentation_kwargs(curr_X_source, DANN_TRAIN_SCALE)
        trainer = UDATrainer(
            loss_function="DANN",
            eval_every=EVAL_EVERY,
            num_epochs=100,
        )
        trainer = trainer.fit(
            curr_X_source,
            curr_y_source,
            curr_X_test,
            curr_y_test,
            **dann_kwargs
        )

        # (3) Collect the uncertainties for ranking the month for hand-labelling.
        method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)
        _, _, muncs = _select_uncertainty_method(method_tuples, DANN_UNCERTAINTY_METHOD)
        return mname, mpreds, muncs_first, muncs, curr_y_test

    elif TRAINER_MODE == "DANN+DANN-Sampler":
        dann_kwargs = _dataset_augmentation_kwargs(curr_X_source, DANN_TRAIN_SCALE)

        if i == 0:
            # (1) First round via normal CE training
            trainer = UDATrainer(
                loss_function="CE",
                eval_every=EVAL_EVERY,
                num_epochs=NUM_EPOCHS,
            )
            trainer = trainer.fit(
                curr_X_source,
                curr_y_source,
                curr_X_test,
                curr_y_test,
                **dann_kwargs
            )
            # Evaluate explicitly on the test month, like every other branch.
            method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)

            # For CE the list has a single entry: [("CE - Classifier", predictions, uncertainties)]
            mname, mpreds, muncs = method_tuples[0]
            muncs_first = muncs
        else:
            # (2) Beginning of the month: only previous labelled and unlabelled data are available.
            trainer = UDATrainer(
                loss_function="DANN",
                eval_every=EVAL_EVERY,
                num_epochs=NUM_EPOCHS,
            )
            trainer = trainer.fit(
                curr_X_source,
                curr_y_source,
                curr_X_target,
                curr_y_target,
                **dann_kwargs
            )

            # (3) These are the predictions for the next month.
            method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)
            mname, mpreds, muncs_first = _select_uncertainty_method(method_tuples, DANN_UNCERTAINTY_METHOD)

            # (4) End of the month: retrain DANN with X_test as target for a more targeted ranking.
            trainer = UDATrainer(
                loss_function="DANN",
                eval_every=EVAL_EVERY,
                num_epochs=NUM_EPOCHS,
            )
            trainer = trainer.fit(
                curr_X_source,
                curr_y_source,
                curr_X_test,
                curr_y_test,
                **dann_kwargs
            )

            # (5) Collect the uncertainties for ranking the month for hand-labelling.
            method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)
            _, _, muncs = _select_uncertainty_method(method_tuples, DANN_UNCERTAINTY_METHOD)

        return mname, mpreds, muncs_first, muncs, curr_y_test

    # Previously an unknown mode silently returned None, failing later at the caller's unpacking.
    raise ValueError(
        f"Unknown TRAINER_MODE {TRAINER_MODE!r}; expected one of "
        "'CE', 'HCCMLP', 'SVC', 'DANN+CE-Sampler', 'CE+DANN-Sampler', 'DANN+DANN-Sampler'"
    )
    


if __name__ == "__main__":
    def test_rejection_improved():
        """Smoke-test Rejector._reject_topn_neg_sampling (neg_label=1) on three hand-checked cases."""
        sampler = Rejector(
            REJECTOR_MODE="reject_topn_neg",
        )._reject_topn_neg_sampling
        uncertainties = np.array([0.1, 0.8, 0.3, 0.9, 0.2, 0.7])

        # Each case: (predictions, budget, expected selection mask, message printed on success)
        cases = (
            (
                np.array([0, 1, 0, 1, 1, 0]),  # 3 negative samples -> budget filled by negatives only
                2,
                [False, True, False, True, False, False],
                "Test 1 passed: Correctly selected top 2 negative samples",
            ),
            (
                np.array([0, 1, 0, 0, 1, 0]),  # 2 negative samples -> topped up with positives
                4,
                [False, True, False, True, True, True],
                "Test 2 passed: Correctly selected 2 negatives + 2 most uncertain positives",
            ),
            (
                np.array([0, 0, 0, 0, 0, 0]),  # no negative samples -> pure uncertainty ranking
                3,
                [False, True, False, True, False, True],
                "Test 3 passed: Correctly selected 3 most uncertain samples when no negatives",
            ),
        )

        for predictions, budget, expected, message in cases:
            _, select_mask = sampler(
                uncertainties,
                predictions,
                neg_label=1,
                MONTHLY_LABEL_BUDGET=budget,
            )
            np.testing.assert_array_equal(select_mask, np.array(expected))
            print(message)

    test_rejection_improved()
