import os
import logging
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random
from einops import rearrange, repeat, reduce

import os
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler

import pandas as pd

def _get_index_train_test_path(data_directory_path, split_num, train=True):
    """
       Build the path of the index file describing one train/test split.
       The file is named "index_train_<split_num>.txt" or
       "index_test_<split_num>.txt" inside the given directory.
       @param data_directory_path  Directory that contains the index files
       @param split_num      Split number whose index file is requested
       @param train          True selects the training index file, otherwise
                             the test index file is selected
       @return path          Path of the file containing the requested indices
    """
    subset = "train" if train else "test"
    return os.path.join(data_directory_path, f"index_{subset}_{split_num!s}.txt")

def onehot_encode_cat_feature(X, cat_var_idx_list):
    """
    One-hot encode the columns of X listed in cat_var_idx_list.

    Each listed column is expanded with pandas.get_dummies(drop_first=True),
        i.e. the first category level of every column is dropped. The
        remaining (numerical) columns keep their order and come first; the
        one-hot columns are appended after them as float32.

    Returns the re-assembled feature matrix and the number of one-hot
        columns that were appended.
    """
    categorical = X[:, cat_var_idx_list]
    numerical = np.delete(arr=X, obj=cat_var_idx_list, axis=1)
    dummies = [
        pd.get_dummies(categorical[:, j], drop_first=True)
        for j in range(categorical.shape[1])
    ]
    encoded = np.concatenate(dummies, axis=1).astype(np.float32)
    n_onehot_cols = encoded.shape[1]
    return np.concatenate([numerical, encoded], axis=1), n_onehot_cols


def preprocess_uci_feature_set(X, data_path):
    """
    Apply task-specific one-hot encoding to a UCI feature matrix.

    The task name is the last '/'-separated component of data_path. For the
        tasks listed below the given column indices are treated as categorical
        and one-hot encoded; for every other task X is returned unchanged.

    Returns (X, dim_cat) where dim_cat is the number of one-hot columns that
        were appended to X (0 when no encoding was applied).
    """
    categorical_columns = {
        'bostonHousing': [3],
        'energy': [4, 6, 7],
        'naval-propulsion-plant': [0, 1, 8, 11],
    }
    task_name = data_path.split('/')[-1]
    cat_idx = categorical_columns.get(task_name)
    if cat_idx is None:
        return X, 0
    X, dim_cat = onehot_encode_cat_feature(X, cat_idx)
    return X, dim_cat

