import logging
import os
import datetime
import torchvision.models as models
import torchvision.transforms.functional as TF
import math
import torch
import yaml
from easydict import EasyDict
import shutil
import pandas as pd
import numpy as np
from PIL import Image
import cv2
from salient_imagenet_utils import *

from torchvision import transforms as transforms
from torch.utils.data import Dataset, DataLoader

class CustomDataSet(Dataset):
    """Dataset yielding ``(image, mask, class_index)`` triples.

    Each item pairs an image with a soft mask obtained by taking the
    element-wise maximum over the stored CAM images of every annotated
    feature whose label in ``core_spurious_dict`` equals ``mask_type``.

    Expected layout under ``cams_path`` (as read by this class):

    * ``<split>/metadata.csv`` with columns ``image_index``, ``image_path``,
      ``class_index`` and ``feature_index_<j>`` / ``image_rank_<j>`` for
      ``j`` in ``0..4``.
    * ``<split>/class_<c>/feature_<f>/cams/<image_index>.jpg`` CAM images.
    * ``whole_imagenet_results/approved_results_new.csv`` consumed by
      ``MTurk_Results`` (not defined in this file; presumably provided by
      ``salient_imagenet_utils``).
    """

    # Number of ``feature_index_<j>`` / ``image_rank_<j>`` column pairs read
    # from each metadata row.
    _NUM_FEATURE_SLOTS = 5
    # Side length of the accumulator the CAM masks are merged into.
    # NOTE(review): CAM images on disk are assumed to be 224x224 -- confirm.
    _CAM_SIDE = 224

    def __init__(self, imagenet_path, cams_path, resize_size=224, mask_type='core', split='train', augment_scale=0):
        """Read the split's metadata and the core/spurious feature labels.

        Args:
            imagenet_path: Root joined with ``image_path`` for the ``'train'``
                split.
            cams_path: Root of the CAM dataset (see class docstring). For any
                split other than ``'train'`` images are also loaded from
                ``<cams_path>/<split>``.
            resize_size: Side length images and masks are resized to.
            mask_type: ``'core'`` or ``'spurious'``; selects which features
                contribute to the mask.
            split: Name of the sub-directory of ``cams_path`` to read.
                Random augmentation is only applied when it is ``'train'``.
            augment_scale: Lower bound of the random-resized-crop area
                fraction. ``0`` (the default) disables augmentation.

        Raises:
            ValueError: If ``mask_type`` is neither ``'core'`` nor
                ``'spurious'``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if mask_type not in ('core', 'spurious'):
            raise ValueError(
                "mask_type must be 'core' or 'spurious', got %r" % (mask_type,))
        super().__init__()

        self.augment_scale = augment_scale
        self.split = split
        self.mask_type = mask_type
        self.resize_size = resize_size
        self.imagenet_path = imagenet_path
        self.cams_path = os.path.join(cams_path, split)

        self.resize = transforms.Resize((resize_size, resize_size))
        self.base_transform = transforms.Compose([
            transforms.Resize((resize_size, resize_size)),
            transforms.ToTensor()])

        metadata_df = pd.read_csv(os.path.join(self.cams_path, 'metadata.csv'))
        self.metadata_df = metadata_df
        self.image_indices = np.array(metadata_df['image_index'])
        self.image_paths = np.array(metadata_df['image_path'])
        self.class_indices = np.array(metadata_df['class_index'])

        # The MTurk results live next to the split directories, hence the
        # un-suffixed ``cams_path`` argument rather than ``self.cams_path``.
        mturk_metadata_path = os.path.join(cams_path, 'whole_imagenet_results/approved_results_new.csv')
        mturk_results_discover = MTurk_Results(mturk_metadata_path)
        self.core_spurious_dict = mturk_results_discover.core_spurious_dict

    def __len__(self):
        """Return the number of rows in the split's metadata."""
        return len(self.image_paths)

    def transform(self, img, mask):
        """Convert an image/mask pair to tensors with shared augmentation.

        Both inputs are first resized to ``resize_size`` so that one crop box
        addresses the same region in each. When ``split == 'train'`` the same
        random resized crop and the same horizontal-flip decision are applied
        to both. Single-channel results are repeated to three channels.

        Args:
            img: PIL image.
            mask: PIL mask image belonging to ``img``.

        Returns:
            ``[image_tensor, mask_tensor]``.
        """
        # Imported here because the module does not import ``random`` itself
        # and relying on a star import to provide it is fragile.
        import random

        is_train = self.split == 'train'
        # Resizing the mask as well keeps the crop aligned when the mask's
        # native size differs from ``resize_size``; it is a plain copy when
        # the sizes already match.
        pair = [self.resize(img), self.resize(mask)]

        if is_train:
            top, left, height, width = transforms.RandomResizedCrop.get_params(
                pair[0], scale=(self.augment_scale, 1.0), ratio=(0.75, 1.25))
            flip = random.random() < 0.5

        tensors = []
        for item in pair:
            if is_train:
                item = TF.crop(item, top, left, height, width)
                if flip:
                    item = TF.hflip(item)

            tensor = self.base_transform(item)
            if tensor.shape[0] == 1:
                tensor = torch.cat([tensor, tensor, tensor], dim=0)
            tensors.append(tensor)

        return tensors

    def __getitem__(self, index):
        """Load the image at ``index`` and build its combined mask.

        Returns:
            ``(image_tensor, mask_tensor, class_index)``. The mask has three
            channels when ``augment_scale > 0`` (it goes through
            :meth:`transform`) and a single channel otherwise.

        Raises:
            FileNotFoundError: If a required CAM image cannot be read.
        """
        image_index = self.image_indices[index]
        class_index = self.class_indices[index]

        # Training images come from the ImageNet tree; other splits ship
        # their images inside the CAM dataset directory.
        image_root = self.imagenet_path if self.split == 'train' else self.cams_path
        image_path = os.path.join(image_root, self.image_paths[index])
        class_path = os.path.join(self.cams_path, 'class_' + str(class_index))

        # ``convert`` returns a fully loaded copy, so the file handle can be
        # released immediately.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        row = self.metadata_df.iloc[index]

        all_mask = np.zeros((self._CAM_SIDE, self._CAM_SIDE))
        for slot in range(self._NUM_FEATURE_SLOTS):
            feature_index = row['feature_index_' + str(slot)]

            # -1 marks an unused feature slot.
            if feature_index == -1:
                continue

            main_key = str(class_index) + '_' + str(feature_index)
            if self.core_spurious_dict[main_key] != self.mask_type:
                continue

            cam_path = os.path.join(
                class_path, 'feature_' + str(feature_index), 'cams',
                str(image_index) + '.jpg')
            cam = cv2.imread(cam_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals failure by returning None rather than raising.
            if cam is None:
                raise FileNotFoundError(
                    'Could not read CAM image: ' + cam_path)

            # Stored CAMs are inverted: dark pixels become values near 1.
            all_mask = np.maximum(all_mask, 1. - (cam / 255.))

        mask_image = Image.fromarray(np.uint8(all_mask * 255))
        if self.augment_scale > 0:
            image_tensor, mask_tensor = self.transform(image, mask_image)
        else:
            image_tensor = self.base_transform(image)
            mask_tensor = self.base_transform(mask_image)

        return image_tensor, mask_tensor, class_index
    
# train_dataset = CustomDataSet('/fs/cml-datasets/ImageNet/ILSVRC2012', 
#                               '/REDACTED/salient_imagenet_dataset/', 
#                               resize_size=224)

def dilate_erode_fast(masks, dilate=True, iterations=15, kernel=5):
    """Repeatedly dilate or erode masks with a plus-shaped structuring element.

    Only channel 0 of ``masks`` is processed. In each iteration every pixel is
    replaced by the maximum (dilate) or minimum (erode) over itself and the
    pixels up to ``kernel // 2`` steps away along its row and its column.
    Pixels outside the mask count as 0 when dilating and 1 when eroding.

    Args:
        masks: Tensor of shape ``(batch, channels, height, width)``.
        dilate: ``True`` to dilate, ``False`` to erode.
        iterations: Number of times the operation is applied.
        kernel: Odd extent of the structuring element along each axis.

    Returns:
        Tensor of shape ``(batch, 3, height, width)`` holding the processed
        channel repeated three times and divided by its global maximum. An
        all-zero result is returned unchanged instead of becoming NaN.

    Raises:
        ValueError: If ``kernel`` is not odd.
    """
    # Explicit raise instead of ``assert`` so the check survives ``-O``.
    if kernel % 2 != 1:
        raise ValueError('kernel must be odd, got %r' % (kernel,))
    half_k = kernel // 2
    batch_size, _, height, width = masks.shape

    # Offsets of the plus-shaped neighbourhood. The centre (0, 0) must be
    # included, otherwise a pixel's own value is ignored by the max/min.
    offsets = [(0, 0)]
    for step in range(1, half_k + 1):
        offsets.extend([(-step, 0), (step, 0), (0, -step), (0, step)])

    # Border value that is neutral for the reduction being applied.
    fill_value = 0.0 if dilate else 1.0

    out = masks[:, 0, :, :].clone()
    for _ in range(iterations):
        # A single padded canvas per iteration; each neighbour is a shifted
        # view into it rather than a separately allocated padded copy.
        canvas = torch.full(
            (batch_size, height + 2 * half_k, width + 2 * half_k),
            fill_value, device=masks.device)
        canvas[:, half_k:half_k + height, half_k:half_k + width] = out

        shifted = torch.stack([
            canvas[:,
                   half_k + d_row:half_k + d_row + height,
                   half_k + d_col:half_k + d_col + width]
            for d_row, d_col in offsets])
        out = shifted.max(dim=0)[0] if dilate else shifted.min(dim=0)[0]

    out = torch.stack([out, out, out], dim=1)
    # Avoid 0/0 -> NaN when the result is entirely zero; any non-zero peak
    # is used as-is.
    peak = torch.max(out)
    safe_peak = torch.where(peak == 0, torch.ones_like(peak), peak)
    return out / safe_peak