from multiprocessing import Pool, cpu_count, current_process
import itertools
from experiments import DsSampler, LoadHCCDatasets, Trainer, Rejector, TransformFixMatchCSR_ApplyBatch, UDATrainer
from tqdm.notebook import tqdm
import psutil
import os
import math
from time import time 
import pickle 
import numpy as np 
from sklearn.metrics import f1_score
from typing import Union, Tuple
import torch
from sklearn.model_selection import KFold
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch import nn    
from kneed import KneeLocator
from tools import RobustPickleIO



def log_live_result(
    muncs,
    y_preds,
    y_test,
    rejc_pts=(250, 500, 750, 1000),
    prefix="Vanilla CE",
    results: Union[list, None] = None,
    i=0,
):
    """Print (and optionally record) binary F1 scores with and without rejection.

    For every ``N`` in ``rejc_pts`` the ``N`` samples with the highest
    uncertainty are dropped and the binary F1 score is computed on the rest.

    Args:
        muncs: Per-sample uncertainty scores; larger values are rejected first.
        y_preds: Per-sample binary predictions aligned with ``muncs``.
        y_test: Per-sample binary ground-truth labels aligned with ``muncs``.
        rejc_pts: Numbers of most-uncertain samples to reject (any iterable of ints).
        prefix: Text printed before the scores and stored under ``"prefix"``.
        results: Optional list; when given, one dict with all scores is appended.
        i: Value stored under ``"month"`` in the appended dict.

    Returns:
        ``results`` (the same list object, extended in place), or ``None`` when
        no list was passed.
    """
    # Accept lists as well as arrays: boolean-mask indexing below needs ndarrays.
    muncs = np.asarray(muncs)
    y_preds = np.asarray(y_preds)
    y_test = np.asarray(y_test)

    f1_full = f1_score(y_test, y_preds, average="binary")

    # Indices ordered from most to least uncertain; loop-invariant, so computed once.
    most_uncertain_first = np.argsort(muncs)[::-1]

    curr_res = {}
    for rej_pt in rejc_pts:
        keep_mask = np.ones(len(muncs), dtype=bool)
        keep_mask[most_uncertain_first[:rej_pt]] = False

        # F1 on the samples that survive rejection.
        curr_f1 = f1_score(y_test[keep_mask], y_preds[keep_mask], average="binary")
        curr_res[f"F1 (Rej.) [N={rej_pt}]"] = curr_f1

    # One line showing the score at each rejection point.
    res_str = " | ".join(f"{k}: {v*100:.1f}" for k, v in curr_res.items())
    print(f"{prefix} | F1 (Full): {100*f1_full:.1f} | {res_str}")

    if results is not None:
        results.append({
            "month": i,
            "prefix": prefix,
            "F1 (Full)": f1_full,
            **curr_res,
        })

    return results







#############################
#### Experimental Setups ####
#############################

import multiprocessing as mp
from multiprocessing import cpu_count, Pool, Manager
import pickle
import os
from functools import partial
import time

class ParallelProcessorMixin():
    """Run ``self.run_experiment`` for every entry of ``self.combinations``.

    Subclasses provide:
      * ``NUM_PROCESSES`` -- requested worker count (clamped to ``[1, cpu_count()]``).
      * ``combinations`` -- sequence of argument tuples, one per experiment.
      * ``run_experiment(args)`` -- returns a list of result dicts for one tuple.
      * ``exp_name`` (optional) -- path of a pickle checkpoint rewritten after
        every finished experiment; when absent or ``None`` nothing is written.
    """

    def run(self):
        """Execute all experiments and return the concatenated result list.

        With more than one process, results are appended in completion order
        (``imap_unordered``). The ``samplers`` directory in the current working
        directory is removed when the run ends, whether it succeeded or not.
        """
        import shutil

        total_cpus = cpu_count()
        # Clamp to >= 1 so NUM_PROCESSES <= 0 cannot cause a division by zero.
        num_processes = max(1, min(self.NUM_PROCESSES, total_cpus))
        cpus_per_process = total_cpus // num_processes
        num_experiments = len(self.combinations)
        # Not every subclass defines exp_name; treat a missing one as "no checkpoint".
        exp_name = getattr(self, "exp_name", None)

        print(f"Running {num_experiments} experiments across {num_processes} processes")
        print(f"Each process will use {cpus_per_process} CPU cores")

        all_results = []
        try:
            if num_processes > 1:
                with Pool(processes=num_processes) as pool:
                    finished = pool.imap_unordered(self.run_experiment, self.combinations)
                    for n_done, local_results in enumerate(finished, start=1):
                        all_results.extend(local_results)
                        self._checkpoint(exp_name, all_results)
                        print(f"Completed {n_done}/{num_experiments} experiments")
            else:
                for n_done, args in enumerate(self.combinations, start=1):
                    all_results.extend(self.run_experiment(args))
                    print(f"Completed {n_done}/{num_experiments} experiments")
                    self._checkpoint(exp_name, all_results)
        except KeyboardInterrupt:
            print("Interrupted!")
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            # Remove on-disk sampler caches that experiments may have created.
            if os.path.exists("samplers"):
                shutil.rmtree("samplers")

        return all_results

    @staticmethod
    def _checkpoint(path, results):
        """Pickle ``results`` to ``path`` atomically; do nothing when ``path`` is None."""
        if path is None:
            return
        tmp_path = f"{path}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(results, f)
        # os.replace is atomic, so an interrupt never leaves a truncated checkpoint.
        os.replace(tmp_path, path)