class UCI(Dataset):
    """
    UCI regression dataset backed by text files under <data_path>/<task>/data.

    Files read from that directory:
        data.txt             full data matrix
        index_features.txt   column indices of the input features
        index_target.txt     column index of the regression target
        index_train_<split>.txt / index_test_<split>.txt   row indices of the split

    The training rows are further divided into a train part (first
        train_ratio fraction) and a validation part (the rest). Optionally the
        features and targets are standardized with statistics fitted on the
        train part only.

    After construction, self.data / self.target hold the split selected by
        train_split, moved to `device`; __len__ and __getitem__ operate on them.
    """

    def __init__(self, data_path, task, split=0, train_split='train', normalize_x=True, normalize_y=True, train_ratio=0.6, device='cuda'):
        """
        @param data_path     Root directory containing one sub-directory per task
        @param task          Task (dataset) name, e.g. 'bostonHousing'
        @param split         Split number used to pick the index files
        @param train_split   'train', 'val', or anything else for the test split
        @param normalize_x   Standardize the (numerical) input features
        @param normalize_y   Standardize the regression target
        @param train_ratio   Fraction of the training rows kept for training;
                             the remainder becomes the validation set
        @param device        Device the selected split is moved to
        """
        data_dir = os.path.join(data_path, task, 'data')

        data = np.loadtxt(os.path.join(data_dir, 'data.txt'))
        # np.loadtxt returns a 0-d array for a single-entry file; atleast_1d keeps
        # column selection working when there is only one feature.
        feature_cols = np.atleast_1d(
            np.loadtxt(os.path.join(data_dir, 'index_features.txt'))).astype(int)
        target_col = int(np.loadtxt(os.path.join(data_dir, 'index_target.txt')).tolist())

        X = data[:, feature_cols].astype(np.float32)
        y = data[:, target_col].astype(np.float32)

        # preprocess_uci_feature_set takes the task name from the last '/'-separated
        # component of the string it receives. Passing the root data_path never
        # matched a task, so the categorical encoding was silently skipped; pass
        # the task name itself instead.
        X, dim_cat = preprocess_uci_feature_set(X=X, data_path=task)
        self.dim_cat = dim_cat

        # load the indices of the train and test sets
        index_train = self._load_indices(_get_index_train_test_path(data_dir, split, train=True))
        index_test = self._load_indices(_get_index_train_test_path(data_dir, split, train=False))

        x_train, y_train = X[index_train], y[index_train].reshape(-1, 1)
        x_test, y_test = X[index_test], y[index_test].reshape(-1, 1)

        # split train set further into train and validation set for hyperparameter tuning
        num_training_examples = int(train_ratio * x_train.shape[0])
        x_train, x_val = x_train[:num_training_examples], x_train[num_training_examples:]
        y_train, y_val = y_train[:num_training_examples], y_train[num_training_examples:]

        self.x_train, self.y_train = torch.from_numpy(x_train), torch.from_numpy(y_train)
        self.x_test, self.y_test = torch.from_numpy(x_test), torch.from_numpy(y_test)
        self.x_val, self.y_val = torch.from_numpy(x_val), torch.from_numpy(y_val)

        self.train_n_samples = x_train.shape[0]
        self.train_dim_x = self.x_train.shape[1]  # dimension of training data input
        self.train_dim_y = self.y_train.shape[1]  # dimension of training regression output

        self.test_n_samples = x_test.shape[0]
        self.test_dim_x = self.x_test.shape[1]  # dimension of testing data input
        self.test_dim_y = self.y_test.shape[1]  # dimension of testing regression output

        self.normalize_x = normalize_x
        self.normalize_y = normalize_y
        self.scaler_x, self.scaler_y = None, None

        if self.normalize_x:
            self.normalize_train_test_x()
        if self.normalize_y:
            self.normalize_train_test_y()

        self.return_dataset(train_split, device=device)

    @staticmethod
    def _load_indices(path):
        """Read a text file of row indices and return them as a 1-D integer array."""
        return np.atleast_1d(np.loadtxt(path)).astype(int)

    def _standardize_features(self, x, fit=False):
        """
        Standardize the numerical columns of x with self.scaler_x and return a
            float32 tensor. The last self.dim_cat columns (one-hot encoded
            categorical variables) are passed through untouched.
        """
        x = x.numpy()
        n_num = x.shape[1] - self.dim_cat
        x_num, x_cat = x[:, :n_num], x[:, n_num:]
        scale = self.scaler_x.fit_transform if fit else self.scaler_x.transform
        x_num = scale(x_num).astype(np.float32)
        return torch.from_numpy(np.concatenate([x_num, x_cat], axis=1))

    def normalize_train_test_x(self):
        """
        Standardize the train, test and validation features with a scaler
            fitted on the train features only.
        When self.dim_cat > 0, we have one-hot encoded number of categorical variables,
            on which we don't conduct standardization. They are arranged as the last
            columns of the feature set.
        """
        self.scaler_x = StandardScaler(with_mean=True, with_std=True)
        # fit on the train part first so test/val reuse its statistics
        self.x_train = self._standardize_features(self.x_train, fit=True)
        self.x_test = self._standardize_features(self.x_test)
        self.x_val = self._standardize_features(self.x_val)

    def normalize_train_test_y(self):
        """
        Standardize the train, test and validation targets with a scaler fitted
            on the train targets only.
        """
        self.scaler_y = StandardScaler(with_mean=True, with_std=True)
        self.y_train = torch.from_numpy(
            self.scaler_y.fit_transform(self.y_train.numpy()).astype(np.float32))
        self.y_test = torch.from_numpy(
            self.scaler_y.transform(self.y_test.numpy()).astype(np.float32))
        self.y_val = torch.from_numpy(
            self.scaler_y.transform(self.y_val.numpy()).astype(np.float32))

    def return_dataset(self, split="train", device='cuda'):
        """
        Select which split is served by __len__/__getitem__ and move it to
            `device`. "train" and "val" select those splits; any other value
            selects the test split.
        """
        if split == "train":
            x, y = self.x_train, self.y_train
        elif split == "val":
            x, y = self.x_val, self.y_val
        else:
            x, y = self.x_test, self.y_test
        self.data = x.to(device)
        self.target = y.to(device)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]


