import numpy as np
import os
from transformers import BlipProcessor, BlipForConditionalGeneration
from tqdm import tqdm
import torch

fps = 10
path = "/home/zijiao/Desktop/Zijiao/side_project/datasets/wen2017"

# Frame arrays are read from <path>/preprocessed, one .npy file per split.
# Shape noted by the original author for the train split:
# (18, 240, 6, 256, 256, 3) -- NOTE(review): confirm this still holds at the
# current fps.
train_frames_file = os.path.join(path,'preprocessed', f'video_train_256_{fps}hz.npy')
video_frames_train = np.load(train_frames_file)
# Shape noted by the original author for the test split:
# (5, 240, 6, 256, 256, 3) -- NOTE(review): same caveat as above.
test_frames_file = os.path.join(path,'preprocessed', f'video_test_256_{fps}hz.npy')
video_frames_test = np.load(test_frames_file)

# The processor and the captioning model come from the same BLIP checkpoint;
# both are cached in the relative '.cache' directory.
blip_checkpoint = "Salesforce/blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(blip_checkpoint, cache_dir='.cache')
model = BlipForConditionalGeneration.from_pretrained(blip_checkpoint, cache_dir='.cache')
# Run the model on the GPU in half precision.
model = model.to("cuda",  torch.float16)

def caption_videos(videos, desc=None):
    """Caption every frame of every clip in ``videos`` with the BLIP model.

    ``videos`` is iterated as videos -> clips -> frames. All frames of one
    clip are sent through the module-level ``processor`` and ``model`` as a
    single batch on the GPU in float16.

    Args:
        videos: nested sequence/array of frames indexed as
            ``videos[video][clip][frame]``.
        desc: optional label shown next to the progress bar.

    Returns:
        A string array built by stacking the decoded captions per clip, then
        per video, i.e. indexed as ``[video, clip, caption]`` (presumably one
        caption per frame -- the processor/model decide the batch output).
    """
    # The bar advances once per clip, so its total is the number of clips
    # actually iterated. Deriving it from ``fps`` is only correct when every
    # clip happens to contain exactly ``fps * 2`` frames.
    n_clips = sum(len(video) for video in videos)

    video_captions = []
    # The context manager closes the bar even if captioning raises.
    with tqdm(total=n_clips, desc=desc) as progress_bar:
        for video in videos:
            clip_captions = []
            for clip in video:
                frames = list(clip)  # one batch = all frames of this clip
                inputs = processor(images=frames, return_tensors="pt", padding=True).to("cuda", torch.float16)
                outputs = model.generate(**inputs)
                texts = processor.batch_decode(outputs, skip_special_tokens=True)
                clip_captions.append(np.stack(texts))
                progress_bar.update(1)
            video_captions.append(np.stack(clip_captions))
    return np.stack(video_captions)


# Caption the test split first, then the train split (same order as before),
# saving each result next to the preprocessed frame arrays.
for split, data_to_generate in (('test', video_frames_test), ('train', video_frames_train)):
    text_list = caption_videos(data_to_generate, desc=split)
    save_path = os.path.join(path, 'preprocessed', f'text_{split}_256_{fps}hz.npy')
    np.save(save_path, text_list)