from PIL import Image
import os
import torch
import cv2

from multiprocessing import Pool
from itertools import cycle
from PIL import ImageFilter
from torchvision.transforms import transforms

import numpy as np
import torch.nn as nn

def apply_mask_and_save_individual_image_in_png(image_mask, masked_data_save_dir, target, image_path):
    """Multiply one image by its mask and save it into a per-target folder.

    The image at ``image_path`` is converted to RGB. The mask is resized to the
    image size with nearest-neighbour interpolation and multiplied into every
    channel. The result is written to
    ``masked_data_save_dir/<target>/<original file name>``. The original file
    name is kept (and hence its extension), despite the "png" in this name.

    Args:
        image_mask: 2-D mask array accepted by ``cv2.resize``; presumably
            values in [0, 1] (TODO confirm with callers).
        masked_data_save_dir: Root directory for the output tree.
        target: Scalar exposing ``.item()`` (e.g. a tensor / numpy scalar);
            its value names the sub-directory.
        image_path: Path of the source image.
    """
    target_dir = os.path.join(masked_data_save_dir, str(target.item()))
    os.makedirs(target_dir, exist_ok=True)

    # Close the source file promptly: convert() returns a new image whose pixel
    # data is already loaded, so the handle is not needed afterwards.
    with Image.open(image_path) as source:
        original_image = source.convert('RGB')

    # PIL's .size is (width, height), which is the order cv2's dsize expects.
    resized_mask = cv2.resize(image_mask, dsize=original_image.size, interpolation=cv2.INTER_NEAREST)
    # Trailing channel axis so the mask broadcasts over the R, G and B planes.
    resized_mask = np.expand_dims(resized_mask, axis=-1)

    masked_pixels = np.array(original_image) * resized_mask
    masked_image = Image.fromarray(masked_pixels.astype(np.uint8))
    # os.path.basename also honours the platform's path separators, unlike split("/").
    masked_image.save(os.path.join(target_dir, os.path.basename(image_path)))



def apply_mask_and_save_images(image_masks, masked_data_save_dir, images_pathes, targets):
    """Mask and save a batch of images in parallel worker processes.

    Each (mask, target, path) triple is handed to
    ``apply_mask_and_save_individual_image_in_png`` together with the shared
    output directory.

    Args:
        image_masks: Iterable of per-image masks.
        masked_data_save_dir: Root output directory shared by all images.
        images_pathes: Iterable of source image paths, aligned with the masks.
        targets: Iterable of per-image targets, aligned with the masks.
    """
    # Using the pool as a context manager guarantees the worker processes are
    # torn down even if a task raises. starmap blocks until every task has
    # finished, so the terminate() on exit cannot cut off pending writes.
    with Pool() as pool:
        pool.starmap(
            apply_mask_and_save_individual_image_in_png,
            zip(image_masks, cycle([masked_data_save_dir]), targets, images_pathes),
        )

def blur_masked_region_and_save_individual_image_in_png(image, image_mask, path, target, id, std, mean):
    """Gaussian-blur the region where the mask is 0 and save the image as PNG.

    Args:
        image: Normalised channel-first image array (axis 0 is moved last
            before conversion to PIL).
        image_mask: Mask array; after ``np.squeeze`` it must be 2-D. Pixels
            where the mask is 1 keep their original value; pixels where it
            is 0 are replaced by the blurred image.
        path: Output directory.
        target: Label embedded in the file name.
        id: Sample identifier embedded in the file name.
        std: Normalisation std; presumably per-channel and broadcastable
            against ``image`` (TODO confirm shape).
        mean: Normalisation mean; same assumption as ``std``.

    The file is written to ``path/{id}_{target}.png``.
    """
    # Undo the normalisation. Clip before the uint8 cast: without it, values
    # outside [0, 255] do not saturate (negative or >255 floats wrap or are
    # platform-dependent).
    pixels = np.clip(255.0 * ((image * std) + mean), 0, 255).astype(np.uint8)
    original = Image.fromarray(np.rollaxis(pixels, 0, 3))
    blurred = original.filter(ImageFilter.GaussianBlur(5))

    # Paste mask is 255 where image_mask == 0, so only that region is blurred.
    paste_mask = np.squeeze(((1 - image_mask) * 255.0).astype(np.uint8))
    original.paste(blurred, mask=Image.fromarray(paste_mask))
    original.save(os.path.join(path, f"{id}_{target}.png"))

