import numpy as np
import tensorflow.keras as keras
import tensorflow.keras.backend as k
from data_generators.dos_generator import DOSBatchGenerator
from data_engineering.transform_data import transform_data, transform_data_to_markovian
from utils.metrics import calculate_stopping_reward

import tensorflow as tf
import time


# Small positive constant. It is not referenced in the visible part of this
# module -- presumably a numerical-stability epsilon; confirm usage before removing.
EPS = 1e-8


def dos_loss(reward, phi):
    """Negative mean expected reward of a soft stopping decision.

    Used as a Keras loss, so ``reward`` arrives as ``y_true`` and ``phi`` as
    ``y_pred``.

    Args:
        reward: 2-D tensor; column 0 is weighted by the stopping probability
            and column 1 by its complement.
        phi: 2-D tensor whose column 0 holds the stopping probability.

    Returns:
        Scalar tensor ``-mean(phi * reward[:, 0] + (1 - phi) * reward[:, 1])``.
    """
    stop_prob = phi[:, 0]
    reward_if_stop = reward[:, 0]
    reward_if_continue = reward[:, 1]

    # Expected reward per sample; negated because Keras minimises the loss.
    expected_reward = stop_prob * reward_if_stop + (1 - stop_prob) * reward_if_continue
    return -k.mean(expected_reward)


def build_dos_model(config, F):
    """Build the feed-forward network that outputs a stopping probability.

    Architecture: input of width ``F + 1`` -> BatchNormalization ->
    ``config['num_stacked_layers']`` x (Dense -> ReLU -> BatchNormalization)
    -> Dense(1, sigmoid).

    Args:
        config: mapping providing ``'num_stacked_layers'`` and
            ``'units_hidden'``.
        F: integer; the input layer has ``F + 1`` units.

    Returns:
        An uncompiled ``keras.Model`` mapping the input to a single sigmoid
        output.
    """
    layers = keras.layers

    inputs = layers.Input(shape=(F + 1,), name='input')
    hidden = layers.BatchNormalization()(inputs)

    # Each stacked block: affine map, ReLU non-linearity, then normalisation.
    for _ in range(config['num_stacked_layers']):
        hidden = layers.Dense(units=config['units_hidden'])(hidden)
        hidden = layers.ReLU()(hidden)
        hidden = layers.BatchNormalization()(hidden)

    stop_probability = layers.Dense(units=1, activation='sigmoid')(hidden)
    return keras.Model(inputs, stop_probability)


