import timeit
from logging import INFO
from typing import Optional

from flwr.common.logger import log
from flwr.server import Server
from flwr.server.history import History
from flwr.server.strategy import Strategy
from torch.utils.tensorboard import SummaryWriter

from src.client_manager import SeedCM as ClientManager


class StandardServer(Server):
    """Standard Flower server that also logs per-round metrics to TensorBoard.

    Args:
        client_manager: Client manager used to sample clients.
        strategy: Federated strategy providing aggregation and evaluation.
        writer: TensorBoard ``SummaryWriter`` receiving scalar metrics.
        write_prox_loss: If true, also log the ``avg_ce_loss`` /
            ``avg_prox_loss`` entries of the aggregated fit metrics.
        eval_freq: Evaluate the global model every ``eval_freq`` rounds.
    """

    def __init__(
        self,
        *,
        client_manager: ClientManager,
        strategy: Strategy,
        writer: SummaryWriter,
        write_prox_loss: bool,
        eval_freq: int,
    ) -> None:
        # Placeholder; ``fit`` replaces it with the initial global parameters.
        self.parameters = None
        self.writer = writer
        self.write_prox_loss = write_prox_loss
        self.eval_freq = eval_freq
        super().__init__(client_manager=client_manager, strategy=strategy)

    # pylint: disable=too-many-locals
    def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
        """Run federated learning for ``num_rounds`` rounds.

        Each round calls ``fit_round`` and writes the aggregated fit metrics
        to TensorBoard. Every ``eval_freq`` rounds, the global model is also
        evaluated in two ways: centrally via ``strategy.evaluate``, and on
        sampled clients via ``evaluate_round``. Those results are recorded in
        the returned ``History``, and the federated ones go to TensorBoard.

        Args:
            num_rounds: Number of federated rounds to run.
            timeout: Client timeout forwarded to Flower, or ``None``.

        Returns:
            A ``History`` holding the initial and per-round centralized
            evaluation results and the per-round distributed evaluation
            losses/metrics.

        Raises:
            BaseException: The first client failure of a fit round, if it is
                an exception. Non-exception failures (e.g. a failed
                ``(client, result)`` pair) are wrapped in ``RuntimeError``.
        """
        history = History()
        # Initialize parameters
        log(INFO, "Initializing global parameters")
        self.parameters = self._get_initial_parameters(timeout=timeout)
        log(INFO, "Evaluating initial parameters")
        res = self.strategy.evaluate(0, parameters=self.parameters)
        if res is not None:
            log(INFO, f"initial parameters (loss, other metrics): {res[0]}, {res[1]}")
            history.add_loss_centralized(server_round=0, loss=res[0])
            history.add_metrics_centralized(server_round=0, metrics=res[1])

        # Run federated learning for num_rounds
        log(INFO, "FL starting")
        for current_round in range(1, num_rounds + 1):
            tb_props = dict(global_step=current_round, walltime=timeit.default_timer())

            start_time_fit = timeit.default_timer()
            res_fit = self.fit_round(server_round=current_round, timeout=timeout)
            if res_fit is None:
                # Flower's fit_round returns None when no clients were selected.
                log(INFO, f"Round {current_round}: no fit results, skipping fit metrics")
            else:
                parameters_prime, fit_metrics, (_, failures) = res_fit
                if failures:
                    failure = failures[0]
                    print(failure)
                    if isinstance(failure, BaseException):
                        raise failure
                    # Failures can also be non-exception objects, which `raise`
                    # would reject with a TypeError; wrap them instead.
                    raise RuntimeError(
                        f"Client fit failed in round {current_round}: {failure!r}"
                    )
                # Keep the previous global model if aggregation produced nothing.
                if parameters_prime is not None:
                    self.parameters = parameters_prime
                print(f"Select & fit time: {timeit.default_timer() - start_time_fit:.1f} s")
                _write_fit_metrics(self.writer, fit_metrics, self.write_prox_loss, tb_props)

            if current_round % self.eval_freq == 0:
                # Evaluate global model using strategy implementation
                start_time_eval = timeit.default_timer()
                res_cen = self.strategy.evaluate(current_round, parameters=self.parameters)
                print(f"Global model eval time: {timeit.default_timer() - start_time_eval:.1f} s")
                if res_cen is not None:
                    history.add_loss_centralized(server_round=current_round, loss=res_cen[0])
                    history.add_metrics_centralized(server_round=current_round, metrics=res_cen[1])

                # Evaluate model on a sample of available clients
                start_time_eval_local = timeit.default_timer()
                res_fed = self.evaluate_round(server_round=current_round, timeout=timeout)
                if res_fed:
                    print(f'Select & evaluate time: {timeit.default_timer() - start_time_eval_local:.1f} s')
                    loss_fed, eval_metrics, _ = res_fed
                    # An aggregated loss of None means there is nothing to
                    # record, and the expected metric keys would be missing.
                    if loss_fed is not None:
                        history.add_loss_distributed(server_round=current_round, loss=loss_fed)
                        history.add_metrics_distributed(server_round=current_round, metrics=eval_metrics)
                        _write_eval_metrics(self.writer, loss_fed, eval_metrics, tb_props)

        log(INFO, "FL finished.")
        return history


def _write_fit_metrics(writer, fit_metrics, write_prox_loss: bool, tb_props):
    writer.add_scalar("avg_train_loss", fit_metrics['avg_train_loss'], **tb_props)
    writer.add_scalar("avg_train_acc", fit_metrics['avg_train_acc'], **tb_props)
    writer.add_scalar("avg_train_ece", fit_metrics['avg_train_ece'], **tb_props)
    if write_prox_loss:
        writer.add_scalar("avg_train_ce_loss", fit_metrics['avg_ce_loss'], **tb_props)
        writer.add_scalar("avg_train_prox_loss", fit_metrics['avg_prox_loss'], **tb_props)
    if "grad_var" in fit_metrics:
        writer.add_scalar("grad_var", fit_metrics["grad_var"], **tb_props)


def _write_eval_metrics(writer, loss_fed, eval_metrics, tb_props):
    writer.add_scalar("avg_val_loss", loss_fed, **tb_props)
    writer.add_scalar("avg_val_acc", eval_metrics['avg_val_acc'], **tb_props)
    writer.add_scalar("avg_val_ece", eval_metrics['avg_val_ece'], **tb_props)
    for i, client_val_acc in enumerate(eval_metrics['val_accs']):
        writer.add_scalar(f"client_val_metrics/{i}_client/val_acc", client_val_acc, **tb_props)
        writer.add_scalar(f"client_val_metrics/{i}_client/val_loss", eval_metrics['val_losses'][i], **tb_props)
        writer.add_scalar(f"client_val_metrics/{i}_client/val_ece", eval_metrics['val_eces'][i], **tb_props)
