#!/usr/bin/env python
import copy
import torch
import argparse
import os
import time
import warnings
import numpy as np
import torchvision
import logging

from flcore.servers.serveravg import FedAvg
from flcore.servers.serverpFedMe import pFedMe
from flcore.servers.serverperavg import PerAvg
from flcore.servers.serverprox import FedProx
from flcore.servers.serverfedaptor import FEDAPTOR
from flcore.servers.servermapmep import MAPMEP
from flcore.servers.servermpavg import MPAvg
from flcore.servers.servermmpavg import MMPAvg
from flcore.servers.servermcpa import MCPA
from flcore.servers.servermp import FedMP
from flcore.servers.serverscaffold import SCAFFOLD



# from flcore.servers.serverfomo import FedFomo
# from flcore.servers.serveramp import FedAMP
# from flcore.servers.servermtl import FedMTL
# from flcore.servers.serverlocal import Local
# from flcore.servers.serverper import FedPer
# from flcore.servers.serverapfl import APFL
# from flcore.servers.serverditto import Ditto
# from flcore.servers.serverrep import FedRep
# from flcore.servers.serverphp import FedPHP
# from flcore.servers.serverbn import FedBN
# from flcore.servers.serverrod import FedROD
# from flcore.servers.serverproto import FedProto
# from flcore.servers.serverdyn import FedDyn
# from flcore.servers.servermoon import MOON
# from flcore.servers.serverbabu import FedBABU
# from flcore.servers.serverapple import APPLE
# from flcore.servers.servergen import FedGen
# from flcore.servers.serverdistill import FedDistill

from flcore.trainmodel.models import *

from flcore.trainmodel.bilstm import BiLSTM_TextClassification
# from flcore.trainmodel.resnet import resnet18 as resnet
from flcore.trainmodel.alexnet import alexnet
from flcore.trainmodel.mobilenet_v2 import mobilenet_v2
from utils.result_utils import average_data
from utils.mem_utils import MemReporter

# @WJM: from FL_HSI
import datetime
from utils.utils_spectral_dataloader import *
# from torchsummaryX import summary
import torch.nn as nn
from utils.options import args_parser
from models.SRN import SRN
# from models.MST_ADA_plain import MST
from models.MST_Adaptor import MST
from models.Prompt import Prompt_base, Prompt_align
from models.Prompt_Transformer import Prompt_Transformer


# from models.Fed import FedAvg

# Raise the root logger's threshold so that only ERROR (and more severe)
# records are emitted by the logging module.
logger = logging.getLogger()
logger.setLevel(logging.ERROR)

# Suppress every Python warning for the whole process.
warnings.simplefilter("ignore")
# Fix the PyTorch RNG seed. Note that NumPy's RNG is not seeded here.
torch.manual_seed(0)

# # hyper-params for Text tasks
# vocab_size = 98635
# max_len=200
# emb_dim=32

# Value of the ``num_blocks`` argument passed to ``MST`` for each supported
# MST variant name (as found in ``args.model`` / ``args.backbone``).
_MST_NUM_BLOCKS = {
    'MST-S': [2, 2, 2],
    'MST-M': [2, 4, 4],
    'MST-L': [4, 7, 5],
}


def _uses_prompt(args):
    """Return a truthy value when the prompt + backbone setup is requested.

    That is the case when ``args.MP`` is set or the algorithm is ``'MPT'``.
    """
    return args.MP or args.algorithm == 'MPT'


def _resolve_trn_split_ratio(args):
    """Normalise ``args.trn_split_ratio`` in place according to ``args.trn_split``.

    * ``1`` -> an all-ones array with one entry per client.
    * ``2`` -> the user supplied ratios converted to a NumPy array.
    * ``0`` -> left untouched.

    Raises:
        ValueError: for any other value of ``args.trn_split``.
    """
    if args.trn_split == 1:
        args.trn_split_ratio = np.ones(args.num_clients)
    elif args.trn_split == 2:
        args.trn_split_ratio = np.array(args.trn_split_ratio)
    elif args.trn_split != 0:
        raise ValueError(
            'unsupported trn_split={!r}; expected 0, 1 or 2'.format(args.trn_split))


