import torch
import os
import torch.optim as optim
import torch.utils.data as data
import math

from solver_base import Solver

class VAESolver(Solver):
    """Training loop for a VAE model built on the shared ``Solver`` base.

    Uses attributes and helpers presumably provided by ``Solver`` (not
    visible in this file; TODO confirm): ``self.args``, ``self.model``,
    ``self.optim``, ``self.train_data``, ``self.cuda``, and the methods
    ``base_vae_loss``, ``log_stats`` and ``test``.
    """

    def __init__(self, args):
        super().__init__(args)

    def solve(self, log_every=50):
        """Train for ``self.args.num_epochs`` epochs, calling ``self.test`` after each.

        Each batch from ``self.train_data`` is unpacked as ``(images, _)``;
        the second element is ignored. The model is called as
        ``self.model(images)`` and must return ``(recon, mu, logvar)``.
        Losses come from ``self.base_vae_loss(recon, images, mu, logvar)``,
        which must return ``(total_loss, recon_loss, kl_loss)``. Only
        ``total_loss`` is back-propagated.

        Args:
            log_every: pass the three losses to ``self.log_stats`` every
                ``log_every`` optimizer steps. Steps are counted globally
                across epochs and start at 0. Defaults to 50, the interval
                that was previously hard-coded.

        Raises:
            ValueError: if ``log_every`` is less than 1.
        """
        if log_every < 1:
            raise ValueError(f"log_every must be >= 1, got {log_every!r}")

        num_iter = 0  # global step counter, never reset between epochs
        for epoch_count in range(self.args.num_epochs):
            # self.test() runs at the end of every epoch and may switch the
            # model to eval mode (TODO confirm in Solver.test). Re-enable
            # training behaviour (dropout, batch-norm statistics) so every
            # epoch trains the same way as the first one.
            self.model.train()

            for images, _ in self.train_data:
                if self.cuda:
                    images = images.cuda()

                recon, mu, logvar = self.model(images)
                total_loss, recon_loss, kl_loss = self.base_vae_loss(
                    recon, images, mu, logvar
                )

                self.optim.zero_grad()
                total_loss.backward()
                self.optim.step()

                if num_iter % log_every == 0:
                    self.log_stats(
                        num_iter,
                        total_loss=total_loss,
                        kl_loss=kl_loss,
                        recon_loss=recon_loss,
                    )
                num_iter += 1

            self.test(epoch_count)
