import itertools

import torch
from utils import compute_loss, compute_auc, zero_copy, compute_mr, inference_personal


# Sentinel meaning "argument not passed": local_fedAvg then falls back to the
# module-level variable of the same name, which is how it originally read
# these values.
_MODULE_GLOBAL = object()


def _resolve_module_global(value, name):
    """Return ``value``, or the module-level variable ``name`` if ``value`` is the sentinel.

    Raises:
        NameError: if ``value`` was not passed and no module-level ``name`` exists.
    """
    if value is not _MODULE_GLOBAL:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name '{name}' is not defined; pass {name}=... to local_fedAvg"
        ) from None


def _apply_gate_correction(module, delta_module):
    """Subtract ``delta_module``'s parameters from ``module``'s gradients, in place.

    Parameters are paired positionally (``zip``), so both modules are expected
    to share the same parameter layout. A parameter that received no gradient
    is treated as having a zero gradient.
    """
    with torch.no_grad():
        for param, delta in zip(module.parameters(), delta_module.parameters()):
            if param.grad is None:
                param.grad = -delta.detach()
            else:
                param.grad.sub_(delta)


def local_fedAvg(train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs, models, predictors,
                 personal_epoch=5, personal_lr=0.01,
                 global_model=_MODULE_GLOBAL, global_pred=_MODULE_GLOBAL,
                 fedtype=_MODULE_GLOBAL, models_delta=_MODULE_GLOBAL, predictors_delta=_MODULE_GLOBAL):
    """Train every client's encoder/predictor locally for ``personal_epoch`` epochs.

    Each client is first initialised from ``global_model`` / ``global_pred``.
    It is then optimised with plain SGD on its own positive and negative edge
    graphs, with ``compute_loss`` applied to the per-edge scores. Models and
    predictors are updated in place, and nothing is returned.

    Args:
        train_gs: dict of per-user graphs; ``ndata['feat']`` is the encoder input.
        train_pos_gs, train_neg_gs: dicts of per-user positive / negative edge
            graphs. Each carries ``edata['etype']``; for edge i, column
            ``etype[i]`` of the predictor output is used as that edge's score
            (presumably one column per relation type -- TODO confirm).
        test_pos_gs, test_neg_gs: unused here; kept for signature compatibility.
        models, predictors: dicts of per-user modules, indexed by the keys of
            ``train_gs``.
        personal_epoch: number of passes over all clients.
        personal_lr: SGD learning rate.
        global_model, global_pred: modules whose weights are copied into every
            client before training; pass ``None`` to skip that copy. If not
            passed, the module-level variable of the same name is used.
        fedtype: ``'fedgate'`` enables the gradient correction
            ``grad <- grad - delta`` before each step; any other value trains
            with plain SGD. If not passed, a module-level ``fedtype`` is used
            when it exists, otherwise plain SGD.
        models_delta, predictors_delta: dicts of per-user modules holding the
            correction terms; only read when ``fedtype == 'fedgate'``. If not
            passed, the module-level variables of the same name are used.

    Raises:
        NameError: if ``global_model`` / ``global_pred`` (or, for ``'fedgate'``,
            the delta dicts) are neither passed nor defined at module level.
    """
    if not train_gs:
        # No clients: every loop below would be a no-op.
        return

    # Start every client from the current global weights. load_state_dict
    # copies the tensors, so one state dict can be shared by all clients.
    global_model = _resolve_module_global(global_model, 'global_model')
    global_pred = _resolve_module_global(global_pred, 'global_pred')
    if global_model is not None:
        global_model_state = global_model.state_dict()
        for user_index in train_gs:
            models[user_index].load_state_dict(global_model_state)
    if global_pred is not None:
        global_pred_state = global_pred.state_dict()
        for user_index in train_gs:
            predictors[user_index].load_state_dict(global_pred_state)

    # One SGD optimizer per client over its encoder + predictor parameters.
    optimizers = {
        user_index: torch.optim.SGD(
            itertools.chain(models[user_index].parameters(), predictors[user_index].parameters()),
            lr=personal_lr,
        )
        for user_index in train_gs
    }

    if fedtype is _MODULE_GLOBAL:
        fedtype = globals().get('fedtype')
    use_gate = fedtype == 'fedgate'
    if use_gate:
        models_delta = _resolve_module_global(models_delta, 'models_delta')
        predictors_delta = _resolve_module_global(predictors_delta, 'predictors_delta')

    for _ in range(personal_epoch):
        for user_index in train_gs:
            model = models[user_index]
            predictor = predictors[user_index]
            optimizer = optimizers[user_index]
            model.train()
            predictor.train()

            # forward
            train_g = train_gs[user_index]
            pos_g = train_pos_gs[user_index]
            neg_g = train_neg_gs[user_index]
            h = model(train_g, train_g.ndata['feat'])

            # Keep, for edge i, only the score in column etype[i].
            pos_etype = pos_g.edata['etype']
            neg_etype = neg_g.edata['etype']
            pos_score = predictor(pos_g, h)[torch.arange(len(pos_etype), device=pos_etype.device), pos_etype]
            neg_score = predictor(neg_g, h)[torch.arange(len(neg_etype), device=neg_etype.device), neg_etype]

            loss = compute_loss(pos_score, neg_score)

            # backward
            optimizer.zero_grad()
            loss.backward()

            if use_gate:
                _apply_gate_correction(model, models_delta[user_index])
                _apply_gate_correction(predictor, predictors_delta[user_index])

            optimizer.step()