class DANNMainExperimentsNoAug(ParallelProcessorMixin):
    """DANN experiment sweep without training-time augmentation.

    Every combination of dataset, sampler mode, trainer mode, DANN uncertainty
    method, monthly label budget and rejector mode is run month by month by
    :meth:`run_experiment`; ``ParallelProcessorMixin.run`` drives the sweep.
    """

    def __init__(self, NUM_PROCESSES=8, exp_name=None):
        """Build the parameter grid.

        Args:
            NUM_PROCESSES: Requested number of worker processes for ``run``.
            exp_name: Optional path of a pickle checkpoint that ``run`` rewrites
                after every finished experiment; ``None`` disables it.
        """
        self.NUM_PROCESSES = NUM_PROCESSES
        # ParallelProcessorMixin.run reads self.exp_name; always define it.
        self.exp_name = exp_name
        MONTHLY_LABEL_BUDGETS = [50, 100, 200, 400]

        DATASETS = [
            "androzoo",
            "apigraph",
            "transcendent_vmonth",
        ]

        SAMPLER_MODES = [
            "full_first_year_subsample_months",
            "subsample_first_year_subsample_months",
        ]

        REJECTOR_MODE = ["reject_topn_overall",]

        TRAINER_MODES = [
            #"CE", # Not run yet!
            "CE+DANN-Sampler",
            "DANN+CE-Sampler",
            "DANN+DANN-Sampler",
        ]
        DANN_UNCERTAINTY_METHODS = [
            "DANN - Combined (average)",
            "DANN - Classifier"
        ]
        # All parameter combinations; the month loop runs inside run_experiment.
        self.combinations = list(itertools.product(
            DATASETS,
            SAMPLER_MODES,
            TRAINER_MODES,
            DANN_UNCERTAINTY_METHODS,
            MONTHLY_LABEL_BUDGETS,
            REJECTOR_MODE,
        ))

    def run_experiment(self, args):
        """Run one parameter combination over every test month.

        Args:
            args: Tuple ``(DATASET, SAMPLER_MODE, TRAINER_MODE,
                DANN_UNCERTAINTY_METHOD, MONTHLY_LABEL_BUDGET, REJECTOR_MODE)``
                as produced by ``__init__``.

        Returns:
            A list with one result dict per test month.
        """
        # Local import: a later module-level ``import time`` rebinds the global
        # name ``time`` to the module. The bare ``time()`` call used here before
        # therefore raised ``TypeError: 'module' object is not callable``.
        from time import perf_counter

        DATASET, SAMPLER_MODE, TRAINER_MODE, DANN_UNCERTAINTY_METHOD, MONTHLY_LABEL_BUDGET, REJECTOR_MODE = args

        local_results = []
        X_train, X_tests, y_train, y_tests = LoadHCCDatasets(DATASET).load_binary_labels()

        # Stateful helpers. sampler.sample() produces the per-month
        # source/target/test data. rejector.update() turns the uncertainties into
        # rejection/selection masks, and the rejector also supplies
        # monthly_test_set_selector to the sampler.
        sampler = DsSampler(
            X_train=X_train,
            y_train=y_train,
            X_tests=X_tests,
            y_tests=y_tests,
            SAMPLER_MODE=SAMPLER_MODE,
            MONTHLY_BUDGET_FIRST_YEAR=400,
            NUM_MONTHS_DANN_BACKLOG=1,
            EVAL_EVERY=150,
            NUM_EPOCHS=100,
        )

        rejector = Rejector(
            REJECTOR_MODE=REJECTOR_MODE,
            MONTHLY_LABEL_BUDGET=MONTHLY_LABEL_BUDGET,
        )

        # Identifier printed with each month's timing line.
        exp_desc = f"{DATASET}-{SAMPLER_MODE}-{TRAINER_MODE}-{DANN_UNCERTAINTY_METHOD}-Budget{MONTHLY_LABEL_BUDGET}"

        for i in range(len(X_tests)):
            time_start = perf_counter()
            curr_res = {}
            curr_res["Dataset"] = DATASET
            curr_res["Sampler-Mode"] = SAMPLER_MODE
            curr_res["Trainer-Mode"] = TRAINER_MODE + "(Upd. Sampler)"
            curr_res["DANN-Uncertainty-Method"] = DANN_UNCERTAINTY_METHOD
            curr_res["Monthly-Label-Budget"] = MONTHLY_LABEL_BUDGET
            curr_res["Test-Month"] = i

            curr_X_source, curr_X_target, curr_X_test, curr_y_source, curr_y_target, curr_y_test = sampler.sample(
                i=i,
                monthly_test_set_selector=rejector.monthly_test_set_selector,
            )
            mname, mpreds, muncs_first, muncs, _ = Trainer(
                # Data
                curr_X_source=curr_X_source,
                curr_X_target=curr_X_target,
                curr_X_test=curr_X_test,
                curr_y_source=curr_y_source,
                curr_y_target=curr_y_target,
                curr_y_test=curr_y_test,
                # Hyperparameters
                i=i,
                TRAINER_MODE=TRAINER_MODE,
                CE_TRAIN_SCALE=0.,
                DANN_TRAIN_SCALE=0.,
                DANN_UNCERTAINTY_METHOD=DANN_UNCERTAINTY_METHOD,
                # General Hyperparameters
                EVAL_EVERY=150,
                NUM_EPOCHS=100,
            )

            curr_res["Method-Name"] = mname
            curr_res["Predictions"] = mpreds
            curr_res["Uncertainties (Month Ahead)"] = muncs_first
            curr_res["Uncertainties (Past Month)"] = muncs
            curr_res["Labels"] = curr_y_test

            reject_mask, select_mask = rejector.update(
                muncs_month_ahead=muncs_first,
                muncs_past_month=muncs,
                i=i,
            )
            curr_res["Rejection-Mask"] = reject_mask
            curr_res["Selection-Mask"] = select_mask
            local_results.append(curr_res)
            time_end = perf_counter()
            print(f"{exp_desc} | Month {i} in {time_end - time_start}s")

        return local_results




class DANN_FE_Experiments(DANNMainExperimentsNoAug):
    """DANN feature-extractor ("DANN-FE") experiment sweep.

    For every test month after the first, a ``UDATrainer`` is trained with the
    DANN loss. Its source is ``X_train`` and its target is all earlier test months.
    Training then continues with cross-entropy on the month's labelled source
    buffer. Month 0 uses a plain CE ``Trainer`` instead.
    """

    def __init__(self, NUM_PROCESSES=8, exp_name=None):
        """Build the parameter grid.

        Args:
            NUM_PROCESSES: Requested number of worker processes for ``run``.
            exp_name: Optional path of a pickle checkpoint that ``run`` rewrites
                after every finished experiment; ``None`` disables it.
        """
        self.NUM_PROCESSES = NUM_PROCESSES
        # ParallelProcessorMixin.run reads self.exp_name; always define it.
        self.exp_name = exp_name
        MONTHLY_LABEL_BUDGETS = [50, 100, 200, 400]

        DATASETS = [
            "androzoo",
            "apigraph",
            "transcendent_vmonth",
        ]

        SAMPLER_MODES = [
            "full_first_year_subsample_months",
            #"subsample_first_year_subsample_months",
        ]

        TRAINER_MODES = [
            "DANN-FE"
        ]

        REJECTOR_MODE = ["reject_topn_overall",]

        DANN_TRAINING_EPOCHS = [30, 50]
        CE_TRAINING_EPOCHS = [5, 10, 20]
        CE_TRAIN_AUGS = [0, 0.05]

        # All parameter combinations; the month loop runs inside run_experiment.
        self.combinations = list(itertools.product(
            DATASETS,
            SAMPLER_MODES,
            TRAINER_MODES,
            MONTHLY_LABEL_BUDGETS,
            REJECTOR_MODE,
            DANN_TRAINING_EPOCHS,
            CE_TRAINING_EPOCHS,
            CE_TRAIN_AUGS,
        ))

    def run_experiment(self, args):
        """Run one parameter combination over every test month.

        Args:
            args: Tuple ``(DATASET, SAMPLER_MODE, TRAINER_MODE,
                MONTHLY_LABEL_BUDGET, REJECTOR_MODE, DANN_TRAINING_EPOCHS,
                CE_TRAINING_EPOCHS, CE_TRAIN_AUG)`` as produced by ``__init__``.

        Returns:
            A list with one result dict per test month.
        """
        # Local import: a later module-level ``import time`` rebinds the global
        # name ``time`` to the module. The bare ``time()`` call used here before
        # therefore raised ``TypeError: 'module' object is not callable``.
        from time import perf_counter

        DATASET, SAMPLER_MODE, TRAINER_MODE, MONTHLY_LABEL_BUDGET, REJECTOR_MODE, DANN_TRAINING_EPOCHS, CE_TRAINING_EPOCHS, CE_TRAIN_AUG = args

        local_results = []
        X_train, X_tests, y_train, y_tests = LoadHCCDatasets(DATASET).load_binary_labels()

        # Stateful helpers. sampler.sample() produces the per-month
        # source/target/test data. rejector.update() turns the uncertainties into
        # rejection/selection masks.
        sampler = DsSampler(
            X_train=X_train,
            y_train=y_train,
            X_tests=X_tests,
            y_tests=y_tests,
            SAMPLER_MODE=SAMPLER_MODE,
            MONTHLY_BUDGET_FIRST_YEAR=400,
            NUM_MONTHS_DANN_BACKLOG=1,
            EVAL_EVERY=150,
            NUM_EPOCHS=100,
        )

        rejector = Rejector(
            REJECTOR_MODE=REJECTOR_MODE,
            MONTHLY_LABEL_BUDGET=MONTHLY_LABEL_BUDGET,
        )

        # Identifier printed with each month's live-result line.
        exp_desc = f"{DATASET}-{SAMPLER_MODE}-{TRAINER_MODE}-Budget{MONTHLY_LABEL_BUDGET}-DANN-Epochs={DANN_TRAINING_EPOCHS}-CE-Epochs={CE_TRAINING_EPOCHS}-CE-Aug={CE_TRAIN_AUG}"

        for i in range(len(X_tests)):
            time_start = perf_counter()

            curr_res = {}
            curr_res["Test-Month"] = i
            curr_res["Dataset"] = DATASET
            curr_res["Sampler-Mode"] = SAMPLER_MODE
            curr_res["Trainer-Mode"] = TRAINER_MODE + "(Upd. Sampler)"
            curr_res["Monthly-Label-Budget"] = MONTHLY_LABEL_BUDGET
            curr_res["Rejector-Mode"] = REJECTOR_MODE
            curr_res["DANN-Training-Epochs"] = DANN_TRAINING_EPOCHS
            curr_res["CE-Training-Epochs"] = CE_TRAINING_EPOCHS
            curr_res["CE-Train-Aug-Scale"] = CE_TRAIN_AUG

            curr_X_source, curr_X_target, curr_X_test, curr_y_source, curr_y_target, curr_y_test = sampler.sample(
                i=i,
                monthly_test_set_selector=rejector.monthly_test_set_selector,
            )

            if i > 0:
                source_aug = None
                if CE_TRAIN_AUG > 0:
                    # Use the swept scale. It was hard-coded to 0.05 before, which
                    # silently ignored any other value listed in CE_TRAIN_AUGS.
                    # NOTE(review): the augmentation is built from curr_X_source but is
                    # also passed to the DANN stage, whose source data is X_train.
                    # Confirm that TransformFixMatchCSR_ApplyBatch does not depend on
                    # the data it was constructed with.
                    source_aug = TransformFixMatchCSR_ApplyBatch(curr_X_source, scale=CE_TRAIN_AUG)

                # (1) Domain-adversarial training: source = full training set,
                #     target = all test months before month i.
                # NOTE(review): np.vstack assumes dense arrays; confirm X_tests is
                # not a list of scipy.sparse matrices.
                trainer = UDATrainer(
                    loss_function="DANN",
                    eval_every=150,
                    num_epochs=DANN_TRAINING_EPOCHS,  # 10 works much better than 100!
                )
                trainer = trainer.fit(
                    X_train,  # X_source_without_M0,
                    y_train,  # y_source_without_M0,
                    np.vstack(X_tests[:i]),  # X_target_M0,
                    np.concatenate(y_tests[:i]),  # y_target_M0,
                    apply_source_aug=source_aug,  # Source Augmentation seems to hurt the process!
                    #apply_target_aug=target_aug,
                )
                # (2) Continue with CE on the labelled source buffer up to the current
                #     month. This warm-starts from the DANN model. The feature extractor
                #     is NOT frozen (set_feature_extractor_trainable=True).
                trainer.set_loss_function("CE")
                trainer.num_epochs = CE_TRAINING_EPOCHS

                trainer = trainer.fit(
                    curr_X_source,
                    curr_y_source,
                    apply_source_aug=source_aug,
                    warm_start_model=True,
                    set_feature_extractor_trainable=True,
                )
                method_tuples, _ = trainer.predict_and_uncertainty(curr_X_test, curr_y_test)
                mname, mpreds, muncs = method_tuples[0]
                # Only one uncertainty estimate exists here; reuse it for both slots.
                muncs_first = muncs
            else:
                # Month 0: no earlier test months exist to act as a DANN target.
                mname, mpreds, muncs_first, muncs, _ = Trainer(
                    # Data
                    curr_X_source=curr_X_source,
                    curr_X_target=curr_X_target,
                    curr_X_test=curr_X_test,
                    curr_y_source=curr_y_source,
                    curr_y_target=curr_y_target,
                    curr_y_test=curr_y_test,
                    # Hyperparameters
                    TRAINER_MODE="CE",
                    CE_TRAIN_SCALE=CE_TRAIN_AUG,
                    DANN_TRAIN_SCALE=0.,
                    # General Hyperparameters
                    EVAL_EVERY=150,
                    NUM_EPOCHS=30,
                    i=i,
                )

            curr_res["Method-Name"] = mname
            curr_res["Predictions"] = mpreds
            curr_res["Uncertainties (Month Ahead)"] = muncs_first
            curr_res["Uncertainties (Past Month)"] = muncs
            curr_res["Labels"] = curr_y_test

            reject_mask, select_mask = rejector.update(
                muncs_month_ahead=muncs_first,
                muncs_past_month=muncs,
                i=i,
            )
            curr_res["Rejection-Mask"] = reject_mask
            curr_res["Selection-Mask"] = select_mask
            local_results.append(curr_res)
            time_end = perf_counter()

            curr_desc = f"{exp_desc}\t\t | Month {i} in {(time_end - time_start):.0f}s |  "
            log_live_result(
                muncs=muncs,
                y_preds=mpreds,
                y_test=curr_y_test,
                prefix=curr_desc,
                results=None,
                i=i,
            )

        return local_results



