import torch.nn as nn
import torch.nn.functional as F
import math

# Forwarded as ``track_running_stats`` to every BatchNorm2d layer built by
# Generator. When False, PyTorch BatchNorm keeps no running mean/variance and
# always normalizes with the statistics of the current batch.
track_running_stats = True

class Generator(nn.Module):
    """Transposed-convolution generator mapping latent vectors to images.

    The latent vector is projected to an (8 * dim)-channel seed map of size
    ceil(height / 16) x ceil(width / 16). Four stride-2 transposed
    convolutions upsample it by 16x, and the result is center-cropped to
    exactly (height, width). A final Sigmoid keeps pixel values in [0, 1].
    """

    def __init__(self, img_size, latent_dim, dim):
        """
        img_size : (int, int, int)
            (height, width, channels) of the generated images. Height and
            width need not be multiples of 16; the output is cropped to fit.
        latent_dim : int
            Size of the last dimension of the latent input.
        dim : int
            Base channel width; the transposed-conv stages go from 8*dim to
            4*dim, 2*dim, dim and finally img_size[2] channels.
        """
        super().__init__()

        self.dim = dim
        self.latent_dim = latent_dim
        self.img_size = img_size
        # Round up so that, after the 16x upsampling below, the generated map
        # is at least as large as the target and can be center-cropped to it.
        self.feature_sizes = (math.ceil(self.img_size[0] / 16), math.ceil(self.img_size[1] / 16))

        self.latent_to_features = nn.Sequential(
            nn.Linear(latent_dim, 8 * dim * self.feature_sizes[0] * self.feature_sizes[1]),
            nn.ReLU()
        )

        # Each ConvTranspose2d(kernel 4, stride 2, padding 1) doubles the
        # spatial size: (n - 1) * 2 - 2 + 4 == 2 * n.
        self.features_to_image = nn.Sequential(
            nn.ConvTranspose2d(8 * dim, 4 * dim, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(4 * dim, track_running_stats=track_running_stats),
            nn.ConvTranspose2d(4 * dim, 2 * dim, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(2 * dim, track_running_stats=track_running_stats),
            nn.ConvTranspose2d(2 * dim, dim, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(dim, track_running_stats=track_running_stats),
            nn.ConvTranspose2d(in_channels=dim, out_channels=self.img_size[2],
                kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, input_data):
        """Generate images from latent vectors.

        input_data : tensor whose last dimension is latent_dim, presumably
            shaped (batch, latent_dim).

        Returns a (batch, img_size[2], img_size[0], img_size[1]) tensor.
        """
        # Project the latent vector, then reshape it into the seed feature map
        # expected by the transposed convolutions.
        x = self.latent_to_features(input_data)
        x = x.view(-1, 8 * self.dim, self.feature_sizes[0], self.feature_sizes[1])
        x = self.features_to_image(x)
        # The upsampled map is 16 * ceil(size / 16) >= size on each axis, so
        # the offsets below are never negative; trim the excess evenly,
        # putting any odd leftover row/column at the bottom/right.
        height, width = x.shape[-2:]
        top = (height - self.img_size[0]) // 2
        left = (width - self.img_size[1]) // 2
        return x[:, :, top:top + self.img_size[0], left:left + self.img_size[1]]



class Discriminator(nn.Module):
    """Convolutional critic mapping a batch of images to one score each."""

    def __init__(self, img_size, dim):
        """
        img_size : (int, int, int)
            (height, width, channels), e.g. (32, 32, 1), (28, 28, 1) or
            (64, 128, 3). Last number indicates number of channels, e.g. 1 for
            grayscale or 3 for RGB. Height and width need not be powers of 2
            or multiples of 16, but each should be at least 16 so that no
            stride-2 stage shrinks the feature map to zero size.
        dim : int
            Base channel width; the four conv stages produce dim, 2*dim,
            4*dim and 8*dim channels.
        """
        super().__init__()

        self.img_size = img_size

        self.image_to_features = nn.Sequential(
            nn.Conv2d(self.img_size[2], dim, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim, 2 * dim, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(2 * dim, 4 * dim, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(4 * dim, 8 * dim, 4, 2, 1),
            nn.LeakyReLU(0.2)
        )

        # Each Conv2d(kernel 4, stride 2, padding 1) maps a side of length n to
        # floor((n + 2 - 4) / 2) + 1 == floor(n / 2), so after four stages a
        # side is floor(n / 16). Integer division keeps the Linear input size
        # equal to the real flattened feature size even when height/width are
        # not multiples of 16 (true division over-counts, e.g. for 28x28).
        feat_h = int(img_size[0]) // 16
        feat_w = int(img_size[1]) // 16
        output_size = 8 * dim * feat_h * feat_w
        self.features_to_prob = nn.Sequential(
            nn.Linear(output_size, 1)
        )

    def forward(self, input_data):
        """Score a batch of images.

        input_data : tensor of shape (batch, img_size[2], height, width).

        Returns a (batch, 1) tensor of raw scores; no sigmoid is applied.
        """
        batch_size = input_data.size()[0]
        x = self.image_to_features(input_data)
        x = x.view(batch_size, -1)
        return self.features_to_prob(x)
    
class Decoder_FC(nn.Module):
    """Fully connected decoder: latent_dim -> x_dim -> x_dim -> x_dim.

    Two ReLU-activated hidden layers are followed by a linear output layer
    with no activation, so outputs are unbounded.
    """

    def __init__(self, x_dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Linear layers sit at Sequential indices 0, 2 and 4, giving
        # state_dict keys prob_to_features.{0,2,4}.{weight,bias}.
        layers = [
            nn.Linear(latent_dim, x_dim),
            nn.ReLU(),
            nn.Linear(x_dim, x_dim),
            nn.ReLU(),
        ]
        layers.append(nn.Linear(x_dim, x_dim))  # output layer, no activation
        self.prob_to_features = nn.Sequential(*layers)

    def forward(self, input_data):
        """Decode input_data (last dimension latent_dim) to size x_dim."""
        decoded = self.prob_to_features(input_data)
        return decoded