def test_global_model(train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs, global_model, global_pred):
    """Evaluate the shared global encoder/predictor on every user in ``test_pos_gs``.

    For each such user, node embeddings are computed once from the user's
    training graph (``ndata['feat']``). The same embeddings score both the
    training and the test edge sets. Test loss, AUC and positive MR are
    printed.

    Returns:
        ``(train_loss, train_auc, train_pos_mr, test_loss, test_auc, test_pos_mr)``.
        Each is averaged over the users in ``test_pos_gs``, i.e. the users
        actually evaluated. Losses and MRs are converted with ``float``; AUCs
        are returned as accumulated from ``compute_auc``.

    Raises:
        ZeroDivisionError: if ``test_pos_gs`` is empty.
    """
    def own_type_scores(scores, g):
        # Row i of ``scores`` belongs to edge i; keep only column etype[i].
        etype = g.edata['etype']
        return scores[torch.arange(len(etype), device=etype.device), etype]

    # Train and test metrics are both accumulated over the users in
    # test_pos_gs, so both are averaged by that count.
    n_users = len(test_pos_gs)

    global_model.eval()
    global_pred.eval()

    total_train_loss = total_train_AUC = total_train_pos_MR = 0
    total_loss = total_AUC = total_pos_MR = 0

    with torch.no_grad():
        for user_index in test_pos_gs:
            train_g = train_gs[user_index]
            h = global_model(train_g, train_g.ndata['feat'])

            # Training edges; the full positive score matrix is reused for MR.
            train_pos_g = train_pos_gs[user_index]
            train_neg_g = train_neg_gs[user_index]
            train_pos_all = global_pred(train_pos_g, h)
            pos_score = own_type_scores(train_pos_all, train_pos_g)
            neg_score = own_type_scores(global_pred(train_neg_g, h), train_neg_g)
            total_train_loss += compute_loss(pos_score, neg_score)
            total_train_AUC += compute_auc(pos_score, neg_score)
            total_train_pos_MR += compute_mr(train_pos_all, train_pos_g)

            # Held-out edges, scored with the same embeddings.
            test_pos_g = test_pos_gs[user_index]
            test_neg_g = test_neg_gs[user_index]
            test_pos_all = global_pred(test_pos_g, h)
            pos_score = own_type_scores(test_pos_all, test_pos_g)
            neg_score = own_type_scores(global_pred(test_neg_g, h), test_neg_g)
            total_loss += compute_loss(pos_score, neg_score)
            total_pos_MR += compute_mr(test_pos_all, test_pos_g)
            total_AUC += compute_auc(pos_score, neg_score)

    print('Global Test Loss', total_loss / n_users)
    print('Global Test AUC', total_AUC / n_users)
    print('Global Test Positive MR', float(total_pos_MR / n_users))

    return (
        float(total_train_loss / n_users),
        total_train_AUC / n_users,
        float(total_train_pos_MR / n_users),
        float(total_loss / n_users),
        total_AUC / n_users,
        float(total_pos_MR / n_users),
    )

        
        
                 
                    
