from copy import deepcopy
import os
from typing import List, Optional, Union
from attrs import define, asdict
from re import S
from mllib.trainers.base_trainers import AbstractTrainer, Trainer
from mllib.models.base_models import AbstractModel
from mllib.param import BaseParameters, Parameterized
from mllib.runners.configs import BaseExperimentConfig, TrainingParams
from mllib.utils.metric_utils import compute_accuracy
import torch
import pytorch_lightning as pl
from pytorch_lightning.lite import LightningLite
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.strategies import Strategy
from pytorch_lightning.plugins import PLUGIN_INPUT

@define(slots=False)
class LightningLiteParams(object):
    """Bundle of keyword arguments used to construct a ``LightningLite``.

    ``LightningLiteWrapper.__init__`` converts an instance of this class to a
    dict with ``attrs.asdict`` and unpacks it into ``LightningLite.__init__``,
    so every field name declared here is forwarded verbatim as a keyword
    argument of that constructor.
    """
    # Field meanings follow the same-named ``LightningLite`` constructor
    # arguments; this class only carries the values and their defaults.
    accelerator: Optional[Union[str, Accelerator]] = None
    strategy: Optional[Union[str, Strategy]] = None
    devices: Optional[Union[List[int], str, int]] = None
    num_nodes: int = 1
    precision: Union[int, str] = 32
    plugins: Optional[PLUGIN_INPUT] = None
    # NOTE(review): `gpus` / `tpu_cores` overlap with `accelerator` + `devices`;
    # confirm they are still accepted by the installed Lightning release.
    gpus: Optional[Union[List[int], str, int]] = None
    tpu_cores: Optional[Union[List[int], str, int]] = None

class LightningLiteWrapper(LightningLite):
    """``LightningLite`` subclass configured from a parameter object.

    The constructor reads ``params.lightning_lite_params``, turns it into a
    plain dict with ``attrs.asdict`` and hands the entries to
    ``LightningLite.__init__`` as keyword arguments.
    """

    def __init__(self, params, *args, **kwargs) -> None:
        """Initialise LightningLite from ``params.lightning_lite_params``.

        Args:
            params: Object exposing a ``lightning_lite_params`` attrs instance.
            *args: Accepted for signature compatibility; not forwarded.
            **kwargs: Accepted for signature compatibility; not forwarded.
        """
        lite_kwargs = asdict(params.lightning_lite_params)
        super().__init__(**lite_kwargs)

class PytorchLightningLiteTrainerMixin(LightningLiteWrapper):
    """Mixin that drives a trainer's ``train`` loop through LightningLite.

    It is written to sit in front of a class that supplies ``train`` and
    ``checkpoint`` (reached through ``super()``) and that stores ``model``,
    ``optimizer``, ``train_loader`` and ``device`` as instance attributes.
    """

    def checkpoint(self, metric, epoch_idx, comparator):
        """Forward to the next ``checkpoint`` in the MRO on global rank zero only.

        On every process where ``self.is_global_zero`` is falsy this is a no-op.
        """
        if self.is_global_zero:
            super().checkpoint(metric, epoch_idx, comparator)

    def run(self):
        """Run the next ``train`` implementation in the MRO."""
        super().train()

    def train(self):
        """Train with the model, optimizer and loader wrapped by LightningLite.

        The original ``device``, ``train_loader``, ``model`` and ``optimizer``
        are remembered, replaced for the duration of ``run()`` by ``None`` /
        the objects returned from ``setup`` and ``setup_dataloaders``, and put
        back afterwards.  The restore happens in a ``finally`` clause so the
        trainer is left in its pre-call state even when setup or training
        raises (including ``KeyboardInterrupt``); the exception still
        propagates to the caller.
        """
        device = self.device
        train_loader = self.train_loader
        model = self.model
        optimizer = self.optimizer

        # NOTE(review): some LightningLite releases expose `device` as a
        # read-only property; confirm this assignment is valid for the host
        # class.  It is kept outside the `try` so that a failure here
        # surfaces unchanged, before anything else has been modified.
        self.device = None
        try:
            self.model, self.optimizer = self.setup(model, optimizer)
            self.train_loader = self.setup_dataloaders(train_loader)
            self.run()
        finally:
            # Always hand the unwrapped objects back to the trainer.
            self.device = device
            self.train_loader = train_loader
            self.model = model
            self.optimizer = optimizer