def fix_random_seeds(seed=31, strict=False):
    """
    Seed the torch (CPU and all CUDA devices), numpy and Python `random`
    generators with the same value.

    A seed larger than 2**32 - 1 is first integer-divided by 2**32.
    With strict=True, cuDNN is additionally switched to deterministic mode
    and its benchmark mode is turned off.
    """
    largest_seed = 2**32 - 1
    if seed > largest_seed:
        seed //= 2**32
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
    if strict:
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

def get_mnist_loaders(data_aug=False, batch_size=128, test_batch_size=1000, download=False, perc=1.0, nw=4, onehot=False):
    """
    Build the MNIST train, test and train-evaluation DataLoaders.

    Data is read from (and, if download=True, downloaded to) '.data/mnist'.

    data_aug: apply RandomCrop(28, padding=4) to the training images.
    batch_size: batch size of the (shuffled, drop_last) training loader.
    test_batch_size: batch size of the test and train-evaluation loaders.
    download: forwarded to torchvision's MNIST dataset.
    perc: accepted for interface compatibility; currently unused.
    nw: number of DataLoader worker processes (0 loads in the main process).
    onehot: return targets as float one-hot vectors of length 10.

    Returns (train_loader, test_loader, train_eval_loader). The
    train-evaluation loader iterates every 10th training sample using the
    test-time transform.
    """
    train_steps = [transforms.RandomCrop(28, padding=4)] if data_aug else []
    transform_train = transforms.Compose(train_steps + [transforms.ToTensor()])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    target_transform = None
    if onehot:
        def target_transform(x): return torch.nn.functional.one_hot(torch.tensor(x), 10).float()

    # DataLoader rejects persistent_workers / a custom prefetch_factor when there
    # are no worker processes, so only pass them when nw > 0.
    worker_kwargs = {}
    if nw > 0:
        worker_kwargs = dict(persistent_workers=True, prefetch_factor=16)

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=True, download=download, transform=transform_train,
                       target_transform=target_transform),
        batch_size=batch_size, shuffle=True, num_workers=nw, drop_last=True, pin_memory=True,
        **worker_kwargs
    )

    # get only 10% set for evaluation (one every 10 samples)
    eval_dataset = datasets.MNIST(root='.data/mnist', train=True, download=download, transform=transform_test, target_transform=target_transform)
    eval_dataset = torch.utils.data.Subset(eval_dataset, list(range(0, len(eval_dataset), 10)))
    train_eval_loader = DataLoader(eval_dataset, batch_size=test_batch_size, shuffle=False, num_workers=nw, drop_last=False)
    # expose the class names on the Subset, which does not forward them itself
    train_eval_loader.dataset.classes = eval_dataset.dataset.classes

    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist', train=False, download=download, transform=transform_test,
                       target_transform=target_transform),
        batch_size=test_batch_size, shuffle=False, num_workers=nw, drop_last=False
    )

    return train_loader, test_loader, train_eval_loader


