from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torchvision import datasets, transforms

from itertools import chain

from core import metrics
from core import utils
from time import time
from core.logger import Logger
from scripts.vae import models

# Column formats for the per-epoch log. 'test_logl' is declared for the
# Logger but this script never writes it.
LOG_FMT = {'lr': '.4f',
           'tr_loss': '.4f',
           'test_logl': '.4f',
           'time': '.3f'}
DATAROOT = "../../data/celeba"
IMAGE_SIZE = 64
BATCH_SIZE = 100
NUM_WORKERS = 4
SEED = 322
EPOCHS = 90
LR_START = 1e-3
# At the start of this epoch an extra, separately named checkpoint is written
# (weights after SNAPSHOT_EPOCH completed epochs, before the LR change below).
SNAPSHOT_EPOCH = 30
# Step LR schedule: at the start of each key epoch the learning rate is set to
# lr_start divided by the mapped value.
LR_DIVISORS = {30: 2.0, 50: 4.0, 70: 10.0}


def milestone_lr(epoch, lr_start=LR_START):
    """Return the learning rate to switch to at the start of ``epoch``.

    Args:
        epoch: zero-based epoch index.
        lr_start: base learning rate the schedule divides.

    Returns:
        ``lr_start / LR_DIVISORS[epoch]`` if ``epoch`` is a schedule milestone,
        otherwise ``None`` (keep the current learning rate).
    """
    divisor = LR_DIVISORS.get(epoch)
    return None if divisor is None else lr_start / divisor


def parse_flags(argv=None):
    """Parse the script's command-line flags (``sys.argv`` when argv is None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--nocuda", action='store_true', default=False)
    parser.add_argument("--checkpoint", type=str, default='')
    return parser.parse_args(argv)


def build_trainloader():
    """Create the CelebA ImageFolder dataset and its shuffling DataLoader.

    Returns:
        (trainset, trainloader)
    """
    trainset = datasets.ImageFolder(root=DATAROOT,
                                    transform=transforms.Compose([
                                        transforms.Resize(IMAGE_SIZE),
                                        transforms.CenterCrop(IMAGE_SIZE),
                                        transforms.ToTensor()
                                    ]))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                              shuffle=True, num_workers=NUM_WORKERS)
    return trainset, trainloader


def train_one_epoch(encoder, decoder, criterion, optimizer, loader, device):
    """Run one optimization pass over ``loader``.

    Returns:
        The sum over batches of the scalar value returned by ``criterion``.
    """
    total = 0.0
    for real_images, _ in loader:  # class labels from ImageFolder are unused
        real_images = real_images.to(device)
        mu, logs2, z = encoder(real_images)
        recon_images = decoder(z)
        loss = criterion(real_images, recon_images, mu, logs2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total


def main():
    """Train the CelebA VAE for EPOCHS epochs, logging and checkpointing each epoch."""
    flags = parse_flags()

    logger = Logger(base='./logs', name="VAE-celeba", fmt=LOG_FMT)
    device = torch.device("cpu") if flags.nocuda else torch.device("cuda")
    torch.manual_seed(SEED)
    encoder, decoder = models.get_vae_celeba(device)
    if flags.checkpoint != '':
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        encoder_dict, decoder_dict = torch.load(flags.checkpoint, map_location='cpu')
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)

    trainset, trainloader = build_trainloader()

    criterion = metrics.ELBO_VAE().to(device)
    optimizer = optim.Adam(chain(encoder.parameters(), decoder.parameters()),
                           lr=LR_START, betas=(0.9, 0.999))

    for epoch in range(EPOCHS):
        t0 = time()

        if epoch == SNAPSHOT_EPOCH:
            torch.save([encoder.state_dict(), decoder.state_dict()],
                       logger.get_checkpoint(SNAPSHOT_EPOCH))
        new_lr = milestone_lr(epoch)
        if new_lr is not None:
            utils.adjust_learning_rate(optimizer, new_lr)

        train_loss = train_one_epoch(encoder, decoder, criterion, optimizer,
                                     trainloader, device)
        # NOTE(review): divides a sum of per-batch losses by the number of
        # images; this is a per-image average only if ELBO_VAE returns a
        # batch *sum* (not a batch mean) — confirm.
        logger.add(epoch, tr_loss=train_loss / len(trainset))

        logger.add(epoch, lr=optimizer.param_groups[0]['lr'])
        logger.add(epoch, time=time() - t0)
        logger.iter_info()
        logger.save(silent=True)
        torch.save([encoder.state_dict(), decoder.state_dict()], logger.checkpoint)


# The guard is required because the DataLoader uses worker processes: under the
# 'spawn' start method each worker re-imports this module, and unguarded
# top-level training code would be re-executed (or raise) in every worker.
if __name__ == "__main__":
    main()