def blur_masked_region_and_save_images(images, image_masks, path, ids, targets, std, mean):
    """Blur-and-save a batch of images in parallel worker processes.

    Each (image, mask, target, id) tuple is handed to
    ``blur_masked_region_and_save_individual_image_in_png`` together with the
    shared output directory and normalisation statistics.

    Args:
        images: Iterable of normalised channel-first images.
        image_masks: Iterable of masks aligned with ``images``.
        path: Output directory shared by all images.
        ids: Iterable of sample identifiers aligned with ``images``.
        targets: Iterable of labels aligned with ``images``.
        std: Normalisation std shared by all images.
        mean: Normalisation mean shared by all images.
    """
    # The context manager shuts the workers down even if a task raises;
    # starmap has already waited for every task when the block exits.
    with Pool() as pool:
        pool.starmap(
            blur_masked_region_and_save_individual_image_in_png,
            zip(images, image_masks, cycle([path]), targets, ids, cycle([std]), cycle([mean])),
        )

def change_masked_region_color_and_save_individual_image_in_png(image, image_mask, path, target, id, std, mean):
    """Colour-jitter the region where the mask is 0 and save the image as PNG.

    The output is ``original * mask + jittered * (1 - mask)``. Pixels where
    the mask is 1 keep their colour; pixels where it is 0 take the jittered
    colour; fractional mask values blend the two.

    Args:
        image: Normalised channel-first image array (axis 0 is moved last
            before conversion to PIL).
        image_mask: Channel-first mask array (axis 0 is moved last so it
            broadcasts against the H x W x C image).
        path: Output directory.
        target: Label embedded in the file name.
        id: Sample identifier embedded in the file name.
        std: Normalisation std; presumably per-channel and broadcastable
            against ``image`` (TODO confirm shape).
        mean: Normalisation mean; same assumption as ``std``.

    The file is written to ``path/{id}_{target}.png``.
    """
    color_jitter = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)

    # Undo the normalisation. Clip before the uint8 cast so out-of-range values
    # saturate instead of wrapping.
    pixels = np.clip(255.0 * ((image * std) + mean), 0, 255).astype(np.uint8)
    original = Image.fromarray(np.rollaxis(pixels, 0, 3))
    jittered = np.asarray(color_jitter(original))

    mask = np.rollaxis(image_mask, 0, 3)
    # Explicit array conversion instead of relying on numpy's reflected
    # operators to coerce the PIL images.
    blended = (np.asarray(original) * mask) + (jittered * (1 - mask))
    Image.fromarray(blended.astype(np.uint8)).save(os.path.join(path, f"{id}_{target}.png"))

def change_masked_region_color_and_save_images(images, image_masks, path, ids, targets, std, mean):
    """Colour-jitter-and-save a batch of images in parallel worker processes.

    Each (image, mask, target, id) tuple is handed to
    ``change_masked_region_color_and_save_individual_image_in_png`` together
    with the shared output directory and normalisation statistics.

    Args:
        images: Iterable of normalised channel-first images.
        image_masks: Iterable of masks aligned with ``images``.
        path: Output directory shared by all images.
        ids: Iterable of sample identifiers aligned with ``images``.
        targets: Iterable of labels aligned with ``images``.
        std: Normalisation std shared by all images.
        mean: Normalisation mean shared by all images.
    """
    # The context manager shuts the workers down even if a task raises;
    # starmap has already waited for every task when the block exits.
    with Pool() as pool:
        pool.starmap(
            change_masked_region_color_and_save_individual_image_in_png,
            zip(images, image_masks, cycle([path]), targets, ids, cycle([std]), cycle([mean])),
        )


def save_numpy(data, file_name, path):
    """Persist ``data`` with ``np.save`` as ``file_name`` inside directory ``path``.

    ``np.save`` appends the ``.npy`` suffix when ``file_name`` lacks it.
    """
    destination = os.path.join(path, file_name)
    np.save(destination, data)