# NOTE: The triple-quoted string below is never executed. It holds an earlier
# ParallelProcessorMixin that shared samplers between processes through a
# multiprocessing Manager dict, kept for reference only. The active
# ParallelProcessorMixin is the class defined above in this module.
'''
class ParallelProcessorMixin:
    def __init__(self, NUM_PROCESSES=8, exp_name=None):
        self.NUM_PROCESSES = NUM_PROCESSES
        self.exp_name = exp_name
        self.combinations = []
        
    def run(self):
        # Calculate number of processes and CPUs per process
        total_cpus = cpu_count()
        num_processes = min(self.NUM_PROCESSES, total_cpus)
        cpus_per_process = total_cpus // num_processes
        
        print(f"Running {len(self.combinations)} experiments across {num_processes} processes")
        print(f"Each process will use {cpus_per_process} CPU cores")
        
        # Create a manager for shared objects
        manager = Manager()
        shared_samplers = manager.dict()
        
        all_results = []
        if num_processes > 1: 
            with Pool(processes=num_processes) as pool:
                # Create partial function with the shared dictionary
                from functools import partial
                run_with_store = partial(self.run_experiment_wrapper, shared_samplers=shared_samplers)
                
                # Process combinations
                I = 0
                for local_results in pool.imap_unordered(run_with_store, self.combinations):
                    all_results.extend(local_results)
                    if self.exp_name is not None:
                        with open(self.exp_name, "wb") as f:
                            pickle.dump(all_results, f)
                    I += 1
                    print(f"Completed {I}/{len(self.combinations)} experiments")
        else:
            I = 0
            for args in self.combinations:
                all_results.extend(self.run_experiment_wrapper(args, shared_samplers))
                I += 1
                print(f"Completed {I}/{len(self.combinations)} experiments")
                if self.exp_name is not None:
                    with open(self.exp_name, "wb") as f:
                        pickle.dump(all_results, f)
        return all_results
        
    def run_experiment_wrapper(self, args, shared_samplers):
        """Wrapper that passes the shared samplers to run_experiment"""
        return self.run_experiment(args, shared_samplers)
    
    def run_experiment(self, args, shared_samplers=None):
        """Must be implemented by subclasses"""
        raise NotImplementedError("Subclasses must implement run_experiment")
'''

