import torch
import copy
import math

import numpy as np

from torchvision import transforms
from multiprocessing import Pool

def train_val_split(dataset, val_ratio):
    """
    Randomly split ``dataset.data`` into a training list and a validation list.

    The order is drawn with ``torch.randperm``, so it follows torch's global
    RNG state. The validation part holds ``int(len(dataset.data) * val_ratio)``
    samples and the training part holds the rest.

    Returns:
        ``(train_samples, val_samples)`` as two Python lists.
    """
    samples = dataset.data
    shuffled_positions = torch.randperm(len(samples))
    cut = int(len(samples) * val_ratio)
    train_samples = [samples[pos] for pos in shuffled_positions[cut:]]
    val_samples = [samples[pos] for pos in shuffled_positions[:cut]]
    return train_samples, val_samples


class InputNormalize(torch.nn.Module):
    """
    Layer that clamps its input to [0, 1] and then standardises it with a
    fixed, user-supplied mean and standard deviation.

    Two trailing singleton dimensions are appended to ``new_mean`` and
    ``new_std`` so that they broadcast over the last two axes of the input
    (e.g. a length-C vector then acts per channel on ``(..., C, H, W)``
    input). Both are registered as buffers: they move with the module and
    appear in its ``state_dict``, but are not trainable parameters.
    """

    def __init__(self, new_mean, new_std):
        super().__init__()
        std_expanded = new_std[..., None, None]
        mean_expanded = new_mean[..., None, None]

        # Registration order (mean, then std) fixes the buffer order in state_dict.
        self.register_buffer("new_mean", mean_expanded)
        self.register_buffer("new_std", std_expanded)

    def forward(self, x):
        """Clamp ``x`` to [0, 1] and return ``(x - new_mean) / new_std``."""
        clamped = torch.clamp(x, 0, 1)
        return (clamped - self.new_mean) / self.new_std


class NormalizedModel(torch.nn.Module):
    """
    Wrapper that feeds its input through an ``InputNormalize`` layer, built
    from ``dataset.mean`` and ``dataset.std``, before calling ``model``.
    """

    def __init__(self, model, dataset):
        super().__init__()
        # The normalizer is registered before the wrapped model; this fixes
        # the submodule order seen by children() and state_dict().
        self.normalizer = InputNormalize(dataset.mean, dataset.std)
        self.model = model

    def forward(self, inp):
        """Return ``model`` applied to the normalized ``inp``."""
        prepared = self.normalizer(inp)
        return self.model(prepared)


def extract_number_info(file_path):
    """
    Parse the trailing integer fields of a file name.

    The base name (text after the last ``/``) is cut at its first ``.`` and
    split on ``_``. If there are at least three components and the
    third-from-last is a decimal number, the last three components are
    returned as ints; otherwise the last two are.

    Examples:
        ``"dir/img_3_14_2.png"`` -> ``(3, 14, 2)``
        ``"dir/img_14_2.png"`` -> ``(14, 2)``
        ``"dir/14_2.png"`` -> ``(14, 2)``

    Raises:
        ValueError: if a component that must be an integer is not.
        IndexError: if the name has fewer than two ``_``-separated components.
    """
    name = file_path.split("/")[-1]
    stem = name.split(".")[0]
    components = stem.split("_")
    # Check the length first: names such as "14_2.png" have only two fields,
    # and indexing [-3] on them used to raise IndexError. isdecimal() (unlike
    # isnumeric()) rejects characters such as "²" that int() cannot parse.
    if len(components) >= 3 and components[-3].isdecimal():
        return tuple(int(component) for component in components[-3:])
    return int(components[-2]), int(components[-1])


def filter_data_by_label(data, targets, class_labels_to_filter):
    """
    Keep only the samples whose label appears in ``class_labels_to_filter``.

    Args:
        data: tensor of samples aligned with ``targets`` along dim 0.
        targets: tensor of labels, compared elementwise with each label.
        class_labels_to_filter: iterable of labels to keep.

    Returns:
        ``(data_subset, targets_subset)``. Samples are grouped by label, in the
        order the labels appear in ``class_labels_to_filter`` (not their
        original order); a label listed twice selects its samples twice. An
        empty ``class_labels_to_filter`` yields empty subsets.
    """
    index_groups = [
        torch.where(targets == label)[0] for label in class_labels_to_filter
    ]
    if index_groups:
        filtered_target_idx = torch.cat(index_groups)
    else:
        # torch.cat raises on an empty list; select nothing instead.
        filtered_target_idx = torch.empty(0, dtype=torch.long, device=targets.device)
    return data[filtered_target_idx], targets[filtered_target_idx]


