import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims
from rlpyt.models.conv2d import Conv2dModel
from rlpyt.models.mlp import MlpModel
from rlpyt.models.utils import scale_grad


from RLDIM.models import *
from RLDIM.utils import select_architecture

from rlpyt.models.utils import update_state_dict


class Arguments(object):
    """Simple namespace exposing the entries of a mapping as attributes."""

    def __init__(self, args):
        """Copy every key/value pair of ``args`` onto this instance.

        Args:
            args: Object providing ``items()`` (e.g. a dict). Each key is used
                as an attribute name and each value as that attribute's value.
        """
        for name, value in args.items():
            setattr(self, name, value)

class DQNHeadModel(torch.nn.Module):
    """Fully connected head: first layer -> ReLU -> ``Linear`` to ``output_size``.

    The first layer is either created here as ``nn.Linear(input_size,
    layer_sizes)`` or supplied by the caller through ``fc_1``, in which case
    that exact module object is used (no copy is made).
    """

    def __init__(self, input_size, layer_sizes, output_size, fc_1=None):
        """Build the head.

        Args:
            input_size: Input feature size of the first layer (only used when
                ``fc_1`` is None).
            layer_sizes: Width of the hidden layer (a single int).
            output_size: Number of outputs of the final linear layer.
            fc_1: Optional module to use as the first layer.
        """
        super().__init__()
        # Only construct a new first layer when the caller did not provide one.
        self.fc_1 = fc_1 if fc_1 is not None else nn.Linear(input_size, layer_sizes)
        self.fc_2 = nn.Linear(layer_sizes, output_size)
        self.mlp = nn.Sequential(self.fc_1, nn.ReLU(), self.fc_2)
        self._output_size = output_size

    def forward(self, input):
        """Apply the MLP and reshape the result to ``[-1, output_size]``."""
        out = self.mlp(input)
        return out.view(-1, self._output_size)

class DistributionalHeadModel(torch.nn.Module):
    """Fully connected head emitting ``n_atoms`` values per output.

    Structure: first layer -> ReLU -> ``Linear`` to ``output_size * n_atoms``.
    The first layer is either created here as ``nn.Linear(input_size,
    layer_sizes)`` or supplied by the caller through ``fc_1``, in which case
    that exact module object is used (no copy is made).
    """

    def __init__(self, input_size, layer_sizes, output_size, n_atoms, fc_1=None):
        """Build the head.

        Args:
            input_size: Input feature size of the first layer (only used when
                ``fc_1`` is None).
            layer_sizes: Width of the hidden layer (a single int).
            output_size: Size of the second-to-last output dimension.
            n_atoms: Size of the last output dimension.
            fc_1: Optional module to use as the first layer.
        """
        super().__init__()
        # Only construct a new first layer when the caller did not provide one.
        self.fc_1 = fc_1 if fc_1 is not None else nn.Linear(input_size, layer_sizes)
        self.fc_2 = nn.Linear(layer_sizes, output_size * n_atoms)
        self.mlp = nn.Sequential(self.fc_1, nn.ReLU(), self.fc_2)
        self._output_size = output_size
        self._n_atoms = n_atoms

    def forward(self, input):
        """Apply the MLP and reshape to ``[-1, output_size, n_atoms]``."""
        flat = self.mlp(input)
        return flat.view(-1, self._output_size, self._n_atoms)

class DistributionalDuelingHeadModel(torch.nn.Module):
    """Dueling head with separate value and advantage streams over atoms.

    The advantage stream is ``MlpModel`` -> bias-free ``Linear`` to
    ``output_size * n_atoms``, plus a learned bias of length ``n_atoms`` that
    is broadcast over the ``output_size`` dimension. The value stream outputs
    ``n_atoms`` values and is broadcast over the ``output_size`` dimension
    when the two streams are combined.
    """

    def __init__(
            self,
            input_size,
            hidden_sizes,
            output_size,
            n_atoms,
            grad_scale=2 ** (-1 / 2),
            fc_1_V=None,
            ):
        """Build both streams.

        Args:
            input_size: Feature size fed to both streams.
            hidden_sizes: Int or list of ints; a single int is wrapped in a list.
            output_size: Size of the second-to-last output dimension.
            n_atoms: Size of the last output dimension.
            grad_scale: Factor handed to ``scale_grad`` in ``forward``.
            fc_1_V: Optional module used as the first layer of the value
                stream. When given, the value stream is
                ``fc_1_V -> ReLU -> Linear(hidden_sizes[0], n_atoms)``.
        """
        super().__init__()
        hidden = [hidden_sizes] if isinstance(hidden_sizes, int) else hidden_sizes

        # Advantage stream (created first, before the value stream).
        self.advantage_hidden = MlpModel(input_size, hidden)
        self.advantage_out = torch.nn.Linear(
            hidden[-1], output_size * n_atoms, bias=False)
        self.advantage_bias = torch.nn.Parameter(torch.zeros(n_atoms))

        # Value stream: either a fresh MLP or one built around the given layer.
        if fc_1_V is not None:
            self.value = nn.Sequential(
                fc_1_V, nn.ReLU(), nn.Linear(hidden[0], n_atoms))
        else:
            self.value = MlpModel(input_size, hidden, output_size=n_atoms)

        self._grad_scale = grad_scale
        self._output_size = output_size
        self._n_atoms = n_atoms

    def forward(self, input):
        """Combine the value stream with the mean-centred advantage stream.

        The input is first passed through ``scale_grad`` with ``grad_scale``.
        The advantage is centred by subtracting its mean over dim 1, then the
        value (reshaped to ``[-1, 1, n_atoms]``) is added.
        """
        scaled = scale_grad(input, self._grad_scale)
        adv = self.advantage(scaled)
        val = self.value(scaled).view(-1, 1, self._n_atoms)
        centred = adv - adv.mean(dim=1, keepdim=True)
        return val + centred

    def advantage(self, input):
        """Return the advantage stream output, shaped ``[-1, output_size, n_atoms]``."""
        hidden = self.advantage_hidden(input)
        per_action = self.advantage_out(hidden).view(
            -1, self._output_size, self._n_atoms)
        return per_action + self.advantage_bias