class CEMainExperiments(ParallelProcessorMixin):
    """Cross-entropy (CE) experiment sweep over first-year samples per month.

    NOTE(review): this module defines ``class CEMainExperiments`` a second time
    further down. That later definition rebinds the name, so this class cannot be
    reached as ``CEMainExperiments`` once the module has been imported. Rename
    one of the two if both are meant to be usable.
    """

    def __init__(self, NUM_PROCESSES=8, exp_name=None):
        """Build the parameter grid.

        Args:
            NUM_PROCESSES: Requested number of worker processes for ``run``.
            exp_name: Optional path of a pickle checkpoint that ``run`` rewrites
                after every finished experiment; ``None`` disables it.
        """
        self.NUM_PROCESSES = NUM_PROCESSES
        self.exp_name = exp_name

        DATASETS = [
            "androzoo",
            "apigraph",
            "transcendent",
        ]

        MONTHLY_LABEL_BUDGETS = [
            50,
            100,
            200,
            400
        ]

        NUM_EPOCHS = [
            #10,
            30,
            50,
        ]

        NUM_SAMPLES_PER_MONTH_FIRST_YEAR = [
            0,
            12//12,
            100//12,
            250//12,
            500//12,
            1000//12,
            2500//12,
            4800//12,
            5000//12,
            10000//12,
            20000//12,
        ]

        RANDOM_SEEDS = [
            0, 1, 2, 3, 4
        ]

        SAMPLER_MODES = [
            "stratk_random_subsample_first_year_subsample_months",
            #"subsample_first_year_subsample_months",

            #"stratk_subsample_first_year_subsample_months",
            #"elbow_subsample_first_year_subsample_months",
            #"full_first_year_subsample_months",
            #"weight_first_year_subsample_months",
        ]

        TRAIN_AUG_SCALES = [
            0,
            #0.05,
            #0.1,
            #0.2
        ]

        TRAINER_MODES = [
            "CE",
            #"HCCMLP",
        ]

        REJECTOR_MODE = [
            #"reject_topn_pos",
            #"reject_topn_neg",
            #"reject_topn_overall",
            "reject_topn_random",
        ]

        # All parameter combinations; the month loop runs inside run_experiment.
        self.combinations = list(itertools.product(
            DATASETS,
            MONTHLY_LABEL_BUDGETS,
            NUM_EPOCHS,
            SAMPLER_MODES,
            TRAIN_AUG_SCALES,
            TRAINER_MODES,
            REJECTOR_MODE,
            NUM_SAMPLES_PER_MONTH_FIRST_YEAR,
            RANDOM_SEEDS
        ))

    @staticmethod
    def _seed_everything(seed):
        """Seed the global Python, NumPy and PyTorch RNGs with ``seed``."""
        import random
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def run_experiment(self, args):
        """Run one parameter combination over every test month.

        Args:
            args: Tuple ``(DATASET, MONTHLY_LABEL_BUDGET, NUM_EPOCHS, SAMPLER_MODE,
                TRAIN_AUG_SCALE, TRAINER_MODE, REJECTOR_MODE,
                NUM_SAMPLES_PER_MONTH_FIRST_YEAR, RANDOM_SEED)`` as built by ``__init__``.

        Returns:
            A list with one result dict per test month.
        """
        DATASET, MONTHLY_LABEL_BUDGET, NUM_EPOCHS, SAMPLER_MODE, TRAIN_AUG_SCALE, \
            TRAINER_MODE, REJECTOR_MODE, NUM_SAMPLES_PER_MONTH_FIRST_YEAR, RANDOM_SEED = args

        # RANDOM_SEED used to be recorded in the results without being applied.
        # Seeding here makes each repetition reproducible. It also stops worker
        # processes (e.g. under the fork start method) from starting with
        # identical inherited RNG state.
        self._seed_everything(RANDOM_SEED)

        exp_desc = f"{DATASET}-{SAMPLER_MODE}-{REJECTOR_MODE}-{TRAINER_MODE}-Scale={TRAIN_AUG_SCALE}-Budget={MONTHLY_LABEL_BUDGET}-Epochs={NUM_EPOCHS}-#Samples={NUM_SAMPLES_PER_MONTH_FIRST_YEAR}-Seed={RANDOM_SEED}"

        # Disabled experiment filters, kept for reference:
        '''
        if SAMPLER_MODE in ["full_first_year_subsample_months", "subsample_first_year_subsample_months"]:
            if REJECTOR_MODE == "reject_topn_overall":
                return []
        '''

        '''
        if SAMPLER_MODE in ["full_first_year_subsample_months"]:
            if NUM_SAMPLES_PER_MONTH_FIRST_YEAR != 400:
                return []
        '''

        rejector = Rejector(
            REJECTOR_MODE=REJECTOR_MODE,
            MONTHLY_LABEL_BUDGET=MONTHLY_LABEL_BUDGET,
        )

        ##################################################################################################
        # Reuse a pickled sampler from "samplers/" if one exists.
        # NOTE(review): the code that writes the cache is disabled below, so this only
        # hits when the files were produced elsewhere. ParallelProcessorMixin.run
        # also deletes "samplers/" at the end of each run.
        pickle_io = RobustPickleIO(max_retries=30, retry_delay=1)
        # exist_ok avoids a FileExistsError race between worker processes.
        os.makedirs("samplers", exist_ok=True)
        sampler_key = f"samplers/{DATASET}_{SAMPLER_MODE}_{NUM_SAMPLES_PER_MONTH_FIRST_YEAR}.pkl"
        if os.path.exists(sampler_key):
            sampler = pickle_io.read_pickle(sampler_key)
            # The cache key does not include the label budgets, so override them.
            sampler.MONTHLY_BUDGET_FIRST_YEAR = NUM_SAMPLES_PER_MONTH_FIRST_YEAR
            sampler.MONTHLY_BUDGET_THEREAFTER = MONTHLY_LABEL_BUDGET
            print(f"Reusing sampler for {sampler_key}")
        else:
            X_train, X_tests, y_train, y_tests = LoadHCCDatasets(DATASET).load_binary_labels()

            # Map experiment-level sampler modes onto the modes DsSampler is given.
            SAMPLER_MODE_ = SAMPLER_MODE
            if SAMPLER_MODE == "weight_first_year_subsample_months":
                SAMPLER_MODE_ = "full_first_year_subsample_months"
            if NUM_SAMPLES_PER_MONTH_FIRST_YEAR == 0:
                SAMPLER_MODE_ = "zero_first_year_subsample_months"

            sampler = DsSampler(
                X_train=X_train,
                y_train=y_train,
                X_tests=X_tests,
                y_tests=y_tests,
                SAMPLER_MODE=SAMPLER_MODE_,
                MONTHLY_BUDGET_FIRST_YEAR=NUM_SAMPLES_PER_MONTH_FIRST_YEAR,
                MONTHLY_BUDGET_THEREAFTER=MONTHLY_LABEL_BUDGET,
                NUM_MONTHS_DANN_BACKLOG=1,
                NUM_EPOCHS=30,  # Keep that fixed for now
            )
            # Disabled: populate the sampler cache.
            '''
            # Invoke Sampler to Store Objects!
            _ = sampler.sample(
                i=0,
                monthly_test_set_selector=rejector.monthly_test_set_selector,
            )
            # Store for future use if shared store exists
            pickle_io.write_pickle(sampler, sampler_key)
            '''

        ##################################################################################################

        curr_weights = None
        local_results = []
        for i in range(len(sampler.X_tests)):
            time_start = time.time()
            curr_res = {}
            curr_res["Dataset"] = DATASET
            curr_res["Rejector-Mode"] = REJECTOR_MODE
            curr_res["Sampler-Mode"] = SAMPLER_MODE
            curr_res["Num-Epochs"] = NUM_EPOCHS
            curr_res["Trainer-Mode"] = TRAINER_MODE
            curr_res["CE-Train-Aug-Scale"] = TRAIN_AUG_SCALE
            curr_res["Monthly-Label-Budget"] = MONTHLY_LABEL_BUDGET
            curr_res["#-Monthly-Samples-First-Year"] = NUM_SAMPLES_PER_MONTH_FIRST_YEAR
            curr_res["Random-Seed"] = RANDOM_SEED
            curr_res["Test-Month"] = i

            curr_X_source, curr_X_target, curr_X_test, curr_y_source, curr_y_target, curr_y_test = sampler.sample(
                i=i,
                monthly_test_set_selector=rejector.monthly_test_set_selector,
            )

            if SAMPLER_MODE == "weight_first_year_subsample_months":
                from experiments import weights_from_training_buffer
                if i == 0:
                    curr_weights = weights_from_training_buffer(
                        X_train_buffer=curr_X_source,
                        y_train_buffer=curr_y_source,
                        num_epochs=30,
                        n_folds=6,
                    )
                else:
                    # Samples added to the source buffer since last month get weight 1.0.
                    num_additional_items_X_source = len(curr_X_source) - len(curr_weights)
                    curr_weights = np.concatenate([curr_weights, np.ones(num_additional_items_X_source).astype(np.float32)])

            if i == 0 and NUM_SAMPLES_PER_MONTH_FIRST_YEAR == 0:
                # No labelled data yet: random predictions and random uncertainties.
                mname = 'CE - Classifier'
                mpreds = np.random.rand(len(curr_y_test)) < 0.5
                muncs_first = muncs = np.random.rand(len(curr_y_test))
            else:
                mname, mpreds, muncs_first, muncs, _ = Trainer(
                    # Data
                    curr_X_source=curr_X_source,
                    curr_X_target=curr_X_target,
                    curr_X_test=curr_X_test,
                    curr_y_source=curr_y_source,
                    curr_y_target=curr_y_target,
                    curr_y_test=curr_y_test,
                    curr_weights_source=curr_weights,
                    # Hyperparameters
                    i=i,
                    TRAINER_MODE=TRAINER_MODE,
                    CE_TRAIN_SCALE=TRAIN_AUG_SCALE,
                    DANN_TRAIN_SCALE=0.,
                    # General Hyperparameters
                    EVAL_EVERY=150,
                    NUM_EPOCHS=NUM_EPOCHS,
                )

            curr_res["Method-Name"] = mname
            curr_res["Predictions"] = mpreds
            curr_res["Uncertainties (Month Ahead)"] = muncs_first
            curr_res["Uncertainties (Past Month)"] = muncs
            curr_res["Labels"] = curr_y_test

            reject_mask, select_mask = rejector.update(
                muncs_month_ahead=muncs_first,
                muncs_past_month=muncs,
                preds_past_month=mpreds,
                i=i,
            )
            curr_res["Rejection-Mask"] = reject_mask
            curr_res["Selection-Mask"] = select_mask
            local_results.append(curr_res)
            time_end = time.time()
            print(f"{exp_desc} | Month {i} in {time_end - time_start}s")

        return local_results
    