def _check_arg_dependencies(args):
    """Assert the cross-option constraints of the PTP / CA / MB / FMABFT flags.

    Raises:
        AssertionError: when one of the constraints is violated.
    """
    if args.PTP:
        assert args.adaptor is not None, 'in PTP setting, must employ adaptor!'
        assert len(args.last_train_clients) == args.num_clients, 'in PTP setting, must provide clients ckpts (pre-trained model_epoch)!'
        assert len(args.model_save_filename_clients) == args.num_clients, 'in PTP setting, must provide clients ckpts (pre-trained model dir)!'
    if args.CA:
        assert len(args.last_train_clients) == args.num_clients, 'in CA setting, must provide clients ckpts (pre-trained model_epoch)!'
        assert len(args.model_save_filename_clients) == args.num_clients, 'in CA setting, must provide clients ckpts (pre-trained model dir)!'
    if args.MB:
        assert args.adaptor is None, 'in MB setting, adaptor is not allowed!'
    if args.FMABFT:
        assert 0 < args.backbone_interval < args.local_steps, 'in FMABFT, [backbone_interval] must satisfy some conditions!'


def _build_prompt_net(name, args):
    """Build the prompt network called ``name`` and wrap it in ``nn.DataParallel``.

    The network is moved with ``.cuda()`` (not ``args.device``), exactly as
    the three prompt variants have always been created in this script.

    Raises:
        NotImplementedError: if ``name`` is not a known prompt model.
    """
    if name == 'Prompt_base' or name == 'Prompt_align':
        prompt_cls = Prompt_base if name == 'Prompt_base' else Prompt_align
        net = prompt_cls(in_ch=28,
                         out_ch=28,
                         n_resblocks=args.Prompt_BLK,
                         n_feats=64,
                         kernel_size=3,
                         bn=args.Prompt_BN)
    elif name == 'Prompt_transformer':
        net = Prompt_Transformer(upscale=1,
                                 in_chans=28,
                                 img_size=args.patch_size,
                                 window_size=8,
                                 img_range=1.,
                                 depths=[2],
                                 embed_dim=args.embed_dim,
                                 num_heads=[4],
                                 mlp_ratio=2,
                                 upsampler='',
                                 resi_connection='1conv')
    else:
        raise NotImplementedError('unknown prompt model: {!r}'.format(name))
    return nn.DataParallel(net.cuda())


def _build_reconstruction_net(name, args):
    """Build the SRN / MST network called ``name`` on ``args.device``.

    The result is wrapped in ``nn.DataParallel``. The same constructor
    arguments are used whether the network serves as ``args.model`` or as
    ``args.backbone``.

    Raises:
        NotImplementedError: if ``name`` is neither ``'SRN'`` nor an MST variant.
    """
    if name == 'SRN':
        net = SRN(in_ch=28,
                  out_ch=28,
                  n_resblocks=16,
                  n_feats=64,
                  kernel_size=3)
    elif name in _MST_NUM_BLOCKS:
        net = MST(dim=28,
                  stage=2,
                  # copy so the module-level table can never be mutated
                  num_blocks=list(_MST_NUM_BLOCKS[name]),
                  adaptor=args.adaptor)
    else:
        raise NotImplementedError('unknown model/backbone: {!r}'.format(name))
    return nn.DataParallel(net.to(args.device))


def _load_checkpoint_and_stamp(args, last_train, save_filename, ckpt_template, loaded_msg):
    """Optionally restore ``args.model`` from a checkpoint and return a run stamp.

    Args:
        args: parsed options; ``args.model_path`` and ``args.model`` are read.
        last_train: epoch to resume from; ``0`` means "train from scratch".
        save_filename: directory (below ``args.model_path``) holding the
            checkpoint; must be ``''`` iff ``last_train`` is ``0``.
        ckpt_template: checkpoint file name with one ``{}`` for the epoch.
        loaded_msg: message printed after the weights were loaded.

    Returns:
        The current date/time converted by ``time2file_name``. A fresh stamp
        is returned in both cases, so a resumed run also gets a new directory.

    Raises:
        AssertionError: if ``save_filename`` is inconsistent with ``last_train``.
    """
    if last_train == 0:
        print('train from stratch')
        assert save_filename == '', 'ERROR: No need to specify model_save_filename'
        wait_upper = 20
    else:
        print('train from checkpoint')
        assert save_filename != '', 'ERROR: please specify model_save_filename when last_train !=0'

        checkpoint_path = './' + args.model_path + '/' + save_filename + '/' + ckpt_template.format(last_train)
        model_checkpoint = torch.load(checkpoint_path)
        # strict=False: tolerate key mismatches and report them instead.
        missing_key, unexpected_key = args.model.load_state_dict(model_checkpoint['model_weights'], strict=False)
        print('Missing keys=', missing_key)
        print('Unexpected_keys=', unexpected_key)
        print(loaded_msg)
        wait_upper = 25

    # Random pause before taking the timestamp -- presumably so that jobs
    # launched at the same moment do not end up with the same directory name
    # (TODO confirm).
    rand_wait = np.random.randint(low=1, high=wait_upper)
    time.sleep(rand_wait)
    return time2file_name(str(datetime.datetime.now()))