class AtariCatDqnModel_nce(torch.nn.Module):
    """Categorical DQN model whose encoder comes from ``select_architecture``.

    The network returned by ``select_architecture`` supplies the convolutional
    layers (``convs``), their flattened output size (``out_channels``) and a
    first fully connected layer (``fc_1``). That ``fc_1`` module is shared with
    the distributional head built here, so both hold the same parameters.
    """

    def __init__(
            self,
            image_shape,
            output_size,
            n_atoms=51,
            fc_sizes=512,
            dueling=False,
            use_maxpool=False,
            channels=None,  # None uses default.
            kernel_sizes=None,
            strides=None,
            paddings=None,
            architecture='Mnih',
            downsample=1,
            frame_stack=4,
            nce_loss='InfoNCE_action_loss',
            algo='c51',
            data_aug=False,
            ema_moco=False
            ):
        """Build the encoder and the distributional head.

        Args:
            image_shape: ``(c, h, w)`` of a single observation.
            output_size: Number of outputs of the head (second-to-last dim of
                the result of ``forward``).
            n_atoms: Number of atoms per output (last dim of the result).
            fc_sizes: Hidden size passed to the head.
            dueling: If True use ``DistributionalDuelingHeadModel``, otherwise
                ``DistributionalHeadModel``.
            use_maxpool, channels, kernel_sizes, strides, paddings, ema_moco:
                Accepted but not used by this constructor.
            architecture, downsample, nce_loss, algo, data_aug: Stored on
                ``self.args`` and handed to ``select_architecture``.
            frame_stack: Stored on ``self.args`` as ``int(frame_stack == 1)``,
                i.e. 1 when ``frame_stack`` equals 1 and 0 otherwise.
        """
        super().__init__()
        self.dueling = dueling
        c, h, w = image_shape
        self.args = Arguments({'architecture':architecture,'downsample':downsample,'frame_stack':int(frame_stack==1),'nce_loss':nce_loss,'algo':algo,'data_aug':data_aug})
        # The selected network is constructed from a channels-last dummy state.
        dummy_state = np.zeros((h, w, c))

        # The architecture is looked up among this module's globals.
        network = select_architecture(self.args, globals())
        self.model = network(dummy_state, output_size, {})

        self.conv = self.model.convs
        conv_out_size = self.model.out_channels
        if dueling:
            self.head = DistributionalDuelingHeadModel(conv_out_size, fc_sizes,
                output_size=output_size, n_atoms=n_atoms, fc_1_V=self.model.fc_1)
        else:
            self.head = DistributionalHeadModel(conv_out_size, fc_sizes,
                output_size=output_size, n_atoms=n_atoms, fc_1=self.model.fc_1)

        # Make the shared first layer trainable. Assigning ``requires_grad``
        # on the module itself would only create a plain Python attribute, so
        # the flag has to be set on each parameter.
        for param in self.model.fc_1.parameters():
            param.requires_grad_(True)

    def forward(self, observation, prev_action, prev_reward):
        """Feedforward layers process as [T*B,H]. Return same leading dims as
        input, can be [T,B], [B], or [].

        ``prev_action`` and ``prev_reward`` are accepted but not used. The
        result is a softmax over the last (atom) dimension.
        """
        img = observation.type(torch.float)  # Expect torch.uint8 inputs
        scale = 1. / 255  # From [0-255] to [0-1].
        if img is observation:
            # Input was already float, so ``type`` returned the caller's own
            # tensor; rescale out of place to avoid modifying it.
            img = img * scale
        else:
            img.mul_(scale)  # Fresh copy from the cast: in place is safe.

        # Infer (presence of) leading dimensions: [T,B], [B], or [].
        lead_dim, T, B, img_shape = infer_leading_dims(img, 3)

        conv_out = self.conv(img.view(T * B, *img_shape))  # Fold if T dimension.

        p = self.head(conv_out.view(T * B, -1))
        p = F.softmax(p, dim=-1)

        # Restore leading dimensions: [T,B], [B], or [], as input.
        p = restore_leading_dims(p, lead_dim, T, B)
        return p
