import argparse
import string

from omegaconf import OmegaConf
from tqdm import tqdm
import json
import numpy as np
import os
import torch
from open_flamingo import create_model_and_transforms
from huggingface_hub import hf_hub_download
from scripts import icl_helpers, dataset
from transformers.modeling_outputs import CausalLMOutputWithPast

# Registry of label transforms, looked up by name.
# 'clevr_count_transform' adds 3 to whatever value it is given.
transform_fns = dict(
    clevr_count_transform=lambda x: x + 3,
)

def save_results(save_file, paths, results_raw, results, answers):
    """Print the classification accuracy and dump the run to ``<save_file>.json``.

    Args:
        save_file: Output path *without* the ``.json`` suffix (it is appended
            here). May be a bare filename, in which case the file is written
            to the current working directory.
        paths: Stored unchanged under the ``"paths"`` key.
        results_raw: Stored unchanged under the ``"results_raw"`` key.
        results: Predictions; compared element-wise with ``answers``.
        answers: Ground-truth values; stored under the ``"answers"`` key.
    """
    correct = (np.array(results) == np.array(answers)).sum()
    # An empty run has no defined accuracy; report NaN without dividing by zero.
    acc = correct / len(results) if len(results) else float("nan")
    print("classification accuracy:", acc)

    save_file = f"{save_file}.json"
    save_folder = os.path.dirname(save_file)
    # dirname is '' for a bare filename, and os.makedirs('') raises, so only
    # create the folder when there actually is one.
    if save_folder:
        os.makedirs(save_folder, exist_ok=True)

    payload = {
        "paths": paths,
        "results_raw": results_raw,
        "results": results,
        "answers": answers,
    }
    # Context manager guarantees the file is flushed and closed.
    with open(save_file, "w") as f:
        json.dump(payload, f)

def build_prompt(classes):
    """Return one in-context chunk: image placeholder, 'Output:', the label text, end-of-chunk marker."""
    return '<image>Output:{}<|endofchunk|>'.format(classes)

def tokenize_classes(dset_classes, tokenizer):
    """Tokenize every class name and return the list of ``input_ids`` entries.

    Each name is passed to ``tokenizer`` on its own, without special tokens and
    with ``return_tensors="pt"``; results are returned in the same order as
    ``dset_classes``. Progress is shown through ``tqdm``.
    """
    return [
        tokenizer(name, add_special_tokens=False, return_tensors="pt")['input_ids']
        for name in tqdm(dset_classes)
    ]