def save_checkpoint(
    model, optimizer, lr_scheduler, checkpoint_path: str, current_epoch
):
    """Serialise a training checkpoint to ``checkpoint_path`` with ``torch.save``.

    The saved dict has the keys ``"optimizer"``, ``"scheduler"``, ``"epoch"``
    and ``"model"``. A ``nn.DataParallel`` wrapper is unwrapped first, so the
    model weights are stored under the inner module's own key names.
    ``torch.cuda.empty_cache()`` is called afterwards.
    """
    # Strip the DataParallel wrapper so the stored keys carry no "module." prefix.
    network = model.module if isinstance(model, nn.DataParallel) else model
    checkpoint = {
        "optimizer": optimizer.state_dict(),
        "scheduler": lr_scheduler.state_dict(),
        "epoch": current_epoch,
        "model": network.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)
    del checkpoint
    torch.cuda.empty_cache()


def load_checkpoint(model, optimizer, lr_scheduler, checkpoint_path: str, map_location=None):
    """Restore model / optimizer / scheduler state from a checkpoint file.

    Model weights are matched by walking the checkpoint keys in order.
    Each checkpoint key is assigned to the next not-yet-matched model key
    when that model key occurs as a substring of it. For example,
    ``"weight"`` matches ``"module.weight"``, so checkpoints saved with or
    without a ``DataParallel`` prefix both load. Checkpoint keys that do not
    match are skipped. Model parameters left unmatched keep their current
    values, and a warning is emitted for them.

    Args:
        model: Module to load into; a ``nn.DataParallel`` wrapper is
            unwrapped first.
        optimizer: Optimizer to restore, or ``None`` to skip.
        lr_scheduler: LR scheduler to restore, or ``None`` to skip.
        checkpoint_path: Path to a file containing a dict with ``"model"``,
            ``"optimizer"``, ``"scheduler"`` and ``"epoch"`` entries.
        map_location: Forwarded to ``torch.load``. For example, pass
            ``"cpu"`` to load a GPU-saved checkpoint on a CPU-only machine.
            The default ``None`` keeps ``torch.load``'s default placement.

    Returns:
        ``(model, optimizer, lr_scheduler, next_epoch)``, where
        ``next_epoch`` is the stored epoch plus one.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    import warnings

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(checkpoint_path)

    state = torch.load(checkpoint_path, map_location=map_location)

    network = model.module if isinstance(model, nn.DataParallel) else model
    model_dict = network.state_dict()
    model_keys = list(model_dict.keys())

    # Order-preserving substring match (see docstring); matched tracks how
    # many model keys have been filled so far.
    checkpoint_weights = state['model']
    matched = 0
    for ckpt_key in list(checkpoint_weights.keys()):
        if matched < len(model_keys) and model_keys[matched] in ckpt_key:
            model_dict[model_keys[matched]] = checkpoint_weights[ckpt_key]
            matched += 1

    unmatched = len(model_keys) - matched
    if unmatched:
        # Unmatched parameters silently keeping their init values is a common
        # source of hard-to-find bugs, so make it visible.
        warnings.warn(
            f"load_checkpoint: {unmatched} of {len(model_keys)} model parameters "
            f"were not found in {checkpoint_path!r} and keep their current values"
        )

    network.load_state_dict(model_dict)
    if optimizer is not None:
        optimizer.load_state_dict(state["optimizer"])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(state["scheduler"])
    current_epoch = state["epoch"] + 1
    del state
    torch.cuda.empty_cache()
    return model, optimizer, lr_scheduler, current_epoch


def load_data_on_ram(image_data_file):
    """Open an image and decode its pixel data into memory right away.

    ``Image.open`` alone is lazy: it reads only the header and keeps the file
    handle open until the pixels are first accessed. Calling ``load()`` forces
    decoding now. For a single-frame image opened from a path, Pillow then
    closes the file handle; multi-frame images keep it open for seeking.

    Args:
        image_data_file: Path or file object accepted by ``Image.open``.

    Returns:
        The opened ``PIL.Image.Image`` with its pixel data loaded.
    """
    image = Image.open(image_data_file)
    image.load()
    return image