from abc import abstractmethod


class BaseServer:
    """Base server that aggregates evaluation metrics reported by its clients.

    Subclasses are expected to override ``load_checkpoint`` and to populate
    ``self.clients`` with objects exposing
    ``test_classification_detection_ability`` and ``test_corrupt_accuracy``.
    """

    def __init__(self, server_args):
        """Create a server with an empty client list.

        Args:
            server_args: Accepted for subclasses; not read by this base class.
        """
        self.clients = []

    # NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is not
    # enforced at instantiation time; the default implementation is a no-op.
    @abstractmethod
    def load_checkpoint(self, checkpoint):
        """Load ``checkpoint`` into the server. Meant to be overridden."""
        pass

    @staticmethod
    def _size_weights(loaders):
        """Return one weight per loader, proportional to ``len(loader)``.

        The weights sum to 1 when at least one loader is non-empty. When the
        total length is zero, every weight is ``0.0`` so callers get a ``0.0``
        weighted metric instead of a ``ZeroDivisionError``.
        """
        sizes = [len(loader) for loader in loaders]
        # Computed once rather than once per element.
        total = sum(sizes)
        if total == 0:
            return [0.0 for _ in sizes]
        return [size / total for size in sizes]

    def test_classification_detection_ability(self, checkpoint, client_id_loaders, ood_loader, score_method="msp"):
        """Evaluate every client and aggregate accuracy, FPR95 and AUROC.

        Args:
            checkpoint: Passed to ``load_checkpoint`` before evaluation.
            client_id_loaders: One in-distribution loader per client, in the
                same order as ``self.clients``. Each must support ``len()``.
            ood_loader: Out-of-distribution loader shared by all clients.
            score_method: Forwarded unchanged to each client.

        Returns:
            A tuple ``(accuracy, fpr95, auroc)``. Accuracy is weighted by
            ``len(id_loader)``; fpr95 and auroc are divided by
            ``len(self.clients)`` (a plain mean when every client is visited).
        """
        self.load_checkpoint(checkpoint)

        accuracy = 0.0
        fpr95 = 0.0
        auroc = 0.0

        # NOTE(review): if the loaders yield batches, len() counts batches,
        # not samples - confirm this is the intended weighting.
        client_weights = self._size_weights(client_id_loaders)

        # zip() stops at the shortest input, so surplus clients or loaders are
        # silently skipped while fpr95/auroc are still divided by
        # len(self.clients).
        for client, id_loader, weight in zip(self.clients, client_id_loaders, client_weights):
            client_accuracy, client_fpr95, client_auroc = client.test_classification_detection_ability(
                id_loader, ood_loader, score_method=score_method
            )
            accuracy += client_accuracy * weight
            fpr95 += client_fpr95 / len(self.clients)
            auroc += client_auroc / len(self.clients)

        return accuracy, fpr95, auroc

    def test_corrupt_accuracy(self, client_cor_loaders):
        """Compute the size-weighted accuracy for each corruption type.

        Args:
            client_cor_loaders: Mapping from corruption type to a sequence of
                per-client loaders, ordered like ``self.clients``. Each loader
                must support ``len()``.

        Returns:
            A dict mapping each corruption type to the accuracy averaged over
            clients, weighted by ``len(cor_loader)``. A corruption type whose
            loaders are all empty maps to ``0.0``.
        """
        cor_accuracy = {}
        for cor_type, cor_loaders in client_cor_loaders.items():
            client_weights = self._size_weights(cor_loaders)

            weighted_accuracy = 0.0
            for client, cor_loader, weight in zip(self.clients, cor_loaders, client_weights):
                weighted_accuracy += client.test_corrupt_accuracy(cor_loader) * weight
            cor_accuracy[cor_type] = weighted_accuracy

        return cor_accuracy
