import os

from disentangle.TI_MNN import TIMechanisticLitModule
import numpy as np
import pytorch_lightning as pl
import utils
from data import LitDataModule, SpeedyWeatherDiscreteDataset,ShallowWaterDiscreteDataset
from disentangle import MechanisticLitModule, ContrastiveLitModule, AdaLitModule
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

# Run identity: which experiment this is and whether the script trains or
# only validates (see the __main__ section).
experiment_name = "New_ada_sw_discrete2"
is_training = False

# Lightning logs live one directory above the working directory.
base_dir = ".."
log_dir = base_dir + "/lightning_logs"

# Create the per-experiment log folder up front; an existing folder is fine.
os.makedirs(
    os.path.join(log_dir, experiment_name),
    exist_ok=True,
)

# ------------------- tuple: (original, share_sim, share_loc) --------------
# Three-view configuration: problem sizes first, then the target device.
n_views = 3
param_dim = 12

device_id = 3
device = "cuda:" + str(device_id)

# Ground-truth sharing description handed to the data module.
factor_sharing = {0: (0, 1)}
shared_ids = [[0]]

# Cut the code indices 0..param_dim-1 into three contiguous index tuples.
splits = list(map(tuple, np.array_split(np.arange(param_dim), 3)))
print(splits)

# no free space for local features
# splits = list(np.array_split(np.arange(param_dim), len(factor_sharing)))

# Each key is a tuple of code dimensions; each value lists the view indices
# that share those dimensions.
code_sharing = dict(
    zip(
        splits,
        [
            (0, 1),  # share layer thickness
            (0, 2),  # share local features
            (0, 1, 2),  # share other global features
        ],
    )
)
print(code_sharing)


# ------------------- pair: (original, shared_sim_diff_loc) --------------

# param_dim = 12
# n_views = 2

# shared_ids = [[0]]
# factor_sharing = {0: (0,1)}
# splits = [range(param_dim//2)]

# # no free space for local features
# # splits = list(np.array_split(np.arange(param_dim), len(factor_sharing)))
# code_sharing = {
#     splits[0]: (0, 1), # share layer thickness
# } # subset of views: shared coding dims
# print(code_sharing)

if __name__ == "__main__":
    # Script entry point: build the data module and the model, then either
    # train (is_training=True) or run one validation pass and gather the
    # predicted / ground-truth parameters the model stored in `method.misc`.
    import torch  # used only by this script section

    data_dir = "data"
    batch_size = 47 * 96 * 2
    pl.seed_everything(423472)

    datamodule = LitDataModule(
        dataset_class=ShallowWaterDiscreteDataset,
        data_path=data_dir,
        model_name="ShallowWaterModel/discrete2",
        num_simulations=2,
        num_views=n_views,
        grid_size=[2] * 1,  # activate in case of discrete model
        include_keys=["u", "v"],
        shared_ids=shared_ids,
        factor_sharing=factor_sharing,
        batch_size=batch_size,
        collate_style="default",
        chunk_size=121,
    )

    method = MechanisticLitModule(
        learning_rate=1e-5,
        batch_size=batch_size,
        # Multi-view input is only used for training; evaluation runs on a
        # single view.
        n_views=n_views if is_training else 1,
        order=2,
        state_dim=2,
        # Keep the model's step count in sync with the dataset chunk length.
        n_step=datamodule.train_set.chunk_size,
        n_iv_steps=10,
        param_dim=param_dim,
        map_location=device,
        mlp_enc=True,
        dct_layer=True,
        freq_frac_to_keep=0.5,
        factor_type="discrete",
        code_sharing=code_sharing,
        alignment_reg=10.0,
        eval_metrics=['r2'],
        # NOTE(review): the tag says "view2" while n_views is 3 above --
        # confirm the tag is still the intended one.
        notes="sw2_iv10_chunk121_view2_dim12",
    )

    # Draw one training batch and stack its ground-truth parameters along a
    # new last axis. The result is only consumed by the disabled check below;
    # the draw is kept because it may advance RNG state after seeding.
    sample = next(iter(datamodule.train_dataloader()))
    params = np.stack(list(sample["gt_params"].values()), -1)
    params = torch.from_numpy(params)

    # Disabled unity check for feature sharing:
    # shared = utils.feature_sharing_fn(params.float(), num_views=n_views,
    #                                   code_sharing=datamodule.test_set.factor_sharing,
    #                                   **sample)
    # assert (params == shared).all(), "unity check for feature sharing failed"

    trainer = Trainer(
        max_steps=30000,
        # max_epochs=100,
        accelerator="auto",
        devices=[device_id],
        check_val_every_n_epoch=30,  # run validation every 30 epochs
        log_every_n_steps=500,
        # CSV logging only while training; evaluation runs without a logger.
        logger=CSVLogger(log_dir, name=experiment_name) if is_training else False,
        inference_mode=not is_training,
    )

    if is_training:
        # run without validation
        # trainer.limit_val_batches = 0
        trainer.num_sanity_val_steps = 0  # skip the pre-fit sanity validation
        trainer.fit(method, datamodule)
    else:
        trainer.validate(method, datamodule.val_dataloader())
        # Presumably the validation loop appends per-batch arrays to
        # method.misc -- TODO confirm against MechanisticLitModule.
        pred_params = np.concatenate(method.misc["pred_params"], axis=0)
        gt_params = np.concatenate(method.misc["gt_params"], axis=0)
