import numpy as np
import torch as t
import os
from PIL import Image
from utils import *

# Maps the class-name prefix of a BAR image file name (the text before the
# first '_') to its integer class index; the position in the tuple is the index.
class2idx = {
    name: index
    for index, name in enumerate(
        ('climbing', 'diving', 'fishing', 'racing', 'throwing', 'pole vaulting')
    )
}


        
def _load_bar_split(_dir):
    """Load every image in ``_dir`` into a tensor together with its class label.

    Args:
        _dir: Directory path ending with a path separator; file names are
            appended to it by plain string concatenation.

    Returns:
        A ``(data, label)`` tuple. ``data`` has shape ``(N, 3, 224, 224)``
        with pixel values scaled to ``[0, 1]``; ``label`` has shape ``(N,)``
        and holds the ``class2idx`` index of each file (stored as floats in
        torch's default dtype). ``N`` is the number of entries in ``_dir``.

    Raises:
        KeyError: If the part of a file name before the first ``'_'`` is not
            a key of ``class2idx``.
    """
    file_names = os.listdir(_dir)
    label = t.zeros(len(file_names))
    data = t.zeros((len(file_names), 3, 224, 224))
    for idx, file_name in enumerate(file_names):
        # The class name is the file-name prefix before the first underscore.
        label[idx] = class2idx[file_name.split('_')[0]]
        # The context manager releases the file handle as soon as the
        # resized copy exists. Resize happens before the RGB conversion,
        # which keeps the pixel values identical to the previous pipeline.
        with Image.open(_dir + file_name) as img:
            rgb = img.resize((224, 224)).convert('RGB')
        # HWC uint8 -> CHW float in [0, 1].
        data[idx] = t.tensor((np.array(rgb) / 255.).transpose((2, 0, 1)))
    return data, label


def bar_gen(args):
    """Convert the BAR image folders into tensor dictionaries and save them.

    Reads ``<args.data_storage>BAR/train/`` and ``<args.data_storage>BAR/test/``.
    The first 90% of the train files (in ``os.listdir`` order) form the
    ``'train'`` split and the rest the ``'valid'`` split. Each split is a
    dict with keys ``'data'``, ``'label'`` and ``'b_label'``; ``'b_label'``
    is all zeros with the same shape as ``'label'``.

    Two objects are written through ``save_data``:
      * ``{'train': ..., 'valid': ...}`` to ``args.save_dir + args.data``
      * the test dict to ``args.save_dir + args.data + '_test'``

    Args:
        args: Object providing the string attributes ``data_storage``
            (dataset root, expected to end with a path separator),
            ``save_dir`` and ``data``.

    Returns:
        None. The results are only persisted via ``save_data``.
    """
    train_valid_split = 0.9

    # ---- train / valid ----------------------------------------------------
    data, label = _load_bar_split(args.data_storage + 'BAR/train/')
    b_label = t.zeros_like(label)

    # NOTE: os.listdir returns entries in arbitrary order, so which files end
    # up in 'train' versus 'valid' follows that order.
    split_at = int(len(label) * train_valid_split)

    train = {
        'data': data[:split_at],
        'label': label[:split_at],
        'b_label': b_label[:split_at],
    }
    valid = {
        'data': data[split_at:],
        'label': label[split_at:],
        'b_label': b_label[split_at:],
    }

    data_name = args.data
    save_data({'train': train, 'valid': valid}, args.save_dir + data_name)

    # ---- test ---------------------------------------------------------------
    data, label = _load_bar_split(args.data_storage + 'BAR/test/')
    ret = {
        'data': data,
        'label': label,
        'b_label': t.zeros_like(label),
    }
    data_name = args.data + '_test'
    save_data(ret, args.save_dir + data_name)
