import os
import argparse
import sys; sys.path.append("./ANODE") # import hack
from tqdm import tqdm
import torch
import torch.nn.functional as F
from models.model import OurModel, Simulator
from models.conv_models import ConvODENet, MNISTConvODENet
from models.mlp_model import ODENet
from omegaconf import OmegaConf
import plotly.express as px
from utils import *
import wandb
from torch.func import vmap, jacrev, jacfwd, functional_call
import time
import copy
import pandas as pd

def learnable(z0, z1, t):
    """Placeholder dynamics: ignores all arguments and yields ``None``."""
    result = None
    return result

def linear(z0, z1, t):
    """Straight-line interpolation: returns ``z0`` at ``t = 0`` and ``z1`` at ``t = 1``."""
    weight_start = 1 - t
    weight_end = t
    return weight_start * z0 + weight_end * z1

def half_circle(z0, z1, t):
    """Move from ``z0`` (``t = 0``) to ``z1`` (``t = 1``) along a half circle.

    The circle is centred on the midpoint of ``z0`` and ``z1``. Its plane is spanned by the
    vector ``(z0 - z1) / 2`` and the component of the midpoint that is orthogonal to it.
    All inputs are brought to 2-D before the computation and the result is reshaped back to
    the original shape of ``z0``.

    NOTE: the result contains NaN/inf when ``z0 == z1`` or when the midpoint has no component
    orthogonal to ``z0 - z1``, because both cases divide by a zero norm.
    """
    out_shape = z0.shape

    def _as_matrix(x):
        # >2-D: keep dim 0 and flatten the rest; <2-D: treat the whole tensor as a single row.
        rank = len(x.shape)
        if rank > 2:
            return x.view(x.shape[0], -1)
        if rank < 2:
            return x.view(1, -1)
        return x

    z0 = _as_matrix(z0)
    z1 = _as_matrix(z1)
    t = _as_matrix(t)

    angle = t * np.pi
    center = (z0 + z1) / 2
    radius = (z0 - z1) / 2
    radius_len = torch.norm(radius, dim=-1, keepdim=True)
    unit_radius = radius / radius_len

    # Any vector works as the seed of the second axis; the midpoint itself is used.
    seed = (z0 + z1) / 2
    # Remove the part of the seed that lies along the radius direction.
    along_radius = (seed * unit_radius).sum(dim=-1, keepdim=True) * unit_radius
    ortho = seed - along_radius
    # Rescale the orthogonal axis so it has the same length as the radius.
    radius_ortho = ortho / torch.norm(ortho, dim=-1, keepdim=True) * radius_len

    point = center + radius * torch.cos(angle) + radius_ortho * torch.sin(angle)
    return point.reshape(out_shape)

# def lin_cos(z0, z1, t):
#     degree = t * np.pi / 2.
#     return torch.cos(degree) * z0 + (1-torch.cos(degree)) * z1

def lin_sin(z0, z1, t):
    """Interpolate between ``z0`` and ``z1`` with weight ``sin(t * pi / 2)`` on ``z1``."""
    angle = t * np.pi / 2.
    toward_z1 = torch.sin(angle)
    return (1 - toward_z1) * z0 + toward_z1 * z1


def cos(z0, z1, t):
    """Combine the endpoints as ``cos(a) * z0 + sin(a) * z1`` with ``a = t * pi / 2``."""
    angle = t * np.pi / 2.
    start_weight = torch.cos(angle)
    end_weight = torch.sin(angle)
    return start_weight * z0 + end_weight * z1

def inv_cos(z0, z1, t):
    """Combine the endpoints as ``(1 - sin(a)) * z0 + (1 - cos(a)) * z1`` with ``a = t * pi / 2``."""
    angle = t * np.pi / 2.
    start_weight = 1 - torch.sin(angle)
    end_weight = 1 - torch.cos(angle)
    return start_weight * z0 + end_weight * z1

def vp_ode(z0, z1, t):
    """Return ``alpha * z1 + sqrt(1 - alpha**2) * z0`` for a time-dependent ``alpha``.

    ``alpha = exp(-0.25 * 19.9 * (1 - t)**2 - 0.5 * 0.1 * (1 - t))``, so ``alpha`` is 1 at
    ``t = 1`` and the result there is exactly ``z1``.
    """
    remaining = 1 - t
    exponent = -0.25 * 19.9 * remaining**2 - 0.5 * 0.1 * remaining
    alpha = torch.exp(exponent)
    beta = torch.sqrt(1 - alpha**2)
    return alpha * z1 + beta * z0

def const_vp_ode(z0, z1, t):
    """Return ``t * z1 + sqrt(1 - t**2) * z0`` (``z0`` at ``t = 0``, ``z1`` at ``t = 1``)."""
    end_weight = t
    start_weight = torch.sqrt(1 - end_weight**2)
    return end_weight * z1 + start_weight * z0

@torch.inference_mode()
def test_metric(net, test_dataloader, method='dopri5', num_timesteps=1+1, return_mse=False, metric_key='accuracy', label_scaler=None):
    """Dispatch to the accuracy or RMSE evaluation routine.

    Args:
        net: model to evaluate; forwarded unchanged.
        test_dataloader: loader forwarded unchanged.
        method: solver name forwarded to the evaluation routine.
        num_timesteps: number of time points forwarded to the evaluation routine.
        return_mse: when true, also return the latent and data MSE.
        metric_key: ``'accuracy'`` (uses ``test_accuracy``) or ``'rmse'`` (uses ``test_rmse``).
        label_scaler: forwarded to ``test_rmse`` only.

    Returns:
        The metric alone, or ``(metric, latent_mse, data_mse)`` when ``return_mse`` is true.

    Raises:
        ValueError: if ``metric_key`` is not one of the supported values.
    """
    if metric_key == 'accuracy':
        metric, latent_mse, data_mse = test_accuracy(net, test_dataloader, method=method, num_timesteps=num_timesteps, return_mse=return_mse)
    elif metric_key == 'rmse':
        metric, latent_mse, data_mse = test_rmse(net, test_dataloader, method=method, num_timesteps=num_timesteps, return_mse=return_mse,
                                               label_scaler=label_scaler)
    else:
        # Previously an unknown key fell through and failed later with an UnboundLocalError.
        raise ValueError(f"Unsupported metric_key {metric_key!r}; expected 'accuracy' or 'rmse'")
    if return_mse:
        return metric, latent_mse, data_mse
    return metric

