import torch

from utils import *
import models as model_utils
from copy import deepcopy

# Target device for models and batches: the GPU when torch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


class Device(object):
    """Base class bundling a data loader with model evaluation and persistence.

    This class never assigns ``self.model`` itself; every method other than
    ``__init__`` reads it, so a subclass (or the caller) must set it first.
    """

    def __init__(self, loader):
        """Remember ``loader`` as the default data source for :meth:`evaluate`."""
        self.loader = loader

    def evaluate(self, loader=None):
        """Run ``eval_op`` on ``self.model``.

        Args:
            loader: Data to evaluate on. When ``None``, ``self.loader`` is used.

        Returns:
            Whatever ``eval_op`` returns for the model and the chosen loader.
        """
        # Test against None explicitly: truth-testing a loader calls len() on
        # it, which would silently replace an explicitly passed empty loader
        # with self.loader and fails for loaders that define no length.
        return eval_op(self.model, self.loader if loader is None else loader)

    def save_model(self, path=None, name=None, verbose=True):
        """Write the model's ``state_dict`` to ``path + name``.

        Nothing is written when ``name`` is empty or ``None``.

        Args:
            path: String prefix concatenated (not joined) with ``name``;
                ``None`` is treated as an empty prefix.
            name: File name to append to ``path``.
            verbose: Print the destination after saving.
        """
        if name:
            target = (path or "") + name
            torch.save(self.model.state_dict(), target)
            if verbose:
                print("Saved model to", target)

    def load_model(self, path=None, name=None, verbose=True):
        """Load a ``state_dict`` from ``path + name`` into ``self.model``.

        Nothing is loaded when ``name`` is empty or ``None``.

        Args:
            path: String prefix concatenated (not joined) with ``name``;
                ``None`` is treated as an empty prefix.
            name: File name to append to ``path``.
            verbose: Print the source after loading.
        """
        if name:
            source = (path or "") + name
            self.model.load_state_dict(torch.load(source))
            if verbose:
                print("Loaded model from", source)