def get_cifar_loaders(data_aug=False, batch_size=128, test_batch_size=1000, download=False, onehot=False, nw=4,
                      debug=False):
    """
    Build the CIFAR-10 train, test and train-evaluation DataLoaders.

    Data is read from (and, if download=True, downloaded to) '.data/cifar'.

    data_aug: apply random crop (padding 4) and horizontal flip to training images.
    batch_size: batch size of the (shuffled, drop_last) training loader.
    test_batch_size: batch size of the test and train-evaluation loaders.
    download: forwarded to torchvision's CIFAR10 dataset.
    onehot: return targets as float one-hot vectors of length 10.
    nw: number of DataLoader worker processes (0 loads in the main process).
    debug: replace all three loaders by unshuffled loaders over the same small
        subset (at most 10 samples per class, taken from the train-evaluation
        subset), all using batch_size.

    Returns (train_loader, test_loader, train_eval_loader). The
    train-evaluation loader iterates every 10th training sample using the
    test-time transform.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    else:
        transform_train = transform_test

    target_transform = None
    if onehot:
        def target_transform(x): return torch.nn.functional.one_hot(torch.tensor(x), 10).float()

    dset_train = datasets.CIFAR10(root='.data/cifar', train=True, download=download,
                                  transform=transform_train, target_transform=target_transform)
    dset_test = datasets.CIFAR10(root='.data/cifar', train=False, download=download,
                                 transform=transform_test, target_transform=target_transform)
    eval_dataset = datasets.CIFAR10(root='.data/cifar', train=True, download=download,
                                     transform=transform_test, target_transform=target_transform)
    dset_trainval = torch.utils.data.Subset(eval_dataset, list(range(0, len(eval_dataset), 10)))

    # persistent_workers=True is rejected by DataLoader when num_workers == 0,
    # so enable it only when worker processes are actually used.
    train_loader = DataLoader(dset_train, batch_size=batch_size, shuffle=True, num_workers=nw, drop_last=True,
                              pin_memory=True, persistent_workers=nw > 0)
    test_loader = DataLoader(dset_test, batch_size=test_batch_size, shuffle=False, num_workers=nw, drop_last=False)
    train_eval_loader = DataLoader(dset_trainval, batch_size=test_batch_size, shuffle=False, num_workers=nw, drop_last=True)
    # expose the class names on the Subset, which does not forward them itself
    train_eval_loader.dataset.classes = dset_trainval.dataset.classes

    if debug:
        # make subset: keep the first few samples of every class
        debug_instance_per_class = 10
        target_indices = []
        target_indices_counter = [0] * 10
        for i, (_, target) in enumerate(dset_trainval):
            if onehot:
                # targets are one-hot vectors here; recover the class index
                target = target.argmax().item()
            if target_indices_counter[target] < debug_instance_per_class:
                target_indices.append(i)
                target_indices_counter[target] += 1
        dset_trainval = torch.utils.data.Subset(dset_trainval, target_indices)
        dset_train = dset_trainval
        dset_test = dset_trainval
        train_loader, test_loader, train_eval_loader = [
            DataLoader(dset, batch_size=batch_size, shuffle=False, num_workers=nw, drop_last=False) for dset in [dset_train, dset_test, dset_trainval]
        ]

    return train_loader, test_loader, train_eval_loader

def get_svhn_loaders(data_aug=False, batch_size=128, test_batch_size=1000, download=False, onehot=False, nw=4):
    """
    Build the SVHN train, test and train-evaluation DataLoaders.

    Data is read from (and, if download=True, downloaded to) '.data/svhn'.

    data_aug: apply RandomCrop(32, padding=4) to the training images.
    batch_size: batch size of the (shuffled, drop_last) training loader.
    test_batch_size: batch size of the test and train-evaluation loaders.
    download: forwarded to torchvision's SVHN dataset.
    onehot: return targets as float one-hot vectors of length 10.
    nw: number of DataLoader worker processes (0 loads in the main process).

    Returns (train_loader, test_loader, train_eval_loader). The
    train-evaluation loader iterates every 10th training sample using the
    test-time transform.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)),
    ])
    if data_aug:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)),
        ])
    else:
        transform_train = transform_test

    target_transform = None
    if onehot:
        def target_transform(x): return torch.nn.functional.one_hot(torch.tensor(x), 10).float()

    dset_train = datasets.SVHN(root='.data/svhn', split='train', download=download,
                               transform=transform_train, target_transform=target_transform)
    dset_test = datasets.SVHN(root='.data/svhn', split='test', download=download,
                              transform=transform_test, target_transform=target_transform)
    eval_dataset = datasets.SVHN(root='.data/svhn', split='train', download=download,
                                 transform=transform_test, target_transform=target_transform)
    dset_trainval = torch.utils.data.Subset(eval_dataset, list(range(0, len(eval_dataset), 10)))

    # persistent_workers=True is rejected by DataLoader when num_workers == 0,
    # so enable it only when worker processes are actually used.
    train_loader = DataLoader(dset_train, batch_size=batch_size, shuffle=True, num_workers=nw, drop_last=True,
                              pin_memory=True, persistent_workers=nw > 0)
    test_loader = DataLoader(dset_test, batch_size=test_batch_size, shuffle=False, num_workers=nw, drop_last=False)
    train_eval_loader = DataLoader(dset_trainval, batch_size=test_batch_size, shuffle=False, num_workers=nw, drop_last=True)

    return train_loader, test_loader, train_eval_loader