@torch.inference_mode()
def test_accuracy(net, test_dataloader, method='dopri5', num_timesteps=1+1, return_mse=False):
    """Measure classification accuracy plus latent/data MSE over ``test_dataloader``.

    Predictions and labels are compared through their ``argmax`` over the last dimension.
    With ``method='dopri5'`` the network is called directly; otherwise ``net.get_traj`` is used
    with ``num_timesteps`` time points. The latent MSE compares the final trajectory state with
    ``net.label_projection(Y)`` and is reported as 0 when ``net.augment_dim > 0``.
    ``return_mse`` is accepted but not used here: all three values are always returned.

    Returns:
        ``(accuracy, latent_mse, data_mse)``, each averaged over the number of samples.
    """
    net.eval()
    n_correct = 0
    n_seen = 0
    latent_total = 0
    data_total = 0
    batches = tqdm(enumerate(test_dataloader), leave=False, total=len(test_dataloader), desc='Measure metric')
    for _, (X, Y) in batches:
        X = X.cuda()
        Y = Y.cuda()
        batch_size = Y.size(0)

        if method != 'dopri5':
            traj, pred = net.get_traj(X, method=method, timesteps=num_timesteps)
        else:
            feat, pred = net(X, return_features=True, method='dopri5')
            traj = [feat]

        hits = pred.argmax(dim=-1) == Y.argmax(dim=-1)
        n_correct += hits.float().sum().item()
        n_seen += batch_size

        if net.augment_dim > 0:
            # Latent MSE is not measured for augmented models.
            latent_total += 0.
        else:
            latent_total += F.mse_loss(traj[-1], net.label_projection(Y)).item() * batch_size
        data_total += F.mse_loss(pred, Y).item() * batch_size

    return n_correct / n_seen, latent_total / n_seen, data_total / n_seen

@torch.inference_mode()
def test_rmse(net, test_dataloader, method='dopri5', num_timesteps=1+1, return_mse=False, label_scaler=None):
    """Measure RMSE plus latent/data MSE over ``test_dataloader``.

    Runs under ``torch.inference_mode`` like ``test_accuracy``, so a direct call neither builds
    an autograd graph nor fails when predictions are converted to NumPy.

    Args:
        net: model exposing ``get_traj``, ``label_projection`` and ``augment_dim``.
        test_dataloader: iterable of ``(X, Y)`` batches; batches are moved to CUDA.
        method: ``'dopri5'`` calls the network directly; anything else goes through
            ``net.get_traj`` with ``num_timesteps`` time points.
        num_timesteps: number of time points for the fixed-step trajectory.
        return_mse: accepted but not used here; all three values are always returned.
        label_scaler: optional object with ``inverse_transform``; when given, the RMSE is
            computed on inverse-transformed labels and predictions.

    Returns:
        ``(rmse, latent_mse, data_mse)``. The squared errors are averaged over samples and the
        square root is taken once at the end. ``latent_mse`` is 0 when ``net.augment_dim > 0``.
    """
    net.eval()
    count = 0
    latent_mse = data_mse = 0
    rmse = 0  # accumulates the sample-weighted squared error; square-rooted after the loop
    for i, (X, Y) in tqdm(enumerate(test_dataloader), leave=False, total=len(test_dataloader), desc='Measure rmse'):
        X, Y = X.cuda(), Y.cuda()
        if method == 'dopri5':
            feat, pred = net(X, return_features=True, method='dopri5')
            traj = [feat]
        else:
            traj, pred = net.get_traj(X, method=method, timesteps=num_timesteps)
        count += Y.size(0)
        if label_scaler is not None:
            # Compare in the original label units.
            Y_unnorm = label_scaler.inverse_transform(Y.cpu().numpy())
            pred_unnorm = label_scaler.inverse_transform(pred.cpu().numpy())
            rmse += np.mean((Y_unnorm - pred_unnorm)**2) * Y.size(0)
        else:
            rmse += F.mse_loss(pred, Y).item() * Y.size(0)
        if net.augment_dim > 0:
            latent_mse += 0.
        else:
            latent_mse += F.mse_loss(traj[-1], net.label_projection(Y)).item() * Y.size(0)
        data_mse += F.mse_loss(pred, Y).item() * Y.size(0)

    latent_mse /= count
    data_mse /= count
    rmse /= count
    rmse = rmse ** 0.5
    return rmse, latent_mse, data_mse


@torch.inference_mode()
def test_straightness(net, test_dataloader, normalize=True):
    """Return the sample-weighted average of ``straightness(net, X)`` over ``test_dataloader``."""
    net.eval()
    weighted_sum = 0.
    n_samples = 0
    batches = tqdm(enumerate(test_dataloader), leave=False, total=len(test_dataloader), desc='Measure Straightness')
    for _, (X, Y) in batches:
        X = X.cuda()
        Y = Y.cuda()
        batch_size = Y.size(0)
        n_samples += batch_size
        # Weight each batch value by its size so the final division gives a per-sample mean.
        weighted_sum += straightness(net, X, normalize=normalize) * batch_size
    weighted_sum /= n_samples
    return weighted_sum