########################### Deep Drebin for various time-windows with re-iteration ###########################




class CEMainExperiments(ParallelProcessorMixin):
    """Cross-entropy (CE) experiment sweep: Deep Drebin over sampler modes and seeds.

    Each combination loads its dataset, builds a fresh ``DsSampler`` and runs the
    month-by-month train / predict / reject loop in :meth:`run_experiment`.
    """

    def __init__(self, NUM_PROCESSES=8, exp_name=None):
        """Build the parameter grid.

        Args:
            NUM_PROCESSES: Requested number of worker processes for ``run``.
            exp_name: Optional path of a pickle checkpoint that ``run`` rewrites
                after every finished experiment; ``None`` disables it.
        """
        self.NUM_PROCESSES = NUM_PROCESSES
        self.exp_name = exp_name

        DATASETS = [
            "androzoo",
            "apigraph",
            "transcendent",
        ]

        MONTHLY_LABEL_BUDGETS = [
            50,
            100,
            200,
            400
        ]

        NUM_EPOCHS = [
            #10,
            30,
            50,
        ]

        NUM_SAMPLES_PER_MONTH_FIRST_YEAR = [
            #0,
            #12//12,
            #100//12,
            #250//12,
            #500//12,
            #1000//12,
            #2500//12,
            4800//12,
            #5000//12,
            #10000//12,
            #20000//12,
        ]

        RANDOM_SEEDS = [
            0, 1, 2, 3, 4
        ]

        SAMPLER_MODES = [
            #"stratk_random_subsample_first_year_subsample_months",
            "subsample_first_year_subsample_months",
            "full_first_year_subsample_months",

            #"stratk_subsample_first_year_subsample_months",
            #"elbow_subsample_first_year_subsample_months",
            #"full_first_year_subsample_months",
            #"weight_first_year_subsample_months",
        ]

        TRAIN_AUG_SCALES = [
            0,
            #0.05,
            #0.1,
            #0.2
        ]

        TRAINER_MODES = [
            "CE",
            #"HCCMLP",
        ]

        REJECTOR_MODE = [
            #"reject_topn_pos",
            #"reject_topn_neg",
            #"reject_topn_overall",
            "reject_topn_random",
        ]

        # All parameter combinations; the month loop runs inside run_experiment.
        self.combinations = list(itertools.product(
            DATASETS,
            MONTHLY_LABEL_BUDGETS,
            NUM_EPOCHS,
            SAMPLER_MODES,
            TRAIN_AUG_SCALES,
            TRAINER_MODES,
            REJECTOR_MODE,
            NUM_SAMPLES_PER_MONTH_FIRST_YEAR,
            RANDOM_SEEDS
        ))

    @staticmethod
    def _seed_everything(seed):
        """Seed the global Python, NumPy and PyTorch RNGs with ``seed``."""
        import random
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def run_experiment(self, args):
        """Run one parameter combination over every test month.

        Args:
            args: Tuple ``(DATASET, MONTHLY_LABEL_BUDGET, NUM_EPOCHS, SAMPLER_MODE,
                TRAIN_AUG_SCALE, TRAINER_MODE, REJECTOR_MODE,
                NUM_SAMPLES_PER_MONTH_FIRST_YEAR, RANDOM_SEED)`` as built by ``__init__``.

        Returns:
            A list with one result dict per test month.
        """
        DATASET, MONTHLY_LABEL_BUDGET, NUM_EPOCHS, SAMPLER_MODE, TRAIN_AUG_SCALE, \
            TRAINER_MODE, REJECTOR_MODE, NUM_SAMPLES_PER_MONTH_FIRST_YEAR, RANDOM_SEED = args

        # RANDOM_SEED used to be recorded in the results without being applied.
        # Seeding here makes each repetition reproducible. It also stops worker
        # processes (e.g. under the fork start method) from starting with
        # identical inherited RNG state.
        self._seed_everything(RANDOM_SEED)

        exp_desc = f"{DATASET}-{SAMPLER_MODE}-{REJECTOR_MODE}-{TRAINER_MODE}-Scale={TRAIN_AUG_SCALE}-Budget={MONTHLY_LABEL_BUDGET}-Epochs={NUM_EPOCHS}-#Samples={NUM_SAMPLES_PER_MONTH_FIRST_YEAR}-Seed={RANDOM_SEED}"

        rejector = Rejector(
            REJECTOR_MODE=REJECTOR_MODE,
            MONTHLY_LABEL_BUDGET=MONTHLY_LABEL_BUDGET,
        )

        ##################################################################################################

        X_train, X_tests, y_train, y_tests = LoadHCCDatasets(DATASET).load_binary_labels()

        # Map experiment-level sampler modes onto the modes DsSampler is given.
        SAMPLER_MODE_ = SAMPLER_MODE
        if SAMPLER_MODE == "weight_first_year_subsample_months":
            SAMPLER_MODE_ = "full_first_year_subsample_months"
        if NUM_SAMPLES_PER_MONTH_FIRST_YEAR == 0:
            SAMPLER_MODE_ = "zero_first_year_subsample_months"

        sampler = DsSampler(
            X_train=X_train,
            y_train=y_train,
            X_tests=X_tests,
            y_tests=y_tests,
            SAMPLER_MODE=SAMPLER_MODE_,
            MONTHLY_BUDGET_FIRST_YEAR=NUM_SAMPLES_PER_MONTH_FIRST_YEAR,
            MONTHLY_BUDGET_THEREAFTER=MONTHLY_LABEL_BUDGET,
            NUM_MONTHS_DANN_BACKLOG=1,
            NUM_EPOCHS=30,  # Keep that fixed for now
        )

        ##################################################################################################

        curr_weights = None
        local_results = []
        for i in range(len(sampler.X_tests)):
            time_start = time.time()
            curr_res = {}
            curr_res["Dataset"] = DATASET
            curr_res["Rejector-Mode"] = REJECTOR_MODE
            curr_res["Sampler-Mode"] = SAMPLER_MODE
            curr_res["Num-Epochs"] = NUM_EPOCHS
            curr_res["Trainer-Mode"] = TRAINER_MODE
            curr_res["CE-Train-Aug-Scale"] = TRAIN_AUG_SCALE
            curr_res["Monthly-Label-Budget"] = MONTHLY_LABEL_BUDGET
            curr_res["#-Monthly-Samples-First-Year"] = NUM_SAMPLES_PER_MONTH_FIRST_YEAR
            curr_res["Random-Seed"] = RANDOM_SEED
            curr_res["Test-Month"] = i

            curr_X_source, curr_X_target, curr_X_test, curr_y_source, curr_y_target, curr_y_test = sampler.sample(
                i=i,
                monthly_test_set_selector=rejector.monthly_test_set_selector,
            )

            if SAMPLER_MODE == "weight_first_year_subsample_months":
                from experiments import weights_from_training_buffer
                if i == 0:
                    curr_weights = weights_from_training_buffer(
                        X_train_buffer=curr_X_source,
                        y_train_buffer=curr_y_source,
                        num_epochs=30,
                        n_folds=6,
                    )
                else:
                    # Samples added to the source buffer since last month get weight 1.0.
                    num_additional_items_X_source = len(curr_X_source) - len(curr_weights)
                    curr_weights = np.concatenate([curr_weights, np.ones(num_additional_items_X_source).astype(np.float32)])

            if i == 0 and NUM_SAMPLES_PER_MONTH_FIRST_YEAR == 0:
                # No labelled data yet: random predictions and random uncertainties.
                mname = 'CE - Classifier'
                mpreds = np.random.rand(len(curr_y_test)) < 0.5
                muncs_first = muncs = np.random.rand(len(curr_y_test))
            else:
                mname, mpreds, muncs_first, muncs, _ = Trainer(
                    # Data
                    curr_X_source=curr_X_source,
                    curr_X_target=curr_X_target,
                    curr_X_test=curr_X_test,
                    curr_y_source=curr_y_source,
                    curr_y_target=curr_y_target,
                    curr_y_test=curr_y_test,
                    curr_weights_source=curr_weights,
                    # Hyperparameters
                    i=i,
                    TRAINER_MODE=TRAINER_MODE,
                    CE_TRAIN_SCALE=TRAIN_AUG_SCALE,
                    DANN_TRAIN_SCALE=0.,
                    # General Hyperparameters
                    EVAL_EVERY=150,
                    NUM_EPOCHS=NUM_EPOCHS,
                )

            curr_res["Method-Name"] = mname
            curr_res["Predictions"] = mpreds
            curr_res["Uncertainties (Month Ahead)"] = muncs_first
            curr_res["Uncertainties (Past Month)"] = muncs
            curr_res["Labels"] = curr_y_test

            reject_mask, select_mask = rejector.update(
                muncs_month_ahead=muncs_first,
                muncs_past_month=muncs,
                preds_past_month=mpreds,
                i=i,
            )
            curr_res["Rejection-Mask"] = reject_mask
            curr_res["Selection-Mask"] = select_mask
            local_results.append(curr_res)
            time_end = time.time()
            print(f"{exp_desc} | Month {i} in {time_end - time_start}s")

        return local_results







