import torch
import torch.nn as nn


class EncoderCeleba(nn.Module):
    """Convolutional VAE encoder for 64x64 RGB images (CelebA setup).

    Four stride-2 5x5 convolutions (3 -> 128 -> 256 -> 512 -> 1024 channels),
    each followed by BatchNorm and ReLU, reduce a 64x64 input to a 1024x4x4
    feature map. Two linear heads then map it to the mean and log-variance of
    a diagonal Gaussian over the latent code.

    Args:
        h_dim: Dimensionality of the latent code.
    """

    def __init__(self, h_dim):
        super(EncoderCeleba, self).__init__()

        self.h_dim = h_dim
        # Layer order (and therefore parameter init order / state_dict keys
        # blocks.0 ... blocks.11) matches the conv -> BN -> ReLU stages.
        channels = [3, 128, 256, 512, 1024]
        layers = []
        for in_ch, out_ch in zip(channels, channels[1:]):
            layers += [nn.Conv2d(in_ch, out_ch, 5, 2, 2, bias=False),
                       nn.BatchNorm2d(out_ch, affine=True),
                       nn.ReLU(inplace=True)]
        self.blocks = nn.Sequential(*layers)

        self.fc_mu = nn.Linear(4 * 4 * 1024, h_dim)
        self.fc_logs2 = nn.Linear(4 * 4 * 1024, h_dim)

    def forward(self, x):
        """Encode ``x`` and draw one reparameterized latent sample.

        Args:
            x: Image batch ``(batch, 3, H, W)``. The linear heads need the conv
                stack to yield 1024x4x4 features, i.e. 64x64 spatial input.

        Returns:
            Tuple ``(mu, logs2, z)``, each ``(batch, h_dim)``: the mean, the
            log-variance, and ``z = mu + eps * exp(0.5 * logs2)`` with
            ``eps ~ N(0, I)``.

        Raises:
            RuntimeError: if the input resolution does not produce 4x4 maps.
        """
        x = self.blocks(x)
        # Keep the batch dimension explicit. This way a wrongly sized input
        # makes the linear heads raise, instead of its features being silently
        # regrouped into a different number of batch rows.
        x = torch.flatten(x, 1)
        mu, logs2 = self.fc_mu(x), self.fc_logs2(x)
        std = torch.exp(0.5 * logs2)
        eps = torch.randn_like(logs2)
        return mu, logs2, mu + eps * std