def test_norm_avg_reg(net, val_dataloader=None):
    """Average ``norm_avg`` of the input and label embeddings over a loader.

    Both ``net.in_projection(X)`` and ``net.label_projection(Y)`` are evaluated per batch and
    their ``norm_avg`` values are averaged with batch-size weights.

    Returns:
        ``(z0_norm_avg, z1_norm_avg)``, or ``(0, 0)`` when no loader is given.
    """
    if val_dataloader is None:
        return 0, 0
    net.eval()
    z0_total = 0
    z1_total = 0
    n_samples = 0
    batches = tqdm(enumerate(val_dataloader), leave=False, total=len(val_dataloader), desc='Measure Norm Avg')
    for _, (X, Y) in batches:
        X = X.cuda()
        z0 = net.in_projection(X)
        z1 = net.label_projection(Y.cuda())
        batch_size = X.size(0)
        z0_total += norm_avg(z0) * batch_size
        z1_total += norm_avg(z1) * batch_size
        n_samples += batch_size
    z0_total /= n_samples
    z1_total /= n_samples
    return z0_total, z1_total


def test_norm_avg_cls(net, val_dataloader=None, num_classes=10):
    """Average ``norm_avg`` of input embeddings and of the embeddings of all class labels.

    The label side is computed once from the one-hot encodings of ``num_classes`` classes; the
    input side is averaged over the loader with batch-size weights.

    Returns:
        ``(z0_norm_avg, z1_norm_avg)``, or ``(0, 0)`` when no loader is given.
    """
    if val_dataloader is None:
        return 0, 0
    net.eval()

    # Label embeddings: one row per class.
    class_ids = torch.arange(num_classes)
    one_hot_labels = F.one_hot(class_ids).float().cuda()
    z1_norm_avg = norm_avg(net.label_projection(one_hot_labels))

    # Input embeddings: running sample-weighted sum.
    z0_total = 0
    n_samples = 0
    batches = tqdm(enumerate(val_dataloader), leave=False, total=len(val_dataloader), desc='Measure Norm Avg')
    for _, (X, _unused_labels) in batches:
        X = X.cuda()
        z0 = net.in_projection(X)
        batch_size = X.size(0)
        z0_total += norm_avg(z0) * batch_size
        n_samples += batch_size
    z0_total /= n_samples
    return z0_total, z1_norm_avg


