import abc
import os
import math

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.distributions import MultivariateNormal
from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

from vae import VAE
from data import get_normal_data, get_augmented_data


class Solver(metaclass=abc.ABCMeta):
    def __init__(self, args):
        """Set up data, model, optimizer, logging and loss for one run.

        Args:
            args: Run configuration object. Attributes read here are
                ``alg``, ``latent_dim``, ``log_path``, ``log_name`` and
                ``dataset``; ``num_channels`` is read from the ``args``
                object handed back by the data helper. ``log_path`` is
                overwritten with ``log_path/log_name``.

        Raises:
            ValueError: If ``args.alg`` is neither ``'vae'`` nor ``'our'``.
        """
        self.args = args

        # Each helper returns the three data splits plus an args object,
        # which replaces the one stored above.
        if self.args.alg == 'vae':
            self.train_data, self.test_data, self.generalization_data, self.args = \
                    get_normal_data(args)
        elif self.args.alg == 'our':
            self.train_data, self.test_data, self.generalization_data, self.args = \
                    get_augmented_data(args)
        else:
            # Fail here instead of leaving the data attributes undefined and
            # crashing later with an unrelated AttributeError.
            raise ValueError(
                "unknown alg {!r}; expected 'vae' or 'our'".format(self.args.alg))

        self.cuda = torch.cuda.is_available()

        self.models = []
        self.optims = []

        self.model = VAE(self.args.num_channels, z_dim=self.args.latent_dim)
        self.optim = optim.Adam(self.model.parameters(), lr=1e-4)

        if self.cuda:
            self.model.cuda()

        self.args.log_path = os.path.join(self.args.log_path, self.args.log_name)

        # makedirs also creates missing parents and tolerates a directory
        # that already exists, so no separate existence check is needed.
        os.makedirs(self.args.log_path, exist_ok=True)

        self.writer = SummaryWriter(self.args.log_path)

        if self.args.dataset == 'mnist':
            self.recon_loss = nn.BCELoss()
        else:
            self.recon_loss = nn.MSELoss()

        # Zero-mean, identity-covariance prior over the latent space. It is
        # created on the CPU; test() moves latents to the CPU before scoring.
        self.prior_distribution = MultivariateNormal(
                torch.zeros(self.args.latent_dim),
                torch.diag(torch.ones(self.args.latent_dim))
        )

    def kl_from_another_gaussian(self, mu1, logvar1, mu2, logvar2):
        """Return KL(N(mu1, var1) || N(mu2, var2)) for diagonal Gaussians.

        The per-element terms are summed over dim 1 and the result is
        averaged over the remaining entries.

        Args:
            mu1, logvar1: Mean and log-variance of the first Gaussian.
            mu2, logvar2: Mean and log-variance of the second Gaussian.

        Returns:
            Scalar tensor holding the averaged divergence.
        """
        sigma_a = torch.exp(0.5 * logvar1)
        sigma_b = torch.exp(0.5 * logvar2)

        # 2 * log(sigma_b / sigma_a) equals log(var_b / var_a).
        log_ratio = 2 * torch.log(sigma_b / sigma_a)
        scaled_spread = (sigma_a ** 2 + (mu2 - mu1) ** 2) / sigma_b ** 2

        summed = torch.sum(log_ratio - 1 + scaled_spread, dim=1)
        return 0.5 * summed.mean()

    def wasserstein_distance(self, mu1, logvar1, mu2, logvar2):
        """Return the mean 2-Wasserstein distance between diagonal Gaussians.

        For Gaussians with diagonal covariances the closed form is

            W2 = sqrt(||mu1 - mu2||^2 + ||sigma1 - sigma2||^2)

        where ``sigma = exp(0.5 * logvar)`` is the standard deviation. The
        squared terms are summed over dim 1 and the resulting distances are
        averaged over the remaining entries.

        Args:
            mu1, logvar1: Mean and log-variance of the first Gaussian.
            mu2, logvar2: Mean and log-variance of the second Gaussian.

        Returns:
            Scalar tensor holding the averaged distance.
        """
        # The covariance term compares standard deviations directly; taking a
        # further square root of them would compare var ** 0.25 instead.
        std1 = torch.exp(0.5 * logvar1)
        std2 = torch.exp(0.5 * logvar2)

        mean_term = torch.sum((mu1 - mu2) ** 2, dim=1)
        scale_term = torch.sum((std1 - std2) ** 2, dim=1)

        distance = torch.sqrt(mean_term + scale_term)
        return distance.mean()


    def base_vae_loss(self, recon_x, x, mu, logvar):
        """Compute the beta-weighted VAE objective for one batch.

        Args:
            recon_x: Decoder output; flattened to ``(x.size(0), -1)``.
            x: Target batch; flattened the same way.
            mu: Posterior means.
            logvar: Posterior log-variances.

        Returns:
            Tuple ``(total, reconstruction, kl)`` where ``total`` is
            ``reconstruction + self.args.beta * kl``. Both terms are summed
            over every element of the batch (no averaging).
        """
        n_samples = x.size(0)
        flat_recon = recon_x.view(n_samples, -1)
        flat_target = x.view(n_samples, -1)
        recon_term = F.binary_cross_entropy(
                flat_recon, flat_target, reduction='sum')

        # Closed-form KL to a standard normal, see Appendix B of
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        total = recon_term + self.args.beta * kl_term
        return total, recon_term, kl_term

    def kl_test(self, mu, logvar):
        """Return the KL-to-standard-normal term averaged over dim 1.

        NOTE(review): this takes the mean over dim 1, whereas base_vae_loss
        sums every element; confirm the different scaling is intended for
        the evaluation metric.

        Args:
            mu: Posterior means.
            logvar: Posterior log-variances.

        Returns:
            Tensor with dim 1 reduced away.
        """
        elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
        averaged = elementwise.mean(1)
        return -0.5 * averaged


    def recon_test(self, recon, imgs):
        """Return the summed binary cross-entropy between recon and imgs.

        Both tensors are flattened to ``(imgs.size(0), -1)`` before the loss
        is computed; the result is a single scalar summed over all elements.

        Args:
            recon: Reconstructed batch.
            imgs: Target batch.
        """
        n_samples = imgs.size(0)
        prediction = recon.view(n_samples, -1)
        target = imgs.view(n_samples, -1)
        return F.binary_cross_entropy(prediction, target, reduction='sum')


    def test(self, epoch):
        self.model.eval()
        z = torch.randn(25, self.args.latent_dim)
        if self.cuda:
            z = z.cuda()
        images = self.model.decode(z)

        self.writer.add_images('eval/generated', images, epoch)


        total_images, total_log_prob, total_kl_loss, total_recon_loss = 0, 0, 0, 0
        
        with torch.no_grad():
            for count, (imgs, _) in enumerate(self.test_data):
                total_images += imgs.size(0)
                if self.cuda:
                    imgs = imgs.cuda()

                z, mu, logvar = self.model.encode(imgs)

                recon = self.model.decode(mu)

                if count == 0:
                    self.writer.add_images('eval/reconstruction', recon[:25], epoch)
                                    
                kl_loss = self.kl_test(mu, logvar).sum()
                recon_loss = self.recon_test(recon, imgs).sum()

                log_prob = -self.prior_distribution.log_prob(z.cpu()).sum()

                total_kl_loss += kl_loss.item() 
                total_recon_loss += recon_loss.item()    
                total_log_prob += log_prob.item()

        avg_kl = total_kl_loss / total_images
        avg_recon = total_recon_loss / total_images
        avg_log_prob = total_log_prob / total_images

        self.writer.add_scalar('eval/avg_kl_loss', avg_kl, epoch)
        self.writer.add_scalar('eval/avg_recon_loss', avg_recon, epoch)
        self.writer.add_scalar('eval/avg_log_prob', avg_log_prob, epoch)

        print('current epoch: ', epoch)
        print('avg_kl_loss: ', avg_kl)
        print('avg_recon_loss: ', avg_recon)
        print('avg_log_prob: ', avg_log_prob)

        total_images, total_log_prob, total_kl_loss, total_recon_loss = 0, 0, 0, 0
        with torch.no_grad(): 
            for count, (imgs, _) in enumerate(self.generalization_data):
                total_images += imgs.size(0)
                if self.cuda:
                    imgs = imgs.cuda()

                z, mu, logvar = self.model.encode(imgs)

                recon = self.model.decode(mu)

                if count == 0:
                    self.writer.add_images('gen/reconstruction', recon[:25], epoch)
                                    
                kl_loss = self.kl_test(mu, logvar).sum()
                recon_loss = self.recon_test(recon, imgs).sum()

                log_prob = -self.prior_distribution.log_prob(z.cpu()).sum()

                total_kl_loss += kl_loss.item() 
                total_recon_loss += recon_loss.item()    
                total_log_prob += log_prob.item()

        avg_kl = total_kl_loss / total_images
        avg_recon = total_recon_loss / total_images
        avg_log_prob = total_log_prob / total_images

        self.writer.add_scalar('gen/avg_kl_loss', avg_kl, epoch)
        self.writer.add_scalar('gen/avg_recon_loss', avg_recon, epoch)
        self.writer.add_scalar('gen/avg_log_prob', avg_log_prob, epoch)

        print('generalize_avg_kl_loss: ', avg_kl)
        print('generalize_avg_recon_loss: ', avg_recon)
        print('generalize_avg_log_prob: ', avg_log_prob)
        self.model.train()

    def log_stats(self, steps, **kwargs):
        """Print and log named training statistics for the given step.

        Args:
            steps: Step value attached to every logged scalar.
            **kwargs: Statistic name -> value. Values exposing ``.item()``
                (e.g. single-element tensors) are converted with it; plain
                Python numbers are logged as they are. Each value is written
                under the tag ``'train/<name>'``.
        """
        print('current steps: {}'.format(steps))
        for k, v in kwargs.items():
            # Convert once so the value is not extracted twice per statistic.
            value = v.item() if hasattr(v, 'item') else v
            self.writer.add_scalar('train/' + k, value, steps)
            print(str(k) + ': ' + str(value))