def group_labels(targets, old_to_new_label_mapping):
    """
    Relabel ``targets`` in place according to a grouping of old labels.

    Args:
        targets: mutable sequence of labels (e.g. a 1-D tensor); it is
            overwritten in place.
        old_to_new_label_mapping: dict ``{new_label: [old_label, ...]}``. A
            target found in a group's list is replaced by that group's key.
            Targets in no group keep their value.

    Returns:
        ``targets`` itself, with each element stored as
        ``torch.tensor(int(label))``.

    Each target is remapped at most once; the first group (in dict order)
    that contains it wins.
    """
    for i, target in enumerate(targets):
        for new_label, old_label_grouping in old_to_new_label_mapping.items():
            if target in old_label_grouping:
                target = new_label
                # Stop at the first match. Without this, the freshly assigned
                # label was tested against later groups and could be remapped
                # again (e.g. {1: [0], 2: [1]} turned 0 into 2).
                break

        targets[i] = torch.tensor(int(target))
    return targets


def make_images_rgb(data):
    """
    Tile every image three times along its first dimension and stack them.

    Each element of ``data`` is expanded with ``Tensor.repeat(3, 1, 1)``, so a
    single-channel ``(1, H, W)`` or ``(H, W)`` image becomes ``(3, H, W)``.
    The result has shape ``(N, 3, H, W)`` for such inputs.
    """
    # Call Tensor.repeat directly. The previous code wrapped it in a torchvision
    # transforms.Lambda that was rebuilt for every image, which added nothing.
    return torch.stack([image.repeat(3, 1, 1) for image in data])


def add_squares_to_images(
    data,
    targets,
    data_percentage_to_add_square,
    square_number,
    reverse=False,
    square_size=6,
):
    """
    Paint red or blue corner squares onto images as a label-correlated cue.

    Within each class, the first ``count * data_percentage_to_add_square``
    images (in iteration order) get that class's "majority" colour and the
    rest get the other colour. Label 1 has a red majority; any other label has
    a blue majority. ``reverse=True`` swaps red and blue.

    Args:
        data: images indexed as ``image[channel, row, col]`` with 3 channels.
            The squares are ``uint8`` with value 255, so 0-255 RGB images are
            presumably expected — TODO confirm against callers.
        targets: labels aligned with ``data``. ``torch.sum(targets)`` is used
            as the label-1 count, so labels are assumed to be 0/1.
        data_percentage_to_add_square: fraction (0-1) of each class painted
            with its majority colour; a fractional quota is rounded up.
        square_number: 4 paints all four corners, 1 paints only the top-left
            corner; any other value leaves the images unpainted.
        reverse: swap the two colours.
        square_size: side length of each square in pixels.

    Returns:
        A new tensor stacking the painted images. ``data`` is not modified.
    """
    red_square = torch.zeros((3, square_size, square_size), dtype=torch.uint8)
    red_square[0, :, :] = 255
    blue_square = torch.zeros((3, square_size, square_size), dtype=torch.uint8)
    blue_square[2, :, :] = 255

    if reverse:
        red_square, blue_square = blue_square, red_square

    target_one_data_count = torch.sum(targets)
    target_zero_data_count = len(data) - target_one_data_count
    # Quotas may be fractional; the "> 0" checks below effectively round them up.
    target_one_data_count_to_add_square = (
        target_one_data_count * data_percentage_to_add_square
    )
    target_zero_data_count_to_add_square = (
        target_zero_data_count * data_percentage_to_add_square
    )

    # Stack (copy) first and paint into the copy. Iterating ``data`` directly
    # yields views, so the old code also overwrote the caller's images.
    modified_data = torch.stack(list(data))
    for image, target in zip(modified_data, targets):
        if target == 1:
            if target_one_data_count_to_add_square > 0:
                target_one_data_count_to_add_square -= 1
                square = red_square
            else:
                square = blue_square
        else:
            if target_zero_data_count_to_add_square > 0:
                target_zero_data_count_to_add_square -= 1
                square = blue_square
            else:
                square = red_square

        # ``image`` is a view of ``modified_data``, so these writes land there.
        if square_number == 4:
            image[:, 0:square_size, 0:square_size] = square
            image[:, -square_size:, -square_size:] = square
            image[:, -square_size:, 0:square_size] = square
            image[:, 0:square_size, -square_size:] = square
        elif square_number == 1:
            image[:, 0:square_size, 0:square_size] = square

    return modified_data