def train_dos_model(config, data_stats_dict, transform_str=None, make_markovian=False):
    """Train one stopping model per time step on every fold and evaluate on the test split.

    For each fold the models are fitted backwards in time (from step ``L - 2``
    down to ``0``). After fitting step ``n``, the running "value" column is
    overwritten with the step-``n`` reward for every sample whose predicted
    stopping probability exceeds 0.5; this updated column is the second target
    column for the next (earlier) step. At inference the models are applied
    forwards in time to build a binary intervention matrix, which is scored
    with ``calculate_stopping_reward``.

    Args:
        config: mapping read here for ``'dos_lr'``, ``'clipnorm'``,
            ``'batch_size'``, ``'dos_epochs'`` and ``'omit_time_zero'``; it is
            also forwarded to ``build_dos_model`` and ``DOSBatchGenerator``.
        data_stats_dict: mapping with ``'training_folds'``,
            ``'validation_folds'`` and ``'test_folds'``; each is a sequence of
            ``(data, rewards)`` pairs, one per fold. ``data`` is indexed as
            ``[sample, step, feature]`` and ``rewards`` as ``[sample, step]``.
            When ``transform_str`` is given, ``data_stats_dict[transform_str]``
            must hold one statistics entry per fold.
        transform_str: optional key selecting the per-fold statistics; passed
            through to ``transform_data`` together with those statistics.
        make_markovian: when true, every split is additionally passed through
            ``transform_data_to_markovian`` and the feature count is re-read
            from the transformed training data.

    Returns:
        dict with one list entry per fold under each key:
            ``'dos_rewards'``: first value returned by ``calculate_stopping_reward``.
            ``'dos_reward_idxs'``: second value returned by ``calculate_stopping_reward``.
            ``'prediction_time_per_ts'``: inference wall time in milliseconds
                divided by ``n_test_samples * L``.
            ``'train_times'``: training wall time in seconds.
    """
    nfolds = len(data_stats_dict['training_folds'])
    L = data_stats_dict['training_folds'][0][0].shape[1]
    F = data_stats_dict['training_folds'][0][0].shape[2]

    dos_rewards = []
    dos_reward_idxs = []
    prediction_time_per_ts = []
    train_times = []

    for i in range(nfolds):

        data_stats = None
        if transform_str is not None:
            data_stats = data_stats_dict[transform_str][i]

        # TRANSFORM INPUT DATA
        transformed_train_data = transform_data(data_stats_dict['training_folds'][i][0], data_stats, transform_str)
        train_rewards = data_stats_dict['training_folds'][i][1]
        transformed_val_data = transform_data(data_stats_dict['validation_folds'][i][0], data_stats, transform_str)
        val_rewards = data_stats_dict['validation_folds'][i][1]
        transformed_test_data = transform_data(data_stats_dict['test_folds'][i][0], data_stats, transform_str)
        test_rewards = data_stats_dict['test_folds'][i][1]

        if make_markovian:
            transformed_train_data = transform_data_to_markovian(transformed_train_data)
            transformed_val_data = transform_data_to_markovian(transformed_val_data)
            transformed_test_data = transform_data_to_markovian(transformed_test_data)
            # The markovian transform may change the feature count.
            F = transformed_train_data.shape[2]

        dos_models = []
        # Indexing with a list keeps a trailing axis of size 1 and (for numpy
        # arrays) yields a copy, so the in-place updates below do not touch
        # the reward matrices themselves.
        dos_train_value = train_rewards[:, [L - 1]]
        dos_val_value = val_rewards[:, [L - 1]]
        test_interventions = np.ones((test_rewards.shape[0], test_rewards.shape[1]))

        start_train_time = time.time()
        # Backward pass: the last step (L - 1) needs no model.
        for n in range(L - 2, -1, -1):

            # BUILD MODEL
            dos_model = build_dos_model(config, F)

            # COMPILE MODEL
            dos_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=0, patience=5)
            # `learning_rate` is the supported argument name; the old `lr`
            # alias is deprecated and rejected by newer Keras releases.
            dos_model.compile(
                loss=dos_loss,
                optimizer=keras.optimizers.Adam(learning_rate=config['dos_lr'], clipnorm=config['clipnorm']),
            )

            # FIT MODEL
            # Target columns: [reward at step n, current value column].
            train_target = np.concatenate([train_rewards[:, [n]], dos_train_value], axis=1)
            val_target = np.concatenate([val_rewards[:, [n]], dos_val_value], axis=1)
            dos_train_generator = DOSBatchGenerator(transformed_train_data[:, n, :], train_target, config, config['batch_size'], randomize=True)
            dos_val_generator = DOSBatchGenerator(transformed_val_data[:, n, :], val_target, config, config['batch_size'], randomize=False)

            dos_model.fit(dos_train_generator,
                          validation_data=dos_val_generator,
                          callbacks=[dos_callback],
                          epochs=config['dos_epochs'], shuffle=False, verbose=0)

            # PREDICT AND UPDATE CONTINUATION VALUE
            # Non-randomised generators so prediction rows line up with sample rows.
            dos_train_prediction_generator = DOSBatchGenerator(transformed_train_data[:, n, :], train_target, config, config['batch_size'], randomize=False)
            dos_val_prediction_generator = DOSBatchGenerator(transformed_val_data[:, n, :], val_target, config, config['batch_size'], randomize=False)
            dos_train_predictions = dos_model.predict(dos_train_prediction_generator)
            dos_val_predictions = dos_model.predict(dos_val_prediction_generator)

            # Where the model would stop at step n, the value becomes the step-n reward.
            train_stop_mask = dos_train_predictions[:, 0] > 0.5
            val_stop_mask = dos_val_predictions[:, 0] > 0.5
            dos_train_value[train_stop_mask, 0] = train_rewards[train_stop_mask, n]
            dos_val_value[val_stop_mask, 0] = val_rewards[val_stop_mask, n]

            # STORE MODELS FOR INFERENCE
            dos_models.append(dos_model)
        end_train_time = time.time()

        # MODEL INFERENCE
        start_predict_time = time.time()
        for n in range(L - 1):
            # Models were appended in reverse time order (step L-2 first).
            dos_model = dos_models[L - 2 - n]
            # The generator requires a target; it is not used for prediction
            # here, so the step reward is simply duplicated -- presumably a
            # placeholder, confirm against DOSBatchGenerator.
            test_target = np.concatenate([test_rewards[:, [n]], test_rewards[:, [n]]], axis=1)
            dos_test_prediction_generator = DOSBatchGenerator(transformed_test_data[:, n, :], test_target, config, config['batch_size'], randomize=False)
            dos_test_predictions = dos_model.predict(dos_test_prediction_generator)
            test_interventions[:, [n]] = (dos_test_predictions > 0.5) * 1
        # The final step is always flagged as an intervention.
        test_interventions[:, -1] = 1
        if config['omit_time_zero']:
            test_interventions[:, 0] = 0
        end_predict_time = time.time()

        stopping_reward, stop_idxs = calculate_stopping_reward(0.5, test_interventions, test_rewards)

        print(str(stopping_reward))
        dos_rewards.append(stopping_reward)
        # Previously computed but never stored, leaving 'dos_reward_idxs' empty.
        dos_reward_idxs.append(stop_idxs)
        prediction_time_per_ts.append((end_predict_time - start_predict_time) * (10**3) / (test_interventions.shape[0] * L))
        train_times.append(end_train_time - start_train_time)

    dos_results = {
        'dos_rewards': dos_rewards,
        'dos_reward_idxs': dos_reward_idxs,
        'prediction_time_per_ts': prediction_time_per_ts,
        'train_times': train_times,
    }

    return dos_results