def test_local_models(train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs, models, predictors):
    """Evaluate each user's own local encoder/predictor on that user's edges.

    For every user in ``test_pos_gs``, node embeddings are computed once from
    the user's training graph (``ndata['feat']``). The same embeddings score
    both the training and the test edge sets. Test loss, AUC and positive MR
    are printed.

    Returns:
        ``(train_loss, train_auc, train_pos_mr, test_loss, test_auc, test_pos_mr)``.
        Each is averaged over the users in ``test_pos_gs``, i.e. the users
        actually evaluated. ``train_loss`` and both MRs are converted with
        ``float``; ``test_loss`` and the AUCs are returned unconverted.

    Raises:
        ZeroDivisionError: if ``test_pos_gs`` is empty.
    """
    def own_type_scores(scores, g):
        # Row i of ``scores`` belongs to edge i; keep only column etype[i].
        etype = g.edata['etype']
        return scores[torch.arange(len(etype), device=etype.device), etype]

    # Train and test metrics are both accumulated over the users in
    # test_pos_gs, so both are averaged by that count.
    n_users = len(test_pos_gs)

    total_train_loss = total_train_AUC = total_train_pos_MR = 0
    total_loss = total_AUC = total_pos_MR = 0

    with torch.no_grad():
        for user_index in test_pos_gs:
            model = models[user_index]
            predictor = predictors[user_index]
            model.eval()
            predictor.eval()

            train_g = train_gs[user_index]
            h = model(train_g, train_g.ndata['feat'])

            # Training edges; the full positive score matrix is reused for MR.
            train_pos_g = train_pos_gs[user_index]
            train_neg_g = train_neg_gs[user_index]
            train_pos_all = predictor(train_pos_g, h)
            pos_score = own_type_scores(train_pos_all, train_pos_g)
            neg_score = own_type_scores(predictor(train_neg_g, h), train_neg_g)
            total_train_loss += compute_loss(pos_score, neg_score)
            total_train_AUC += compute_auc(pos_score, neg_score)
            total_train_pos_MR += compute_mr(train_pos_all, train_pos_g)

            # Held-out edges, scored with the same embeddings.
            test_pos_g = test_pos_gs[user_index]
            test_neg_g = test_neg_gs[user_index]
            test_pos_all = predictor(test_pos_g, h)
            pos_score = own_type_scores(test_pos_all, test_pos_g)
            neg_score = own_type_scores(predictor(test_neg_g, h), test_neg_g)
            total_loss += compute_loss(pos_score, neg_score)
            total_AUC += compute_auc(pos_score, neg_score)
            total_pos_MR += compute_mr(test_pos_all, test_pos_g)

    print('Local Test Loss', total_loss / n_users)
    print('Local Test AUC', total_AUC / n_users)
    print('Local Test Positive MR', float(total_pos_MR / n_users))
    return (
        float(total_train_loss / n_users),
        total_train_AUC / n_users,
        float(total_train_pos_MR / n_users),
        total_loss / n_users,
        total_AUC / n_users,
        float(total_pos_MR / n_users),
    )
        
        
                   
                    