def add_background_to_images(
    data,
    targets,
    data_percentage_to_add_square,
    reverse=False,
):
    """
    Set zero-valued pixels of one colour channel to 255 as a label-correlated cue.

    Within each class, the first ``count * data_percentage_to_add_square``
    images (in iteration order) use that class's "majority" channel and the
    rest use the other one. Label 1 uses red (channel 0); any other label uses
    blue (channel 2). ``reverse=True`` swaps the two channels. In the chosen
    channel every pixel equal to 0 becomes 255; all other pixels and channels
    are left as they are.

    Args:
        data: images indexed as ``image[channel, row, col]`` with at least
            3 channels; presumably 0-255 valued — TODO confirm against callers.
        targets: labels aligned with ``data``. ``torch.sum(targets)`` is used
            as the label-1 count, so labels are assumed to be 0/1.
        data_percentage_to_add_square: fraction (0-1) of each class that gets
            its majority channel; a fractional quota is rounded up.
        reverse: swap the red and blue channels.

    Returns:
        A new tensor stacking the modified images. ``data`` is not modified.
    """
    blue_channel, red_channel = 2, 0
    if reverse:
        blue_channel, red_channel = red_channel, blue_channel

    target_one_data_count = torch.sum(targets)
    target_zero_data_count = len(data) - target_one_data_count
    # Quotas may be fractional; the "> 0" checks below effectively round them up.
    target_one_data_count_to_add_square = (
        target_one_data_count * data_percentage_to_add_square
    )
    target_zero_data_count_to_add_square = (
        target_zero_data_count * data_percentage_to_add_square
    )

    # Stack (copy) first and edit the copy. Iterating ``data`` directly yields
    # views, so the old code also overwrote the caller's images.
    modified_data = torch.stack(list(data))
    for image, target in zip(modified_data, targets):
        if target == 1:
            if target_one_data_count_to_add_square > 0:
                target_one_data_count_to_add_square -= 1
                channel = red_channel
            else:
                channel = blue_channel
        else:
            if target_zero_data_count_to_add_square > 0:
                target_zero_data_count_to_add_square -= 1
                channel = blue_channel
            else:
                channel = red_channel

        # Masked assignment on the channel view writes into ``modified_data``.
        plane = image[channel, :, :]
        plane[plane == 0] = 255

    return modified_data

def get_heat_map(images, heat_map_generator):
    """
    Call ``heat_map_generator`` on ``images`` and return its result unchanged.

    Used as the ``starmap`` target in ``calculate_data_heat_map_mean``.
    """
    heat_map = heat_map_generator(images)
    return heat_map

def _heat_map_to_numpy(heat_map):
    """Return ``heat_map`` as a NumPy array, detaching torch tensors and moving them to CPU."""
    if isinstance(heat_map, torch.Tensor):
        return heat_map.detach().cpu().numpy()
    return np.asarray(heat_map)