##############################################################################################################

from misc.transcendent_ice import LinearSVC as T_LinearSVC
from misc.transcendent_ice import DrebinMLP as T_DrebinMLP
from misc.transcendent_ice import EnsembleTranscendSelector
from functools import partial

class TranscendentExperiments(DANNMainExperimentsNoAug):
    """Grid of Transcendent (ICE) experiments.

    ``__init__`` builds ``self.combinations`` as the Cartesian product of the
    hyper-parameter lists below; each tuple is consumed by
    :meth:`run_experiment`, which re-fits an ``EnsembleTranscendSelector`` on
    every test month and passes its scores to a ``Rejector``.
    """

    # Base-model identifiers understood by ``_build_model_cls``.
    SUPPORTED_MODEL_TYPES = ("SVM", "DeepDrebin")

    def __init__(self, NUM_PROCESSES=8, exp_name = None):
        """
        Args:
            NUM_PROCESSES: number of worker processes, stored as
                ``self.NUM_PROCESSES`` (presumably read by the base-class runner).
            exp_name: experiment identifier, stored as ``self.exp_name``.
        """
        self.NUM_PROCESSES = NUM_PROCESSES
        self.exp_name = exp_name

        DATASETS = [
            "androzoo", 
            "apigraph", 
            #"transcendent_vmonth",
            "transcendent", 
        ]
        MONTHLY_LABEL_BUDGETS = [50, 100, 200, 400]

        NUM_EPOCHS = [
            #10, 
            30, 
            #50,
        ]

        NUM_SAMPLES_PER_MONTH_FIRST_YEAR = [
            400, 
            #600,
        ]
        RANDOM_SEEDS = [
            0, 
            #1, 
            ##2, 
            #3, 
            #4
        ]

        SAMPLER_MODES = [
            "full_first_year_subsample_months", 
            "subsample_first_year_subsample_months",
            #"weight_first_year_subsample_months",
        ]

        CRIT_MODE = [
            "cred", 
            "cred+conf",
        ]

        ENSEMBLE_MODE = [
            "ensemble", 
            #"ensemble_ncm_weighted", 
            "singleton",
        ]

        MODEL_TYPE = [
            #"DeepDrebin",
            "SVM", 
        ]

        REJECTOR_MODE = ["reject_topn_overall",]

        # Generate all parameter combinations (excluding the month iteration).
        # The tuple order here must match the unpacking in run_experiment.
        self.combinations = list(itertools.product(
            DATASETS,
            MONTHLY_LABEL_BUDGETS,
            NUM_EPOCHS, 
            SAMPLER_MODES,
            CRIT_MODE, 
            ENSEMBLE_MODE,
            MODEL_TYPE, 
            REJECTOR_MODE,
            NUM_SAMPLES_PER_MONTH_FIRST_YEAR,
            RANDOM_SEEDS
        ))

    @staticmethod
    def _seed_everything(seed):
        """Seed Python's ``random``, NumPy's global RNG and torch's RNG with ``seed``."""
        import random

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    @staticmethod
    def _build_model_cls(MODEL_TYPE, n_features):
        """Return ``(model_factory, method_name)`` for the requested base model.

        Args:
            MODEL_TYPE: one of ``SUPPORTED_MODEL_TYPES``.
            n_features: number of input columns; only used by ``"DeepDrebin"``.

        Raises:
            ValueError: if ``MODEL_TYPE`` is not supported.
        """
        if MODEL_TYPE == "SVM":
            return partial(T_LinearSVC, dual=True, C=0.01, max_iter=100), "Dist. to Hyperplane"
        if MODEL_TYPE == "DeepDrebin":
            return partial(T_DrebinMLP, input_size=n_features), "Softmax Uncert."
        raise ValueError(
            f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected one of "
            f"{TranscendentExperiments.SUPPORTED_MODEL_TYPES}"
        )

    def run_experiment(self, args):
        """Run a complete experiment (all test months) for one parameter combination.

        Args:
            args: one tuple from ``self.combinations``.

        Returns:
            list[dict]: one result dict per test month (configuration, predictions,
            uncertainties, labels and rejection/selection masks).

        Raises:
            ValueError: if the combination names an unsupported ``MODEL_TYPE``.
        """
        DATASET, MONTHLY_LABEL_BUDGET, NUM_EPOCHS, SAMPLER_MODE, CRIT_MODE, ENSEMBLE_MODE, MODEL_TYPE, REJECTOR_MODE, NUM_SAMPLES_PER_MONTH_FIRST_YEAR, RANDOM_SEED = args

        # Fail fast, before the (potentially expensive) dataset load, instead of
        # hitting an unbound MODEL_CLS inside the month loop.
        if MODEL_TYPE not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected one of {self.SUPPORTED_MODEL_TYPES}"
            )

        # The seed is recorded in every result row, so actually apply it; otherwise
        # the run depends on whatever RNG state this process happens to hold.
        self._seed_everything(RANDOM_SEED)

        local_results = []
        # Load dataset for this experiment
        X_train, X_tests, y_train, y_tests = LoadHCCDatasets(DATASET).load_binary_labels()

        # The "weight_*" sampler mode is sampled like "full_*"; the original name is
        # still what gets recorded in the results.
        SAMPLER_MODE_ = SAMPLER_MODE
        if SAMPLER_MODE == "weight_first_year_subsample_months":
            SAMPLER_MODE_ = "full_first_year_subsample_months"

        # Initialize stateful objects (they carry state across months)
        sampler = DsSampler(
            X_train=X_train, 
            y_train=y_train, 
            X_tests=X_tests, 
            y_tests=y_tests,
            SAMPLER_MODE=SAMPLER_MODE_, 
            MONTHLY_BUDGET_FIRST_YEAR=NUM_SAMPLES_PER_MONTH_FIRST_YEAR, 
            MONTHLY_BUDGET_THEREAFTER=MONTHLY_LABEL_BUDGET,
            NUM_MONTHS_DANN_BACKLOG=1,
            EVAL_EVERY=150,
            NUM_EPOCHS=NUM_EPOCHS,
        )

        rejector = Rejector(
            REJECTOR_MODE=REJECTOR_MODE,
            MONTHLY_LABEL_BUDGET=MONTHLY_LABEL_BUDGET,
        )

        # Prefix for the per-month log line printed by log_live_result
        exp_desc = f"{DATASET}-{SAMPLER_MODE}-{CRIT_MODE}-{ENSEMBLE_MODE}-{MODEL_TYPE}-Budget={MONTHLY_LABEL_BUDGET}-Epochs={NUM_EPOCHS}-#Samples={NUM_SAMPLES_PER_MONTH_FIRST_YEAR}-Seed={RANDOM_SEED}"    

        for i in range(len(X_tests)):
            time_start = time()
            curr_res = {}
            curr_res["Test-Month"] = i

            curr_res["Trainer-Mode"] = "Transcendent (ICE)"

            curr_res["Dataset"] = DATASET
            curr_res["Sampler-Mode"] = SAMPLER_MODE
            curr_res["Num-Epochs"] = NUM_EPOCHS

            curr_res["transcendent-crit-Mode"] = CRIT_MODE
            curr_res["transcendent-ensemble"] = ENSEMBLE_MODE
            curr_res["transcendent-model-type"] = MODEL_TYPE

            curr_res["Monthly-Label-Budget"] = MONTHLY_LABEL_BUDGET
            curr_res["#-Monthly-Samples-First-Year"] = NUM_SAMPLES_PER_MONTH_FIRST_YEAR
            curr_res["Random-Seed"] = RANDOM_SEED

            curr_X_source, curr_X_target, curr_X_test, curr_y_source, curr_y_target, curr_y_test = sampler.sample(
                i=i,
                monthly_test_set_selector=rejector.monthly_test_set_selector,
            )

            # Built per month because DeepDrebin needs the current feature count.
            MODEL_CLS, mname = self._build_model_cls(MODEL_TYPE, curr_X_source.shape[1])

            ets = EnsembleTranscendSelector(
                model_cls=MODEL_CLS,
                crit=CRIT_MODE,
                mode=ENSEMBLE_MODE
            )
            ets.fit(curr_X_source, curr_y_source, n_splits=10)

            median_ncms, preds = ets.predict(X_test=curr_X_test)

            # The same scores serve as both the month-ahead and past-month
            # uncertainties for this method.
            muncs = muncs_first = median_ncms
            mpreds = preds

            curr_res["Method-Name"] = mname
            curr_res["Predictions"] = mpreds
            curr_res["Uncertainties (Month Ahead)"] = muncs_first
            curr_res["Uncertainties (Past Month)"] = muncs
            curr_res["Labels"] = curr_y_test

            reject_mask, select_mask = rejector.update(
                muncs_month_ahead=muncs_first,
                muncs_past_month=muncs,
                i=i,
            )
            curr_res["Rejection-Mask"] = reject_mask
            curr_res["Selection-Mask"] = select_mask
            local_results.append(curr_res)

            time_end = time()
            curr_desc = f"{exp_desc}\t\t | Month {i} in {(time_end - time_start):.0f}s |  "

            log_live_result(
                muncs=muncs, 
                y_preds=mpreds, 
                y_test=curr_y_test, 
                prefix=curr_desc, 
                results=None,
                i=i, 
            )
        return local_results