class DecoderCeleba(nn.Module):
    """Convolutional VAE decoder producing 64x64 RGB images (CelebA setup).

    A linear layer maps the latent code to a 1024x8x8 feature map. Three
    stride-2 transposed convolutions (1024 -> 512 -> 256 -> 128 channels,
    each with BatchNorm and ReLU) upsample it to 64x64. A final 5x5
    convolution projects to 3 channels, squashed into [0, 1] by a sigmoid.

    Args:
        h_dim: Dimensionality of the latent code.
    """

    def __init__(self, h_dim):
        super(DecoderCeleba, self).__init__()

        self.h_dim = h_dim
        self.fc1 = nn.Linear(h_dim, 8 * 8 * 1024, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        layers = []
        for i in range(3):
            in_ch, out_ch = 2 ** (10 - i), 2 ** (9 - i)
            # kernel 5 / stride 2 / padding 2 / output_padding 1 exactly
            # doubles H and W: 8 -> 16 -> 32 -> 64.
            layers += [nn.ConvTranspose2d(in_ch, out_ch, 5, stride=2, padding=2,
                                          output_padding=1, bias=False),
                       nn.BatchNorm2d(out_ch, affine=True),
                       nn.ReLU(inplace=True)]
        self.blocks = nn.Sequential(*layers)
        self.conv4 = nn.Conv2d(128, 3, 5, 1, 2)

    def forward(self, z):
        """Decode latent codes into images.

        Args:
            z: Latent batch, assumed ``(batch, h_dim)``.

        Returns:
            Tensor ``(batch, 3, 64, 64)`` with values in [0, 1].
        """
        z = self.relu1(self.fc1(z))
        # fc1 emits 8 * 8 * 1024 features per sample; unfold them into a map.
        z = z.view([-1, 1024, 8, 8])
        z = self.blocks(z)
        return torch.sigmoid(self.conv4(z))


class EncoderCifar(nn.Module):
    """Convolutional VAE encoder for 32x32 RGB images (CIFAR setup).

    Four stride-2 4x4 convolutions (3 -> 128 -> 256 -> 512 -> 1024 channels),
    each followed by BatchNorm and ReLU, reduce a 32x32 input to a 1024x2x2
    feature map. Two linear heads then map it to the mean and log-variance of
    a diagonal Gaussian over the latent code.

    Args:
        h_dim: Dimensionality of the latent code.
    """

    def __init__(self, h_dim):
        super(EncoderCifar, self).__init__()

        self.h_dim = h_dim
        # Layer order (and therefore parameter init order / state_dict keys
        # blocks.0 ... blocks.11) matches the conv -> BN -> ReLU stages.
        channels = [3, 128, 256, 512, 1024]
        layers = []
        for in_ch, out_ch in zip(channels, channels[1:]):
            layers += [nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                       nn.BatchNorm2d(out_ch, affine=True),
                       nn.ReLU(inplace=True)]
        self.blocks = nn.Sequential(*layers)

        self.fc_mu = nn.Linear(2 * 2 * 1024, h_dim)
        self.fc_logs2 = nn.Linear(2 * 2 * 1024, h_dim)

    def forward(self, x):
        """Encode ``x`` and draw one reparameterized latent sample.

        Args:
            x: Image batch ``(batch, 3, H, W)``. The linear heads need the conv
                stack to yield 1024x2x2 features, i.e. 32x32 spatial input.

        Returns:
            Tuple ``(mu, logs2, z)``, each ``(batch, h_dim)``: the mean, the
            log-variance, and ``z = mu + eps * exp(0.5 * logs2)`` with
            ``eps ~ N(0, I)``.

        Raises:
            RuntimeError: if the input resolution does not produce 2x2 maps.
        """
        x = self.blocks(x)
        # Keep the batch dimension explicit. This way a wrongly sized input
        # makes the linear heads raise, instead of its features being silently
        # regrouped into a different number of batch rows.
        x = torch.flatten(x, 1)
        mu, logs2 = self.fc_mu(x), self.fc_logs2(x)
        std = torch.exp(0.5 * logs2)
        eps = torch.randn_like(logs2)
        return mu, logs2, mu + eps * std


class DecoderCifar(nn.Module):
    """Convolutional VAE decoder producing 32x32 RGB images (CIFAR setup).

    The latent code is projected to a 1024x8x8 map. Two stride-2 4x4
    transposed convolutions (1024 -> 512 -> 256 channels, with BatchNorm and
    ReLU) upsample it to 32x32. A right/bottom zero pad followed by a 4x4
    convolution (padding 1) maps it to 3 channels at the same 32x32 size,
    then a sigmoid is applied.

    Args:
        h_dim: Dimensionality of the latent code.
    """

    def __init__(self, h_dim):
        super().__init__()

        self.h_dim = h_dim
        self.fc1 = nn.Linear(h_dim, 8 * 8 * 1024, bias=True)
        self.relu1 = nn.ReLU(inplace=True)

        stages = []
        for width in (1024, 512):
            half = width // 2
            stages.extend((
                nn.ConvTranspose2d(width, half, 4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(half, affine=True),
                nn.ReLU(inplace=True),
            ))
        self.blocks = nn.Sequential(*stages)
        # Asymmetric (0, 1, 0, 1) padding lets an even 4x4 kernel keep 32x32.
        self.conv4 = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)),
                                   nn.Conv2d(256, 3, 4, 1, 1))

    def forward(self, z):
        """Decode latent codes (assumed ``(batch, h_dim)``) to images in [0, 1]."""
        feature_map = self.relu1(self.fc1(z)).view([-1, 1024, 8, 8])
        logits = self.conv4(self.blocks(feature_map))
        return torch.sigmoid(logits)


def get_vae_celeba(device, h_dim=64):
    """Build a CelebA VAE ``(encoder, decoder)`` pair and move it to ``device``.

    Args:
        device: Any target accepted by ``nn.Module.to``.
        h_dim: Latent dimensionality shared by encoder and decoder. Defaults
            to 64, the value previously hard-coded here.

    Returns:
        Tuple ``(EncoderCeleba, DecoderCeleba)``.
    """
    return EncoderCeleba(h_dim).to(device), DecoderCeleba(h_dim).to(device)


def get_vae_cifar(device, h_dim=128):
    """Build a CIFAR VAE ``(encoder, decoder)`` pair and move it to ``device``.

    Args:
        device: Any target accepted by ``nn.Module.to``.
        h_dim: Latent dimensionality shared by encoder and decoder. Defaults
            to 128, the value previously hard-coded here.

    Returns:
        Tuple ``(EncoderCifar, DecoderCifar)``.
    """
    return EncoderCifar(h_dim).to(device), DecoderCifar(h_dim).to(device)


