import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import copy
import lightly

from lightly.models.modules.heads import MoCoProjectionHead
from lightly.models.utils import deactivate_requires_grad
from lightly.models.utils import update_momentum
from lightly.models.utils import batch_shuffle
from lightly.models.utils import batch_unshuffle
import torchmetrics

# Size of the flattened feature vector for each selectable backbone output.
# Keys are the `layer_idx` values accepted by `Classifier`: 0-4 index the
# list of intermediate outputs returned by the backbone, -1 selects its final
# output. Values are used as `in_features` of the linear classification head.
# NOTE(review): the sizes presumably match one specific backbone / input
# resolution used by this project -- confirm before swapping either.
feature_map = {
    0: 65536,
    1: 32768,
    2: 16384,
    3: 8192,
    4: 512,
    -1: 512,
}


class Classifier(pl.LightningModule):
    """Linear classification head on top of a multi-output backbone.

    The backbone is expected to return a pair ``(output, all_outputs)`` where
    ``all_outputs`` is an indexable collection of intermediate outputs
    (e.g. ``out5, [out1, out2, out3, out4, out5]``). One of those outputs is
    flattened and fed to a single ``nn.Linear`` layer.

    Args:
        backbone: Feature extractor module returning ``(output, all_outputs)``.
        layer_idx: Which backbone output to classify from. ``-1`` uses the
            final ``output``; any other value indexes ``all_outputs``. Must be
            a key of the module-level ``feature_map`` (a ``KeyError`` is
            raised otherwise).
        num_classes: Number of target classes.
        train_backbone: If ``False`` the backbone parameters get
            ``requires_grad=False`` and only the linear head is optimized.
        max_epochs: Period (``T_max``) of the cosine learning-rate schedule.
        use_vit: Currently unused; kept so existing callers keep working.

    NOTE(review): freezing only disables gradients; the backbone is not put
    into ``eval()`` mode here. If it contains BatchNorm/Dropout layers their
    training-mode behavior would still apply -- confirm this is intended.
    """

    def __init__(self, backbone,
                 layer_idx=-1, num_classes=100, train_backbone=True,
                 max_epochs=500, use_vit=False):
        super().__init__()
        self.max_epochs = max_epochs
        self.layer_idx = layer_idx
        self.is_cuda = torch.cuda.is_available()

        self.train_backbone = train_backbone
        self.backbone = backbone
        if not self.train_backbone:
            # `backbone` and `self.backbone` are the same object, so a single
            # call is enough to freeze it.
            deactivate_requires_grad(self.backbone)
        self.num_classes = num_classes
        self.feature_dimension = feature_map[self.layer_idx]

        self.fc = nn.Linear(self.feature_dimension, num_classes)
        # NOTE(review): a Lightning Trainer moves the whole module to the
        # target device on its own; this eager move is kept only so that
        # existing usage outside a Trainer behaves as before.
        if self.is_cuda:
            self.fc = self.fc.cuda()

        # The task-based torchmetrics API needs `num_classes` for the
        # multiclass task; omitting it fails at construction time.
        self.accuracy = torchmetrics.Accuracy(
            task='multiclass', num_classes=num_classes
        )

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        output, all_outputs = self.backbone(x)
        if self.layer_idx != -1:
            selected_output = all_outputs[self.layer_idx]
        else:
            selected_output = output

        # Collapse everything but the batch dimension before the linear head.
        features = selected_output.flatten(start_dim=1)
        return self.fc(features)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch and log loss/accuracy.

        ``batch`` is unpacked as ``(x, y, _)``; the third element is ignored.
        """
        x, y, _ = batch

        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        # Updates the running metric state; the metric object itself is
        # logged so Lightning handles compute/reset.
        self.accuracy(y_hat, y)
        self.log("train_loss_fc", loss)
        self.log("Loss/total_loss", loss)
        self.log('train_acc_step', self.accuracy)
        return loss

    def training_epoch_end(self, outputs):
        """Log the accuracy metric under its per-epoch name."""
        self.log('train_acc_epoch', self.accuracy)

    def validation_step(self, batch, batch_idx):
        """Return ``(batch_size, number_of_correct_top1_predictions)``.

        ``batch`` is unpacked as ``(x, y, _)``; the third element is ignored.
        """
        x, y, _ = batch
        logits = self.forward(x)

        # Softmax is monotonic per row, so the top-1 class can be read
        # directly from the logits.
        predicted = logits.argmax(dim=1)
        num = predicted.shape[0]
        correct = (predicted == y).float().sum()
        return num, correct

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch counts and log top-1 accuracy as ``val_acc``."""
        # Nothing to aggregate; also avoids a division by zero below.
        if not outputs:
            return

        total_num = sum(num for num, _ in outputs)
        total_correct = sum(correct for _, correct in outputs)
        acc = total_correct / total_num
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Build the Adam optimizer and its cosine-annealing LR schedule.

        With a frozen backbone only the linear head is optimized
        (lr=0.1, weight decay 10e-6, i.e. 1e-5); otherwise all module
        parameters are optimized with lr=0.05.
        """
        if not self.train_backbone:
            optim = torch.optim.Adam(self.fc.parameters(), lr=0.1, weight_decay=10e-6)
        else:
            optim = torch.optim.Adam(self.parameters(), lr=0.05)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, self.max_epochs)
        return [optim], [scheduler]