class CENewSamplers(DANNMainExperimentsNoAug):
    """Cross-entropy trainer evaluated under several monthly sampler modes.

    ``self.combinations`` holds tuples
    ``(DATASET, SAMPLER_MODE, MONTHLY_LABEL_BUDGET, TRAIN_AUG_SCALE,
    TRAINER_MODE, REJECTOR_MODE)`` consumed by :meth:`run_experiment`.
    """

    # Value passed as MONTHLY_BUDGET_FIRST_YEAR to DsSampler.
    _FIRST_YEAR_MONTHLY_BUDGET = 600
    # Epoch count passed to both DsSampler and Trainer so the two stay in sync.
    _NUM_EPOCHS = 30

    def __init__(self, NUM_PROCESSES=8, exp_name=None):
        """
        Args:
            NUM_PROCESSES: number of worker processes, stored as
                ``self.NUM_PROCESSES`` (presumably read by the base-class runner).
            exp_name: optional experiment identifier, stored as ``self.exp_name``
                like the other experiment classes in this module do.
        """
        self.NUM_PROCESSES = NUM_PROCESSES
        self.exp_name = exp_name

        DATASETS = [
            #"androzoo", 
            #"apigraph", 
            "transcendent_vmonth",
        ]

        MONTHLY_LABEL_BUDGETS = [50, 100, 200, 400]

        SAMPLER_MODES = [
            "full_first_year_subsample_months",
            "subsample_first_year_subsample_months",
            "resample_monthly_fixed", 
            "resample_monthly_grow", 
        ]

        TRAIN_AUG_SCALES = [0.05]

        TRAINER_MODES = ["CE",]

        REJECTOR_MODE = ["reject_topn_overall",]

        # Generate all parameter combinations (excluding the month iteration).
        # The tuple order here must match the unpacking in run_experiment.
        self.combinations = list(
            itertools.product(
                DATASETS,
                SAMPLER_MODES,
                MONTHLY_LABEL_BUDGETS,
                TRAIN_AUG_SCALES, 
                TRAINER_MODES,
                REJECTOR_MODE,
            )
        )

    def run_experiment(self, args):
        """Run a complete experiment (all test months) for one parameter combination.

        Args:
            args: one tuple from ``self.combinations``.

        Returns:
            list[dict]: one result dict per test month.
        """
        DATASET, SAMPLER_MODE, MONTHLY_LABEL_BUDGET, TRAIN_AUG_SCALE, TRAINER_MODE, REJECTOR_MODE = args

        local_results = []
        # Load dataset for this experiment
        X_train, X_tests, y_train, y_tests = LoadHCCDatasets(DATASET).load_binary_labels()

        # Initialize stateful objects (they carry state across months)
        sampler = DsSampler(
            X_train=X_train, 
            y_train=y_train, 
            X_tests=X_tests, 
            y_tests=y_tests,
            SAMPLER_MODE=SAMPLER_MODE, 
            MONTHLY_BUDGET_FIRST_YEAR=self._FIRST_YEAR_MONTHLY_BUDGET, 
            MONTHLY_BUDGET_THEREAFTER=MONTHLY_LABEL_BUDGET,
            NUM_MONTHS_DANN_BACKLOG=1,
            EVAL_EVERY=150,
            NUM_EPOCHS=self._NUM_EPOCHS,
            show_sampler_progress=False,
        )

        rejector = Rejector(
            REJECTOR_MODE=REJECTOR_MODE,
            MONTHLY_LABEL_BUDGET=MONTHLY_LABEL_BUDGET,
        )

        # Prefix for the per-month log line printed by log_live_result
        exp_desc = f"{DATASET}-{SAMPLER_MODE}-{TRAINER_MODE}-Scale={TRAIN_AUG_SCALE}-Budget={MONTHLY_LABEL_BUDGET}"

        for i in range(len(X_tests)):
            time_start = time()
            curr_res = {}
            curr_res["Dataset"] = DATASET
            curr_res["Sampler-Mode"] = SAMPLER_MODE
            curr_res["Trainer-Mode"] = TRAINER_MODE
            curr_res["CE-Train-Aug-Scale"] = TRAIN_AUG_SCALE
            curr_res["Monthly-Label-Budget"] = MONTHLY_LABEL_BUDGET
            curr_res["Test-Month"] = i

            curr_X_source, curr_X_target, curr_X_test, curr_y_source, curr_y_target, curr_y_test = sampler.sample(
                i=i,
                monthly_test_set_selector=rejector.monthly_test_set_selector,
            )

            mname, mpreds, muncs_first, muncs, _ = Trainer(
                # Data
                curr_X_source=curr_X_source,
                curr_X_target=curr_X_target,
                curr_X_test=curr_X_test,
                curr_y_source=curr_y_source,
                curr_y_target=curr_y_target,
                curr_y_test=curr_y_test,
                # Hyperparameters
                i=i,
                TRAINER_MODE=TRAINER_MODE,
                CE_TRAIN_SCALE=TRAIN_AUG_SCALE,
                DANN_TRAIN_SCALE=0.,
                # General Hyperparameters
                EVAL_EVERY=150,
                NUM_EPOCHS=self._NUM_EPOCHS,
                device="cpu",
            )

            curr_res["Method-Name"] = mname
            curr_res["Predictions"] = mpreds
            curr_res["Uncertainties (Month Ahead)"] = muncs_first
            curr_res["Uncertainties (Past Month)"] = muncs
            curr_res["Labels"] = curr_y_test

            reject_mask, select_mask = rejector.update(
                muncs_month_ahead=muncs_first,
                muncs_past_month=muncs,
                i=i,
            )
            curr_res["Rejection-Mask"] = reject_mask
            curr_res["Selection-Mask"] = select_mask
            local_results.append(curr_res)
            time_end = time()

            curr_desc = f"{exp_desc}\t\t | Month {i} in {(time_end - time_start):.0f}s |  "
            log_live_result(
                muncs=muncs, 
                y_preds=mpreds, 
                y_test=curr_y_test, 
                prefix=curr_desc, 
                results=None,
                i=i, 
            )

        return local_results