class DiscriminatorCeleba(nn.Module):
    """DCGAN-style critic for CelebA-sized images.

    A stride-2 conv + LeakyReLU stem is followed by three stride-2 stages
    (conv, non-affine InstanceNorm, LeakyReLU), each doubling the width
    ndf -> 2ndf -> 4ndf -> 8ndf. A final 4x4 valid convolution produces a
    single output channel.

    Args:
        nc: Number of input channels.
        ndf: Width of the first convolution.
    """

    def __init__(self, nc=3, ndf=64):
        super().__init__()
        stages = [nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                  nn.LeakyReLU(0.2, inplace=True)]
        width = ndf
        for _ in range(3):
            stages.extend((
                nn.Conv2d(width, 2 * width, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(2 * width, affine=False),
                nn.LeakyReLU(0.2, inplace=True),
            ))
            width *= 2
        stages.append(nn.Conv2d(width, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Return the critic output flattened over all dims (``(batch,)`` for 64x64 input)."""
        scores = self.main(input)
        return torch.flatten(scores)


class DiscriminatorCifar(nn.Module):
    """DCGAN-style critic for CIFAR-sized images.

    A stride-2 conv + LeakyReLU stem is followed by two stride-2 stages
    (conv, non-affine InstanceNorm, LeakyReLU), each doubling the width
    ndf -> 2ndf -> 4ndf. A final 4x4 valid convolution produces a single
    output channel.

    Args:
        nc: Number of input channels.
        ndf: Width of the first convolution.
    """

    def __init__(self, nc=3, ndf=64):
        super().__init__()
        stages = [nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                  nn.LeakyReLU(0.2, inplace=True)]
        width = ndf
        for _ in range(2):
            stages.extend((
                nn.Conv2d(width, 2 * width, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(2 * width, affine=False),
                nn.LeakyReLU(0.2, inplace=True),
            ))
            width *= 2
        stages.append(nn.Conv2d(width, 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*stages)

    def forward(self, input):
        """Return the critic output flattened over all dims (``(batch,)`` for 32x32 input)."""
        scores = self.main(input)
        return torch.flatten(scores)


class DiscriminatorUB(nn.Module):
    """Wrap a discriminator and compare two samples through the score difference.

    Args:
        discriminator: Callable module returning a real-valued score per sample.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, input):
        return self.discriminator(input)

    def acceptance_ratio(self, next_sample, prev_sample):
        """Return ``sigmoid(s_next - s_prev) / sigmoid(s_prev - s_next)``.

        Here ``s`` is the wrapped discriminator's score. Algebraically this
        equals ``exp(s_next - s_prev)``.
        """
        score_next = self(next_sample)
        score_prev = self(prev_sample)
        numerator = torch.sigmoid(score_next - score_prev)
        denominator = torch.sigmoid(score_prev - score_next)
        return numerator / denominator

    def d(self, x, y):
        """Return ``sigmoid(score(x) - score(y))``."""
        gap = self(x) - self(y)
        return torch.sigmoid(gap)


class DiscriminatorCCE(DiscriminatorCifar):
    """Wrap a discriminator whose sigmoid output is read as a per-sample probability.

    Args:
        discriminator: Callable module returning one logit per sample.

    NOTE(review): subclassing ``DiscriminatorCifar`` runs its ``__init__``.
    That builds a ``self.main`` network which ``forward`` never uses, yet its
    parameters still appear in ``parameters()`` and ``state_dict()``. It is
    kept so existing checkpoints keep loading. Consider switching the base to
    ``nn.Module`` together with a checkpoint migration.
    """

    def __init__(self, discriminator):
        super(DiscriminatorCCE, self).__init__()
        self.discriminator = discriminator

    def forward(self, input):
        return self.discriminator(input)

    def acceptance_ratio(self, next_sample, prev_sample):
        """Return the odds ratio ``d_x / (1 - d_x) * (1 - d_y) / d_y``.

        Here ``d = sigmoid(logit)``. Because ``sigmoid(a) / (1 - sigmoid(a))
        == exp(a)``, the ratio equals ``exp(logit_x - logit_y)``, and it is
        computed in that form. With float32 tensors, sigmoid rounds to exactly
        1.0 for logits above roughly 17 (and to 0.0 for very negative ones).
        The probability form then divides by zero and yields inf/nan.
        """
        logit_next = self(next_sample)
        logit_prev = self(prev_sample)
        return torch.exp(logit_next - logit_prev)

    def d(self, x):
        """Return ``sigmoid(logit(x))``."""
        return torch.sigmoid(self(x))


class DiscriminatorMCE(DiscriminatorCifar):
    """Wrap a discriminator and compare two samples through the score difference.

    Args:
        discriminator: Callable module returning a real-valued score per sample.

    NOTE(review): the inherited ``DiscriminatorCifar.__init__`` also builds a
    ``self.main`` network that ``forward`` never uses; its parameters still
    show up in ``parameters()`` / ``state_dict()``.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, input):
        return self.discriminator(input)

    def acceptance_ratio(self, next_sample, prev_sample):
        """Return ``sigmoid(s_next - s_prev) / sigmoid(s_prev - s_next)``.

        Here ``s`` is the wrapped discriminator's score.
        """
        score_next, score_prev = self(next_sample), self(prev_sample)
        toward_next = torch.sigmoid(score_next - score_prev)
        toward_prev = torch.sigmoid(score_prev - score_next)
        return toward_next / toward_prev

    def d(self, x, y):
        """Return ``sigmoid(score(x) - score(y))``."""
        gap = self(x) - self(y)
        return torch.sigmoid(gap)

