import os
from copy import deepcopy
import numpy as np

from tqdm import tqdm

# fmt: off
import logging
# NOTE: basicConfig runs at import time and configures the *root* logger for
# the whole process (it is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the collection routine below.
logger = logging.getLogger(__name__)
# fmt: on


class DiscountedValueBuffer:
    """Fixed-capacity, append-only store of (obs, next_obs, reward, done) transitions.

    Transitions are written sequentially into preallocated NumPy arrays.
    ``pointer`` is both the next slot to write and the number of stored
    transitions. There is no wrap-around: storing more than ``size``
    transitions fails with an ``IndexError`` when indexing the arrays.
    """

    def __init__(
        self, save_path: str, size: int, batch_size: int, observation_dim: int
    ) -> None:
        self.save_path = save_path
        self.buffer_size = size
        self.batch_size = batch_size
        # NOTE: holds the integer observation dimension despite its name.
        self.observation_shape = observation_dim

        obs_layout = (size, observation_dim)
        self.observations = np.zeros(obs_layout, dtype=np.float32)
        self.next_observations = np.zeros(obs_layout, dtype=np.float32)
        self.rewards = np.zeros(size, dtype=np.float32)
        self.dones = np.zeros(size, dtype=bool)

        self.pointer = 0

    def __len__(self):
        """Return how many transitions have been stored so far."""
        return self.pointer

    def store(self, observation, next_observation, reward, done):
        """Write one transition into the next free slot, then advance ``pointer``."""
        slot = self.pointer
        self.observations[slot] = observation
        self.next_observations[slot] = next_observation
        self.rewards[slot] = reward
        self.dones[slot] = done
        self.pointer = slot + 1

    def dump(self):
        """Save all four arrays to ``save_path`` with ``np.savez``.

        The full preallocated arrays are written, including any unused
        slots at or beyond ``pointer``.
        """
        named_arrays = {
            "observations": self.observations,
            "next_observations": self.next_observations,
            "rewards": self.rewards,
            "dones": self.dones,
        }
        np.savez(self.save_path, **named_arrays)

    def sample(self):
        """Draw ``batch_size`` stored transitions uniformly at random, with replacement.

        Returns:
            Tuple ``(observations, next_observations, rewards, dones)`` of
            arrays, each indexed by the same random batch indices.
        """
        assert self.pointer >= self.batch_size, "Buffer does not have enough samples"
        picks = np.random.randint(0, self.pointer, size=self.batch_size)
        columns = (self.observations, self.next_observations, self.rewards, self.dones)
        return tuple(column[picks] for column in columns)


def _act(policy, observation):
    """Query ``policy`` on a float32 copy of ``observation``; unwrap tuple outputs to their first element."""
    action = policy(observation.astype(np.float32))
    if isinstance(action, tuple):
        action = action[0]
    return action


def _sample_init_index(init_index_type, init_horizon, horizon):
    """Return the number of roll-in steps to take for one episode."""
    if init_index_type == "spaced":
        return init_horizon
    if init_index_type == "spaced_random":
        return np.random.randint(0, init_horizon)
    if init_index_type == "random":
        return np.random.randint(0, horizon - 1)
    raise ValueError(f"Unknown init index type {init_index_type}")


