"""
    reference: https://github.com/bigdata-ustc/EduCDM/blob/main/examples/NCDM/NCDM.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence
from collections import defaultdict
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader
from lib.utils import tensor2npy
from .BaseModel import BaseModel
from lib.callback import CallbackList, Callback
from torch import Tensor
from typing import Union
from lib.components import PosMLP, PosLinear

import pandas as pd
from evaluate import doa_report

class NCDM(BaseModel):
    """Cognitive-diagnosis network predicting the probability of a correct response.

    Every student and every item is embedded into a vector with one entry per
    concept (``cpt_count``); each item additionally owns a scalar embedding.
    In ``forward`` the sigmoid-squashed (student - item) difference is scaled
    by the item scalar, masked with the item's Q-matrix row and fed to a
    ``PosMLP`` prediction network.

    Attributes that are not created in ``build_model``:
        Q_mat: assigned in ``fit`` from ``train_dataset.Q_mat``; ``forward``
            indexes it, so it must exist before the first forward pass.
        df_Q_final, df_interact_final: read by ``get_doa(gt=True)``. They are
            not assigned in this class and have to be provided from outside.
    """

    expose_default_cfg = {
        'dnn_units': [512, 256],
        'dropout_rate': 0.5,
        'activation': 'sigmoid',
        'disc_scale': 10
    }

    def __init__(self, cfg, xavier_init=True):
        super().__init__(cfg, xavier_init)

    def build_cfg(self):
        """Cache user / item / concept counts from ``data_cfg['dt_info']``."""
        self.n_user = self.data_cfg['dt_info']['user_count']
        self.n_item = self.data_cfg['dt_info']['item_count']
        self.n_cpt = self.data_cfg['dt_info']['cpt_count']

    def build_model(self):
        """Create the embedding tables and the prediction network."""
        # prediction sub-net
        self.student_emb = nn.Embedding(self.n_user, self.n_cpt)
        self.k_difficulty = nn.Embedding(self.n_item, self.n_cpt)
        self.e_difficulty = nn.Embedding(self.n_item, 1)
        self.pd_net = PosMLP(
            input_dim=self.n_cpt, output_dim=1, activation=self.model_cfg['activation'],
            dnn_units=self.model_cfg['dnn_units'], dropout_rate=self.model_cfg['dropout_rate']
        )

    def forward(self, users, items):
        """Return the predicted probability for each (user, item) pair.

        Args:
            users: index tensor used to look up ``student_emb``.
            items: index tensor used to look up the item embeddings and the
                rows of ``self.Q_mat``.

        Returns:
            Sigmoid of the ``pd_net`` output (``pd_net`` is built with
            ``output_dim=1``); callers flatten it themselves.
        """
        # before prednet
        stat_emb = torch.sigmoid(self.student_emb(users))
        k_difficulty = torch.sigmoid(self.k_difficulty(items))
        e_difficulty = torch.sigmoid(self.e_difficulty(items)) * self.model_cfg['disc_scale']
        # prednet: only concepts flagged in the item's Q-matrix row contribute
        input_knowledge_point = self.Q_mat[items]
        input_x = e_difficulty * (stat_emb - k_difficulty) * input_knowledge_point
        return self.pd_net(input_x).sigmoid()

    def fit(self, train_dataset, val_dataset=None, callbacks: Sequence[Callback] = ()):
        """Train with binary cross-entropy, reporting to ``callbacks`` every epoch.

        Args:
            train_dataset: dataset exposing ``Q_mat``; each batch is indexed as
                column 0 = user index, column 1 = item index, column 2 = label.
            val_dataset: optional dataset evaluated after every epoch; its
                metrics are logged with a ``val_`` prefix.
            callbacks: callbacks wrapped in a ``CallbackList``.

        Side effects:
            Sets ``self.optimizer`` and ``self.Q_mat``. ``get_doa(gt=True)`` is
            called after every epoch, so ``df_Q_final`` and
            ``df_interact_final`` must already be set on the model.
        """
        lr = self.train_cfg['lr']
        epoch_num = self.train_cfg['epoch_num']
        batch_size = self.train_cfg['batch_size']
        num_workers = self.train_cfg['num_workers']
        eval_batch_size = self.train_cfg['eval_batch_size']
        weight_decay = self.train_cfg['weight_decay']
        eps = self.train_cfg['eps']

        model = self.train()
        optimizer = self._get_optim(optimizer=self.train_cfg['optim'], lr=lr, weight_decay=weight_decay, eps=eps)
        self.optimizer = optimizer

        self.Q_mat = train_dataset.Q_mat.to(self.device)

        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
        val_loader = None
        if val_dataset is not None:
            val_loader = DataLoader(val_dataset, shuffle=False, batch_size=eval_batch_size, num_workers=num_workers)

        callback_list = CallbackList(callbacks=callbacks, model=model, logger=self.logger)
        callback_list.on_train_begin()
        for epoch in range(epoch_num):
            callback_list.on_epoch_begin(epoch + 1)
            # evaluate() switches the module to eval mode and does not switch
            # back; without this call every epoch after the first validation
            # pass would train with dropout disabled.
            self.train()
            # One slot per batch, NaN-filled so nanmean ignores unwritten slots.
            logs = defaultdict(lambda: np.full((len(train_loader),), np.nan, dtype=np.float32))
            for batch_id, batch in enumerate(
                    tqdm(train_loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[EPOCH={:03d}]".format(epoch + 1))
            ):
                batch = batch.to(self.device)
                users = batch[:, 0]
                items = batch[:, 1]
                labels = batch[:, 2].float()
                # Named `pred` rather than `pd` so the pandas alias is not shadowed.
                pred = model(users, items).flatten()
                loss = F.binary_cross_entropy(input=pred, target=labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                logs['loss'][batch_id] = loss.item()

            # Collapse the per-batch arrays into one scalar per logged quantity.
            for name in logs:
                logs[name] = float(np.nanmean(logs[name]))

            if val_loader is not None:
                val_metrics = self.evaluate(val_loader)
                logs.update({f"val_{metric}": val_metrics[metric] for metric in val_metrics})
            # Only the gt=True variant of get_doa is implemented.
            logs.update({"official_doa_gt": self.get_doa(gt=True)})
            callback_list.on_epoch_end(epoch + 1, logs=logs)
            if self.share_obj_dict.get('stop_training', False):
                break

        callback_list.on_train_end()

    @torch.no_grad()
    def get_user_emb(self, users=None):
        """Return raw (pre-sigmoid) student embeddings.

        With ``users=None`` the whole embedding weight matrix is returned,
        otherwise only the rows looked up for ``users``.
        """
        if users is None:
            return self.student_emb.weight
        return self.student_emb(users)

    @torch.no_grad()
    def get_item_emb(self, items=None):
        """Return raw (pre-sigmoid) ``k_difficulty`` embeddings.

        With ``items=None`` the whole embedding weight matrix is returned,
        otherwise only the rows looked up for ``items``.
        """
        if items is None:
            return self.k_difficulty.weight
        return self.k_difficulty(items)

    @torch.no_grad()
    def evaluate(self, loader):
        """Compute every metric in ``eval_cfg['metrics']`` over ``loader``.

        The module is switched to eval mode and left there.

        Args:
            loader: iterable of batches indexed as column 0 = user index,
                column 1 = item index, column 2 = label.

        Returns:
            Dict mapping metric name to ``self._get_metrics(name)(y_gt, y_pd)``.
        """
        self.eval()
        pred_chunks = []
        label_chunks = []
        for batch in tqdm(loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[PREDICT]"):
            batch = batch.to(self.device)
            pred_chunks.append(self.forward(batch[:, 0], batch[:, 1]).flatten())
            label_chunks.append(batch[:, 2].flatten())
        y_pd = tensor2npy(torch.hstack(pred_chunks))
        y_gt = tensor2npy(torch.hstack(label_chunks))
        return {
            metric: self._get_metrics(metric)(y_gt, y_pd) for metric in self.eval_cfg['metrics']
        }

    def get_doa(self, gt=False):
        """Return the ``'doa'`` entry of ``doa_report`` for the current student embeddings.

        Args:
            gt: must be true; the ``gt=False`` variant is not implemented.

        Returns:
            ``float(doa_report(df)['doa'])`` where ``df`` joins
            ``self.df_interact_final`` (on ``uid``) with one ``theta`` list per
            user and ``self.df_Q_final`` (on ``iid``), with columns renamed to
            ``user_id`` / ``item_id`` / ``score``.

        Raises:
            NotImplementedError: if ``gt`` is false.
        """
        if not gt:
            raise NotImplementedError
        df_Q = self.df_Q_final
        df_interact = self.df_interact_final

        user_emb = np.asarray(tensor2npy(self.get_user_emb()))
        # Build the per-user rows directly instead of serialising each row to a
        # string and eval()-ing it back: that round trip depended on how NumPy
        # scalars are printed and executed eval once per user.
        df_user = pd.DataFrame({
            'uid': list(range(user_emb.shape[0])),
            'theta': user_emb.tolist(),
        })
        df = df_interact.merge(df_user, on=['uid']).merge(df_Q, on=['iid'])
        df = df.rename(columns={"uid": 'user_id', 'iid': 'item_id', 'label': 'score'})
        official_doa = doa_report(df)
        return float(official_doa['doa'])

class NCDM2(NCDM):
    """NCDM variant whose prediction network is spelled out layer by layer.

    Instead of a ``PosMLP`` it stacks three ``PosLinear`` layers with sigmoid
    activations and dropout after the first two. ``dnn_units`` must contain
    exactly two hidden sizes. Unlike ``NCDM.forward``, the item scalar
    embedding is not multiplied by ``disc_scale``.
    """

    expose_default_cfg = {
        'dnn_units': [512, 256],
        'dropout_rate': 0.5,
    }

    def __init__(self, cfg, xavier_init=True):
        super().__init__(cfg, xavier_init)

    def build_model(self):
        """Create the embedding tables and the three-layer prediction network."""
        n_cpt = self.n_cpt
        dropout_p = self.model_cfg['dropout_rate']

        # Embedding tables: per-concept vectors for students and items plus
        # one scalar per item.
        self.student_emb = nn.Embedding(self.n_user, n_cpt)
        self.k_difficulty = nn.Embedding(self.n_item, n_cpt)
        self.e_difficulty = nn.Embedding(self.n_item, 1)

        # Hidden widths; the unpacking requires exactly two entries.
        self.prednet_len1, self.prednet_len2 = self.model_cfg['dnn_units']

        # Layers are created in the same order as they are applied in forward.
        self.prednet_full1 = PosLinear(n_cpt, self.prednet_len1)
        self.drop_1 = nn.Dropout(p=dropout_p)
        self.prednet_full2 = PosLinear(self.prednet_len1, self.prednet_len2)
        self.drop_2 = nn.Dropout(p=dropout_p)
        self.prednet_full3 = PosLinear(self.prednet_len2, 1)

    def forward(self, users, items):
        """Return a flat tensor of predicted probabilities for (user, item) pairs.

        Args:
            users: index tensor used to look up ``student_emb``.
            items: index tensor used to look up the item embeddings and the
                rows of ``self.Q_mat``.
        """
        proficiency = self.student_emb(users).sigmoid()
        concept_difficulty = self.k_difficulty(items).sigmoid()
        item_scalar = self.e_difficulty(items).sigmoid()  # no disc_scale factor here

        # Mask with the Q-matrix rows of the items.
        q_rows = self.Q_mat[items]
        hidden = item_scalar * (proficiency - concept_difficulty) * q_rows

        hidden_stages = (
            (self.prednet_full1, self.drop_1),
            (self.prednet_full2, self.drop_2),
        )
        for linear, dropout in hidden_stages:
            hidden = dropout(torch.sigmoid(linear(hidden)))

        return torch.sigmoid(self.prednet_full3(hidden)).view(-1)