class Client(Device):
    """Federated-learning client owning a local model, optimizer and data loader.

    The ``mode`` given at construction selects the local training routine run
    by :meth:`compute_weight_update`: ``'scaffold'``, ``'fedgen'``, or anything
    else (delegated to ``train_op``).
    """

    def __init__(self, model_name, optimizer_fn, loader, idnum=0, num_classes=10, dataset='cifar10', lr_schedule=None,
                 quant=None, mode=None):
        """Build the local model on ``device`` and attach an optimizer.

        Args:
            model_name: Key passed to ``model_utils.get_model``; element ``[0]``
                of its result is called to construct the model.
            optimizer_fn: Callable receiving the model parameters and returning
                the optimizer.
            loader: Local data loader (stored by the base class).
            idnum: Client identifier, stored as ``self.id``.
            num_classes: Forwarded to the model constructor; also defines
                ``self.unique_target`` as ``0 .. num_classes - 1``.
            dataset: Forwarded to the model constructor and stored.
            lr_schedule: Optional callable receiving the optimizer; its result
                is stored as ``self.lr_schedule``. When ``None`` the attribute
                is not created.
            quant: Forwarded to the model constructor.
            mode: Training mode, see the class docstring.
        """
        super().__init__(loader)
        self.id = idnum
        print(f"dataset client {dataset}")
        self.model_name = model_name
        model_ctor = model_utils.get_model(self.model_name)[0]
        self.model = model_ctor(num_classes=num_classes, dataset=dataset, quant=quant).to(device)

        # Name -> parameter mapping over the live model parameters.
        self.W = dict(self.model.named_parameters())
        if mode == 'scaffold':
            # Control variates, filled in by the scaffold update; c_local is
            # keyed by client id.
            self.c_local = {}
            self.c_global = []

        self.optimizer = optimizer_fn(self.model.parameters())
        if lr_schedule is not None:
            self.lr_schedule = lr_schedule(self.optimizer)

        self.mode = mode
        self.dataset = dataset
        self.unique_target = list(range(num_classes))

    def synchronize_with_server(self, server, quant=None):
        """Overwrite the local weights with the server's and snapshot the result.

        Args:
            server: Object exposing ``model``; its ``state_dict`` is loaded
                into the local model with ``strict=True``.
            quant: Unused by this method.
        """
        self.model.load_state_dict(server.model.state_dict(), strict=True)

        # Frozen copy of the freshly synchronised model; the scaffold update
        # measures the local drift against it.
        self.origin = deepcopy(self.model)

    def compute_weight_update(self, epochs=1, loader=None, quant_fn=None, lambda_fedprox=0.0, c_global=None,
                              current_global_epoch=None, generator=None, regularization=0):
        """Run local training according to ``self.mode`` and return its statistics.

        Args:
            epochs: Number of passes over the data.
            loader: Alternative loader; only honoured outside the
                ``'scaffold'`` and ``'fedgen'`` modes, which always train on
                ``self.loader``.
            quant_fn: In ``'scaffold'``/``'fedgen'`` mode a mapping with the
                callables ``'weight_Q'`` and ``'grad_Q'``; otherwise forwarded
                to ``train_op`` unchanged.
            lambda_fedprox: Forwarded to ``train_op`` (default mode only).
            c_global: Global control variates (``'scaffold'`` mode only).
            current_global_epoch: Current global round, used by
                :meth:`exp_coef_scheduler` (``'fedgen'`` mode only).
            generator: Generator module (``'fedgen'`` mode only).
            regularization: Truthy to add the generator-based loss terms
                (``'fedgen'`` mode only).

        Returns:
            ``'scaffold'``: dict with ``"loss"``, ``"y_delta"``, ``"c_delta"``.
            ``'fedgen'``: dict with ``"loss"``, ``"delta"``, ``"weight"``,
            ``"label_counts"``.
            Otherwise: whatever ``train_op`` returns.

        Raises:
            ZeroDivisionError: In ``'scaffold'``/``'fedgen'`` mode when no
                sample was processed (empty loader or ``epochs == 0``).
        """
        if self.mode == 'scaffold':
            return self._scaffold_update(epochs, quant_fn, c_global)
        if self.mode == 'fedgen':
            return self._fedgen_update(epochs, quant_fn, current_global_epoch, generator, regularization)

        # Test against None explicitly so that truth-testing never calls
        # len() on the loader.
        return train_op(self.model, self.loader if loader is None else loader, self.optimizer, epochs,
                        quant_fn=quant_fn, lambda_fedprox=lambda_fedprox, id=self.id)

    def _quantized_step(self, weight_Q, grad_Q):
        """Quantize the gradients, take one optimizer step, then quantize the weights."""
        with torch.no_grad():
            for param in self.model.parameters():
                param.grad.data = grad_Q(param.grad.data).data

        self.optimizer.step()
        with torch.no_grad():
            for param in self.model.parameters():
                param.data = weight_Q(param.data).data

    def _scaffold_update(self, epochs, quant_fn, c_global):
        """Train with control-variate corrected gradients and update the local variates.

        Requires :meth:`synchronize_with_server` to have been called, because
        the model drift is measured against ``self.origin``.
        """
        weight_Q = quant_fn['weight_Q']
        grad_Q = quant_fn['grad_Q']
        self.c_global = c_global
        if self.id not in self.c_local:
            self.c_local[self.id] = [torch.zeros_like(c) for c in c_global]

        criterion = nn.CrossEntropyLoss()
        self.model.train()
        running_loss, samples = 0.0, 0
        for _ in range(epochs):
            for x, y in self.loader:
                x, y = x.to(device), y.to(device)
                self.optimizer.zero_grad()
                loss = criterion(self.model(x), y)
                running_loss += loss.item() * y.shape[0]
                samples += y.shape[0]
                loss.backward()

                # Correct each gradient by (global variate - local variate).
                for param, c, c_i in zip(self.model.parameters(), self.c_global, self.c_local[self.id]):
                    param.grad.data += c - c_i

                self._quantized_step(weight_Q, grad_Q)

        with torch.no_grad():
            # Per-parameter drift of the trained model from the synchronised snapshot.
            y_delta = [
                y_i.data - x.data
                for x, y_i in zip(self.origin.parameters(), self.model.parameters())
            ]

            # NOTE(review): the normaliser uses `epochs`, not the number of
            # optimizer steps taken - confirm this is the intended scaling.
            coef = 1 / (epochs * self.optimizer.defaults['lr'])
            c_plus = [
                c_i - c - coef * y_del
                for c, c_i, y_del in zip(self.c_global, self.c_local[self.id], y_delta)
            ]
            c_delta = [c_p - c_l for c_p, c_l in zip(c_plus, self.c_local[self.id])]

            self.c_local[self.id] = c_plus

        return {"loss": running_loss / samples, "y_delta": y_delta, "c_delta": c_delta}

    def _fedgen_update(self, epochs, quant_fn, current_global_epoch, generator, regularization):
        """Train locally, optionally regularised by samples drawn from ``generator``."""
        from args import parse_argument
        self.args = parse_argument()
        self.current_global_epoch = current_global_epoch

        # One pass over the loader to collect the labels present on this
        # client and how often each occurs. Assumes 1-D label tensors whose
        # values all lie in self.unique_target (anything else raises).
        all_targets = []
        target_count = {target: 0 for target in self.unique_target}
        for _, targets in self.loader:
            batch_targets = targets.tolist()
            all_targets.extend(batch_targets)
            for target in batch_targets:
                target_count[target] += 1
        self.available_labels = torch.unique(torch.tensor(all_targets)).tolist()
        # Labels without local samples are reported with a count of 0.
        target_list = [target_count[target] for target in self.unique_target]

        weight_Q = quant_fn['weight_Q']
        grad_Q = quant_fn['grad_Q']
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        # NOTE(review): the generator is switched to train mode here although
        # this method never steps an optimizer for it - confirm intended.
        generator.train()
        running_loss, samples = 0.0, 0
        for _ in range(epochs):
            for x, y in self.loader:
                x, y = x.to(device), y.to(device)
                logits = self.model(x)
                loss = criterion(logits, y)

                if regularization:
                    alpha = self.exp_coef_scheduler(self.args.generative_alpha)
                    beta = self.exp_coef_scheduler(self.args.generative_beta)

                    # Pull the model's predictions on real data towards the
                    # (detached) classifier output on generator samples for
                    # the same labels.
                    generator_output, _ = generator(y)
                    logits_gen = self.model.classifier(generator_output).detach()
                    latent_loss = beta * F.kl_div(
                        F.log_softmax(logits, dim=1),
                        F.softmax(logits_gen, dim=1),
                        reduction="batchmean",
                    )

                    # Classify generator samples for labels drawn from those
                    # seen locally.
                    sampled_y = torch.tensor(
                        np.random.choice(
                            self.available_labels, self.args.gen_batch_size
                        ),
                        dtype=torch.long,
                        device=device,
                    )
                    generator_output, _ = generator(sampled_y)
                    sampled_logits = self.model.classifier(generator_output)
                    teacher_loss = alpha * criterion(sampled_logits, sampled_y)

                    gen_ratio = self.args.gen_batch_size / self.args.batch_size

                    loss += gen_ratio * teacher_loss + latent_loss

                running_loss += loss.item() * y.shape[0]
                samples += y.shape[0]

                self.optimizer.zero_grad()
                loss.backward()
                self._quantized_step(weight_Q, grad_Q)

        # The model's full state dict, not a difference, despite the key name.
        delta = self.model.state_dict()

        return {"loss": running_loss / samples, "delta": delta, "weight": len(self.loader), "label_counts": target_list}

    def compute_weight_update_ma(self, epochs=1, loader=None, quant_fn=None, moving_weight=0.1):
        """Run ``train_op_ma`` on the local model and return its statistics.

        Args:
            epochs: Forwarded to ``train_op_ma``.
            loader: Alternative loader; ``self.loader`` is used when ``None``.
            quant_fn: Forwarded to ``train_op_ma``.
            moving_weight: Forwarded to ``train_op_ma``.
        """
        return train_op_ma(self.model, self.loader if loader is None else loader, self.optimizer, epochs,
                           quant_fn=quant_fn, moving_weight=moving_weight)

    def predict_logit(self, x):
        """Return the model's raw output for ``x`` without tracking gradients.

        No softmax is applied. The model is put into train mode before the
        forward pass and is left in that mode.
        """
        # NOTE(review): train mode during prediction keeps train-time layer
        # behaviour active - confirm this is intended rather than eval().
        self.model.train()

        with torch.no_grad():
            y_ = self.model(x)

        return y_

    def exp_coef_scheduler(self, init_coef):
        """Decay ``init_coef`` exponentially in steps over global epochs.

        Returns ``init_coef * coef_decay ** (current_global_epoch //
        coef_decay_epoch)``, floored at ``1e-4``. Reads ``self.args`` and
        ``self.current_global_epoch``, which are set by the ``'fedgen'``
        branch of :meth:`compute_weight_update`.
        """
        return max(
            1e-4,
            init_coef
            * (
                self.args.coef_decay
                ** (self.current_global_epoch // self.args.coef_decay_epoch)
            ),
        )
