#!/usr/bin/env python
# coding: utf-8

# In[1]:
import imp
import torch
from torch import nn
from torch.nn.modules.linear import Linear
from torch.utils.data import Dataset
import torch.optim as optim
import torchvision.models as models
import numpy as np
import os,sys,os.path
from tensorboardX import SummaryWriter
import pickle
from tqdm import tqdm
import copy
import gc
import torch.nn.functional as F
import time
# In[2]:


from option import args_parser
from utils import Accuracy,average_weights
from sampling import LocalDataset, LocalDataloaders , partition_data
from finch import FINCH

# In[3]:

# CUDA reads CUDA_VISIBLE_DEVICES when it is first initialised, so export it
# before anything below queries torch.cuda.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

# All tensors/modules created from here on default to double precision.
torch.set_default_dtype(torch.float64)
print(torch.__version__)
# Use the first visible GPU when there is one; otherwise fall back to the CPU
# instead of failing later on the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
np.set_printoptions(threshold=np.inf)


# In[4]:
# Parse the command line once, then pin the experiment-specific settings.
args = args_parser()
args.num_clients = 10
args.code_len = 32
args.batch_size = 64

# In[5]:

class net(nn.Module):
    """ResNet-18 backbone followed by a small projection head and classifier.

    Args:
        code_length: output width of the ResNet-18 backbone (passed as its
            ``num_classes``) and input/output width of the first linear layer.
        num_classes: number of logits produced by the final linear layer.
    """

    def __init__(self,
                 code_length=32,
                 num_classes=10,
                 ):
        super().__init__()
        self.code_length = code_length
        self.num_classes = num_classes
        self.feature_extractor = models.resnet18(num_classes=self.code_length)

        self.l1 = nn.Linear(self.code_length, self.code_length)
        self.l2 = nn.Linear(self.code_length, 256)
        self.l3 = nn.Linear(256, self.num_classes)

    def forward(self, x):
        """Return ``(h, projection, logits)`` for the input batch ``x``.

        ``h`` is the ReLU-activated backbone output, ``projection`` is the
        256-wide output of ``l2`` and ``logits`` is the output of ``l3``.
        Every returned tensor keeps its leading batch dimension, including
        for a batch of one sample.
        """
        # NOTE(review): the previous comment described x as [batch, time, freq];
        # torchvision's resnet18 normally takes (batch, 3, H, W) -- confirm
        # against the data loaders.
        h = F.relu(self.feature_extractor(x))
        # flatten(1) instead of squeeze(): squeeze() also removed the batch
        # dimension when the batch held a single sample, which made
        # `logits.argmax(1)` in the evaluation code fail.
        h = h.flatten(1)
        x = self.l1(h)
        x = F.relu(x)
        x = self.l2(x)
        y = self.l3(x)

        return h, x, y

# Build the shared (server-side) model. The code length comes from `args` so it
# cannot drift from the value handed to the per-client models.
global_model = net(code_length=args.code_len, num_classes=10)
print('# model parameters:', sum(param.numel() for param in global_model.parameters()))
# Wrap for multi-GPU execution and move to the selected device.
global_model = nn.DataParallel(global_model)
global_model.to(device)
# In[8]:
train_dataset, testset, dict_users, dict_users_test = partition_data(
    n_users=args.num_clients, alpha=5, rand_seed=0, dataset='SVHN')
# In[9]:
Loaders_train = LocalDataloaders(train_dataset, dict_users, args.batch_size,
                                 ShuffleorNot=True, frac=0.1)
Major_classes = []
Counts = []            # per-client label histogram (10 bins)
Available_labels = []  # per-client list of labels that occur at least once
for client_id in range(args.num_clients):
    # Count how often each of the 10 labels shows up in this client's loader.
    label_hist = [0] * 10
    for _, batch_labels in Loaders_train[client_id]:
        for label in np.array(batch_labels):
            label_hist[int(label)] += 1
    print(label_hist)
    Counts.append(label_hist)
    Available_labels.append(
        [label for label, seen in enumerate(label_hist) if seen != 0])
# In[10]:
Loaders_test = LocalDataloaders(testset, dict_users_test, args.batch_size,
                                ShuffleorNot=True, frac=0.2)
Major_classes = []
for client_id in range(args.num_clients):
    # Print the 10-bin label histogram of this client's test loader.
    test_hist = [0] * 10
    for _, batch_labels in Loaders_test[client_id]:
        labels_np = np.array(batch_labels)
        for pos in range(len(batch_labels)):
            test_hist[int(labels_np[pos])] += 1
    print(test_hist)
# In[11]:
# TensorBoard writer used by the per-client trainers.
logger = SummaryWriter('./logs')
checkpoint_dir = './checkpoint/'+ args.dataset + '/'
# exist_ok=True creates the directory only when needed, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(checkpoint_dir, exist_ok=True)
# Persist the run configuration next to the checkpoints.
with open(checkpoint_dir+'args.pkl', 'wb') as fp:
    pickle.dump(args, fp)
