import ipdb
import torch
import torchvision
from .dataset_tinyimagenet import load_train_dataset, load_val_dataset, obtain_aug
import torch.utils.data as data
import numpy as np
from PIL import Image
import os
from robustness.tools.breeds_helpers import setup_breeds
from robustness.tools.breeds_helpers import make_living17, make_entity30
from robustness.tools.breeds_helpers import print_dataset_info
from robustness.tools.breeds_helpers import ClassHierarchy
from robustness import datasets
from robustness.tools import folder
from .stl_cifar_style import get_dataset_stl_pretrain, get_dataset_stl_train
from .cifar_stl_style import get_dataset_cifar_stl_style_pretrain
from .domainnet import get_dataset_domainet_pretrain, get_dataset_domainnet_onedomain
from .domainnet import DomainNet, DomainNetPair


def get_dataset(dataset, data_dir, transform, train=True, download=True):
    """Build the dataset object selected by name.

    Args:
        dataset: Dataset name. Supported values are 'cifar10', 'cifar100',
            'imagenet' and 'tiny-imagenet'.
        data_dir: Accepted for interface compatibility; not read by this
            function (see the note below).
        transform: Transform forwarded to the dataset constructor / loader.
        train: Selects the training split when equal to ``True``, otherwise
            the validation/test split.
        download: Forwarded to the torchvision CIFAR constructors only.

    Returns:
        The constructed dataset object.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported names.
    """
    # NOTE(review): the CIFAR root is the literal placeholder below rather
    # than `data_dir`; kept as-is so existing setups are not disturbed.
    if dataset == 'cifar10':
        return torchvision.datasets.CIFAR10(
            'PATH_TO_DATASET', train=train, transform=transform, download=download)
    if dataset == 'cifar100':
        return torchvision.datasets.CIFAR100(
            'PATH_TO_DATASET', train=train, transform=transform, download=download)
    if dataset in ('imagenet', 'tiny-imagenet'):
        # Both names share the same loaders, which receive the name itself.
        # The explicit `== True` is kept so that only values equal to True
        # select the training split, exactly as before.
        loader = load_train_dataset if train == True else load_val_dataset  # noqa: E712
        return loader(dataset, transform)
    raise NotImplementedError(
        "Unsupported dataset {!r}; expected one of 'cifar10', 'cifar100', "
        "'imagenet', 'tiny-imagenet'".format(dataset))


def get_dataset_breeds_pretrain(dataset, data_dir, transform, train=True, download=True):
    """Build an ImageNet training-folder dataset relabelled for a BREEDS task.

    Args:
        dataset: BREEDS task name, either 'living17' or 'entity30'.
        data_dir: Ignored; it is overwritten by a hard-coded ImageNet path
            below (kept for backward compatibility with existing callers).
        transform: Transform passed to ``folder.ImageFolder``.
        train: Unused; the 'train' sub-directory is always read.
        download: Unused.

    Returns:
        A ``folder.ImageFolder`` over ``<data_dir>/train`` using the label
        mapping taken from a ``datasets.CustomImageNet`` built for the task.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported BREEDS task.
            Previously an unknown name silently returned ``None``; raising
            matches ``get_dataset`` and surfaces the mistake at the call site.
    """
    # Validate the name first so a typo fails fast, before any filesystem
    # access or class-hierarchy download is attempted.
    if dataset == 'living17':
        make_task = make_living17
    elif dataset == 'entity30':
        make_task = make_entity30
    else:
        raise NotImplementedError(
            "Unsupported BREEDS dataset {!r}; expected 'living17' or "
            "'entity30'".format(dataset))

    # NOTE(review): the caller-supplied `data_dir` is deliberately shadowed
    # by these cluster-specific paths, as in the original implementation.
    data_dir = '/u/scr/nlp/imagenet'
    info_dir = '/tiger/u/jhaochen/BREEDS-Benchmarks/imagenet_class_hierarchy/modified'

    # Fetch the class-hierarchy files when the directory is missing or empty.
    if not (os.path.exists(info_dir) and os.listdir(info_dir)):
        print("Downloading class hierarchy information into `info_dir`")
        setup_breeds(info_dir)

    # The helper's result is unpacked as a 3-tuple; only the first entry of
    # the subclass split is used here.
    _, subclass_split, _ = make_task(info_dir, split=None)
    all_subclasses = subclass_split[0]

    # CustomImageNet is instantiated only to obtain its label mapping; the
    # images themselves are read through ImageFolder with our own transform.
    label_source = datasets.CustomImageNet(data_dir, all_subclasses)
    traindir = os.path.join(data_dir, 'train')
    return folder.ImageFolder(
        traindir, transform, label_mapping=label_source.label_mapping)