def _build_server(args):
    """Instantiate the server class registered for ``args.algorithm``.

    Raises:
        NotImplementedError: if the algorithm name is not supported.
    """
    # Built lazily so the table is only evaluated when a server is needed.
    server_classes = {
        "FedAvg": FedAvg,
        "PerAvg": PerAvg,
        "pFedMe": pFedMe,
        "FedProx": FedProx,
        "FEDAPTOR": FEDAPTOR,
        "MAPMEP": MAPMEP,
        "MPAvg": MPAvg,
        "MMPAvg": MMPAvg,
        "MCPA": MCPA,
        "MPT": FedMP,
        "SCAFFOLD": SCAFFOLD,
    }
    try:
        server_cls = server_classes[args.algorithm]
    except KeyError:
        raise NotImplementedError(
            'unknown algorithm: {!r}'.format(args.algorithm)) from None
    return server_cls(args)


def run(args):
    """Build the model(s), restore checkpoints, set up logs and run the server.

    Steps, in order:

    1. Normalise ``args.trn_split_ratio`` and assert option dependencies.
    2. Replace the model *name* in ``args.model`` (and, in prompt mode, the
       name in ``args.backbone``) by the constructed network.
    3. Initialise the weights (``global_model_init``; ``global_init_CA`` when
       ``args.CA`` is set) and optionally load a checkpoint.
    4. Append a fresh time stamp to ``args.model_path``, create that
       directory and write the header of every client log and the global log.
    5. Create the server for ``args.algorithm`` and call ``test()`` when
       ``args.test_mode`` is set, ``train()`` otherwise.

    Args:
        args: option namespace produced by ``args_parser``; it is modified in
            place (``model``, ``backbone``, ``trn_split_ratio``, ``model_path``).

    Raises:
        ValueError: for an unsupported ``args.trn_split``.
        AssertionError: when option dependencies are violated.
        NotImplementedError: for an unknown / unimplemented model, backbone
            or algorithm name.
    """
    model_str = args.model

    _resolve_trn_split_ratio(args)
    _check_arg_dependencies(args)

    print("Creating server and clients ...")

    if _uses_prompt(args):
        # Prompt mode: args.model is the prompt net, args.backbone the backbone.
        args.model = _build_prompt_net(model_str, args)
        # 'S2_ViT' has no implementation yet; the backbone is left as its
        # name string in that case (unchanged historical behaviour).
        if args.backbone != 'S2_ViT':
            args.backbone = _build_reconstruction_net(args.backbone, args)
    else:
        if model_str == 'S2_ViT':
            # Fail here with a clear message instead of crashing below when
            # named_parameters() is called on the un-replaced name string.
            raise NotImplementedError("model 'S2_ViT' is not implemented")
        args.model = _build_reconstruction_net(model_str, args)

    for name, params in args.model.named_parameters():
        print(name, params.shape)

    # Global weight initialisation. Author's note: backbone (adaptor) weights
    # set here will be erased if a checkpoint is loaded afterwards.
    global_model_init(model=args.model, param_init=args.param_init)

    # Global weight initialisation from pre-trained clients.
    if args.CA:
        args.model = global_init_CA(args=args, strict=False)
        print('------Successfully load the global model with pre-trained clients!------')

    if _uses_prompt(args):
        date_time = _load_checkpoint_and_stamp(
            args,
            last_train=args.last_train_prompt,
            save_filename=args.model_save_filename_prompt,
            ckpt_template='prompt_model_epoch_{}.pth',
            loaded_msg='------Successfully load the pre-trained Prompt global model!------')
    else:
        date_time = _load_checkpoint_and_stamp(
            args,
            last_train=args.last_train,
            save_filename=args.model_save_filename,
            ckpt_template='model_epoch_{}.pth',
            loaded_msg='------Successfully load the pre-trained global model!------')

    args.model_path = args.model_path + '/' + date_time
    print('>>>args.model_path', args.model_path)
    os.makedirs(args.model_path, exist_ok=True)

    # Initialise the loggers: ids 0..K-1 are the clients, id K is the global model.
    for client_id in range(args.num_clients):
        msg = "/*******Client {:d} logger******/ \n Learning rate:{}, batch_size:{}.\n".format(
            client_id, args.local_learning_rate, args.batch_size)
        gen_log(model_path=args.model_path, msg=msg, user_id=client_id)
    global_id = args.num_clients
    gen_log(model_path=args.model_path, msg="/*******Global model logger******/ \n", user_id=global_id)
    # The full option dictionary is recorded in the global log.
    gen_log(model_path=args.model_path, msg=vars(args), user_id=global_id)

    server = _build_server(args)

    if args.test_mode:
        server.test()
    else:
        server.train()

    print("All done!")


