import librosa
import glob
import os

import pandas
import torch
import json
import laion_clap
import numpy as np

from PIL import Image
from PIL import ImageFile
from transformers import CLIPProcessor, CLIPModel

# Let Pillow open truncated/partially-written image files instead of raising.
# This is a module attribute on PIL.ImageFile, so it applies to every image
# opened anywhere in the process, not just in this file.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def CLAP_text_emb(text_data, batch_size, save_path):
    """Embed a list of texts with the LAION-CLAP (fusion) text encoder and save them.

    Loads the checkpoint ``./laion_clap_fullset_fusion.pt``, embeds
    ``text_data`` chunk by chunk and concatenates the per-chunk embeddings in
    input order.

    Args:
        text_data: Sequence of strings to embed.
        batch_size: Nominal chunk size. The texts are split into
            ``len(text_data) // batch_size`` chunks (at least one), and the
            last chunk absorbs any remainder.
        save_path: Destination passed to ``torch.save`` for the stacked tensor.

    Returns:
        A tensor with one row per entry of ``text_data``, in input order.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``text_data`` is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    total = len(text_data)
    if total == 0:
        raise ValueError("text_data is empty; nothing to embed")

    model = laion_clap.CLAP_Module(enable_fusion=True)
    model.load_ckpt('./laion_clap_fullset_fusion.pt')

    # Keep the original chunking (last chunk takes the remainder), but always
    # produce at least one chunk: with total < batch_size the plain floor
    # division gave 0 chunks and torch.cat() then failed on an empty list.
    split_num = max(1, total // batch_size)
    text_emb_list = []
    for i in range(split_num):
        start = batch_size * i
        end = total if i == split_num - 1 else start + batch_size
        text_emb = model.get_text_embedding(x=text_data[start:end])
        # as_tensor avoids an extra copy when given a numpy array (laion_clap's
        # default return) and accepts a tensor without the torch.tensor warning.
        text_emb_list.append(torch.as_tensor(text_emb))
        print(f"{end}//{total} shape:{text_emb.shape}")
    text_embs = torch.cat(text_emb_list)
    torch.save(text_embs, save_path)

    return text_embs

def CLIP_text_emb(text_data, batch_size, save_path):
    """Embed texts with the local CLIP ViT-B/32 text encoder and save them.

    Loads model and processor from
    ``./clip-vit-base-patch32/clip-vit-base-patch32``, encodes ``text_data``
    in chunks of at most ``batch_size`` texts and concatenates the results in
    input order.  Every row is L2-normalised, i.e. the same values that
    ``CLIPModel``'s forward pass reports as ``text_embeds``.

    Args:
        text_data: Sequence of strings to embed.
        batch_size: Maximum number of texts tokenised and encoded per step.
        save_path: Destination passed to ``torch.save`` for the stacked tensor.

    Returns:
        Tensor with one row per input text. It lives on the GPU when CUDA is
        available, otherwise on the CPU.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``text_data`` is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    total = len(text_data)
    if total == 0:
        raise ValueError("text_data is empty; nothing to embed")

    model = CLIPModel.from_pretrained('./clip-vit-base-patch32/clip-vit-base-patch32')
    processor = CLIPProcessor.from_pretrained('./clip-vit-base-patch32/clip-vit-base-patch32')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()

    text_emb_list = []
    # Inference only: without no_grad every batch's autograd graph is kept
    # alive through text_emb_list, so memory grows with the dataset size.
    with torch.no_grad():
        for start in range(0, total, batch_size):
            end = min(start + batch_size, total)
            # Let the tokenizer truncate to CLIP's 77-token context. It keeps
            # the end-of-text token the text encoder pools from, whereas slicing
            # input_ids afterwards would cut that token off for long captions.
            inputs = processor(
                text=text_data[start:end],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=77,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Text tower only; no dummy image is needed just to read text_embeds.
            text_features = model.get_text_features(**inputs)
            text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
            text_emb_list.append(text_features)
            print(f"{end}//{total} shape:{text_features.shape}")

    text_embs = torch.cat(text_emb_list)
    torch.save(text_embs, save_path)
    print(text_embs.shape)
    return text_embs

def load_text_data(text_path):
    """Load a JSON ``{key: [text, ...]}`` mapping and flatten it deterministically.

    Keys are visited in sorted order, and each key's texts are sorted before
    being appended. The result therefore does not depend on the order of
    entries in the file.

    Args:
        text_path: Path to a UTF-8 JSON file whose top-level object maps keys
            to lists of strings (assumed from usage; verify against the data).

    Returns:
        A flat list of all texts.
    """
    # JSON is UTF-8 by spec; relying on the locale default breaks non-ASCII
    # text on platforms whose default encoding is not UTF-8.
    with open(text_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    text_data = []
    for key in sorted(data):
        # sorted() instead of list.sort() so the parsed lists are not mutated.
        text_data.extend(sorted(data[key]))
    return text_data

def load_mscoco_text_data(text_path):
    """Map each ``image_id`` in an MS-COCO-style caption JSON to a caption.

    Reads the ``'annotations'`` list. Each entry is expected to carry
    ``'image_id'`` and ``'caption'`` keys.

    Note that when several annotations share an ``image_id``, later entries
    overwrite earlier ones. Only the last caption listed for each image is
    kept.

    Args:
        text_path: Path to a UTF-8 JSON file with a top-level
            ``'annotations'`` list.

    Returns:
        Dict ``{image_id: caption}`` in first-seen ``image_id`` order.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(text_path, 'r', encoding='utf-8') as f:
        annotations = json.load(f)['annotations']
    return {ann['image_id']: ann['caption'] for ann in annotations}

def load_CC3M_text_data(text_path):
    """Return the caption column of a CC3M tab-separated file as a plain list.

    The file is parsed with pandas' default header inference. Its first line
    is therefore consumed as the header, and the captions are taken from the
    column named ``'a very typical bus station'``.

    NOTE(review): that column name looks like a caption rather than a header,
    which suggests the TSV has no header row and its first caption is dropped
    from the result (``header=None`` would keep it). Confirm before relying on
    list positions lining up with image indices.
    """
    frame = pandas.read_csv(text_path, sep='\t')
    column_values = frame['a very typical bus station'].values
    return column_values.tolist()

def load_MSRVTT_text_data(text_path):
    """Return every caption from an MSR-VTT-style annotation JSON, in file order.

    Args:
        text_path: Path to a UTF-8 JSON file with a top-level ``'sentences'``
            list whose entries carry a ``'caption'`` key.

    Returns:
        List of caption strings in the order they appear under ``'sentences'``.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(text_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return [sentence['caption'] for sentence in data['sentences']]

def load_MAD_text_data(text_path):
    """Return the ``'sentence'`` field of every entry in a MAD-style annotation JSON.

    Args:
        text_path: Path to a UTF-8 JSON file whose top-level object maps ids
            to records that each carry a ``'sentence'`` key.

    Returns:
        List of sentences in the order the records appear in the file.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(text_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # The ids are not needed, so iterate the records directly.
    return [record['sentence'] for record in data.values()]