class PytorchLightningTrainerMixin(pl.LightningModule):
    """Adapter exposing a trainer's step/epoch methods as LightningModule hooks.

    Each Lightning hook forwards through ``super()`` to a trainer-style method
    (``train_step``, ``val_step``, ``test_step``, ``train_epoch_end``,
    ``val_epoch_end``, ``test_epoch_end``).  The host object must also provide
    ``model``, ``optimizer``, ``scheduler``, ``scheduler_step_after_epoch``,
    ``tracked_metric``, ``tracking_mode``, ``early_stop_patience`` and
    ``logdir``, all of which are read below.

    NOTE(review): ``LightningModule`` itself defines ``test_step`` and
    ``test_epoch_end``; confirm that ``super()`` resolves to the trainer's
    implementations in the concrete class's MRO.
    """

    def __init__(self, *args, **kwargs):
        """Pass all arguments on to the next ``__init__`` in the MRO."""
        super().__init__(*args, **kwargs)

    @staticmethod
    def _retag_logs(logs, stage):
        """Return a copy of ``logs`` with ``'train'`` in each key replaced by ``stage``."""
        return {name.replace('train', stage): value for name, value in logs.items()}

    @staticmethod
    def _split_outputs_and_logs(outputs):
        """Separate the ``'logs'`` entries from step outputs.

        Args:
            outputs: Either a single dict containing a ``'logs'`` key, or a
                list of per-step outputs as Lightning passes to the
                ``*_epoch_end`` hooks.

        Returns:
            ``(outputs_without_logs, metrics)``.  For a dict input this is a
            deep copy of the dict with ``'logs'`` removed, plus the removed
            value.  For a list input it is a list of shallow-copied step dicts
            with ``'logs'`` removed (non-dict items are passed through
            untouched), plus a dict holding the per-key mean of the logged
            values across steps.

        Raises:
            KeyError: If a dict input has no ``'logs'`` key.
        """
        if isinstance(outputs, dict):
            remaining = deepcopy(outputs)
            metrics = remaining.pop('logs')
            return remaining, metrics

        # Lightning hands epoch-end hooks a list of step outputs; calling
        # ``list.pop('logs')`` on it would raise TypeError.
        stripped = []
        collected = {}
        for step_output in outputs:
            if not isinstance(step_output, dict):
                stripped.append(step_output)
                continue
            # Only the top-level key is removed, so a shallow copy is enough
            # to leave the objects owned by Lightning unmodified.
            step_copy = dict(step_output)
            for name, value in step_copy.pop('logs', {}).items():
                collected.setdefault(name, []).append(value)
            stripped.append(step_copy)
        # NOTE(review): averaging over steps is assumed to match what the
        # trainer's epoch-end hooks expect as `metrics`; confirm against them.
        metrics = {name: sum(values) / len(values) for name, values in collected.items()}
        return stripped, metrics

    def _log(self, logs):
        """Log every entry of ``logs`` per step and per epoch, on the progress bar and logger."""
        for k, v in logs.items():
            self.log(k, v, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def training_step(self, batch, batch_idx):
        """Run ``train_step``, add the current learning rate to the logs and log them.

        Returns:
            The output of ``train_step`` with the logs attached under ``'logs'``.
        """
        output, logs = super().train_step(batch, batch_idx)
        # Learning rate of the first parameter group of the scheduler's optimizer.
        logs['lr'] = self.scheduler.optimizer.param_groups[0]['lr']
        self._log(logs)
        output['logs'] = logs
        return output

    def validation_step(self, batch, batch_idx):
        """Run ``val_step`` and log its metrics with ``'train'`` renamed to ``'val'``.

        Returns:
            The output of ``val_step`` with the renamed logs under ``'logs'``.
        """
        output, logs = super().val_step(batch, batch_idx)
        val_logs = self._retag_logs(logs, 'val')
        val_logs['lr'] = self.scheduler.optimizer.param_groups[0]['lr']
        self._log(val_logs)
        output['logs'] = val_logs
        return output

    def test_step(self, batch, batch_idx):
        """Run the parent ``test_step`` and log its metrics with ``'train'`` renamed to ``'test'``.

        Returns:
            The output of the parent ``test_step`` with the renamed logs under ``'logs'``.
        """
        output, logs = super().test_step(batch, batch_idx)
        test_logs = self._retag_logs(logs, 'test')
        self._log(test_logs)
        output['logs'] = test_logs
        return output

    def training_epoch_end(self, outputs):
        """Forward the collected training outputs and metrics to ``train_epoch_end``."""
        step_outputs, metrics = self._split_outputs_and_logs(outputs)
        super().train_epoch_end(step_outputs, metrics, self.current_epoch)

    def validation_epoch_end(self, outputs):
        """Forward the collected validation outputs and metrics to ``val_epoch_end``."""
        step_outputs, metrics = self._split_outputs_and_logs(outputs)
        super().val_epoch_end(step_outputs, metrics, self.current_epoch)

    def test_epoch_end(self, outputs):
        """Forward the collected test outputs and metrics to the parent ``test_epoch_end``."""
        step_outputs, metrics = self._split_outputs_and_logs(outputs)
        super().test_epoch_end(step_outputs, metrics)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return ``self.model`` applied to the first element of a two-element batch."""
        x, _ = batch
        return self.model(x)

    def configure_optimizers(self):
        """Return the host's optimizer and scheduler in Lightning's list format.

        The scheduler is stepped per epoch when ``scheduler_step_after_epoch``
        is truthy, otherwise per step.
        """
        scheduler_config = {
            'scheduler': self.scheduler,
            'interval': 'epoch' if self.scheduler_step_after_epoch else 'step'
        }
        return [self.optimizer], [scheduler_config]

    def configure_callbacks(self):
        """Build early-stopping and checkpoint callbacks on the tracked metric.

        Both callbacks monitor ``self.tracked_metric`` in ``self.tracking_mode``.
        Checkpoints go to ``<logdir>/checkpoints`` and the last one is also saved.
        """
        early_stop = pl.callbacks.EarlyStopping(
            monitor=self.tracked_metric,
            mode=self.tracking_mode,
            patience=self.early_stop_patience,
        )
        ckpdir = os.path.join(self.logdir, 'checkpoints')
        checkpoint = pl.callbacks.ModelCheckpoint(
            monitor=self.tracked_metric,
            mode=self.tracking_mode,
            dirpath=ckpdir,
            save_last=True,
        )
        # Override the naming attributes on this callback instance only.
        checkpoint.CHECKPOINT_NAME_LAST = 'model_checkpoint'
        checkpoint.FILE_EXTENSION = '.ckpt'
        return [early_stop, checkpoint]