if __name__ == "__main__":
    # Script entry point: parse the options, load data and masks, print a
    # summary of the configuration, then hand over to run().
    args = args_parser()

    # Restrict the GPUs visible to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id

    # There is no CPU fallback: abort when CUDA was requested but is missing.
    if args.device == "cuda" and not torch.cuda.is_available():
        print("\ncuda is not avaiable.\n")
        raise Exception('NO GPU!')

    # On the 'US' cluster the training data sits in a sub-directory of
    # train_data_path; for any other cluster value (e.g. 'CHN') the path is
    # used as given.
    if args.cluster == 'US':
        if args.debug:
            args.train_data_path = args.train_data_path + 'debug_mat/'
        else:
            args.train_data_path = args.train_data_path + 'mat/'

    # The loaded training set is passed on through args.dataset; it is not
    # loaded in test mode.
    if not args.test_mode:
        args.dataset = LoadTraining(args.train_data_path)

    # load mask(s)
    assert args.num_clients == len(args.mask_ids), 'ERROR: current FL only support num_clients=num_masks!'
    args.mask4d_ls = generate_masks(args)
    # Use the % operator so the placeholders are actually substituted
    # (passing the tuple as a second print() argument printed them verbatim).
    print('>>> type args.mask4d_ls is %s len(mask4d)=%d' % (type(args.mask4d_ls), len(args.mask4d_ls)))

    # load testset
    args.test_data = LoadTest(args.test_data_path)
    print('type of the args.test_data', type(args.test_data))

    # Configuration summary.
    print("=" * 50)

    print("Algorithm: {}".format(args.algorithm))
    print("Local batch size: {}".format(args.batch_size))
    print("Local steps: {}".format(args.local_steps))
    print("Local learing rate: {}".format(args.local_learning_rate))
    print("Local learing rate decay: {}".format(args.learning_rate_decay))
    if args.learning_rate_decay:
        print("Local learing rate decay gamma: {}".format(args.learning_rate_decay_gamma))
    print("Total number of clients: {}".format(args.num_clients))
    print("Clients join in each round: {}".format(args.join_ratio))
    print("Clients randomly join: {}".format(args.random_join_ratio))
    print("Client drop rate: {}".format(args.client_drop_rate))
    print("Client select regarding time: {}".format(args.time_select))
    if args.time_select:
        print("Time threthold: {}".format(args.time_threthold))
    # At this point args.model still holds the model name; run() replaces it
    # with the constructed network.
    print("Backbone: {}".format(args.model))
    print("Using device: {}".format(args.device))
    print("Auto break: {}".format(args.auto_break))
    if not args.auto_break:
        print("Global rounds: {}".format(args.global_rounds))
    if args.device == "cuda":
        print("Cuda device id: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
    print("=" * 50)

    run(args)