print('Data and model loaded')
print('Checkpoint dir:', checkpoint_dir)
# In[12]:
# Xavier-uniform initialisation (ReLU gain) for every conv / linear layer of
# the global model.
relu_gain = nn.init.calculate_gain('relu')
for layer in global_model.modules():
    if not isinstance(layer, (nn.Conv2d, nn.Linear)):
        continue
    nn.init.xavier_uniform_(layer.weight, gain=relu_gain)

class LocalUpdate(object):
    """Per-client trainer holding a local model and its previous-round copy.

    Args:
        index: index of this client.
        args: run configuration; this class reads ``lr``, ``local_ep``,
            ``temp`` and ``moon_mu`` from it.
        Loader_train: data loader iterated during local training.
        available_labels: labels present in this client's training data.
        Loader_test: data loader iterated by ``test_accuracy``.
        idxs: indices of the samples assigned to this client.
        logger: object exposing ``add_scalar`` (the training loss is logged to it).
        code_length: code length used to build the local ``net``.
        num_classes: number of classes used to build the local ``net``.
        device: device the models and batches are moved to.
    """
    def __init__(self, index, args, Loader_train,available_labels,Loader_test,idxs, logger, code_length, num_classes, device):
        self.index = index
        self.args = args
        self.logger = logger
        self.trainloader = Loader_train
        self.testloader = Loader_test
        self.idxs = idxs
        self.ce = nn.CrossEntropyLoss()
        self.device = device
        self.code_length = code_length
        self.cos = torch.nn.CosineSimilarity(dim=-1)
        # Honour the requested code length (it used to be hard-coded to 32).
        self.model = nn.DataParallel(net(code_length, num_classes)).to(device)
        # Frozen-in-eval copy used as the "negative" reference in the
        # contrastive term; refreshed from outside after each round.
        self.previous_model = copy.deepcopy(self.model).eval()
        self.early_stop = 20
        self.latent_layer_idx = -1
        self.loss_func = nn.CrossEntropyLoss().to(device)
        self.available_labels = available_labels
        self.gen_batch_size = 32
        self.batch_size = 64

    def update_weights_Gen(self, global_round, regularization=True, global_net=None):
        """Train the local model for ``args.local_ep`` epochs.

        The loss is cross-entropy on the logits plus ``args.moon_mu`` times a
        two-way contrastive term: the local projection should be closer
        (cosine similarity, scaled by ``1 / args.temp``) to the global model's
        projection than to the previous local model's projection.

        Args:
            global_round: index of the current global round (only printed).
            regularization: accepted for compatibility; not used.
            global_net: model supplying the positive projection. Defaults to
                the module-level ``global_model``.

        Returns:
            The mean over local epochs of the mean batch loss.
        """
        if global_net is None:
            global_net = global_model

        self.model.to(self.device)
        self.model.train()
        optimizer = optim.Adam(self.model.parameters(), lr=self.args.lr)
        # NOTE(review): a StepLR scheduler used to be constructed here but was
        # never stepped, so it had no effect on the learning rate; omitted.

        epoch_loss = []
        for local_epoch in range(self.args.local_ep):
            epoch_loss_collector = []
            for batch_idx, (X, y) in enumerate(self.trainloader):
                X = X.to(self.device).double()
                y = y.to(self.device).long()
                optimizer.zero_grad()

                _, pro1, out = self.model(X)
                if out.dim() == 1:
                    out = out.unsqueeze(0)

                # The reference models are never optimised here; running them
                # without autograd avoids building graphs and accumulating
                # gradients on their parameters.
                with torch.no_grad():
                    _, pro2, _ = global_net(X)
                    _, pro3, _ = self.previous_model(X)

                posi = self.cos(pro1, pro2)
                nega = self.cos(pro1, pro3)
                logits = torch.cat((posi.reshape(-1, 1), nega.reshape(-1, 1)), dim=1)
                logits = logits / self.args.temp
                # Class 0 (the global-model similarity) is the positive pair.
                # Created on the logits' device rather than via a hard-coded
                # `.cuda()` so the method also runs on CPU.
                labels = torch.zeros(X.size(0), dtype=torch.long, device=logits.device)

                loss2 = self.args.moon_mu * self.loss_func(logits, labels)
                loss1 = self.loss_func(out, y)
                loss = loss1 + loss2
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                if batch_idx % 10 == 0:
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, local_epoch, batch_idx * len(X),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss_value))
                self.logger.add_scalar('loss', loss_value)

                epoch_loss_collector.append(loss_value)

            epoch_loss.append(sum(epoch_loss_collector) / len(epoch_loss_collector))
        return sum(epoch_loss) / len(epoch_loss)

    def exp_lr_scheduler(self, epoch, decay=0.98, init_lr=0.1, lr_decay_epoch=1):
        """Return ``init_lr`` decayed by ``decay`` every ``lr_decay_epoch`` epochs, floored at 1e-4."""
        lr = max(1e-4, init_lr * (decay ** (epoch // lr_decay_epoch)))
        return lr

    def test_accuracy(self):
        """Return the mean per-batch ``Accuracy`` of the local model on the test loader."""
        self.model.eval()
        accuracy = 0
        cnt = 0
        # Evaluation only: no gradients are needed.
        with torch.no_grad():
            for X, y in self.testloader:
                X = X.to(self.device).double()
                y = y.to(self.device).double()
                _, _, p = self.model(X)
                y_pred = p.argmax(1)
                accuracy += Accuracy(y, y_pred)
                cnt += 1
        return accuracy/cnt

    def load_model(self, global_weights):
        """Overwrite the local model's weights with ``global_weights`` (a state dict)."""
        self.model.load_state_dict(global_weights)
# In[15]:
# State dict of the global model, handed to clients at the start of a round.
global_weights = global_model.state_dict()
# In[16]:
# training
args.num_epochs = 50
train_loss = []
train_accuracy = []
val_acc_list = []
net_list = []
cv_loss = []
cv_acc = []
print_every = 2
val_loss_pre = 0
counter = 0
# One LocalUpdate trainer per client, each with its own train/test loaders.
LocalModels = [
    LocalUpdate(client_id, args, Loaders_train[client_id], Available_labels[client_id],
                Loaders_test[client_id], idxs=dict_users[client_id],
                logger=logger, code_length=args.code_len, num_classes=10, device=device)
    for client_id in range(args.num_clients)
]
# In[19]:
loader_test = LocalDataloaders(testset, dict_users_test, args.batch_size,
                               ShuffleorNot=False, frac=1)
# Loader over the full test set, used to evaluate the global model.
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)


# One iteration per global round: sample clients, train them locally, average
# their weights into the global model, then evaluate the global model.
for epoch in tqdm(range(args.num_epochs)):
    test_accuracy = 0
    begin_time = time.time()

    print(f'\n | Global Training Round : {epoch+1} |\n')
    # The global model is only a fixed reference during local training. Keep it
    # in eval mode so its BatchNorm running statistics (which `global_weights`
    # aliases) are not mutated by the clients' forward passes.
    global_model.eval()
    m = max(int(args.sampling_rate * args.num_clients), 1)
    idxs_users = np.random.choice(range(args.num_clients), m, replace=False)
    train_loss = 0
    for idx in idxs_users:
        LocalModels[idx].load_model(global_weights)
        loss = LocalModels[idx].update_weights_Gen(global_round=epoch, regularization = True)
        train_loss += loss
        acc = LocalModels[idx].test_accuracy()
        test_accuracy += acc

    #####aggregate_nets######
    online_clients = idxs_users
    # Build each client's state dict once instead of once per parameter key.
    client_states = [LocalModels[i].model.state_dict() for i in online_clients]
    global_w = global_model.state_dict()

    for k in global_w.keys():
        # Average in float64: the models run in double precision, so casting
        # to float32 here would truncate the weights every round.
        global_w[k] = torch.stack([state[k].double() for state in client_states], 0).mean(0)
    global_model.load_state_dict(global_w)

    # Participating clients remember their just-trained weights as the
    # "previous" model for the next round's contrastive term.
    for idx in online_clients:
        LocalModels[idx].previous_model.load_state_dict(LocalModels[idx].model.state_dict())

    # Every client (sampled or not) restarts from the new global weights.
    for cur_net in LocalModels:
        cur_net.model.load_state_dict(global_model.state_dict())
    #####################################################################

    global_model.eval()
    accuracy = 0
    cnt = 0
    # Evaluation only: skip autograd bookkeeping.
    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device).double()
            y = y.to(device).double()
            _, _, p = global_model(X)
            y_pred = p.argmax(1)
            accuracy += Accuracy(y,y_pred)
            cnt += 1

    # Only the sampled clients contributed to `test_accuracy`, so average over
    # them rather than over all clients.
    print('average test accuracy:', test_accuracy / len(idxs_users))

    end_time = time.time()
    training_time  = end_time - begin_time
    print('training time: ', training_time)

    print('global test accuracy: ', accuracy/cnt)


    # global_model.eval()
    # test_loss = 0
    # correct = 0
    # with torch.no_grad():
    #     for index, (data, target) in enumerate(test_loader):
    #         data, target = data.to(device).double(), target.to(device).double()
    #         _, _, output = global_model(data)
    #         test_loss += F.cross_entropy(output, target.long(), reduction='sum').item()
    #         pred = output.data.max(dim=1, keepdim=True)[1]
    #         correct += pred.eq(target.data.view_as(pred)).long().to(device).sum()
    # test_loss /= len(test_loader.dataset)
    # accuracy = correct.item()/ len(test_loader.dataset)

    # train_loss /= args.num_clients

    # print("Round {:3d}, Testing accuracy:{:.4f}".format(i + 1, accuracy))
    # print("Train_loss:{:.5f}, Test_loss:{:.5f}".format(train_loss, test_loss))
    # print("-" * 100)

