import os
import torch
from tqdm import tqdm
import numpy as np
from torchmetrics.functional import pairwise_cosine_similarity

# Number of nearest neighbours written out per embedding.
k = 3

# Every model that embeddings may exist for; `n` picks the single one this run
# processes. A one-element slice (rather than indexing) keeps `model_names` a list.
model_names = [
    'ensemble',
    'dreamsim_ensemble',
    'dino_vitb16',
    'dreamsim_dino_vitb16',
    'clip_vitb32',
    'dreamsim_clip_vitb32',
    'open_clip_vitb32',
    'dreamsim_open_clip_vitb32',
    'synclr_vitb16',
    'dreamsim_synclr_vitb16',
    'dinov2_vitb14',
    'dreamsim_dinov2_vitb14',
]
n = 9
model_names = model_names[n:n + 1]

# Dataset identifiers; they are interpolated into the embedding and output
# paths. Uncomment entries to process additional datasets.
datasets = [
    # 'oxford_flowers102',
    # 'clevr(task="closest_object_distance")',
    # 'diabetic_retinopathy(config="btgraham-300")',
    # 'patch_camelyon',
    # 'dsprites(predicted_attribute="label_orientation",num_classes=16)',
    # 'caltech101',
    # 'dmlab',
    # 'cifar(num_classes=100)',
    # 'smallnorb(predicted_attribute="label_elevation")',
    # 'clevr(task="count_all")',
    # 'kitti(task="closest_vehicle_distance")',
    'dsprites(predicted_attribute="label_x_position",num_classes=16)',
    # 'eurosat',
    # 'smallnorb(predicted_attribute="label_azimuth")',
    # 'svhn',
    # 'oxford_iiit_pet',
    # 'dtd',
    # 'resisc45',
    # 'sun397_ours'
]

# Embedding splits whose files are concatenated before the neighbour search.
splits = ('val',)
device = 'cuda:0'

def load_split_embeds(embed_paths):
    """Load and row-concatenate the ``embeds`` arrays from a list of ``.npz`` files.

    ``np.load`` on an ``.npz`` returns an ``NpzFile`` that keeps the file open
    until closed, so each archive is opened in a ``with`` block and released as
    soon as its array has been read.

    Raises:
        OSError: if a file is missing or cannot be read.
        ValueError: for ``np.load`` format errors (e.g. pickled content).
        KeyError: if an archive has no ``embeds`` entry.
    """
    arrays = []
    for path in embed_paths:
        with np.load(path) as archive:
            arrays.append(archive['embeds'])
    return np.concatenate(arrays)


def top_k_excluding_self(sim_rows, row_offset, k):
    """Return the ``k`` most similar columns for each row, never the row itself.

    ``sim_rows`` holds rows ``row_offset .. row_offset + len(sim_rows) - 1`` of a
    square self-similarity matrix. Each row's self entry is overwritten with
    ``-inf`` *in place* before ``torch.topk``. This excludes the sample even when
    a duplicate embedding ties with it. Dropping the first topk column instead
    can keep the self match and lose a real neighbour.

    Returns:
        The ``torch.topk`` result (``values``, ``indices``), each of shape
        ``(len(sim_rows), k)``, sorted by decreasing similarity.
    """
    rows = torch.arange(sim_rows.shape[0], device=sim_rows.device)
    sim_rows[rows, rows + row_offset] = float('-inf')
    return torch.topk(sim_rows, k=k, dim=1, largest=True, sorted=True)


if __name__ == '__main__':
    # Query rows of the similarity matrix computed per step: peak memory is about
    # chunk_size * N similarities instead of N * N.
    chunk_size = 4096

    for dataset in datasets:
        print(dataset)
        output_dir = f'outputs_nn/{dataset}/'

        for model_name in tqdm(model_names):
            embed_paths = [f'embeds/{model_name}_{dataset}_{split}.npz' for split in splits]

            try:
                embeds_np = load_split_embeds(embed_paths)
            except (OSError, ValueError) as err:
                # Missing or unreadable embedding files: skip this model, but say why.
                print(f'skipping {model_name}... ({err})')
                continue

            num_samples = embeds_np.shape[0]
            if num_samples <= k:
                print(f'skipping {model_name}: {num_samples} samples, need more than k={k}')
                continue

            embeds = torch.from_numpy(embeds_np).to(device)
            sims_parts, idcs_parts = [], []
            for start in range(0, num_samples, chunk_size):
                sim_rows = pairwise_cosine_similarity(embeds[start:start + chunk_size], embeds)
                top_k = top_k_excluding_self(sim_rows, start, k)
                sims_parts.append(top_k.values.cpu().numpy())
                idcs_parts.append(top_k.indices.cpu().numpy())

            os.makedirs(output_dir, exist_ok=True)
            np.savez(
                os.path.join(output_dir, f'{model_name}_k{k}.npz'),
                sims=np.concatenate(sims_parts),
                idcs=np.concatenate(idcs_parts),
            )