class Baselines(DANNMainExperimentsNoAug):
    """SVC baseline sweep over datasets, budgets, sampler modes, ``C`` and seeds.

    ``self.combinations`` holds tuples
    ``(DATASET, MONTHLY_LABEL_BUDGET, SAMPLER_MODE, TRAINER_MODE, REJECTOR_MODE,
    NUM_SAMPLES_PER_MONTH_FIRST_YEAR, SVM_C, RANDOM_SEED)`` consumed by
    :meth:`run_experiment`.
    """

    def __init__(self, NUM_PROCESSES=8, exp_name = None):
        """
        Args:
            NUM_PROCESSES: number of worker processes, stored as
                ``self.NUM_PROCESSES`` (presumably read by the base-class runner).
            exp_name: experiment identifier, stored as ``self.exp_name``.
        """
        self.NUM_PROCESSES = NUM_PROCESSES

        self.exp_name = exp_name

        DATASETS = [
            "androzoo",
            "apigraph", 
            #"transcendent_vmonth",
            "transcendent", 
        ]
        MONTHLY_LABEL_BUDGETS = [50, 100, 200, 400]

        NUM_SAMPLES_PER_MONTH_FIRST_YEAR = [
            400, 
        ]
        RANDOM_SEEDS = [
            0, 1, 2, 3, 4
        ]

        SAMPLER_MODES = [
            "full_first_year_subsample_months", 
            "subsample_first_year_subsample_months",
        ]

        TRAINER_MODES = [
            "SVC",
        ]
        SVMCs = [1.0, 0.1, 0.01]

        REJECTOR_MODE = ["reject_topn_overall",]

        # Generate all parameter combinations (excluding the month iteration).
        # The tuple order here must match the unpacking in run_experiment.
        self.combinations = list(itertools.product(
            DATASETS,
            MONTHLY_LABEL_BUDGETS,
            SAMPLER_MODES,
            TRAINER_MODES,
            REJECTOR_MODE,
            NUM_SAMPLES_PER_MONTH_FIRST_YEAR,
            SVMCs, 
            RANDOM_SEEDS
        ))

    @staticmethod
    def _seed_everything(seed):
        """Seed Python's ``random``, NumPy's global RNG and torch's RNG with ``seed``."""
        import random

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def run_experiment(self, args):
        """Run a complete experiment (all test months) for one parameter combination.

        Args:
            args: one tuple from ``self.combinations``.

        Returns:
            list[dict]: one result dict per test month.
        """
        DATASET, MONTHLY_LABEL_BUDGET, SAMPLER_MODE, TRAINER_MODE, REJECTOR_MODE, NUM_SAMPLES_PER_MONTH_FIRST_YEAR, SVM_C, RANDOM_SEED = args

        # The seed sweep (RANDOM_SEEDS) is recorded in every result row, so it has
        # to be applied; otherwise all seeds share whatever RNG state this
        # process already holds.
        self._seed_everything(RANDOM_SEED)

        local_results = []
        # Load dataset for this experiment
        X_train, X_tests, y_train, y_tests = LoadHCCDatasets(DATASET).load_binary_labels()

        # Initialize stateful objects (they carry state across months)
        sampler = DsSampler(
            X_train=X_train, 
            y_train=y_train, 
            X_tests=X_tests, 
            y_tests=y_tests,
            SAMPLER_MODE=SAMPLER_MODE, 
            MONTHLY_BUDGET_FIRST_YEAR=NUM_SAMPLES_PER_MONTH_FIRST_YEAR, 
            MONTHLY_BUDGET_THEREAFTER=MONTHLY_LABEL_BUDGET,
            NUM_MONTHS_DANN_BACKLOG=1,
            EVAL_EVERY=150,
            NUM_EPOCHS=30,
            show_sampler_progress=False,
        )

        rejector = Rejector(
            REJECTOR_MODE=REJECTOR_MODE,
            MONTHLY_LABEL_BUDGET=MONTHLY_LABEL_BUDGET,
        )

        # Prefix for the per-month log line printed by log_live_result
        exp_desc = f"{DATASET}-{SAMPLER_MODE}-{TRAINER_MODE}-Budget={MONTHLY_LABEL_BUDGET}-#Samples={NUM_SAMPLES_PER_MONTH_FIRST_YEAR}-SVM_C={SVM_C}-Seed={RANDOM_SEED}"

        for i in range(len(X_tests)):
            time_start = time()
            curr_res = {}
            curr_res["Dataset"] = DATASET
            curr_res["Sampler-Mode"] = SAMPLER_MODE
            curr_res["Trainer-Mode"] = TRAINER_MODE
            curr_res["Monthly-Label-Budget"] = MONTHLY_LABEL_BUDGET
            curr_res["#-Monthly-Samples-First-Year"] = NUM_SAMPLES_PER_MONTH_FIRST_YEAR
            curr_res["svm_c"] = SVM_C
            curr_res["Random-Seed"] = RANDOM_SEED
            curr_res["Test-Month"] = i

            curr_X_source, curr_X_target, curr_X_test, curr_y_source, curr_y_target, curr_y_test = sampler.sample(
                i=i,
                monthly_test_set_selector=rejector.monthly_test_set_selector,
            )

            mname, mpreds, muncs_first, muncs, _ = Trainer(
                # Data
                curr_X_source=curr_X_source,
                curr_X_target=curr_X_target,
                curr_X_test=curr_X_test,
                curr_y_source=curr_y_source,
                curr_y_target=curr_y_target,
                curr_y_test=curr_y_test,
                # Hyperparameters
                TRAINER_MODE=TRAINER_MODE,
                SVM_C=SVM_C,
                i=i,
            )

            curr_res["Method-Name"] = mname
            curr_res["Predictions"] = mpreds
            curr_res["Uncertainties (Month Ahead)"] = muncs_first
            curr_res["Uncertainties (Past Month)"] = muncs
            curr_res["Labels"] = curr_y_test

            reject_mask, select_mask = rejector.update(
                muncs_month_ahead=muncs_first,
                muncs_past_month=muncs,
                i=i,
            )
            curr_res["Rejection-Mask"] = reject_mask
            curr_res["Selection-Mask"] = select_mask
            local_results.append(curr_res)
            time_end = time()

            curr_desc = f"{exp_desc}\t\t | Month {i} in {(time_end - time_start):.0f}s |  "
            log_live_result(
                muncs=muncs, 
                y_preds=mpreds, 
                y_test=curr_y_test, 
                prefix=curr_desc, 
                results=None,
                i=i, 
            )

        return local_results




if __name__ == "__main__":
    # Script entry point: exactly one experiment grid is active at a time.
    # The alternatives are kept below as comments; to switch experiments,
    # comment out the active section and uncomment another one.
    #
    # --- DANN without augmentation ---
    # results = DANNMainExperimentsNoAug(NUM_PROCESSES=8).run()
    # with open("parallel_dann_no_aug.pkl", "wb") as f:
    #     pickle.dump(results, f)
    #
    # --- CE with the new sampler modes ---
    # results = CENewSamplers(NUM_PROCESSES=8).run()
    # with open("parallel_ce_new_samplers.pkl", "wb") as f:
    #     pickle.dump(results, f)
    #
    # --- DANN feature-extractor experiments ---
    # results = DANN_FE_Experiments(NUM_PROCESSES=8).run()
    # with open("parallel_dann_fe.pkl", "wb") as f:
    #     pickle.dump(results, f)
    # print("Done!")

    ####################################################
    # --- ACTIVE: CE with random monthly subsampling ---
    # The pickle file name doubles as the experiment name.
    exp_name = "parallel_ce_random_monthly_subsampling.pkl"
    runner = CEMainExperiments(NUM_PROCESSES=4, exp_name=exp_name)
    results = runner.run()
    with open(exp_name, mode="wb") as out_file:
        pickle.dump(results, out_file)

    # --- SVC baselines ---
    # exp_name = "parallel_svc_v2.pkl"
    # results = Baselines(NUM_PROCESSES=8, exp_name=exp_name).run()
    # with open(exp_name, "wb") as f:
    #     pickle.dump(results, f)
    #
    # --- Transcendent (ICE) experiments ---
    # exp_name = "transcendent_svm_dd_v2.pkl"
    # results = TranscendentExperiments(NUM_PROCESSES=4, exp_name=exp_name).run()
    # with open(exp_name, "wb") as f:
    #     pickle.dump(results, f)