class Trainer(object):
    """Drive training, periodic evaluation and checkpointing of ``net``.

    Two training modes are implemented in :meth:`training_step`:

    * ``method == 'node'``: the network output is fitted directly with ``task_criterion``.
    * any other ``method``: inputs and labels are embedded with ``net.in_projection`` and
      ``net.label_projection``, a point on the path defined by ``dynamics`` is sampled at a
      random time, and ``net.pred_v`` is regressed onto the time-derivative of that path.
      A label auto-encoding loss and optional Jacobian-clamp penalties are added.

    Metrics go to the active wandb run and checkpoints are written below a ``ckpts`` directory
    derived from ``wandb.run.dir``; a wandb run must therefore be active at construction time.
    """

    def __init__(self, net, total_steps, optimizer='adam', scheduler='none', lr=1e-3, wd=[0., 0., 0.], loss_start_epoch=[0, 0, 0], lambdas=[1.0, 1.0, 0.0],
          task_criterion=torch.nn.MSELoss(), force_zero_prob=0., test_every=2, 
          label_ae_noise=0., method='ours', f_sg_target=False, augment_t=1, t_transform='identity', label_ae_mse=True,
          task_dec=False, f_jac_clamp=(-1, -1), g_jac_clamp=(-1,-1), train_alter=False, alter_order=['fgh'], alter_epoch=[1],
          fgh_lr=None, sync_t=False, label_flow_noise=0., label_flow_noise_0=0., invert_transform_t=False, dynamics=linear,
          dataset='mnist', label_scaler=None, save_every=24, patience=-1, steer=0, ema=0):
        """Store the configuration, create the checkpoint directory and build the optimizer.

        Args (only those read by code in this class are described):
            net: model to train; a deep copy is kept as the EMA model.
            total_steps: number of optimisation steps after which :meth:`fit` stops.
            optimizer: ``'adam'`` (AdamW) or ``'radam'``.
            scheduler: ``'none'``, ``'cos'`` or ``'step'``.
            lr: base learning rate, also the default for every entry of ``fgh_lr``.
            wd: weight decay; either one number or a sequence indexed like ``fgh_lr``.
            lambdas: loss weights; index 0 scales the flow loss, index 1 the label AE loss.
            task_criterion: loss used in ``'node'`` mode.
            force_zero_prob: probability of forcing a sampled time to zero.
            test_every: evaluate every this many steps.
            label_ae_noise: std of Gaussian noise added to label embeddings for the AE loss.
            method: ``'node'`` or a flow-matching mode such as ``'ours'``.
            augment_t: number of time samples drawn per training example.
            t_transform: ``'identity'``, ``'square'``, ``'cubic'`` or ``'one_minus_cos'``.
            label_ae_mse: use MSE (true) or ``custom_ce`` (false) for the label AE loss.
            f_jac_clamp / g_jac_clamp: ``(min, max)`` targets of the Jacobian-clamp penalty;
                the penalty is active only when both values are non-negative.
            train_alter / alter_order / alter_epoch: parameter-group layout of the optimizers.
            fgh_lr: learning rates for in-projection, out/label-projection and ODE block.
            sync_t: share one sampled time across the whole batch.
            label_flow_noise / label_flow_noise_0: noise std added to z1 / z0 for the flow loss.
            invert_transform_t: replace the sampled ``t`` by ``1 - t`` after the transform.
            dynamics: callable ``(z0, z1, t)`` defining the path; differentiated w.r.t. ``t``.
            dataset: ``'mnist'``, ``'cifar10'`` and ``'svhn'`` select the accuracy metric,
                everything else RMSE.
            label_scaler: forwarded to the RMSE evaluation.
            save_every: hours between time-based checkpoints.
            patience: early-stopping patience in evaluations; disabled when not positive.
            steer: when non-zero in ``'node'`` mode, ``net.steer`` is used instead of ``net``.
            ema: EMA decay; the EMA model is used for evaluation when it is positive.
        """
        self.net = net
        self.wd = wd
        self.lr = lr
        self.total_steps = total_steps
        self.loss_start_epoch = loss_start_epoch
        self.lambdas = lambdas
        self.task_criterion = task_criterion
        self.force_zero_prob = force_zero_prob
        self.test_every = test_every
        self.label_ae_noise = label_ae_noise
        self.method = method
        self.f_sg_target = f_sg_target
        self.augment_t = augment_t
        self.t_transform = t_transform
        self.label_ae_mse = label_ae_mse
        self.task_dec = task_dec
        self.f_jac_clamp = f_jac_clamp
        self.g_jac_clamp = g_jac_clamp
        self.train_alter = train_alter
        self.alter_order = alter_order
        self.alter_epoch = alter_epoch
        self.fgh_lr = fgh_lr
        self.sync_t = sync_t
        self.label_flow_noise = label_flow_noise
        self.label_flow_noise_0 = label_flow_noise_0
        self.invert_transform_t = invert_transform_t
        self.dynamics = dynamics
        # Per-sample derivative of the path with respect to t (the flow-matching target).
        self.dyn_v = vmap(jacfwd(dynamics, argnums=2))
        self.ckpt_dir = wandb.run.dir.replace('wandb', 'ckpts') # os.path.join(wandb.run.dir, 'files')
        self.ckpt_dir = os.path.join(self.ckpt_dir, wandb.run.name)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.label_scaler = label_scaler
        self.dataset = dataset
        self.save_every = save_every
        self.time_threshold = save_every * 3600  # hours -> seconds
        self.early_stop_count = 0
        self.early_stopping_patience = patience
        self.steer = steer

        self.ema_alpha = ema
        self.ema = copy.deepcopy(net.state_dict())
        self.ema_net = copy.deepcopy(net)

        for p in self.ema_net.parameters():
            p.requires_grad_(False)

        if dataset in ['mnist', 'cifar10', 'svhn']:
            self.metric_type = 'accuracy'
        else:
            self.metric_type = 'rmse'

        self.configure_optimizer()
    
    def ema_update(self, alpha=0.999):
        """Blend the current parameters of ``self.net`` into the stored EMA weights.

        Each entry becomes ``alpha * ema + (1 - alpha) * param``. Only named parameters are
        touched; buffers keep the values copied at construction time.
        """
        for n, p in self.net.named_parameters():
            self.ema[n] = alpha * self.ema[n] + (1 - alpha) * p.data
    
    def ema_restore(self):
        """Load the stored EMA weights into the parameters of ``self.ema_net``."""
        # NOTE(review): buffers of ema_net are never refreshed here — confirm this is fine
        # for models with stateful buffers.
        for n, p in self.ema_net.named_parameters():
            p.data = self.ema[n]

    def _weight_decay_for(self, index):
        """Return the weight decay of parameter group ``index`` (0: f, 1: g, 2: h).

        ``self.wd`` may be a single number shared by all groups, or a sequence with one entry
        per group. Assumes the sequence follows the same (f, g, h) order as ``fgh_lr`` —
        TODO confirm against the configs.
        """
        if isinstance(self.wd, (int, float)):
            return self.wd
        return self.wd[index]
    
    def configure_optimizer(self):
        """Replace the optimizer/scheduler names on ``self`` by the actual objects.

        One optimizer is created per entry of ``alter_order``; the letters ``f``, ``g`` and
        ``h`` in an entry select the in-projection, the out/label projections and the ODE block.
        Only a single optimizer is supported in the end (asserted below).

        Raises:
            ValueError: for an unsupported optimizer or scheduler name.
        """
        if self.fgh_lr is None:
            self.fgh_lr = [self.lr, self.lr, self.lr]

        if self.optimizer == 'adam':
            opt = torch.optim.AdamW
        elif self.optimizer == 'radam':
            opt = torch.optim.RAdam
        else:
            raise ValueError(f'Optimizer {self.optimizer} not supported')

        if not self.train_alter:
            assert self.alter_order == ['fgh'] and self.alter_epoch == [1]

        optimizers = []
        for target in self.alter_order:
            assert isinstance(target, str), 'alter_order should be a list of strings'
            params = []
            # The optimizers need one number per group: passing the whole ``wd`` list as
            # ``weight_decay`` breaks the optimizer step.
            if 'f' in target:
                params.append({'params': self.net.in_projection.parameters(), 'lr': self.fgh_lr[0], 'weight_decay': self._weight_decay_for(0)})
            if 'g' in target:
                params.append({'params': self.net.out_projection.parameters(), 'lr': self.fgh_lr[1], 'weight_decay': self._weight_decay_for(1)})
                params.append({'params': self.net.label_projection.parameters(), 'lr': self.fgh_lr[1], 'weight_decay': self._weight_decay_for(1)})
            if 'h' in target:
                params.append({'params': self.net.odeblock.parameters(), 'lr': self.fgh_lr[2], 'weight_decay': self._weight_decay_for(2)})
            
            optimizers.append(opt(params, lr=self.lr))
        assert len(optimizers) == len(self.alter_epoch), 'optimizers and train_alter_epoch should have same length'

        if self.scheduler == 'none':
            scheduler = None
        elif self.scheduler == 'cos':
            assert len(optimizers) == 1, 'Cosine annealing scheduler only supports single optimizer'
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizers[0], T_max=self.total_steps, eta_min=0)
        elif self.scheduler == 'step':
            # Halve the learning rate every 10k steps, ten times in total.
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizers[0], [10000 * (i+1) for i in range(10)], gamma=0.5)
        else:
            # Previously an unknown name left ``scheduler`` unbound and failed with UnboundLocalError.
            raise ValueError(f'Scheduler {self.scheduler} not supported')
        assert len(optimizers) == 1
        
        self.optimizer = optimizers[0]
        self.scheduler = scheduler

    @torch.inference_mode()
    def evaluate(self, test_dataloader, val_dataloader, subset_loader, current_step, best_test_metric=0, best_val_metric=0):
        """Evaluate on test and validation data, save checkpoints and log the results.

        ``last*.ckpt`` is always written; ``best*.ckpt`` / ``best_val*.ckpt`` are written when
        the dopri metric on the test / validation set improves (higher for accuracy, lower for
        RMSE). A validation improvement also resets the early-stopping counter, otherwise the
        counter is incremented. ``subset_loader`` is currently unused.

        Returns:
            The updated ``(best_test_metric, best_val_metric)``.
        """
        metric_dict = self.test(test_dataloader)
        valtric_dict = self.test(val_dataloader, metric_key='val')
        # if subset_loader is not None:
        #     subset_dict = self.test(subset_loader, metric_key='train_subset')
        #     metric_dict.update(subset_dict)
        metric_dict.update(valtric_dict)
        metric_val = metric_dict[f'val/{self.metric_type}_dopri']
        metric1 = metric_dict[f'test/{self.metric_type}_1']
        metric2 = metric_dict[f'test/{self.metric_type}_2']
        metricinf =  metric_dict[f'test/{self.metric_type}_dopri']
        straight = metric_dict['test/straightness']
        latent_mse =  metric_dict['test/latent_mse']
        data_mse = metric_dict['test/data_mse']

        if self.metric_type == 'accuracy':
            condition = metricinf > best_test_metric
            val_condition = metric_val > best_val_metric
        else:
            condition = metricinf < best_test_metric
            val_condition = metric_val < best_val_metric

        if condition:
            best_test_metric = metricinf
            torch.save(self.net.state_dict(), os.path.join(self.ckpt_dir, 'best.ckpt'))
            torch.save(self.ema_net.state_dict(), os.path.join(self.ckpt_dir, 'best_ema.ckpt'))
        torch.save(self.net.state_dict(), os.path.join(self.ckpt_dir, 'last.ckpt'))
        torch.save(self.ema_net.state_dict(), os.path.join(self.ckpt_dir, 'last_ema.ckpt'))
        
        if val_condition:
            best_val_metric = metric_val
            torch.save(self.net.state_dict(), os.path.join(self.ckpt_dir, 'best_val.ckpt'))
            torch.save(self.ema_net.state_dict(), os.path.join(self.ckpt_dir, 'best_val_ema.ckpt'))
            # Snapshot the test metrics that belong to the best validation checkpoint.
            best_val_log = {}
            for k, v in metric_dict.items():
                if 'val' in k:
                    continue
                if 'test' in k:
                    best_val_log[k.replace('test', 'test_on_best_val')] = v
            pd.DataFrame(best_val_log, index=[0]).to_csv(os.path.join(self.ckpt_dir, 'best_val_log.csv'))
            wandb.log(best_val_log, commit=False)
            self.early_stop_count = 0
        else:
            self.early_stop_count += 1

        if self.metric_type == 'accuracy':
            metric_dict['test/best_acc'] = best_test_metric
            metric_dict['val/best_acc'] = best_val_metric
        else:
            metric_dict['test/best_rmse'] = best_test_metric
            metric_dict['val/best_rmse'] = best_val_metric
        print(f'Step {current_step}/{self.total_steps}, Val {self.metric_type} {metric_val:.4f}, Test {self.metric_type} 1/2/inf {metric1:.4f}/{metric2:.4f}/{metricinf:.4f}, Straightness {straight:.4f}, Latent MSE {latent_mse:.4f}, Data MSE {data_mse:.4f}')
        wandb.log(metric_dict, commit=False)
        return best_test_metric, best_val_metric

    def compute_jacobian(self, x, y, mode='f'):
        """Penalise the finite-difference change ratio between ``x`` and ``y``.

        For neighbouring samples in the batch the ratio ``||y[i+1] - y[i]|| / ||x[i+1] - x[i]||``
        is computed. Ratios that are not greater than 1e-9 (including NaN from 0/0) are dropped.
        The mean ratio is logged, and ratios outside ``[lam_min, lam_max]`` are penalised with
        an MSE towards the violated bound.

        Args:
            x: inputs of the mapping (batch first).
            y: outputs of the mapping for the same batch.
            mode: ``'f'`` uses ``f_jac_clamp``, ``'g'`` uses ``g_jac_clamp``.

        Returns:
            The clamp loss; a zero tensor of shape ``(1,)`` when the clamp is disabled (any
            bound negative) or when no ratio is left after masking.

        Raises:
            ValueError: if ``mode`` is neither ``'f'`` nor ``'g'``.
        """
        if mode == 'f':
            lam_min, lam_max = self.f_jac_clamp
        elif mode == 'g':
            lam_min, lam_max = self.g_jac_clamp
        else:
            raise ValueError(f"Unsupported jacobian mode {mode!r}; expected 'f' or 'g'")

        delta_x = x[1:] - x[:-1]
        delta_y = y[1:] - y[:-1]
        Q = delta_y.reshape(delta_y.size(0), -1).norm(dim=-1) / delta_x.reshape(delta_x.size(0), -1).norm(dim=-1)

        # pass if delta_z is zero
        mask = Q > 1e-9
        Q = Q[mask]

        if mode == 'f':
            wandb.log({"train/f_jac_Q": Q.mean().item()}, commit=False)
        else:
            wandb.log({"train/g_jac_Q": Q.mean().item()}, commit=False)

        # Allocate on the device of the data instead of assuming CUDA.
        jac_clamp_loss = torch.zeros(1, device=x.device)
        # With no surviving ratio (e.g. a batch of size 1) the MSE over an empty tensor would
        # be NaN and poison the total loss, so the penalty is skipped in that case.
        if lam_min >= 0 and lam_max >= 0 and Q.numel() > 0:
            target_min = torch.ones_like(Q) * lam_min
            target_max = torch.ones_like(Q) * lam_max
            L_min = F.mse_loss(torch.minimum(Q, target_min), target_min, reduction='mean')
            L_max = F.mse_loss(torch.maximum(Q, target_max), target_max, reduction='mean')
            jac_clamp_loss = L_min + L_max
        return jac_clamp_loss
    
    def sample_timestep(self, z0, device):
        """Sample training times, one per (augmented) example.

        Uniform samples are optionally warped by ``t_transform``, optionally flipped to
        ``1 - t``, scaled by ``net.t_final``, expanded with ``append_dims`` to the rank of
        ``z0`` and, with probability ``force_zero_prob``, set to zero.

        Returns:
            A tensor with ``augment_t * z0.size(0)`` entries along the first dimension.
        """
        n_samples = self.augment_t * z0.size(0)
        if self.sync_t: # sample one t and use it for all instances
            # Repeat to the augmented batch size so that t matches z0_aug / z1_aug.
            t = torch.rand(1).to(device).repeat(n_samples)
        else:
            t = torch.rand(n_samples).to(device)

        # Strategies to give more sampling chance to certain timesteps
        if self.t_transform == 'square':
            t = t**2
        elif self.t_transform == 'cubic':
            t = t**3
        elif self.t_transform == 'one_minus_cos':
            t = 1 - torch.cos(t * np.pi / 2)
        else:
            assert self.t_transform == 'identity'
        if self.invert_transform_t:
            t = 1 - t

        t = t * self.net.t_final # scale t to [0, t_final]
        t = append_dims(t, z0.ndim)
        # make some portion of sampled t to zero
        if self.force_zero_prob > 0.:
            mask = (torch.rand_like(t) < self.force_zero_prob).float()
            t = t * (1. - mask)
        return t

    def fit(self, train_dataloader, val_dataloader, test_dataloader, subset_loader=None,):
        """Train until ``total_steps`` optimisation steps are done or early stopping triggers.

        Evaluation runs every ``test_every`` steps and additionally every ``save_every`` hours
        of wall-clock time; the time-based evaluation also saves model and optimizer state.

        Returns:
            ``self.net``.
        """
        self.current_step = 0
        best_test_metric = 0 if self.metric_type == 'accuracy' else 1e9
        best_val_metric = 0 if self.metric_type == 'accuracy' else 1e9
        epoch = 0
        start_time = time.time()
        next_thresh = start_time + self.time_threshold
        hour_count = 0
        self.early_stop_count = 0

        pbar = tqdm(total=self.total_steps, desc='Training')
        # Per-time-bin flow-loss statistics are only collected for method 'ours'.
        flow_loss_timestep_bin = torch.zeros(11) if self.method == 'ours' else None
        timestep_bin_count = torch.zeros(11) if self.method == 'ours' else None

        while True:
            logs = {}
            count = 0
            for i, (X, Y) in enumerate(train_dataloader):
                ### single epoch training
                log, flow_loss_timestep_bin, timestep_bin_count = self.training_step(X, Y, flow_loss_timestep_bin, timestep_bin_count)
                pbar.update(1)

                # accumulate log
                for k, v in log.items():
                    if k in logs:
                        logs[k] += v
                    else:
                        logs[k] = v
                count += 1                

                # Show the running epoch averages in the progress bar.
                cur_log = {k: v / count for k, v in logs.items()}
                pbar.set_description(
                    ', '.join([f'{k}: {v:.3e}' for k, v in cur_log.items()])
                    )

                ### evaluate on test_every
                self.current_step += 1

                if self.current_step % self.test_every == 0:
                    best_test_metric, best_val_metric = self.evaluate(test_dataloader, val_dataloader, subset_loader, self.current_step, best_test_metric, best_val_metric)
                    if self.method == 'ours':
                        # The averaged bins are computed but currently not logged (plot disabled below).
                        flow_loss_timestep_bin = torch.nan_to_num(flow_loss_timestep_bin / timestep_bin_count, posinf=0)
                        # fig = wandb.Plotly(px.bar(x=list(range(len(flow_loss_timestep_bin))), y=flow_loss_timestep_bin.tolist()))
                        # wandb.log({'train/flow_loss_timestep_bin': fig}, commit=False)
                        flow_loss_timestep_bin = torch.zeros(11) if self.method == 'ours' else None
                        timestep_bin_count = torch.zeros(11) if self.method == 'ours' else None
                    if self.early_stop_count >= self.early_stopping_patience and self.early_stopping_patience > 0:
                        print(f'Early stopping at step {self.current_step}')
                        return self.net
                
                if time.time() > next_thresh:
                    # Wall-clock checkpoint, named after the elapsed number of hours.
                    hour_count += 1
                    best_test_metric, best_val_metric = self.evaluate(test_dataloader, val_dataloader, subset_loader, self.current_step, best_test_metric, best_val_metric)
                    next_thresh = time.time() + self.time_threshold
                    torch.save(self.net.state_dict(), os.path.join(self.ckpt_dir, f'{hour_count * self.save_every}hr.ckpt'))
                    torch.save(self.optimizer.state_dict(), os.path.join(self.ckpt_dir, f'{hour_count * self.save_every}hr_opt.ckpt'))                        

                if self.current_step >= self.total_steps:
                    break
            
            if self.current_step >= self.total_steps:
                break

            for k, v in logs.items():
                logs[k] /= count
            # per-epoch logging
            wandb.log({k+'_epoch': v for k, v in logs.items()}, commit=False)
            wandb.log({'epoch': epoch}, commit=False)
            epoch += 1

        # save last checkpoint and finish
        self.evaluate(test_dataloader, val_dataloader, subset_loader, self.current_step, best_test_metric, best_val_metric)
        return self.net
    
    def training_step(self, X, Y, flow_loss_timestep_bin=None, timestep_bin_count=None):
        """Run one optimisation step on the batch ``(X, Y)``.

        Args:
            X, Y: input and label batch; moved to ``net.device``.
            flow_loss_timestep_bin, timestep_bin_count: running per-time-bin sums of the flow
                loss and of the sample counts. They are updated in place when both are given
                and ignored otherwise.

        Returns:
            ``(log, flow_loss_timestep_bin, timestep_bin_count)`` where ``log`` maps metric
            names to Python floats. In ``'node'`` mode the two bin entries are ``None``.
        """
        self.net.train()
        if self.method == 'node':
            self.optimizer.zero_grad()
            device = self.net.device
            if self.dataset in ['mnist', 'cifar10', 'svhn']:
                # Classification: turn the label vectors into class indices for the criterion.
                X, Y = X.to(device), Y.to(device).argmax(dim=-1)
            else:
                X, Y = X.to(device), Y.to(device)
            if self.steer == 0:
                pred = self.net(X)
            else:
                pred = self.net.steer(X, b=self.steer)
            loss = self.task_criterion(pred, Y)
            # Read the function-evaluation counter after the forward pass, reset it, and read
            # it again after the backward pass.
            forward_nfe = int(self.net.odeblock.odefunc.nfe)
            self.net.odeblock.odefunc.nfe = 0
            loss.backward()
            backward_nfe = int(self.net.odeblock.odefunc.nfe)

            # gradient clipping
            torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()
            log = {
                'train/loss': loss.item(),
                'train/nfe': forward_nfe,
                'train/backward_nfe': backward_nfe,
            }
            wandb.log(dict(log))
            self.ema_update(alpha=self.ema_alpha)
            return log, None, None
        else:
            device = self.net.device
            X, Y = X.to(device), Y.to(device)
            z0 = self.net.in_projection(X)
            z1 = self.net.label_projection(Y)

            # jacobian clamping
            f_jac_clamp_loss = self.compute_jacobian(X, z0, mode='f')
            g_jac_clamp_loss = self.compute_jacobian(Y, z1, mode='g')

            # sampling timestep
            t = self.sample_timestep(z0, device)
            
            # augment z if needed
            z0_aug = z0
            z1_aug = z1        
            if self.augment_t > 1:
                z0_aug = repeat(z0_aug, 'B ... -> (B a) ...', a=self.augment_t)
                z1_aug = repeat(z1_aug, 'B ... -> (B a) ...', a=self.augment_t)
            if self.label_flow_noise > 0.: # add noise to label embedding for flow prediction
                z1_aug = z1_aug + self.label_flow_noise * torch.randn_like(z1_aug)
            if self.label_flow_noise_0 > 0.: # add noise to input embedding for flow prediction
                z0_aug = z0_aug + self.label_flow_noise_0 * torch.randn_like(z0_aug)
            
            # Run dynamics
            zt = self.dynamics(z0_aug, z1_aug, t)
            v_target = self.dyn_v(z0_aug, z1_aug, t).squeeze()
            # task and boundary losses are placeholders that stay at zero in this step.
            flow_loss, label_ae_loss, task_loss = (torch.tensor(0.) for _ in range(3))
            boundary_loss_0, boundary_loss_1 = (torch.tensor(0.) for _ in range(2))

            # flow loss
            v_pred = self.net.pred_v(zt, t)
            flow_loss = F.mse_loss(v_pred, v_target) * self.lambdas[0]

            # label autoencoding loss
            z1_noised = z1
            if self.label_ae_noise > 0.:
                z1_noised = z1 + self.label_ae_noise * torch.randn_like(z1)
            y_pred = self.net.out_projection(z1_noised)
            if self.label_ae_mse:
                label_ae_loss = F.mse_loss(y_pred, Y) * self.lambdas[1]
            else:
                label_ae_loss = custom_ce(y_pred, Y) * self.lambdas[1]

            # optimizer step
            self.net.zero_grad()
            loss = flow_loss + label_ae_loss + task_loss + f_jac_clamp_loss + g_jac_clamp_loss +\
                boundary_loss_0 + boundary_loss_1
            loss.backward()
            if self.dataset == 'uci':
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()
            if self.scheduler is not None:
                self.scheduler.step()

            # variance of pred_v within each sample, averaged over the batch
            v_pred_var = v_pred.detach().reshape(v_pred.size(0), -1).var(dim=-1).mean().item()
            normalized_flow_loss = flow_loss.item() / (v_target.detach().reshape(v_target.size(0), -1).norm(dim=-1).mean() + 1e-9)

            # logging
            # NOTE(review): the key 'train/norm_flow_loss: ' carries a trailing colon and space;
            # it is kept unchanged so existing wandb charts keep working.
            wandb.log({
                'learning_rate': self.optimizer.param_groups[0]['lr'],
                'train/loss': loss.item(),
                'train/flow_loss': flow_loss.item(),
                'train/norm_flow_loss: ': normalized_flow_loss.item(), # normalized flow loss
                'train/label_ae_loss': label_ae_loss.item(),
                'train/task_loss': task_loss.item(),
                'train/f_jac_clamp_loss': f_jac_clamp_loss.item(),
                'train/g_jac_clamp_loss': g_jac_clamp_loss.item(),
                'train/flow_pred_var': v_pred_var,
                'train/z0_norm': z0.detach().reshape(z0.size(0), -1).norm(dim=-1).mean().item(),
                'train/z1_norm': z1.detach().reshape(z1.size(0), -1).norm(dim=-1).mean().item(),
                'train/boundary_loss_0': boundary_loss_0.item(),
                'train/boundary_loss_1': boundary_loss_1.item(),
            })

            # flow loss timestep bin
            # Skipped when the caller passes no bins (fit only allocates them for 'ours').
            if flow_loss_timestep_bin is not None and timestep_bin_count is not None:
                t_detach = t.detach().squeeze()
                t_candidates = torch.linspace(0, 1, len(timestep_bin_count)).to(t_detach.device)
                flow_loss_detach = F.mse_loss(v_pred.detach(), v_target.detach(), reduction='none') * self.lambdas[0]
                flow_loss_detach = reduce(flow_loss_detach, 'B ... -> B', 'mean')
                # Bin i covers [t_candidates[i], t_candidates[i+1]); the last slot stays empty.
                for i in range(len(t_candidates) - 1):
                    t_start, t_end = t_candidates[i], t_candidates[i+1]
                    mask = (t_detach >= t_start) & (t_detach < t_end)
                    timestep_bin_count[i] += mask.sum().cpu()
                    flow_loss_timestep_bin[i] += flow_loss_detach[mask].sum().cpu()
            
            log = {
                'loss': loss.item(),
                'flow_loss': flow_loss.item(),
                'label_ae_loss': label_ae_loss.item(),
                'task_loss': task_loss.item(),
            }
            self.ema_update(alpha=self.ema_alpha)
            return log, flow_loss_timestep_bin, timestep_bin_count


    def test(self, test_dataloader, metric_key='test', do_dopri=True):
        '''
        Evaluate the model on ``test_dataloader`` and return a dict of metrics.

        The EMA model is evaluated when ``ema_alpha > 0``, otherwise ``self.net``.
        Keys are prefixed with ``metric_key`` (``<m>`` is ``accuracy`` or ``rmse``):
        - <key>/straightness
        - <key>/<m>_1, _2, _10, _20 (Euler with that many steps)
        - <key>/<m>_100, _1000 (only for metric_key == 'test' with rmse)
        - <key>/<m>_dopri, <key>/dopri_nfe, <key>/latent_mse, <key>/data_mse (only if do_dopri)
        - <key>/z0_norm_avg, <key>/z1_norm_avg
        '''
        ret = {}
        self.ema_restore()
        net = self.ema_net if self.ema_alpha > 0 else self.net
        straight = test_straightness(net, test_dataloader)
        metric1 = test_metric(net, test_dataloader, method='euler', num_timesteps=1+1, metric_key=self.metric_type, label_scaler=self.label_scaler)
        metric2 = test_metric(net, test_dataloader, method='euler', num_timesteps=2+1, metric_key=self.metric_type, label_scaler=self.label_scaler)
        metric10 = test_metric(net, test_dataloader, method='euler', num_timesteps=10+1, metric_key=self.metric_type, label_scaler=self.label_scaler)
        metric20 = test_metric(net, test_dataloader, method='euler', num_timesteps=20+1, metric_key=self.metric_type, label_scaler=self.label_scaler)
        if metric_key == 'test' and self.metric_type == 'rmse':
            metric100 = test_metric(net, test_dataloader, method='euler', num_timesteps=100+1, metric_key=self.metric_type, label_scaler=self.label_scaler)
            metric1000 = test_metric(net, test_dataloader, method='euler', num_timesteps=1000+1, metric_key=self.metric_type, label_scaler=self.label_scaler)
        if do_dopri:
            metricinf, latent_mse, data_mse = test_metric(net, test_dataloader, method='dopri5',
                                                          return_mse=True, metric_key=self.metric_type, label_scaler=self.label_scaler)
            dopri_nfe = net.odeblock.odefunc.nfe # assume no odesolve after test_metric / inaccurate since it only measures last batch
        test_norm_avg = test_norm_avg_cls if self.metric_type == 'accuracy' else test_norm_avg_reg
        z0_norm_avg, z1_norm_avg = test_norm_avg(net, test_dataloader)


        ret[f'{metric_key}/straightness'] = straight
        ret[f'{metric_key}/{self.metric_type}_1'] = metric1
        ret[f'{metric_key}/{self.metric_type}_2'] = metric2
        ret[f'{metric_key}/{self.metric_type}_10'] = metric10
        ret[f'{metric_key}/{self.metric_type}_20'] = metric20
        if metric_key == 'test' and self.metric_type == 'rmse':
            ret[f'{metric_key}/{self.metric_type}_100'] = metric100
            ret[f'{metric_key}/{self.metric_type}_1000'] = metric1000
        if do_dopri:
            ret[f'{metric_key}/{self.metric_type}_dopri'] = metricinf
            ret[f'{metric_key}/dopri_nfe'] = dopri_nfe
            ret[f'{metric_key}/latent_mse'] = latent_mse
            ret[f'{metric_key}/data_mse'] = data_mse
        ret[f'{metric_key}/z0_norm_avg'] = z0_norm_avg
        ret[f'{metric_key}/z1_norm_avg'] = z1_norm_avg
        return ret