import torch
from tqdm import tqdm
import numpy as np


@torch.no_grad()
def DOA(R: torch.Tensor, user_emb: torch.FloatTensor, Q: torch.LongTensor):
    # 分子：对于每个知识点k，针对两两学生a和b，统计所有包含知识点k的习题中，(学生a做对，学生b做错)且(学生a的第k维比学生b的第k维大)的数量
    # 分母：对于每个知识点k，针对两两学生a和b，(学生a的第k维比学生b的第k维大)的数量
    cpt_count = Q.shape[1]
    doa = np.zeros((cpt_count,))
    for k in tqdm(range(cpt_count), desc='DOA'):
        # Denominator, Molecular = 0, 0  # Denominator分母，Molecular分子
        stu_cpt_master_diff_mat = (user_emb[:, k][:, None] - user_emb[:, k]) > 0 # 所有学生之间关于第k个知识点的熟练度差值矩阵，再选择符合条件的
        # Denominator = stu_cpt_master_diff_mat.sum()
        
        exercise_k = torch.argwhere(Q[:, k] == 1) # 选择出存在知识点k的习题id
        denominator, molecular = 0, 0
        for e in exercise_k:
            m = R[:, e.item()][:, None] - R[:, e.item()] # 所有学生之间关于习题e的答题情况差值矩阵
            stu_ans_ques_diff_mat1 =  m == 2 
            stu_ans_ques_diff_mat2 =  m == -2
            molecular += (stu_cpt_master_diff_mat & stu_ans_ques_diff_mat1).sum()
            denominator += (stu_cpt_master_diff_mat & (stu_ans_ques_diff_mat1 | stu_ans_ques_diff_mat2)).sum()

        doa_k = (molecular / denominator).item() if denominator > 0 else 0.0
        doa[k] = doa_k
    return doa


def DOAV1(R: torch.Tensor, user_emb: torch.FloatTensor, Q: torch.LongTensor):
    # 分子：对于每个知识点k，针对两两学生a和b，统计所有包含知识点k的习题中，(学生a做对，学生b做错)且(学生a的第k维比学生b的第k维大)的数量
    # 分母：对于每个知识点k，针对两两学生a和b，(学生a的第k维比学生b的第k维大)的数量
    cpt_count = Q.shape[1]
    doa = np.zeros((cpt_count,))
    for k in tqdm(range(cpt_count), desc='DOA'):
        Denominator, Molecular = 0, 0  # Denominator分母，Molecular分子
        stu_cpt_master_diff_mat = (user_emb[:, k][:, None] - user_emb[:, k]) > 0 # 所有学生之间关于第k个知识点的熟练度差值矩阵，再选择符合条件的
        Denominator = stu_cpt_master_diff_mat.sum()
        
        exercise_k = torch.argwhere(Q[:, k] == 1) # 选择出存在知识点k的习题id
        for e in exercise_k:
            m = R[:, e.item()][:, None] - R[:, e.item()] # 所有学生之间关于习题e的答题情况差值矩阵
            stu_ans_ques_diff_mat1 =  m == 2 
            stu_ans_ques_diff_mat2 =  m == -2
            tmp = (stu_cpt_master_diff_mat & stu_ans_ques_diff_mat1).sum()
            if tmp != 0:
                Molecular +=  tmp / (stu_cpt_master_diff_mat & (stu_ans_ques_diff_mat1 | stu_ans_ques_diff_mat2)).sum()

        doa_k = Molecular / Denominator
        doa[k] = doa_k.item()
    return doa
