# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.

import torch
import torch.nn as nn
from categorical_sampler import CategoricalSampler
from common_utils import assert_eq


class RuleBasedCountModel(nn.Module):
    """Actor-critic network over count and resource features.

    Each feature group read from the input ``batch`` dict is embedded by its
    own small linear encoder. The embeddings are concatenated along dim 1,
    passed through a shared fully-connected layer (width 128), and then fed
    to two heads: a scalar value head (squashed with ``tanh``) and a policy
    head over ``num_actions`` actions (softmax).

    Args:
        num_actions: size of the action space. This is the policy output
            width and the expected feature width of ``batch['prev_action']``.
        num_unit_types: number of unit types. ``batch['unit_count']`` is
            expected to have ``2 * num_unit_types`` features;
            ``batch['unit_cons_count']`` and ``batch['moving_enemy_count']``
            have ``num_unit_types`` features each.
        num_resource_bins: feature width of ``batch['resource_bin']``.
    """

    def __init__(self, num_actions, num_unit_types, num_resource_bins):
        super().__init__()

        self.num_actions = num_actions
        self.num_unit_types = num_unit_types
        self.num_resource_bins = num_resource_bins

        # NOTE: submodule attribute names below become state_dict keys;
        # renaming them would break loading of existing checkpoints.
        unit_count_net_dim = 256
        self.unit_count_net = nn.Sequential(
            nn.Linear(2 * num_unit_types, unit_count_net_dim),
            nn.ReLU(),
        )

        unit_cons_count_net_dim = 128
        self.unit_cons_count_net = nn.Sequential(
            nn.Linear(num_unit_types, unit_cons_count_net_dim),
            nn.ReLU(),
        )

        moving_enemy_count_net_dim = 128
        self.moving_enemy_count_net = nn.Sequential(
            nn.Linear(num_unit_types, moving_enemy_count_net_dim),
            nn.ReLU(),
        )

        resource_net_dim = 32
        self.resource_net = nn.Sequential(
            nn.Linear(num_resource_bins, resource_net_dim),
            nn.ReLU(),
        )

        # Unlike the other encoders, this one has no ReLU: its output is a
        # purely linear embedding of the previous action.
        prev_action_net_dim = 32
        self.prev_action_net = nn.Sequential(
            nn.Linear(num_actions, prev_action_net_dim),
        )

        # Width of the concatenated encoder outputs fed to fc_net.
        self.fc_in_dim = (unit_count_net_dim
                          + unit_cons_count_net_dim
                          + moving_enemy_count_net_dim
                          + resource_net_dim
                          + prev_action_net_dim)
        self.fc_out_dim = 128
        self.fc_net = nn.Sequential(
            nn.Linear(self.fc_in_dim, self.fc_out_dim),
            nn.ReLU(),
        )

        self.value = nn.Linear(self.fc_out_dim, 1)
        self.policy = nn.Linear(self.fc_out_dim, num_actions)

        # NOTE(review): the meaning of the boolean flag is defined in
        # categorical_sampler; presumably it selects the sampling mode.
        self.sampler = CategoricalSampler(False)

    def get_input(self, batch):
        """Return ``batch`` unchanged; the model consumes the raw dict."""
        return batch

    def get_action(self, batch):
        """Return ``batch['action']``; only valid in training mode."""
        assert self.training
        return batch['action']

    def get_policy(self, batch):
        """Return ``batch['policy']``; only valid in training mode."""
        assert self.training
        return batch['policy']

    def get_reward(self, batch):
        """Return ``batch['reward']``."""
        return batch['reward']

    def get_value(self, batch):
        """Return ``batch['v']`` (the key ``forward`` writes the value under)."""
        return batch['v']

    def get_terminal(self, batch):
        """Return ``batch['terminal']``."""
        return batch['terminal']

    def _forward(self, batch):
        """Encode every feature group and return the shared features.

        Reads ``unit_count``, ``unit_cons_count``, ``moving_enemy_count``,
        ``resource_bin`` and ``prev_action`` from ``batch``. Each is expected
        to be a 2-D (batch, features) tensor, because the encoder outputs are
        concatenated along dim 1.

        Returns:
            Tensor of shape (batch, ``self.fc_out_dim``).
        """
        unit_count = batch['unit_count']
        unit_cons_count = batch['unit_cons_count']
        moving_enemy_count = batch['moving_enemy_count']
        resource = batch['resource_bin']
        # Cast because nn.Linear needs a floating-point input. NOTE(review):
        # presumably prev_action arrives as an integer one-hot; confirm.
        prev_action = batch['prev_action'].float()

        unit_count_out = self.unit_count_net(unit_count)
        unit_cons_count_out = self.unit_cons_count_net(unit_cons_count)
        moving_enemy_count_out = self.moving_enemy_count_net(moving_enemy_count)
        resource_out = self.resource_net(resource)
        prev_action_out = self.prev_action_net(prev_action)

        fc_in = torch.cat(
            [unit_count_out,
             unit_cons_count_out,
             moving_enemy_count_out,
             resource_out,
             prev_action_out],
            dim=1)
        return self.fc_net(fc_in)

    def _compute_policy(self, feat):
        """Map shared features to action probabilities (softmax over dim 1).

        Shared by ``act`` and ``forward`` so that both compute the policy
        identically.
        """
        logit = self.policy(feat)
        assert_eq(logit.dim(), 2)
        return nn.functional.softmax(logit, 1)

    def act(self, batch):
        """Compute the policy for ``batch`` and sample an action from it.

        Returns:
            dict with ``'action'`` (the sampler's output cast to long) and
            ``'policy'`` (softmax probabilities, shape (batch, num_actions)).
        """
        feat = self._forward(batch)
        pi = self._compute_policy(feat)
        action = self.sampler.sample(pi)
        return {'action': action.long(), 'policy': pi}

    def forward(self, batch):
        """Training-mode forward pass producing the value and the policy.

        Returns:
            dict with ``'v'`` (tanh-squashed value in (-1, 1), shape (batch,))
            and ``'policy'`` (softmax probabilities, shape (batch, num_actions)).

        Raises:
            AssertionError: if the module is not in training mode.
        """
        # Fail fast, before running the network, when used outside training.
        assert self.training
        feat = self._forward(batch)
        pi = self._compute_policy(feat)
        # squeeze(1): (batch, 1) -> (batch,); tanh bounds the value to (-1, 1).
        value = torch.tanh(self.value(feat).squeeze(1))
        return {'v': value, 'policy': pi}
