import os

import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from torchvision import transforms

def recursive_glob(rootdir=".", suffix=""):
    """Collect every file below *rootdir* whose name ends with *suffix*.

    :param rootdir: directory at which the recursive walk starts
    :param suffix: filename suffix to match (the empty string matches all files)
    :returns: list of paths joined onto the directory they were found in, in
        ``os.walk`` traversal order
    """
    matches = []
    for dirpath, _subdirs, names in os.walk(rootdir):
        matches.extend(
            os.path.join(dirpath, name) for name in names if name.endswith(suffix)
        )
    return matches

def _remove_axes(ax):
    """Hide tick labels and tick marks on a single matplotlib Axes."""
    # Blank out the major-tick labels on both axes (x first, then y).
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(plt.NullFormatter())
    # Then drop the tick marks themselves.
    ax.set_xticks([])
    ax.set_yticks([])


def remove_axes(axes):
    """Strip ticks from one Axes, a 1-D array of Axes, or a 2-D grid of Axes.

    Presumably ``axes`` is what ``plt.subplots`` returns; a 2-D container is
    flattened row by row, and anything else that is not an ``Axes`` is iterated
    directly.
    """
    if isinstance(axes, plt.Axes):
        targets = [axes]
    elif len(axes.shape) == 2:
        targets = [ax for row in axes for ax in row]
    else:
        targets = axes
    for target in targets:
        _remove_axes(target)

def create_cityscapes_label_colormap(train_dset):
    """Build the label colormap for CITYSCAPES segmentation visualisation.

    Entry 0 is black. It is followed by the color taken from the last field of
    each entry in ``train_dset.classes[1:-1]``, so the first and last classes
    are skipped. This assumes each class entry ends with an RGB tuple, as in
    torchvision's ``Cityscapes.classes``; TODO confirm for other datasets.

    Returns:
        A list of numpy arrays, one color per label index.
    """
    palette = [np.array([0, 0, 0])]
    palette.extend(np.array(list(entry[-1])) for entry in train_dset.classes[1:-1])
    return palette

def feature_pca(features, components=3, pcas=None, return_pca=False, fit_all=False, normalize=True):
    """Project per-pixel channel features onto their leading principal components.

    Args:
        features: tensor unpacked as ``(b, c, h, w)``. It is detached and moved
            to CPU before use.
        components: number of principal components to keep.
        pcas: optional list of already-fitted PCA-like objects (anything with a
            ``transform`` method). Sample ``i`` uses ``pcas[i``], or the last
            entry when the list is shorter than the batch. Only consulted when
            ``fit_all`` is False; a new PCA per sample is fitted when it is None.
        return_pca: if True, also return the list of PCA objects that was used.
        fit_all: if True, fit a single PCA on the pixels of all samples combined
            and use it for every sample. ``pcas`` is ignored in this case.
        normalize: if True, min-max scale each sample's projection to [0, 1],
            jointly over pixels and components. A constant sample maps to 0.

    Returns:
        numpy array of shape ``(b, h, w, components)``, or the tuple
        ``(array, pcas)`` when ``return_pca`` is True. With ``fit_all`` the
        returned ``pcas`` is ``[shared_pca]``, so it can be passed back in to
        reproduce the projection.
    """
    b, c, h, w = features.shape
    # Move channels last, then flatten the spatial dims so each pixel is one PCA sample.
    features_flat_batched = (
        features.detach().cpu().permute(0, 2, 3, 1).reshape(b, h * w, c).numpy()
    )

    if fit_all:
        # Fit one PCA on all pixels of all samples; it is shared by every sample.
        shared = PCA(n_components=components)
        shared.fit(features_flat_batched.reshape(b * h * w, c))
        pcas = [shared]
    elif pcas is None:
        # Fit a separate PCA per sample.
        pcas = [PCA(n_components=components).fit(features_flat_batched[i]) for i in range(b)]

    # Samples beyond the end of the list reuse the last PCA (e.g. a single shared one).
    last = len(pcas) - 1
    features_pca = np.array(
        [pcas[min(i, last)].transform(features_flat_batched[i]) for i in range(b)]
    )

    if normalize:
        features_pca = features_pca - np.min(features_pca, axis=(1, 2), keepdims=True)
        span = np.max(features_pca, axis=(1, 2), keepdims=True)
        # A constant sample has span 0; divide by 1 instead so it maps to 0, not NaN.
        features_pca = features_pca / np.where(span > 0, span, 1)

    features_pca = features_pca.reshape(b, h, w, components)
    return (features_pca, pcas) if return_pca else features_pca


def get_preprocess(model_type):
    """Return the name of the preprocessing pipeline for a model identifier.

    The result is chosen by substring match on ``model_type``. The first
    matching key wins, in the order lpips, dists, psnr, ssim. Everything else,
    including CLIP / OpenCLIP / DINO / MAE models, gets ``'DEFAULT'``.
    """
    ordered_rules = (
        ('lpips', 'LPIPS'),
        ('dists', 'DISTS'),
        ('psnr', 'PSNR'),
        ('ssim', 'SSIM'),
    )
    for needle, pipeline in ordered_rules:
        if needle in model_type:
            return pipeline
    return 'DEFAULT'


def get_preprocess_fn(preprocess, load_size, interpolation):
    """Build a callable that turns a PIL image into a tensor for *preprocess*.

    Every pipeline first converts the image to RGB.

    - ``"LPIPS"``: ``ToTensor`` followed by ``x / 0.5 - 1``. This maps
      ToTensor's [0, 1] range to [-1, 1]. No resize.
    - ``"DEFAULT"``: resize to ``(load_size, load_size)`` with
      *interpolation*, then ``ToTensor``.
    - ``"DISTS"``: resize to 256x256 (ignores *load_size* and
      *interpolation*), then ``ToTensor``.
    - ``"SSIM"`` / ``"PSNR"``: ``ToTensor`` only.

    Raises:
        ValueError: for any other *preprocess* value.
    """
    if preprocess == "LPIPS":
        to_tensor = transforms.ToTensor()

        def _lpips(pil_img):
            return to_tensor(pil_img.convert("RGB")) / 0.5 - 1.

        return _lpips

    if preprocess == "DEFAULT":
        pipeline = transforms.Compose([
            transforms.Resize((load_size, load_size), interpolation=interpolation),
            transforms.ToTensor()
        ])
    elif preprocess == "DISTS":
        pipeline = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])
    elif preprocess in ("SSIM", "PSNR"):
        pipeline = transforms.ToTensor()
    else:
        raise ValueError("Unknown preprocessing method")

    def _apply(pil_img):
        return pipeline(pil_img.convert("RGB"))

    return _apply