def test_personal_models(train_gs, train_pos_gs, train_neg_gs, test_pos_gs, test_neg_gs, models, predictors, personal_models, personal_predictors, alphas):
    """Evaluate each user's personalised model on that user's edges.

    For every user in ``test_pos_gs``, embeddings and edge scores come from
    ``inference_personal(personal_module, shared_module, alpha, graph, x)``.
    It is called with the user's personal module, the matching shared module
    from ``models`` / ``predictors``, and ``alphas[user_index]`` (presumably a
    mixing weight between the two -- confirm in ``utils.inference_personal``).
    Embeddings are computed once from the user's training graph and reused to
    score both the training and the test edge sets. Test loss, AUC and
    positive MR are printed.

    Returns:
        ``(train_loss, train_auc, train_pos_mr, test_loss, test_auc, test_pos_mr)``.
        Each is averaged over the users in ``test_pos_gs``, i.e. the users
        actually evaluated. ``train_loss`` and both MRs are converted with
        ``float``; ``test_loss`` and the AUCs are returned unconverted.

    Raises:
        ZeroDivisionError: if ``test_pos_gs`` is empty.
    """
    def own_type_scores(scores, g):
        # Row i of ``scores`` belongs to edge i; keep only column etype[i].
        etype = g.edata['etype']
        return scores[torch.arange(len(etype), device=etype.device), etype]

    # Train and test metrics are both accumulated over the users in
    # test_pos_gs, so both are averaged by that count.
    n_users = len(test_pos_gs)

    total_train_loss = total_train_AUC = total_train_pos_MR = 0
    total_loss = total_AUC = total_pos_MR = 0

    with torch.no_grad():
        for user_index in test_pos_gs:
            shared_model = models[user_index]
            shared_pred = predictors[user_index]
            personal_model = personal_models[user_index]
            personal_pred = personal_predictors[user_index]
            alpha = alphas[user_index]
            # Every module passed to inference_personal goes into eval mode,
            # not only the shared ones.
            for module in (shared_model, shared_pred, personal_model, personal_pred):
                module.eval()

            train_g = train_gs[user_index]
            h = inference_personal(personal_model, shared_model, alpha, train_g, train_g.ndata['feat'])

            # Training edges; the full positive score matrix is reused for MR.
            train_pos_g = train_pos_gs[user_index]
            train_neg_g = train_neg_gs[user_index]
            train_pos_all = inference_personal(personal_pred, shared_pred, alpha, train_pos_g, h)
            train_neg_all = inference_personal(personal_pred, shared_pred, alpha, train_neg_g, h)
            pos_score = own_type_scores(train_pos_all, train_pos_g)
            neg_score = own_type_scores(train_neg_all, train_neg_g)
            total_train_loss += compute_loss(pos_score, neg_score)
            total_train_AUC += compute_auc(pos_score, neg_score)
            total_train_pos_MR += compute_mr(train_pos_all, train_pos_g)

            # Held-out edges, scored with the same embeddings.
            test_pos_g = test_pos_gs[user_index]
            test_neg_g = test_neg_gs[user_index]
            test_pos_all = inference_personal(personal_pred, shared_pred, alpha, test_pos_g, h)
            test_neg_all = inference_personal(personal_pred, shared_pred, alpha, test_neg_g, h)
            pos_score = own_type_scores(test_pos_all, test_pos_g)
            neg_score = own_type_scores(test_neg_all, test_neg_g)
            total_loss += compute_loss(pos_score, neg_score)
            total_AUC += compute_auc(pos_score, neg_score)
            total_pos_MR += compute_mr(test_pos_all, test_pos_g)

    print('Personal Test Loss', total_loss / n_users)
    print('Personal Test AUC', total_AUC / n_users)
    print('Personal Test Positive MR', float(total_pos_MR / n_users))

    return (
        float(total_train_loss / n_users),
        total_train_AUC / n_users,
        float(total_train_pos_MR / n_users),
        total_loss / n_users,
        total_AUC / n_users,
        float(total_pos_MR / n_users),
    )


        