import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import torch.utils.data as data
from PIL import Image
import json
from torch.utils.data import Dataset


class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images.

    The CSV at ``csv_path`` is read with its first column as the index. The
    index values are image file names, resolved relative to ``img_dir``; the
    remaining columns are per-image labels.

    Attributes:
        img_dir: Directory the image file names are joined onto.
        csv_path: The ``csv_path`` argument exactly as given.
        img_names: Array of image file names, one per CSV row.
        y: 2-D label array, one row per image and one column per label.
        transform: Optional callable applied to each loaded PIL image.
        imratio_list: For each label column, ``ones / (ones + zeros)``, i.e.
            the fraction of entries equal to 1 among the entries equal to 0
            or 1. Entries with any other value are ignored; a column with no
            0/1 entries gets NaN. (Assumes labels are encoded as 0/1 —
            confirm for the CSV in use.)
    """

    def __init__(self, csv_path, img_dir, transform=None):
        df = pd.read_csv(csv_path, index_col=0)
        self.img_dir = img_dir
        self.csv_path = csv_path
        self.img_names = df.index.values
        self.y = df.values
        self.transform = transform

        # Compute the positive ratio for each label column.
        self.imratio_list = self._compute_imratios(self.y)

    @staticmethod
    def _compute_imratios(y):
        """Return a list with ``ones / (ones + zeros)`` for each column of ``y``."""
        # Count along axis 0 so each ratio describes one label column across
        # all images (not one image across all labels).
        ones = np.count_nonzero(y == 1, axis=0)
        zeros = np.count_nonzero(y == 0, axis=0)
        totals = ones + zeros
        # Columns without any 0/1 entries keep NaN instead of dividing by zero.
        ratios = np.full(totals.shape, np.nan)
        np.divide(ones, totals, out=ratios, where=totals > 0)
        return ratios.tolist()

    def __getitem__(self, index):
        """Return ``(image, label_row)`` for the sample at ``index``.

        ``image`` is the PIL image (after ``transform`` if one was given) and
        ``label_row`` is ``self.y[index]``.
        """
        path = os.path.join(self.img_dir, self.img_names[index])
        # Image.open is lazy and keeps the file open until the pixels are
        # read; load eagerly inside the context manager so the handle is
        # released even when no transform touches the pixel data.
        with Image.open(path) as img:
            img.load()

        if self.transform is not None:
            img = self.transform(img)

        label = self.y[index]
        return img, label

    def __len__(self):
        """Return the number of images (CSV rows)."""
        return self.y.shape[0]
