import tqdm
import torch
import matplotlib

matplotlib.use('Agg')
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as col


def visualize(source_feature: torch.Tensor, target_feature: torch.Tensor,
              filename: str, source_color='r', target_color='b'):
    """
    Visualize features from different domains using t-SNE and save the plot.

    Source and target features are embedded jointly into 2-D with t-SNE
    (fixed ``random_state`` so layouts are reproducible) and drawn as a
    scatter plot colored by domain.

    Args:
        source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
        target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
        filename (str): the file name to save t-SNE
        source_color (str): the color of the source features. Default: 'r'
        target_color (str): the color of the target features. Default: 'b'

    """
    # detach + cpu so tensors that require grad or live on a GPU can be converted
    source_feature = source_feature.detach().cpu().numpy()
    target_feature = target_feature.detach().cpu().numpy()
    features = np.concatenate([source_feature, target_feature], axis=0)

    # map features to 2-d using TSNE
    X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features)

    # domain labels, 1 represents source while 0 represents target
    domains = np.concatenate((np.ones(len(source_feature)), np.zeros(len(target_feature))))

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        for spine in ax.spines.values():
            spine.set_visible(False)
        # colormap index 0 -> target_color, 1 -> source_color, matching `domains`
        ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains,
                   cmap=col.ListedColormap([target_color, source_color]), s=20)
        ax.set_xticks([])
        ax.set_yticks([])
        fig.savefig(filename)
    finally:
        # release the figure; otherwise every call leaves an open figure behind
        plt.close(fig)


def collect_feature(data_loader, model, device, mode=None, num_classes: int = 385) -> torch.Tensor:
    """
    Run ``model`` in eval mode without gradients and gather its features on the CPU.

    Args:
        data_loader: iterable of batches. In mode ``"s"`` element 0 of each batch
            is fed to the model, in mode ``"t"`` element 1. Unused in mode ``"text"``.
        model: callable with an ``eval()`` method. In modes ``"s"``/``"t"`` it must
            return a pair ``(_, feature)`` where ``feature`` is a tensor, or a tuple
            whose element 1 is the tensor to keep. In mode ``"text"`` it is called
            as ``model(indices, mode="scene")`` and must return a tensor.
        device: device the model inputs are moved to.
        mode (str): one of ``"s"``, ``"t"`` or ``"text"``.
        num_classes (int): in mode ``"text"``, the model is queried with the
            indices ``0 .. num_classes - 1``. Default: 385

    Returns:
        torch.Tensor: the collected features on the CPU; in modes ``"s"``/``"t"``
        the per-batch features are concatenated along dim 0.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    model.eval()
    if mode == "text":
        with torch.no_grad():
            indices = torch.arange(num_classes, dtype=torch.long, device=device)
            # NOTE(review): this previously called `feature_extractor`, which is not
            # defined in this module (NameError); `model` is the only extractor in
            # scope — confirm it is the intended one.
            features = model(indices, mode="scene")
        return features.cpu()

    # pick which element of each batch holds the model input
    if mode == "s":
        input_index = 0
    elif mode == "t":
        input_index = 1
    else:
        raise ValueError(f"unsupported mode {mode!r}; expected 's', 't' or 'text'")

    all_features = []
    with torch.no_grad():
        for data in tqdm.tqdm(data_loader):
            inputs = data[input_index].to(device)
            _, feature = model(inputs)
            # some models return a tuple of features; element 1 is the one kept
            if isinstance(feature, tuple):
                feature = feature[1]
            all_features.append(feature.cpu())
    return torch.cat(all_features, dim=0)