import os
import shutil
from typing import Dict, List, Tuple

import hydra
import torch

import flwr
import random
import numpy as np

from flwr.server import ServerConfig
from flwr.common import Metrics
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.utils.tensorboard import SummaryWriter

from src.strategy import FedAvgCustom, FLOCO
from src.client_manager import SeedCM
from src.data import get_dataloaders
from src.util import server_test_approx, get_label_dist
from src.networks import net_fn
from src.server import StandardServer
from src.client import FLOCOClient, flwr_get_parameters, flwr_set_parameters

def _init_rnd(seed):
    """Seed the Python, NumPy and PyTorch RNGs with ``seed`` for reproducibility.

    ``PYTHONHASHSEED`` is exported as well. Setting it at runtime does not change
    hash randomisation of the already-running interpreter; only processes
    started afterwards from this environment pick it up. When CUDA is
    available, the current CUDA device is seeded too, and cuDNN is switched to
    deterministic, non-benchmarking mode.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def get_experiment_name(cfg):
    """Derive a unique, human-readable run name from the configuration.

    The base name encodes the strategy, network architecture, data-split rule,
    optimiser and the main hyper-parameters. As a side effect the config is put
    into struct mode, and the base name is stored on ``cfg.experiment_name``
    (added through ``open_dict``).

    Returns:
        The base name plus an ``,i=<k>`` suffix, where ``k`` is the smallest
        index for which ``reports/<name>,i=<k>`` does not exist yet. Reruns
        therefore get a fresh report directory.
    """
    strategy = cfg.strategy
    dm = cfg.dataset_model

    # Strategy-specific prefix.
    parts = [f"{strategy.strategy_name}_{dm.network_arch}"]
    if strategy.local_model:
        parts.append(f"_LM_{strategy.local_prox_mu}_TAU_{strategy.tau}_CL")
    if strategy.strategy_name == "FedProx":
        parts.append(f"_FP_{dm.fedprox_mu}")
    elif strategy.strategy_name == "FLOCO":
        parts.append(f"_NT_{strategy.network_type}_PJ_{strategy.projection}_REG_{strategy.reg_hp}")
        if cfg.rule.num_points > 1:
            parts.append(
                f"_NP_{cfg.rule.num_points}_SUS_{strategy.subspace_start}"
                f"_RHO_{cfg.rule.rho}_{cfg.rule.distance_measure}"
            )

    # Dataset, split rule and optimiser, followed by TAG_value hyper-parameter pairs.
    head = "".join(parts) + (
        f"_{dm.dataset_name}_{cfg.rule.rule_name}_{cfg.rule.rule_arg}_{dm.optimizer[:3]}"
    )
    tagged = (
        ("AG", cfg.alpha_g),
        ("AL", dm.opt_args.lr),
        ("C", cfg.num_clients),
        ("S", cfg.clients_per_round),
        ("LE", dm.local_epochs),
        ("BS", dm.batch_size),
        ("WD", dm.opt_args.get('weight_decay', '')),  # TODO create opt_args.to_string
        ("M", dm.opt_args.get('momentum', '')),
        ("VF", cfg.val_frac),
        ("NEC", cfg.evaluate_clients_per_round),
        ("SEED", cfg.seed),
    )
    name = head + "_" + "_".join(f"{tag}_{value}" for tag, value in tagged)

    OmegaConf.set_struct(cfg, True)
    with open_dict(cfg):
        cfg.experiment_name = name

    # Find the first report-directory index not used by a previous run.
    run_idx = 0
    while os.path.exists(f"reports/{name},i={run_idx}"):
        run_idx += 1
    return f"{name},i={run_idx}"


@hydra.main(version_base=None, config_path="./configs", config_name="config")
def fl_training(cfg: DictConfig) -> None:
    """Run one federated-learning simulation described by the Hydra config.

    Steps:
    - Build the experiment name and seed the RNGs.
    - Load the client and test data.
    - Create the strategy and server selected by ``cfg.strategy``.
    - Run a Flower simulation for ``cfg.communication_rounds`` rounds.

    TensorBoard logs are written to ``./reports/<experiment_name>``. The writer
    is closed, and temporary local client models are deleted, even when the
    simulation raises.

    Args:
        cfg: Hydra/OmegaConf configuration, injected by ``@hydra.main``.

    Raises:
        NotImplementedError: if ``cfg.strategy.strategy_name`` is not FedAvg,
            FedProx or FLOCO. SCAFFOLD is recognised, but its strategy and
            server classes are not imported by this module.
    """
    print(OmegaConf.to_yaml(cfg))
    experiment_name = get_experiment_name(cfg)
    print(f"Starting experiment {experiment_name} ...")
    _init_rnd(cfg.seed)
    cfg.use_cuda = not cfg.disable_cuda and torch.cuda.is_available()
    cfg.device = 'cuda' if cfg.use_cuda else 'cpu'
    print(f"CUDA IS AVAILABLE: {torch.cuda.is_available()}; CUDA IS ENABLED: {cfg.use_cuda}")

    # prep config: a POINT network is always built with a single point,
    # whatever the rule config requests. TODO document this
    if cfg.strategy.network_type == "POINT":
        cfg.rule.num_points = 1

    # Initialize Tensorboard writer
    writer = SummaryWriter(log_dir="./reports/" + experiment_name)

    # Get dataloaders
    trainloaders, valloaders, testloader, num_classes, folds = get_dataloaders(cfg)

    # Compute label distributions
    client_label_dists = {f"{i}_client": get_label_dist(loader, num_classes) for i, loader in enumerate(trainloaders)}

    def build_net():
        # Every model in this run (initial params, client nets, local nets and
        # server-side evaluation nets) is built with exactly these arguments.
        return net_fn(
            dataset_name=cfg.dataset_model.dataset_name,
            num_classes=cfg.dataset_model.num_classes,
            network_arch=cfg.dataset_model.network_arch,
            network_type=cfg.strategy.network_type,
            num_points=cfg.rule.num_points,
            seed=cfg.seed,
            device=cfg.device,
        )

    # Initialize 1 model for initial params (also passed to FLOCO as its subspace)
    net = build_net()
    initial_params = flwr_get_parameters(net)

    def client_fn(client_name) -> flwr.client.Client:
        # client_name has the form "<index>_client" (see client_names below).
        net = build_net()
        local_net = build_net() if cfg.strategy.local_model else None
        client_id = int(client_name.split('_')[0])
        return FLOCOClient(
            client_name=client_name,
            net=net,
            local_net=local_net,
            trainloader=trainloaders[client_id],
            valloader=valloaders[client_id],
            cfg=cfg,
            experiment_name=experiment_name
        )

    # The `evaluate` function will be called by Flower after every round
    def server_eval_fn(server_round: int, parameters: flwr.common.NDArrays, config: Dict[str, flwr.common.Scalar]):
        net = build_net()
        flwr_set_parameters(net, parameters)  # Update model with the latest parameters
        server_test_approx(server_round, net, testloader, writer, num_classes, cfg=cfg, folds=folds,
                           strategy_config=config)
        if server_round % 100 == 0:
            net_state_dict = net.state_dict()
            # Save the model parameters (checkpoint every 100 rounds when enabled)
            if net_state_dict is not None and cfg.save_model:
                os.makedirs(f"trained_models/global_models/{experiment_name}/", exist_ok=True)
                torch.save(
                    net_state_dict, f"trained_models/global_models/{experiment_name}/global_model_round_{server_round}.pth"
                )

    # NOTE(review): if the strategy uses Flower's default FedAvg sizing
    # (int(num_available * fraction)), floating-point error here could select
    # one client fewer than intended; confirm against FedAvgCustom/FLOCO.
    fraction_fit = (1 / cfg.num_clients) * cfg.clients_per_round
    fraction_evaluate = (1 / cfg.num_clients) * cfg.evaluate_clients_per_round

    # Pass parameters to the Strategy for server-side parameter initialization
    strategy_args = {
        "fraction_fit": fraction_fit,
        "fraction_evaluate": fraction_evaluate,
        "initial_parameters": flwr.common.ndarrays_to_parameters(initial_params),
        "evaluate_fn": server_eval_fn,
        "on_fit_config_fn": fit_config,
        "on_evaluate_config_fn": eval_config,
        "fit_metrics_aggregation_fn": train_metrics_aggregation_fn,
        "evaluate_metrics_aggregation_fn": eval_metrics_aggregation_fn
    }
    print(f'cfg.strategy.strategy_name {cfg.strategy.strategy_name}')
    if cfg.strategy.strategy_name in ['FedAvg', 'FedProx']:
        strategy = FedAvgCustom(**strategy_args)
    elif cfg.strategy.strategy_name == 'SCAFFOLD':
        # ScaffoldStrategy / ScaffoldServer are not imported in this module, so
        # this branch used to crash with a NameError. Fail explicitly instead.
        # TODO: import them to re-enable SCAFFOLD.
        raise NotImplementedError("SCAFFOLD strategy/server are not imported in this module")
    elif cfg.strategy.strategy_name == 'FLOCO':
        strategy_args.update({
            "xp_name": experiment_name,
            "subspace": net,
            "writer": writer,
            "cfg": cfg,
        })
        strategy = FLOCO(**strategy_args)
    else:
        raise NotImplementedError("Strategy not implemented")

    # Do client label histogram plots
    for i, client_label_dist in enumerate(client_label_dists.values()):
        for j, l in enumerate(client_label_dist):
            writer.add_scalar(f"label_dists/{i}_client_label_dist", l, global_step=j)

    server = StandardServer(
        client_manager=SeedCM(seed=cfg.seed),
        strategy=strategy,
        writer=writer,
        write_prox_loss=cfg.strategy.strategy_name == "FedProx" or cfg.strategy.local_model,
        eval_freq=cfg.eval_freq,
    )

    client_names = [f"{i}_client" for i in np.arange(cfg.num_clients)]
    local_models_dir = f"./trained_models/local_client_models/{experiment_name}"
    if cfg.strategy.local_model:
        os.makedirs(local_models_dir)

    # TODO hack: turn the OmegaConf nodes into plain dicts before handing them
    # to Ray (presumably Ray does not accept DictConfig here; confirm).
    ray_init_args = dict(cfg.ray.init_args)
    ray_init_args["_system_config"] = dict(ray_init_args["_system_config"])

    try:
        flwr.simulation.start_simulation(
            client_fn=client_fn,
            clients_ids=client_names,
            server=server,
            config=ServerConfig(num_rounds=cfg.communication_rounds),
            client_resources=cfg.ray.client_resources_gpu if cfg.use_cuda else cfg.ray.client_resources_cpu,
            ray_init_args=ray_init_args,
            keep_initialised=True,
        )
        print("Simulation finished successfully.")
    finally:
        # Clean up on failure as well: temporary local client models must not
        # accumulate on disk, and buffered TensorBoard events must be flushed.
        if cfg.strategy.local_model:
            shutil.rmtree(local_models_dir, ignore_errors=True)
        writer.close()


def train_metrics_aggregation_fn(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate the fit metrics reported by the clients of one round.

    Args:
        metrics: ``(num_examples, client_metrics)`` pairs. Each dict must
            contain ``train_loss``, ``train_acc``, ``train_ece``, ``ce_loss``
            and ``prox_loss``.

    Returns:
        The per-client lists ``train_losses``, ``train_accs`` and ``train_eces``.
        ``avg_train_loss``, ``avg_train_acc`` and ``avg_train_ece`` are
        averages weighted by ``num_examples``; they are NaN when the clients
        report zero examples in total. ``avg_ce_loss`` and ``avg_prox_loss``
        are unweighted means over clients.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)

    def weighted_mean(key):
        # Previously a round with zero total examples raised ZeroDivisionError
        # and aborted the whole simulation over a logging metric.
        if total_examples == 0:
            return float("nan")
        return sum(num_examples * m[key] for num_examples, m in metrics) / total_examples

    return {
        "train_losses": [m["train_loss"] for _, m in metrics],
        "train_accs": [m["train_acc"] for _, m in metrics],
        "train_eces": [m["train_ece"] for _, m in metrics],
        "avg_train_loss": weighted_mean("train_loss"),
        "avg_train_acc": weighted_mean("train_acc"),
        "avg_train_ece": weighted_mean("train_ece"),
        "avg_ce_loss": np.mean([m["ce_loss"] for _, m in metrics]),
        "avg_prox_loss": np.mean([m["prox_loss"] for _, m in metrics]),
    }


def eval_metrics_aggregation_fn(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate the evaluation metrics reported by the clients of one round.

    Args:
        metrics: ``(num_examples, client_metrics)`` pairs. Each dict must
            contain ``val_loss``, ``val_acc`` and ``val_ece``.

    Returns:
        The per-client lists ``val_losses``, ``val_accs`` and ``val_eces``.
        ``avg_val_loss``, ``avg_val_acc`` and ``avg_val_ece`` are averages
        weighted by ``num_examples``; they are NaN when there are no results
        or zero examples in total.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)

    def weighted_mean(key):
        # Previously an empty or zero-example round raised ZeroDivisionError
        # and aborted the whole simulation over a logging metric.
        if total_examples == 0:
            return float("nan")
        return sum(num_examples * m[key] for num_examples, m in metrics) / total_examples

    return {
        "val_losses": [m["val_loss"] for _, m in metrics],
        "val_accs": [m["val_acc"] for _, m in metrics],
        "val_eces": [m["val_ece"] for _, m in metrics],
        "avg_val_loss": weighted_mean("val_loss"),
        "avg_val_acc": weighted_mean("val_acc"),
        "avg_val_ece": weighted_mean("val_ece"),
    }


def fit_config(server_round: int):
    """Return the per-round config sent with each fit instruction.

    Only the current round number is forwarded to the clients.
    """
    return {"server_round": server_round}


def eval_config(server_round: int):
    """Return the per-round config sent with each evaluate instruction.

    Only the current round number is forwarded to the clients.
    """
    return {"server_round": server_round}


if __name__ == "__main__":
    # Script entry point. The @hydra.main decorator on fl_training builds cfg
    # from ./configs (config_name="config") plus any command-line overrides,
    # so no argument is passed here.
    fl_training()
