import math

import torch.optim.optimizer
from torch.utils.data import ConcatDataset
import pytorch_lightning as pl
from transformers import (
    get_linear_schedule_with_warmup,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup
)

from .model.icd_model import ICDModel
from .dataset import MimicFullDataset, build_dataloader, KFold
from .metrics import ICDCodingMetrics, ICDCodingDevTestMetrics
from file_io import Files


class ICDCodingWrapper(pl.LightningModule):
    """Lightning module bundling ``ICDModel`` with its datasets, metrics and optimizer.

    The module owns the train/dev/test ``MimicFullDataset`` objects. The test
    set is fixed at construction time, while the train and dev sets are
    refreshed from a ``KFold`` helper at the start of every training epoch.
    """

    def __init__(self, config):
        """Build the model, load the datasets and create the metric trackers.

        Args:
            config: Configuration object. This class reads ``icd_model``,
                ``dataset``, ``metrics.ks``, ``train``, ``train_loader`` and
                ``dev_loader`` from it.
        """
        super().__init__()
        self.config = config

        self.model = ICDModel(config.icd_model)

        version = config.dataset.version
        train_section_dataset = Files.section_datasets[version]['train'].load()
        dev_section_dataset = Files.section_datasets[version]['dev'].load()
        train_dataset = Files.datasets[version]['train'].load()
        dev_dataset = Files.datasets[version]['dev'].load()
        test_dataset = Files.datasets[version]['test'].load()

        # Source of the per-epoch train/dev data (see on_train_epoch_start).
        self.kfold = KFold(train_section_dataset, dev_section_dataset, train_dataset, dev_dataset)

        # NOTE: an earlier variant sized the training set as
        # len(train_dataset) + len(dev_dataset); the current setup uses the
        # training split size only.
        self.train_dataset = MimicFullDataset(config.dataset, 'train', len(train_dataset))
        self.dev_dataset = MimicFullDataset(config.dataset, 'dev', len(dev_dataset))
        self.test_dataset = MimicFullDataset(config.dataset, 'test', len(test_dataset))
        # Only the test data is installed here; train/dev data is set per epoch.
        self.test_dataset.set_dataset(test_dataset)

        # Registered as buffers so they follow the module across devices and
        # can be passed to the model on every step.
        c_input_ids = torch.LongTensor(self.train_dataset.c_input_ids)
        c_word_mask = torch.FloatTensor(self.train_dataset.c_word_mask)
        self.register_buffer('c_input_ids', c_input_ids)
        self.register_buffer('c_word_mask', c_word_mask)

        self.train_metrics = ICDCodingMetrics(config.metrics.ks)
        # NOTE(review): the dev-set length is presumably used to separate dev
        # from test predictions in the concatenated validation loader - confirm
        # against ICDCodingDevTestMetrics.
        self.dev_test_metrics = ICDCodingDevTestMetrics(config.metrics.ks, len(self.dev_dataset))

        print(f'# training/dev/test samples:'
              f'{len(self.train_dataset)}/{len(self.dev_dataset)}/{len(self.test_dataset)}')

    def log_metrics(self, metrics, type_):
        """Log every entry of ``metrics`` under the key ``f'{type_}_{key}'``.

        Args:
            metrics: Mapping from metric name to value.
            type_: Prefix prepended to each metric name (e.g. ``'dev'``).
        """
        for key, value in metrics.items():
            self.log(f'{type_}_{key}', value, prog_bar=False)

    def training_step(self, batch, batch_index):
        """Run the model on one training batch and return its loss.

        Predictions and labels returned by the model are fed to the training
        metric tracker as NumPy arrays; the loss is logged as ``train_loss``.
        """
        input_ids, word_mask, labels = batch['input_ids'], batch['word_mask'], batch['labels']
        output = self.model(input_ids, word_mask, labels, self.c_input_ids, self.c_word_mask)
        loss, preds, labels = output['loss'], output['preds'], output['labels']
        self.train_metrics(preds.detach().cpu().numpy(), labels.detach().cpu().numpy())
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_index):
        """Evaluate one batch of the validation loader and return its loss.

        The validation loader concatenates the dev and test datasets, so this
        step feeds the combined dev/test metric tracker. The loss is logged as
        ``val_loss``.
        """
        input_ids, word_mask, labels = batch['input_ids'], batch['word_mask'], batch['labels']
        output = self.model.evaluate(input_ids, word_mask, labels, self.c_input_ids, self.c_word_mask)
        loss, preds, labels = output['loss'], output['preds'], output['labels']
        self.dev_test_metrics(preds.detach().cpu().numpy(), labels.detach().cpu().numpy())
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        """Compute the epoch's training metrics and log the dev/test metrics."""
        print('training metrics')
        # The return value is intentionally unused here, as in the original flow.
        self.train_metrics.compute()
        print('dev and test metrics')
        dev_metrics, test_metrics = self.dev_test_metrics.compute()
        self.log_metrics(dev_metrics, 'dev')
        self.log_metrics(test_metrics, 'test')

    def on_train_epoch_start(self):
        """Fetch this epoch's train/dev data from ``KFold`` and install it."""
        train_dataset, dev_dataset = self.kfold.choice()
        self.train_dataset.set_dataset(train_dataset)
        self.dev_dataset.set_dataset(dev_dataset)

    def train_dataloader(self):
        """Return the data loader over the training dataset."""
        return build_dataloader(self.config.train_loader, self.train_dataset, self.train_dataset)

    def val_dataloader(self):
        """Return one data loader over the dev and test datasets concatenated."""
        dev_and_test = ConcatDataset([self.dev_dataset, self.test_dataset])
        return build_dataloader(self.config.dev_loader, dev_and_test, self.train_dataset)

    def configure_optimizers(self):
        """Create the AdamW optimizer and the warmup learning-rate scheduler.

        Parameters whose name contains ``"bias"`` or ``"LayerNorm.weight"``
        are excluded from weight decay. The scheduler is selected by
        ``config.train.scheduler`` (``"linear"``, ``"constant"`` or
        ``"cosine"``).

        Returns:
            A ``([optimizer], [scheduler_config])`` pair, where the scheduler
            config asks Lightning to step the scheduler after every batch.

        Raises:
            KeyError: If ``config.train.scheduler`` is not a supported name.
        """
        config = self.config.train
        no_decay = ["bias", "LayerNorm.weight"]

        # Split the parameters in a single pass over named_parameters().
        decay_params, no_decay_params = [], []
        for name, param in self.named_parameters():
            if any(marker in name for marker in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        params = [
            {
                "params": decay_params,
                "weight_decay": config.weight_decay,
                "lr": config.learning_rate,
            },
            {
                "params": no_decay_params,
                "weight_decay": 0.0,
                "lr": config.learning_rate,
            },
        ]

        optimizer = torch.optim.AdamW(params, eps=config.adam_epsilon)

        # Schedule length is measured in batches over the training dataset.
        # NOTE: an earlier variant also counted len(self.dev_dataset) here.
        # NOTE(review): this ignores gradient accumulation and multi-device
        # sharding - confirm against the Trainer settings.
        n_iterations = math.ceil(len(self.train_dataset) / self.config.train_loader.batch_size)
        total_steps = n_iterations * config.epochs
        warmup_steps = int(total_steps * config.warmup_ratio)

        if config.scheduler == "linear":
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps,
            )
        elif config.scheduler == "constant":
            scheduler = get_constant_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
            )
        elif config.scheduler == "cosine":
            scheduler = get_cosine_schedule_with_warmup(
                optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=total_steps,
            )
        else:
            raise KeyError(f'scheduler {config.scheduler} is not supported')

        # Lightning steps a bare scheduler once per *epoch* by default, but the
        # warmup/total lengths above are counted in batches. Request per-step
        # updates so the schedule actually advances as configured.
        scheduler_config = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler_config]