def main(args):
    """Run few-shot classification on a random subset of the test set.

    For each sampled item, a prompt is built from the in-context examples the
    dataset returns alongside it (``x_nns`` / ``label_nns``). Every class name
    is then scored as a continuation of that prompt by its mean token
    log-probability, and the best-scoring class is the prediction. Predictions,
    ground truth, the sampled indices and the accuracy are written to
    ``<out_dir>/metadata3/<nn_file stem>_<seed>.json``.

    If that output file already exists the function returns immediately,
    before the model is built.

    Args:
        args: Parsed CLI namespace providing ``config_path``, ``nn_file``,
            ``seed`` and ``n`` (number of test items to sample without
            replacement).
    """
    device = 'cuda:0'
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    config = OmegaConf.load(args.config_path)
    config['dataset_kwargs']['nn_file'] = args.nn_file

    # The output lives two path levels above the nn file, in a 'metadata3'
    # folder, and is named after the nn file's stem plus the seed.
    nn_components = config['dataset_kwargs']['nn_file'].split('/')
    out_dir = '/'.join(nn_components[:-2])
    metadata_name = nn_components[-1].split('.')[0]
    save_file = f'{out_dir}/metadata3/{metadata_name}_{seed}.json'
    # Check for an existing result *before* the expensive model construction.
    if os.path.exists(save_file):
        print(f'Skipping {save_file} as it already exists.')
        return

    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path='ViT-L-14',
        clip_vision_encoder_pretrained='openai',
        lang_encoder_path=config.model_path,
        tokenizer_path=config.model_path,
        cross_attn_every_n_layers=1,
    )

    # NOTE: 'prompts.json' is resolved relative to the current working directory.
    with open('prompts.json', 'r') as f:
        prompt_data = json.load(f)
    dataset_name = config['dataset_kwargs']['name']
    dset_classes = prompt_data[dataset_name]['classes']

    checkpoint_path = hf_hub_download(config.ckpt_path, 'checkpoint.pt')
    # The model is still on CPU here, so load the checkpoint tensors to CPU as
    # well instead of onto whichever device they were saved from.
    # strict=False tolerates missing/unexpected keys in the checkpoint.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    tokenizer.padding_side = 'left'
    model.eval()
    model = model.to(device)
    test_dataset = dataset.get_dataset(image_processor, config)

    class_tokens = tokenize_classes(dset_classes, tokenizer)

    preds, gts, nns, labels, nn_labels = [], [], [], [], []
    corrects = 0
    # rmse / mae are never updated in this script; they are kept only because
    # the final summary line prints them.
    rmse = 0
    mae = 0

    n = args.n
    # Sampling without replacement: numpy raises if n exceeds the dataset size.
    n_random = np.random.choice(len(test_dataset), n, replace=False)
    for rand_idx in tqdm(n_random, total=n):
        x, label, x_nns, label_nns = test_dataset[rand_idx]
        # Stack the in-context images followed by the query image, then add a
        # singleton axis at position 1 and a leading singleton (batch) axis.
        vision_x = torch.stack([*x_nns, x]).unsqueeze(1).unsqueeze(0).to(device)
        nns.append(test_dataset.nns[rand_idx].tolist())
        labels.append(label)
        nn_labels.append(label_nns)

        # One completed chunk per in-context example, then an open chunk for
        # the query image that the class name will complete.
        prompt = ''.join(
            build_prompt(dset_classes[label_nn]) for label_nn in label_nns
        )
        prompt += '<image>Output:'

        input_ids = tokenizer([prompt], return_tensors="pt")["input_ids"].to(device)

        overall_probs = []
        for class_token in class_tokens:
            class_token = class_token.to(device)
            num_tokens_in_classname = class_token.shape[1]
            _lang_x = torch.cat([input_ids, class_token], dim=1)

            with torch.inference_mode():
                outputs = model(
                    vision_x=vision_x,
                    lang_x=_lang_x,
                    clear_conditioned_layers=True,
                    past_key_values=None,
                    use_cache=False,
                )

            logprobs = torch.log_softmax(outputs.logits, dim=-1)
            # Take the positions just before each class-name token (shifted by
            # one, assuming the usual causal-LM alignment where position t
            # scores token t+1).
            gen_probs = logprobs[
                :, -num_tokens_in_classname - 1 : -1, :
            ]  # (B, num_tokens_in_classname, vocab_len)
            # Pick out the log-prob assigned to each actual class-name token.
            gen_probs = torch.gather(
                gen_probs, 2, class_token[:, :, None]
            ).squeeze(-1)
            # Mean over tokens so the score does not depend on name length.
            overall_probs.append(torch.mean(gen_probs, dim=1))

        overall_probs = torch.cat(overall_probs)
        pred_val = torch.topk(overall_probs, 1).indices.item()

        corrects += pred_val == label
        preds.append(pred_val)
        gts.append(label)
    corrects /= n

    metadata = {
        'metrics': {
            'corrects': corrects,
        },
        'preds': preds,
        'gts': gts,
        'idcs': n_random.tolist(),
        'labels': labels,
        'nn_idcs': nns,
        'nn_labels': nn_labels
    }

    os.makedirs(f'{out_dir}/metadata3', exist_ok=True)
    # Context manager guarantees the JSON is flushed and the handle closed.
    with open(save_file, 'w') as f:
        json.dump(metadata, f)

    print(save_file)
    print('------------------------------------------------', corrects, rmse, mae)

if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them, and
    # hand the resulting namespace to main().
    cli = argparse.ArgumentParser(description="")
    for option, option_type in (
        ("--config_path", str),
        ("--nn_file", str),
        ("--seed", int),
    ):
        cli.add_argument(option, type=option_type)
    # Number of test items to sample; the only option with a default.
    cli.add_argument("--n", type=int, default=500)
    args = cli.parse_args()
    main(args)