def inf_generator(iterable):
    """Allows training with DataLoaders in a single infinite loop:
        for i, (x, y) in enumerate(inf_generator(train_loader)):

    The iterable is restarted every time it is exhausted, so it must be
    re-iterable (a one-shot iterator cannot be restarted).

    Raises:
        ValueError: if a pass over the iterable yields no items. Restarting
            in that situation would spin forever without producing anything.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError("inf_generator: iterable produced no items")


def learning_rate_with_decay(batch_size, batch_denom, batches_per_epoch, boundary_epochs, decay_rates, base_lr):
    """
    Build a piecewise-constant learning-rate schedule.

    The initial rate is base_lr * batch_size / batch_denom. Each entry of
    boundary_epochs is converted to an iteration boundary via
    batches_per_epoch, and the k-th entry of decay_rates scales the initial
    rate on the k-th segment (decay_rates needs one more entry than
    boundary_epochs).

    Returns a function mapping an iteration number to its learning rate: the
    rate of the first boundary the iteration is strictly below, or the last
    rate when it is below none.
    """
    lr0 = base_lr * batch_size / batch_denom
    step_boundaries = [int(batches_per_epoch * e) for e in boundary_epochs]
    segment_rates = [lr0 * factor for factor in decay_rates]

    def learning_rate_fn(itr):
        segment = next(
            (k for k, bound in enumerate(step_boundaries) if itr < bound),
            len(step_boundaries),
        )
        return segment_rates[segment]

    return learning_rate_fn


def count_parameters(model):
    """Return the total number of elements over all parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def makedirs(dirname):
    """
    Create `dirname` (including missing parent directories) unless the path
    already exists.

    An existing path of any kind is left untouched. exist_ok=True keeps the
    call from failing if another process creates the directory between the
    existence check and the creation.
    """
    if not os.path.exists(dirname):
        os.makedirs(dirname, exist_ok=True)


def get_logger(logpath, filepath, package_files=[], displaying=True, saving=True, debug=False):
    """
    Configure and return the root logger.

    The root logger level is set to DEBUG when debug is true, otherwise INFO.
    When saving is true a FileHandler appending to `logpath` is attached; when
    displaying is true a console StreamHandler is attached. Both handlers use
    the same level as the logger.

    Every call attaches new handlers, so calling this more than once results
    in repeated output.

    `filepath` and `package_files` are accepted but currently not used.
    """
    level = logging.DEBUG if debug else logging.INFO
    root = logging.getLogger()
    root.setLevel(level)

    new_handlers = []
    if saving:
        new_handlers.append(logging.FileHandler(logpath, mode="a"))
    if displaying:
        new_handlers.append(logging.StreamHandler())

    for handler in new_handlers:
        handler.setLevel(level)
        root.addHandler(handler)

    return root


