import torch 
import torch.nn as nn 
import torch.distributions as D 

class ForwardPolicy(nn.Module):
    """Forward policy producing a 1-D Gaussian-mixture distribution over actions.

    An MLP maps each state to ``3 * num_comp`` outputs, which are split into
    per-component means, log standard deviations and mixture logits. A
    separate linear head predicts one scalar log-flow per state.

    Args:
        input_dim: Size of the last dimension of ``batch_state.state``.
        hidden_dim: Width of the two hidden layers.
        num_comp: Number of mixture components.
        device: Device the parameters are moved to.
        eps: Exploration noise. ``eps / 2`` is added to every component's
            standard deviation while the module is in training mode; it is
            ignored in eval mode.
    """

    def __init__(self, input_dim, hidden_dim, num_comp, device='cpu', eps=0.3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_comp = num_comp
        self.device = device

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
            nn.Linear(hidden_dim, 3 * num_comp),
        ).to(self.device)

        self.mlp_flow = nn.Linear(input_dim, 1).to(self.device)
        # Used by get_dist: widens every component's std during training only.
        self.eps = eps

    def get_dist(self, batch_state):
        """Build the action distribution for ``batch_state.state``.

        For a state tensor of shape ``(batch_size, input_dim)`` the result has
        batch shape ``(batch_size,)`` and scalar events. Any leading shape is
        accepted, since all splitting happens along the last dimension.
        """
        out = self.mlp(batch_state.state)
        means, log_sigma, mixture_logits = torch.chunk(out, chunks=3, dim=-1)
        eps = self.eps if self.training else 0.
        # Shape (..., num_comp). The eps term keeps the std away from 0 while exploring.
        sigma = log_sigma.exp() + eps / 2
        # Passing logits (not softmaxed probs) avoids the lossy log(clamp(probs))
        # that Categorical would otherwise apply when scoring samples.
        mixture = D.Categorical(logits=mixture_logits)
        return D.MixtureSameFamily(mixture, D.Normal(means, sigma))

    def forward(self, batch_state, actions=None):
        """Sample (or score) actions and predict log-flows.

        Args:
            batch_state: Object exposing a ``state`` tensor.
            actions: Optional actions to score. When ``None``, one action per
                state is sampled from the current distribution.

        Returns:
            ``(actions, log_probs, log_flows, log_probs)``. ``log_probs`` is
            returned twice; the four-tuple is kept for interface compatibility
            with existing callers.
        """
        dist = self.get_dist(batch_state)

        if actions is None:
            actions = dist.sample()

        log_probs = dist.log_prob(actions)
        # NOTE(review): originally marked "probably unused" by callers; confirm before removing.
        log_flows = self.mlp_flow(batch_state.state).squeeze(dim=-1)

        return actions, log_probs, log_flows, log_probs
        
class BackwardPolicy(nn.Module):
    """Parameter-free backward policy.

    Echoes whatever actions it is given and pairs them with an all-zero tensor
    of length ``batch_state.batch_size``. Presumably this stands in for
    per-state log-probabilities (a fixed/uniform backward policy); confirm
    against callers.

    Args:
        device: Device on which the zero tensor is allocated.
    """

    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, batch_state, actions=None):
        """Return ``(actions, zeros)`` where ``zeros`` has shape ``(batch_size,)``."""
        zero_scores = torch.zeros(
            (batch_state.batch_size,),
            device=self.device,
        )
        return actions, zero_scores
