from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import jax.numpy as jnp


def termination_fn_false(obs, act, next_obs):
    """Termination function that never signals the end of an episode.

    All three arguments must be rank-2 (checked with an ``assert``); only the
    leading length of ``obs`` is otherwise used.

    Returns:
        A boolean NumPy array of shape ``(len(obs), 1)`` filled with ``False``.
    """
    expected_rank = 2
    assert len(obs.shape) == len(next_obs.shape) == len(act.shape) == expected_rank
    # One "not done" flag per row, already shaped as a column vector.
    return np.zeros((len(obs), 1), dtype=bool)


class ModelEnv:
    """Environment-like wrapper that advances a learned model by one step.

    The wrapped ``model`` must provide ``sample(inputs, deterministic=...)``
    returning a 3-tuple ``(samples, variances, uncertainties)``. The first
    ``rew_dim`` columns of ``samples`` are treated as rewards; the remaining
    columns are added to the current observation to form the next observation.
    """

    def __init__(self, model, env_id=None, rew_dim=1):
        """Store the model and the number of reward columns.

        Args:
            model: Object exposing ``sample(inputs, deterministic=...)``.
            env_id: Accepted for interface compatibility; not used here.
            rew_dim: Number of leading columns of the model output that hold
                rewards.
        """
        self.model = model
        self.rew_dim = rew_dim
        self.termination_func = termination_fn_false

    def step(self, obs, act, deterministic: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
        """Predict the outcome of taking ``act`` in ``obs`` using the model.

        ``obs`` and ``act`` must have the same rank. Rank-1 inputs are treated
        as a single transition: they are reshaped to one-row batches for the
        model call and the leading axis is stripped from every output.

        Args:
            obs: Current observation(s).
            act: Action(s) to apply.
            deterministic: Forwarded unchanged to ``model.sample``.

        Returns:
            ``(next_obs, rewards, terminals, info)`` where ``info`` is a dict
            with keys ``'uncertainty'``, ``'var_obs'`` and ``'var_rewards'``.
        """
        assert len(obs.shape) == len(act.shape)
        return_single = len(obs.shape) == 1
        if return_single:
            obs = obs.reshape(1, -1)
            act = act.reshape(1, -1)

        inputs = jnp.concatenate([obs, act], axis=-1)
        samples, variances, uncertainties = self.model.sample(inputs, deterministic=deterministic)

        # Work on a private, writable NumPy copy: JAX arrays are immutable
        # (in-place item assignment raises TypeError), and a NumPy array owned
        # by the model must not be modified behind its back.
        samples = np.array(samples)
        rewards = samples[:, : self.rew_dim]
        next_obs = samples[:, self.rew_dim :]
        # The non-reward columns are offsets relative to the current obs.
        next_obs += np.asarray(obs)

        terminals = self.termination_func(obs, act, next_obs)
        var_rewards = variances[:, : self.rew_dim]
        var_obs = variances[:, self.rew_dim :]

        if return_single:
            # Undo the one-row batching applied above.
            next_obs = next_obs[0]
            rewards = rewards[0]
            terminals = terminals[0]
            uncertainties = uncertainties[0]
            var_obs = var_obs[0]
            var_rewards = var_rewards[0]

        info = {'uncertainty': uncertainties,
                'var_obs': var_obs,
                'var_rewards': var_rewards}

        return next_obs, rewards, terminals, info