class RunningAverageMeter(object):
    """Tracks the latest value and an exponential moving average of all values seen."""

    def __init__(self, momentum=0.99):
        # weight given to the previous average on each update
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all history: no current value, average back to 0."""
        self.val, self.avg = None, 0

    def update(self, val):
        """Record `val`; the first value after a reset initializes the average directly."""
        is_first = self.val is None
        if is_first:
            self.avg = val
        else:
            keep = self.momentum
            self.avg = self.avg * keep + val * (1 - keep)
        self.val = val


def one_hot(x, K):
    """Return an integer matrix with one row per entry of `x`; column k is 1 where the entry equals k (k in 0..K-1)."""
    class_ids = np.arange(K)[None, :]
    matches = x[:, None] == class_ids
    return np.array(matches, dtype=int)


def custom_ce(pred, target):
    '''
    Cross-entropy loss for one-hot targets.

    The class index is recovered from `target` by an argmax over its last
    dimension and then passed, together with `pred`, to F.cross_entropy.
    '''
    class_index = torch.argmax(target, dim=-1)
    loss = F.cross_entropy(pred, class_index)
    return loss


def label_transform(y, mode='no', num_dim=2304, num_classes=10, smoothing=0., rand_scale=0.2):
    '''
    Embed class labels into a num_dim-dimensional vector.

    y: Batch of labels, shape (batch_size,)
    mode:
        'no'                     return y unchanged
        'onehot+zero'            one-hot in the first num_classes entries, zeros elsewhere
        'onehot+zero+spatial'    64-entry unit (one-hot + zeros) repeated over an h x w grid
        'onehot+random'          one-hot followed by Gaussian noise scaled by rand_scale
        'onehot+random+spatial'  64-entry unit (one-hot + uniform noise) repeated over an h x w grid
        'tile'                   one-hot repeated as often as it fits, zero-padded to num_dim
    smoothing: if > 0, the one-hot vector is label-smoothed before embedding.

    The spatial modes assume num_dim = 64 * h * w with h = w.
    Raises ValueError for an unknown mode.
    '''
    if mode == 'no':
        return y

    encoded = torch.nn.functional.one_hot(y, num_classes)
    if smoothing > 0:
        encoded = encoded.float() * (1 - smoothing) + smoothing / num_classes
    batch = encoded.shape[0]

    def spread_over_grid(filler):
        # build one 64-entry unit per sample and copy it to every grid position
        side = int((num_dim // 64) ** 0.5)
        unit = torch.cat([encoded, filler], dim=1)[:, :, None, None]
        return unit.repeat(1, 1, side, side).flatten(1, -1)

    if mode == 'onehot+zero':
        padding = torch.zeros(batch, num_dim - num_classes).to(y.device)
        return torch.cat([encoded, padding], dim=1)

    if mode == 'onehot+zero+spatial':
        return spread_over_grid(torch.zeros(batch, 64 - num_classes).to(y.device))

    if mode == 'onehot+random':
        noise = torch.randn(batch, num_dim - num_classes).to(y.device) * rand_scale
        return torch.cat([encoded, noise], dim=1)

    if mode == 'onehot+random+spatial':
        # note: uniform noise in [0, 1), not scaled by rand_scale
        return spread_over_grid(torch.rand(batch, 64 - num_classes).to(y.device))

    if mode == 'tile':
        copies, leftover = divmod(num_dim, num_classes)
        tiled = encoded.repeat(1, copies).view(-1, copies * num_classes)
        # pad until num_dim
        tail = torch.zeros(tiled.shape[0], leftover).to(y.device)
        return torch.cat([tiled, tail], dim=1)

    raise ValueError('mode not found')


def label_parse(y, mode='no', num_dim=2304, num_classes=10):
    '''
    Map a batch of (transformed) label vectors back to class indices.

    y: Batch of label vectors, one row per sample.
    mode:
        'no'         argmax over dim 1 of y
        'onehot...'  argmax over the first num_classes entries; for modes
                     containing 'spatial', y is first viewed as (-1, 64, h, w)
                     and summed over the grid
        'tile'       the tiled copies are summed before the argmax; trailing
                     padding is ignored
    Raises ValueError for an unknown mode.
    '''
    if mode == 'no':
        return torch.argmax(y, dim=1)

    if mode == 'tile':
        copies = num_dim // num_classes
        # remove the padding, then stack the copies on their own axis
        stacked = y[:, :copies * num_classes].view(-1, copies, num_classes)
        return torch.argmax(stacked.sum(dim=1), dim=1)

    if not mode.startswith('onehot'):
        raise ValueError('mode not found')

    if 'spatial' in mode:
        side = int((num_dim // 64) ** 0.5)
        y = y.view(-1, 64, side, side).sum(dim=(2, 3))
    head = y[:, :num_classes]
    return torch.argmax(head, dim=1)


def append_dims(x, target_dims):
    """Return `x` with trailing singleton dimensions added until it has `target_dims` dimensions.

    Raises ValueError when `x` already has more than `target_dims` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    index = (Ellipsis, *([None] * missing))
    return x[index]


def norm_avg(vecs):
    """Return the mean distance of the samples from their centroid.

    `vecs` is flattened to (n_samples, -1); the per-feature mean over samples
    is subtracted, and the norm of each centered row is averaged into a float.
    """
    flat = vecs.reshape(vecs.shape[0], -1)
    centroid = flat.mean(dim=0, keepdim=True)
    distances = (flat - centroid).norm(dim=1)
    return distances.mean().item()


