import numpy as np

import torch

# fmt: off
import logging
# Import-time side effect: configure the root logger to emit INFO and above.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# fmt: on

# Device on which this module places the tensors it builds: CUDA when
# torch reports it as available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class FixedProbabilityPolicy:
    """Policy that samples actions from a fixed categorical distribution.

    The observation passed to any method is ignored; the action index is
    drawn with NumPy's global random state (``np.random.choice``).
    """

    def __init__(self, action_probs):
        """Keep a reference to ``action_probs``, one probability per action."""
        self.action_probs = action_probs

    def __call__(self, observation):
        """Shorthand for :meth:`sample_action`."""
        return self.sample_action(observation)

    def get_probs(self, observation):
        """Return the stored action probabilities (``observation`` is unused)."""
        return self.action_probs

    def sample_action(self, observation):
        """Draw one action index according to the stored probabilities."""
        distribution = self.action_probs
        num_actions = len(distribution)
        return np.random.choice(num_actions, p=distribution)

    def train(self):
        """No-op: this policy has no training mode to switch into."""
        return None

    def eval(self):
        """No-op: this policy has no evaluation mode to switch into."""
        return None


def create_fixed_ap_policies(action_probs):
    """Build one ``FixedProbabilityPolicy`` per probability vector.

    Args:
        action_probs: Iterable of probability vectors. Each vector is an
            iterable of per-action probabilities that must sum to 1.

    Returns:
        A list of ``FixedProbabilityPolicy`` objects in input order, each
        holding its vector as a NumPy array.

    Raises:
        AssertionError: If a vector does not sum to 1 within the default
            ``np.isclose`` tolerance. Like every ``assert``, the check is
            skipped when Python runs with ``-O``.
    """
    policies = []
    for ap in action_probs:
        # Materialize the vector once and validate the array itself, so a
        # one-shot iterable (e.g. a generator) is not consumed by the check
        # before the policy is built from it.
        probs = np.array(list(ap))
        total = np.sum(probs)
        assert np.isclose(total, 1.0), (
            f"action probabilities must sum to 1, got {total}"
        )
        policies.append(FixedProbabilityPolicy(probs))

    return policies


class MaxFollowingPolicy:
    """Act with whichever candidate policy currently has the highest value.

    Policies and value functions are paired by position: ``value_functions[i]``
    scores ``policies[i]``. On each call every value function is evaluated on
    the observation and the action is taken from the best-scoring policy.
    """

    def __init__(self, policies, value_functions):
        """Store the candidates.

        Args:
            policies: Indexable collection of callables mapping an
                observation array to an action.
            value_functions: Collection of callables, paired by index with
                ``policies``, each mapping a ``(1, n)`` float32 tensor on
                ``DEVICE`` to a single-element tensor.
        """
        self.policies = policies
        self.value_functions = value_functions

    def __call__(self, observation):
        """Return ``(action, index_of_selected_policy)`` for ``observation``."""
        return self._max_following(self.policies, self.value_functions, observation)

    def _max_following(self, policies, value_functions, observation):
        """Compare the value for all policies and return the action of the policy with the highest value.

        Args:
            policies: Candidate policies (see ``__init__``).
            value_functions: Value functions paired with ``policies``.
            observation: Array-like observation; it is converted to float32
                and flattened into a single row for the value functions.

        Returns:
            A tuple ``(action, idx_policy_selected)``: the action produced by
            the selected policy and that policy's index (as returned by
            ``np.argmax``; ties resolve to the lowest index).
        """
        # One float32 view of the observation is shared by the value
        # functions and the selected policy so both see the same dtype.
        observation = np.asarray(observation, dtype=np.float32)

        # The input is identical for every value function, so build it once.
        obs_tensor = torch.tensor(
            observation.reshape(1, -1), dtype=torch.float32
        ).to(DEVICE)

        # Only scalar values are needed; skip building autograd graphs.
        with torch.no_grad():
            values = [
                vf(obs_tensor).item() for _, vf in zip(policies, value_functions)
            ]

        # Only the winning policy is queried: unselected policies are not
        # evaluated, so they neither cost a forward pass nor consume RNG.
        idx_policy_selected = np.argmax(values)
        action = policies[idx_policy_selected](observation)

        return action, idx_policy_selected

    def __len__(self):
        """Return the number of candidate policies."""
        return len(self.policies)

    def eval(self):
        """Call ``eval()`` on every value function (policies are left untouched)."""
        for vf in self.value_functions:
            vf.eval()

    def train(self):
        """Call ``train()`` on every value function (policies are left untouched)."""
        for vf in self.value_functions:
            vf.train()