from abc import ABC

import pytorch_lightning as pl

from src.utils.misc import *
from src.learners.single_task import SingleTaskLearner
from src.data.datasets.dataset_utils import get_dataset_split

# Phase prefixes for logged metric names. VALIDATION_KEY is prepended to the
# validation metric keys built in this module; TRAINING_KEY and TEST_KEY are
# presumably the training/testing counterparts (their usage is not shown
# here -- confirm against the other learners).
TRAINING_KEY = "training/"
VALIDATION_KEY = "validation/"
TEST_KEY = "testing/"


@quick_register
class SingleMixtureTaskLearner(SingleTaskLearner, ABC):
    """Single-task learner whose validation/test loaders each cover one dataset split.

    On top of ``SingleTaskLearner`` this class records which splits appear in
    the training data ("in-distribution") and which appear only in the
    validation data ("out-of-distribution", OOD), and reports the average
    validation error separately for the two groups at the end of every
    validation epoch.
    """

    def __init__(self, model, train_loader, valid_loader, test_loader, optimizer, lr_scheduler, loss_fn,
                 logging_functions, dirs_dict, regularization_fn=None, use_lookahead=False, **kwargs):
        """Build the learner and derive the in-distribution / OOD split sets.

        All arguments are forwarded unchanged to ``SingleTaskLearner.__init__``.
        ``valid_loader`` and ``test_loader`` are iterated over here, so they
        must be iterables of dataloaders (each exposing ``.dataset``).

        Raises:
            AssertionError: if the validation and test splits differ, or if
                any validation/test dataloader maps to more than one split.
        """
        super().__init__(model=model, train_loader=train_loader, valid_loader=valid_loader, test_loader=test_loader,
                         optimizer=optimizer, lr_scheduler=lr_scheduler, loss_fn=loss_fn,
                         logging_functions=logging_functions, dirs_dict=dirs_dict,
                         regularization_fn=regularization_fn, use_lookahead=use_lookahead, **kwargs)

        # Get the split ids (i.e. related to the corruption levels of images etc.) used in each dataset.
        # `get_dataset_split` returns a pair; only the first element is needed for the training data.
        self.train_splits, _ = get_dataset_split(dataset=self.train_loader.dataset)

        valid_datasets = [loader.dataset for loader in self.valid_loader]
        test_datasets = [loader.dataset for loader in self.test_loader]
        self.valid_splits, self.val_dataloader_idx_2_split = get_dataset_split(dataset=valid_datasets)
        self.test_splits, self.test_dataloader_idx_2_split = get_dataset_split(dataset=test_datasets)

        # Run some sanity checks on the datasets and dataloaders.
        assert self.valid_splits == self.test_splits, "Validation and testing splits must be the same. "
        assert self._has_single_split_per_dataloader(self.val_dataloader_idx_2_split), \
            "Assume each val or test loader corresponds to single split. "
        assert self._has_single_split_per_dataloader(self.test_dataloader_idx_2_split), \
            "Assume each val or test loader corresponds to single split. "

        # Determine which splits are in-distribution, and which are OOD.
        self.in_distribution_splits = self.train_splits
        self.ood_splits = self.valid_splits - self.train_splits

    @staticmethod
    def _has_single_split_per_dataloader(dataloader_idx_2_split):
        """Return True if every dataloader index maps to exactly one split."""
        return all(len(list(splits)) == 1 for splits in dataloader_idx_2_split.values())

    # ____ Validation. ____
    def validation_step(self, data_batch, batch_nb, dataloader_idx=0, **kwargs):
        """Run one validation batch and return its metric logs.

        Args:
            data_batch: the batch handed to ``common_step``.
            batch_nb: index of the batch.
            dataloader_idx: index of the validation dataloader the batch came
                from. Defaults to 0 so the hook can also be called without an
                explicit index.

        Returns:
            The metric logs produced by ``common_step`` (its second return value).
        """
        # Update the prepend key to reflect the dataloader that's being used.
        prepend_key = VALIDATION_KEY + f"dataloader_idx_{dataloader_idx}/"

        # Run forward pass and log.
        _, metric_logs = self.common_step(data_batch, batch_nb, optimizer_idx=0, prepend_key=prepend_key,
                                          dataloader_idx=dataloader_idx)

        return metric_logs

    def validation_epoch_end(self, outputs):
        """Average the validation outputs, add the in-distribution/OOD summaries, and log them.

        Returns:
            The averaged metrics, extended with
            ``validation/average_in_distribution_val_error`` and
            ``validation/average_ood_val_error``.
        """
        averaged_metrics = average_evaluation_results(outputs)

        # Compute the average in-distribution and OOD validation errors.
        avg_in_dist_val_error, avg_ood_val_error = self._get_average_in_and_ood_validation_errors(averaged_metrics)
        averaged_metrics[f"{VALIDATION_KEY}average_in_distribution_val_error"] = avg_in_dist_val_error
        averaged_metrics[f"{VALIDATION_KEY}average_ood_val_error"] = avg_ood_val_error

        # Log the averaged metrics.
        self.logger.log_metrics(averaged_metrics, step=self.global_step)

        # Empty cache for memory-intensive use cases.
        # torch.cuda.empty_cache()

        return averaged_metrics

    def _get_average_in_and_ood_validation_errors(self, metrics):
        """Average the validation errors separately over in-distribution and OOD dataloaders.

        A metric counts as a validation error when its key contains both
        ``"val"`` and ``"error"``. It is attributed to a dataloader when its key
        carries that dataloader's exact ``dataloader_idx_<idx>`` tag, i.e. the
        tag that ``validation_step`` puts into the prepend key.

        Args:
            metrics: mapping from metric key to averaged value.

        Returns:
            ``(avg_in_distribution_error, avg_ood_error)`` as floats. A group
            with no matching metrics yields ``nan``.
        """
        val_errors = [(key, value) for key, value in metrics.items() if "val" in key and "error" in key]

        in_distribution_idx = [idx for idx, splits in self.val_dataloader_idx_2_split.items()
                               if len(self.in_distribution_splits.intersection(splits)) > 0]
        ood_idx = [idx for idx, splits in self.val_dataloader_idx_2_split.items()
                   if len(self.ood_splits.intersection(splits)) > 0]

        in_distribution_errors = [
            value for key, value in val_errors
            if any(self._key_belongs_to_dataloader(key, idx) for idx in in_distribution_idx)
        ]
        ood_errors = [
            value for key, value in val_errors
            if any(self._key_belongs_to_dataloader(key, idx) for idx in ood_idx)
        ]

        return self._mean_or_nan(in_distribution_errors), self._mean_or_nan(ood_errors)

    @staticmethod
    def _key_belongs_to_dataloader(key, dataloader_idx):
        """Return True if ``key`` contains the exact tag ``dataloader_idx_<dataloader_idx>``.

        A plain substring test on the index is not enough: index 1 would also
        match ``dataloader_idx_10`` or any other digit in the metric name. The
        tag therefore must not be followed by another digit.
        """
        tag = f"dataloader_idx_{dataloader_idx}"
        start = key.find(tag)
        while start != -1:
            end = start + len(tag)
            if end == len(key) or not key[end].isdigit():
                return True
            start = key.find(tag, start + 1)
        return False

    @staticmethod
    def _mean_or_nan(values):
        """Return the mean of ``values`` as a float, or ``nan`` when ``values`` is empty.

        The explicit empty check yields the same ``nan`` that ``np.mean([])``
        would, but without the RuntimeWarning it emits.
        """
        if not values:
            return float("nan")
        return float(np.mean(values))