def calculate_data_heat_map_mean(data_loader, gpu_ids, heat_map_generators):
    """
    Compute the mean heat-map value over a dataset, ignoring NaNs.

    Each batch is split along dim 0 into at most ``len(gpu_ids)`` chunks.
    Chunk ``i`` is moved to ``cuda:{gpu_ids[i]}`` and passed to
    ``heat_map_generators[i]``; the generators run concurrently in a thread
    pool. NaN values are dropped, each batch contributes the mean of its
    remaining values, and the result is the unweighted mean of those batch
    means.

    Args:
        data_loader: iterable yielding ``(images, _, _)`` triples, where
            ``images`` is a tensor.
        gpu_ids: CUDA device indices, one per chunk.
        heat_map_generators: callables; the i-th is used with ``gpu_ids[i]``
            and presumably already lives on that device — TODO confirm.

    Returns:
        The mean, also printed as the threshold. Returns ``nan`` if no batch
        produced a non-NaN value.

    Raises:
        ValueError: if ``gpu_ids`` is empty or there are fewer generators than
            GPU ids.
    """
    from multiprocessing.pool import ThreadPool

    if not gpu_ids:
        raise ValueError("gpu_ids must not be empty")
    if len(heat_map_generators) < len(gpu_ids):
        raise ValueError("need one heat map generator per GPU id")

    batch_means = []
    # Use threads rather than processes. A process pool would pickle every
    # generator and chunk on each batch, and CUDA cannot be used in forked
    # workers. PyTorch releases the GIL inside its ops, so threads can still
    # keep several GPUs busy.
    with ThreadPool(len(gpu_ids)) as pool:
        for images, _, _ in data_loader:
            if len(images) == 0:
                continue
            chunk_size = math.ceil(len(images) / len(gpu_ids))
            # Tensor.to returns a new tensor; the result must be kept.
            chunks = [
                chunk.to(f"cuda:{gpu_id}")
                for chunk, gpu_id in zip(torch.split(images, chunk_size), gpu_ids)
            ]
            heat_maps = pool.starmap(get_heat_map, zip(chunks, heat_map_generators))
            # starmap returns one result per chunk; flatten them into one array
            # so NaNs can be masked out.
            values = np.concatenate(
                [_heat_map_to_numpy(heat_map).ravel() for heat_map in heat_maps]
            )
            values = values[~np.isnan(values)]
            if values.size == 0:
                continue
            batch_means.append(values.mean())

    heat_maps_mean = np.mean(batch_means) if batch_means else np.nan
    print("Threshold is set to {}".format(heat_maps_mean))
    return heat_maps_mean


def add_one_square_to_images(
    data,
    targets,
    data_percentage_to_add_square,
    reverse=False,
    square_size=6,
):
    """
    Paint a red top-left square on a label-correlated subset of images.

    Let ``p = data_percentage_to_add_square``, or ``1 - p`` when ``reverse``
    is true.
    - Label 1: the first ``count * p`` images get the square; the rest do not.
    - Any other label: the first ``count * p`` images are left unmarked; the
      rest get the square.
    So for ``p > 0.5`` the square is more frequent among label-1 images.

    Args:
        data: images indexed as ``image[channel, row, col]`` with 3 channels.
            The square is ``uint8`` with red = 255, so 0-255 RGB images are
            presumably expected — TODO confirm against callers.
        targets: labels aligned with ``data``. ``torch.sum(targets)`` is used
            as the label-1 count, so labels are assumed to be 0/1.
        data_percentage_to_add_square: the fraction ``p`` described above; a
            fractional quota is rounded up.
        reverse: use ``1 - p`` instead of ``p``.
        square_size: side length of the square in pixels.

    Returns:
        A new tensor stacking the resulting images. ``data`` is not modified.
    """
    red_square = torch.zeros((3, square_size, square_size), dtype=torch.uint8)
    red_square[0, :, :] = 255
    if reverse:
        data_percentage_to_add_square = 1 - data_percentage_to_add_square

    target_one_data_count = torch.sum(targets)
    target_zero_data_count = len(data) - target_one_data_count
    # Quotas may be fractional; the "> 0" checks below effectively round them up.
    target_one_data_count_to_add_square = (
        target_one_data_count * data_percentage_to_add_square
    )
    # For the other label this quota counts images to leave *unmarked*.
    target_zero_data_count_to_skip = (
        target_zero_data_count * data_percentage_to_add_square
    )

    # Stack (copy) first and paint into the copy. Iterating ``data`` directly
    # yields views, so the old code also overwrote the caller's images.
    modified_data = torch.stack(list(data))
    for image, target in zip(modified_data, targets):
        if target == 1:
            if target_one_data_count_to_add_square > 0:
                target_one_data_count_to_add_square -= 1
                image[:, 0:square_size, 0:square_size] = red_square
        elif target_zero_data_count_to_skip > 0:
            target_zero_data_count_to_skip -= 1
        else:
            image[:, 0:square_size, 0:square_size] = red_square

    return modified_data