@torch.inference_mode()
def straightness(net, X, normalize=True, minibatch=250, num_steps=128):
    '''
    Measure the straightness of the trajectories produced by `net` using
    `num_steps` Euler steps (128 by default).

    For each sample, the finite-difference velocity of every step is compared
    with the overall displacement (end point minus start point); the squared
    differences are summed over features and averaged over steps and samples.
    A perfectly straight, constant-speed trajectory gives 0.

    net: object providing `device` and
        `get_traj(x, timesteps=..., method='euler')`, whose first return value
        is a tensor of shape (num_steps + 1, B, ...).
    X: batch of inputs; processed in chunks of `minibatch` to limit memory use.
    normalize: divide both terms by the norm of the overall displacement, so
        that the displacement has norm 1.
    minibatch: chunk size.
    num_steps: number of Euler steps.

    Raises:
        ValueError: if X is empty or the trajectory length returned by `net`
            is not num_steps + 1.
    '''
    n_samples = X.shape[0]
    if n_samples == 0:
        raise ValueError("X must contain at least one sample")

    dt = 1 / num_steps
    total = 0.0
    for start in range(0, n_samples, minibatch):
        batch = X[start:start + minibatch].to(net.device)
        traj, _ = net.get_traj(batch, timesteps=num_steps + 1, method='euler')
        if len(traj) != num_steps + 1:
            raise ValueError(
                f"expected a trajectory of length {num_steps + 1}, got {len(traj)}"
            )
        # flatten the spatial dimension
        traj = traj.reshape(num_steps + 1, traj.shape[1], -1)
        gt = traj[-1:] - traj[:1]  # (1, B, -1) overall displacement
        pred = torch.diff(traj, dim=0)  # (num_steps, B, -1) per-step displacement
        # normalize so that gt has norm 1
        norm_const = 1
        if normalize:
            norm_const = torch.norm(gt, dim=-1, keepdim=True)
        total += torch.sum((gt / norm_const - pred / dt / norm_const) ** 2).item()
    return total / num_steps / n_samples  # mean over time and batch, sum over channels
    


@torch.inference_mode()
def flow_loss_timestep(net, X, Y, timesteps=128):
    '''
    Returns the flow MSE loss at each timestep. (num_timesteps)

    `timesteps` evenly spaced times in [0, 1) are used (the end point 1 is
    excluded). For each time t the point z_t on the straight line between
    z0 = net.in_projection(X) and z1 = net.label_projection(Y) is built, the
    velocity predicted by net.pred_v(z_t, t) is compared with the constant
    target velocity z1 - z0, and the squared error is averaged over everything
    except the time axis.

    Runs under torch.inference_mode(), so no gradients are recorded.
    '''
    # timesteps+1 grid points on [0, 1], placed on the same device as X
    timesteps = torch.linspace(0, 1, timesteps+1).to(X.device)
    timesteps = timesteps[:-1]  # exclude the last one
    B = X.shape[0]  # batch size
    T = len(timesteps)  # number of evaluated timesteps
    z0 = net.in_projection(X).unsqueeze(0)  # (1, B, ...)
    z1 = net.label_projection(Y).unsqueeze(0)  # leading singleton axis added, like z0
    # add trailing singleton dims so the times broadcast against z0 / z1
    timesteps = append_dims(timesteps, z0.ndim)  # (T, ...)
    # linear interpolation between z0 and z1 at every timestep
    zt = z0 + (z1 - z0) * timesteps
    # merge the time and batch axes into one leading axis
    zt = rearrange(zt, 'T B ... -> (T B) ...')
    # the target velocity does not depend on t; repeat it once per timestep
    v_target = repeat(z1[0] - z0[0], 'B ... -> (T B) ...', T=T)
    # one time value per (timestep, sample) pair, aligned with zt's leading axis
    timesteps_repeat = repeat(timesteps.squeeze(-1), 'T ... -> (T B) ...', B=B)
    v_pred = net.pred_v(zt, timesteps_repeat)
    # split the merged axis back into (T, B) before reducing
    error = rearrange(v_target - v_pred, '(T B) ... -> T B ...', B=B)
    # mean squared error over all axes except time
    return reduce(error**2, 'T ... -> T', 'mean')

def area_scaling_factor(weight_matrix):
    """
    Calculate the area scaling factor of a linear layer.

    The factor is the product of the strictly positive singular values of the
    weight matrix. If no singular value is positive (e.g. an all-zero matrix),
    the empty product 1.0 is returned.

    Args:
    - weight_matrix: torch.Tensor of shape (n, m) representing the weights of a linear layer

    Returns:
    - area_scaling_factor: float representing the area scaling factor
    """
    # Only the singular values are needed, so skip computing U and V
    # (torch.svd is deprecated in favour of the torch.linalg routines).
    singular_values = torch.linalg.svdvals(weight_matrix)

    # Product of the non-zero singular values
    factor = torch.prod(singular_values[singular_values > 0])

    return factor.item()

