#!/bin/python3
import sys

sys.path.append("..")
import os
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import transforms
from tqdm import tqdm

from hyperdiffusion import FrozenUnet, HyperDiffusion
from luna import Luna

# Directory containing the LUNA dataset *.npy files generated by
# create_luna_dataset.py. Fill this in before running; the value is passed
# to Luna unchanged.
PATH = ""


def main() -> None:
    """Train a HyperDiffusion model on a subset of LUNA and save its weights.

    Trains on the first ``N`` samples of the dataset at ``PATH`` for
    ``num_epochs`` epochs with Adam (lr=1e-4), then writes the final
    ``state_dict`` to ``weights/hyperddpm.pt``. Uses ``cuda:0`` when CUDA is
    available, otherwise the CPU.
    """
    print("SEED:", torch.random.initial_seed())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    write_dir = Path("weights")
    write_dir.mkdir(exist_ok=True, parents=True)

    # Experiment settings
    N = 1000  # number of samples taken from the start of the dataset
    B = 32  # batch size
    T = 100  # number of diffusion timesteps
    R = 128  # resolution: images are resized to R x R
    num_epochs = 400
    # Passed to HyperDiffusion as n_params; presumably must match the size of
    # the backbone's generated parameter vector — TODO confirm.
    n_params = 2440241

    tfm = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((R, R), antialias=True),
        ]
    )
    dset = Luna(PATH, tfm)
    train_set = Subset(dset, range(N))
    loader = DataLoader(train_set, batch_size=B, shuffle=True)

    backbone = FrozenUnet(
        dim=16, dim_mults=(1, 2, 4, 8), channels=1, self_condition=True, device=device
    )
    model = HyperDiffusion(backbone, image_size=R, timesteps=T, n_params=n_params).to(
        device
    )
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    pbar = tqdm(range(num_epochs))
    for _epoch in pbar:
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            # One timestep per example, sampled directly on the target device
            # to avoid a host-to-device copy on every step.
            t = torch.randint(0, T, (len(x),), device=device)
            # Fresh random vector of length model.in_dim each step; its role is
            # defined by HyperDiffusion.p_losses (presumably the hypernetwork
            # input — TODO confirm).
            in_vec = torch.randn(model.in_dim, device=device)
            loss = model.p_losses(x, t, y, in_vec)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_description(f"Loss: {loss.item():.3f}")

    # Only the final weights are saved; no intermediate checkpoints.
    torch.save(model.state_dict(), write_dir / "hyperddpm.pt")


if __name__ == "__main__":
    main()