def rollin_rollout(
    env,
    policy,
    buffer,
    horizon,
    num_episodes=8,
    init_policy=None,
    init_horizon=None,
    init_index_type="spaced",
    name="",
):
    """Collect ``num_episodes`` episodes into ``buffer``, optionally after a roll-in.

    Each episode starts from ``env.reset()``. If ``init_policy`` is given and
    ``init_horizon > 0``, ``init_policy`` first acts for a number of steps
    chosen by ``init_index_type``; those roll-in transitions are NOT stored.
    ``policy`` then acts from that step up to ``horizon - 1``, and every
    transition it takes is stored via ``buffer.store``. Episodes that
    terminate during the roll-in are discarded and retried; they do not count
    toward ``num_episodes``.

    Args:
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(observation, reward, done, info)``.
        policy: Callable mapping a float32 observation to an action, or to a
            tuple whose first element is the action.
        buffer: Object providing ``store(observation, next_observation,
            reward, done)`` and ``__len__``.
        horizon: Maximum number of steps per episode, roll-in included.
        num_episodes: Number of episodes to collect.
        init_policy: Optional roll-in policy with the same calling convention
            as ``policy``. If it has ``eval``/``train`` methods, it is put in
            eval mode for the roll-in and back in train mode afterwards.
        init_horizon: Roll-in bound used by ``"spaced"`` and
            ``"spaced_random"``; defaults to ``horizon - 1``.
        init_index_type: How many roll-in steps to take:
            ``"spaced"`` -> exactly ``init_horizon``;
            ``"spaced_random"`` -> uniform in ``[0, init_horizon)``;
            ``"random"`` -> uniform in ``[0, horizon - 1)``.
        name: Label shown in the progress bar.

    Returns:
        ``(buffer, policy_ep_rew, all_ep_rew)``: the same ``buffer`` object,
        the reward collected by ``policy`` in each episode, and each
        episode's total reward including the roll-in.

    Raises:
        ValueError: If ``init_index_type`` is unknown (only checked when a
            roll-in is performed).
    """
    logger.info(f"Collecting {num_episodes} episodes with max horizon {horizon}")

    if init_horizon is None:
        init_horizon = horizon - 1

    all_ep_rew = []
    policy_ep_rew = []
    episodes_collected = 0
    with tqdm(
        total=num_episodes, desc=f"Collecting samples for buffer {name}", leave=False
    ) as pbar:
        while episodes_collected < num_episodes:
            episode_reward = 0
            done = False
            init_policy_index = 0

            observation = env.reset()
            if init_policy is not None and init_horizon > 0:
                if hasattr(init_policy, "eval"):
                    init_policy.eval()

                init_policy_index = _sample_init_index(
                    init_index_type, init_horizon, horizon
                )
                logger.info(f"Using init policy for {init_policy_index} steps")
                for _ in range(init_policy_index):
                    action = _act(init_policy, observation)
                    observation, reward, done, _info = env.step(action)
                    episode_reward += reward
                    if done:
                        break

                # Restore the roll-in policy to train mode whether or not the
                # roll-in ended early.
                if hasattr(init_policy, "train"):
                    init_policy.train()

            if done:
                # Terminated during roll-in: nothing was stored, so retry with
                # a fresh episode (the loop top resets the env).
                # NOTE(review): if every episode terminates during the roll-in,
                # this loops forever.
                continue

            policy_episode_reward = 0
            for h in range(init_policy_index, horizon):
                action = _act(policy, observation)
                next_observation, reward, done, _info = env.step(action)
                episode_reward += reward
                policy_episode_reward += reward

                # Reaching the horizon is treated as truncation, not
                # termination, so the last transition is stored with
                # done=False. NOTE(review): this also masks a genuine
                # terminal reported by the env on that final step.
                if h == horizon - 1:
                    done = False

                buffer.store(
                    deepcopy(observation),
                    deepcopy(next_observation),
                    reward,
                    done,
                )

                observation = next_observation

                if done:
                    break

            # Count each completed episode exactly once, whether it ended by
            # termination or by reaching the horizon.
            policy_ep_rew.append(policy_episode_reward)
            all_ep_rew.append(episode_reward)
            pbar.update(1)
            episodes_collected += 1

    logger.info(f"Successfully added to buffer, new buffer size: {len(buffer)}")
    logger.info(
        f"Average policy episode reward: {np.mean(policy_ep_rew)} "
        + f"and average episode reward: {np.mean(all_ep_rew)}"
    )

    return buffer, policy_ep_rew, all_ep_rew
