import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence
from collections import defaultdict
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader
from lib.utils import tensor2npy
from .BaseModel import BaseModel
from lib.callback import CallbackList, Callback
import torch.autograd as autograd
import pandas as pd
from evaluate import doa_report

class DINA(BaseModel):
    """Response model with per-item slip/guess and per-user mastery logits.

    Every item owns a scalar ``slip`` and ``guess`` parameter (passed through
    a sigmoid and scaled by ``max_slip`` / ``max_guess``); every user owns a
    ``theta`` vector with ``cpt_count`` entries.  In training mode the
    mastery test is relaxed with a temperature-scaled softmax whose
    temperature follows a sine schedule over ``max_step`` forward calls; in
    evaluation mode mastery is the hard test ``theta >= 0``.

    ``self.Q_mat`` is taken from the training dataset inside :meth:`fit`, so
    :meth:`forward` and :meth:`evaluate` require it to have been set first.
    """

    expose_default_cfg = {
        "step": 0,
        "max_step": 1000,
        "max_slip": 0.4,
        "max_guess": 0.4,
    }

    def __init__(self, cfg, xavier_init=True):
        super().__init__(cfg, xavier_init)

    def build_cfg(self):
        """Copy dataset sizes and model hyper-parameters onto the instance."""
        self.n_user = self.data_cfg['dt_info']['user_count']
        self.n_item = self.data_cfg['dt_info']['item_count']

        # counter driving the temperature schedule used by forward() in training mode
        self.step = 0
        self.max_step = self.model_cfg['max_step']
        self.max_slip = self.model_cfg['max_slip']
        self.max_guess = self.model_cfg['max_guess']
        self.emb_dim = self.data_cfg['dt_info']['cpt_count']

    def build_model(self):
        """Create the item-level guess/slip and user-level theta embeddings."""
        self.guess = nn.Embedding(self.n_item, 1)
        self.slip = nn.Embedding(self.n_item, 1)
        self.theta = nn.Embedding(self.n_user, self.emb_dim)

    def forward(self, user, item):
        """Predict the probability of a correct response per (user, item) pair.

        Args:
            user: index tensor of user ids, looked up in ``self.theta``.
            item: index tensor of item ids aligned with ``user``; also used to
                select rows of ``self.Q_mat``.

        Returns:
            Tensor with one predicted probability per pair.

        Side effects:
            In training mode ``self.step`` is advanced (and wrapped to 0 once
            it reaches ``self.max_step``).
        """
        theta = self.theta(user)
        # Squeeze only the trailing embedding axis: a bare squeeze() would
        # also drop the batch axis when the batch holds a single sample.
        slip = (torch.sigmoid(self.slip(item)) * self.max_slip).squeeze(-1)
        guess = (torch.sigmoid(self.guess(item)) * self.max_guess).squeeze(-1)
        knowledge = self.Q_mat[item]

        if not self.training:
            # hard rule: n is the product over concepts of
            # "concept not required, or theta >= 0"
            n = torch.prod(knowledge * (theta >= 0) + (1 - knowledge), dim=1)
            return (1 - slip) ** n * guess ** (1 - n)

        n = torch.sum(knowledge * (torch.sigmoid(theta) - 0.5), dim=1)
        # sine-shaped temperature in (0, 100], floored at 1e-6 to avoid a
        # division by zero in the softmax below
        t = max((np.sin(2 * np.pi * self.step / self.max_step) + 1) / 2 * 100, 1e-6)
        self.step = self.step + 1 if self.step < self.max_step else 0

        outcomes = torch.stack([1 - slip, guess], dim=-1)
        weights = torch.softmax(torch.stack([n, torch.zeros_like(n)], dim=-1) / t, dim=-1)
        return torch.sum(outcomes * weights, dim=1)

    def fit(self, train_dataset, val_dataset=None, callbacks: Sequence[Callback] = ()):
        """Train with binary cross-entropy on (user, item, label) batches.

        Args:
            train_dataset: dataset exposing ``Q_mat``; its batches are indexed
                as ``batch[:, 0]`` (user), ``batch[:, 1]`` (item) and
                ``batch[:, 2]`` (label).
            val_dataset: optional dataset evaluated after every epoch; its
                metrics are logged with a ``val_`` prefix.
            callbacks: callbacks driven through ``CallbackList``.
        """
        lr = self.train_cfg['lr']
        epoch_num = self.train_cfg['epoch_num']
        batch_size = self.train_cfg['batch_size']
        num_workers = self.train_cfg['num_workers']
        eval_batch_size = self.train_cfg['eval_batch_size']
        weight_decay = self.train_cfg['weight_decay']
        eps = self.train_cfg['eps']

        model = self.train()
        optimizer = self._get_optim(optimizer=self.train_cfg['optim'], lr=lr, weight_decay=weight_decay, eps=eps)
        self.optimizer = optimizer

        self.Q_mat = train_dataset.Q_mat.to(self.device)

        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
        val_loader = None
        if val_dataset is not None:
            val_loader = DataLoader(val_dataset, shuffle=False, batch_size=eval_batch_size, num_workers=num_workers)

        callback_list = CallbackList(callbacks=callbacks, model=model, logger=self.logger)
        callback_list.on_train_begin()
        for epoch in range(epoch_num):
            callback_list.on_epoch_begin(epoch + 1)
            # evaluate() switches the module to eval mode and does not switch
            # it back; without this call every epoch after the first would
            # optimise through the hard eval branch of forward().
            model.train()
            logs = defaultdict(lambda: np.full((len(train_loader),), np.nan, dtype=np.float32))
            progress = tqdm(train_loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[EPOCH={:03d}]".format(epoch + 1))
            for batch_id, batch in enumerate(progress):
                batch = batch.to(self.device)
                users = batch[:, 0]
                items = batch[:, 1]
                labels = batch[:, 2].float()
                # named `pred` (not `pd`) so the pandas alias is not shadowed
                pred = model(users, items).flatten()
                loss = F.binary_cross_entropy(input=pred, target=labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                logs['loss'][batch_id] = loss.item()

            # collapse the per-batch arrays into one scalar per logged quantity
            for name in logs:
                logs[name] = float(np.nanmean(logs[name]))

            if val_loader is not None:
                val_metrics = self.evaluate(val_loader)
                logs.update({f"val_{metric}": value for metric, value in val_metrics.items()})
            logs["official_doa_gt"] = self.get_doa(gt=True)
            callback_list.on_epoch_end(epoch + 1, logs=logs)
            if self.share_obj_dict.get('stop_training', False):
                break

        callback_list.on_train_end()

    @torch.no_grad()
    def get_user_emb(self, users=None):
        """Return theta rows for ``users``, or the full theta table when ``users`` is None."""
        if users is None:
            return self.theta.weight
        return self.theta(users)

    @torch.no_grad()
    def evaluate(self, loader):
        """Compute every metric listed in ``self.eval_cfg['metrics']`` over ``loader``.

        Switches the module to eval mode and leaves it there.

        Returns:
            Dict mapping metric name to the value returned by the metric function.
        """
        self.eval()
        pd_list = []
        gt_list = []
        for batch in tqdm(loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[PREDICT]"):
            batch = batch.to(self.device)
            u = batch[:, 0]
            i = batch[:, 1]
            r = batch[:, 2]
            pd_list.append(self.forward(u, i).flatten())
            gt_list.append(r.flatten())
        y_pd = tensor2npy(torch.hstack(pd_list))
        y_gt = tensor2npy(torch.hstack(gt_list))
        return {
            metric: self._get_metrics(metric)(y_gt, y_pd) for metric in self.eval_cfg['metrics']
        }

    def get_doa(self, gt=False):
        """Return the ``'doa'`` entry of ``doa_report`` for the current user embeddings.

        The user table (``uid``, ``theta``) is merged with
        ``self.df_interact_final`` on ``uid`` and ``self.df_Q_final`` on
        ``iid``; the columns are then renamed to ``user_id`` / ``item_id`` /
        ``score`` before being passed to ``doa_report``.

        Args:
            gt: must be True; the ``gt=False`` path is not implemented.

        Raises:
            NotImplementedError: if ``gt`` is False.
        """
        if not gt:
            raise NotImplementedError

        user_emb = np.asarray(tensor2npy(self.get_user_emb()))
        # One list of floats per user; built directly instead of going
        # through str(...) and eval(...).
        df_user = pd.DataFrame({
            'uid': np.arange(user_emb.shape[0]),
            'theta': user_emb.tolist(),
        })
        df = self.df_interact_final.merge(df_user, on=['uid']).merge(self.df_Q_final, on=['iid'])
        df = df.rename(columns={"uid": 'user_id', 'iid': 'item_id', 'label': 'score'})
        official_doa = doa_report(df)
        return float(official_doa['doa'])

class STEFunction(autograd.Function):
    """Binary step in the forward pass with a clipped pass-through gradient.

    forward: strictly positive entries become 1.0, all other entries 0.0.
    backward: the incoming gradient is returned after hardtanh, i.e. limited
    to the interval [-1, 1]; the step itself contributes nothing.
    """

    @staticmethod
    def forward(ctx, input):
        is_positive = torch.gt(input, 0)
        return is_positive.to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output):
        clipped = F.hardtanh(grad_output, min_val=-1.0, max_val=1.0)
        return clipped


class StraightThroughEstimator(nn.Module):
    """``nn.Module`` wrapper that applies ``STEFunction`` to its input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``STEFunction.apply(x)``."""
        return STEFunction.apply(x)


class STEDINA(DINA):
    """DINA variant that binarizes theta with a straight-through estimator.

    The same forward rule is used in training and evaluation mode, so no
    temperature schedule (``step`` / ``max_step``) is configured here.
    """

    expose_default_cfg = {
        "max_slip": 0.4,
        "max_guess": 0.4,
    }

    def __init__(self, cfg, xavier_init=True):
        super().__init__(cfg, xavier_init)
        # binarizes the user embedding inside forward()
        self.sign = StraightThroughEstimator()

    def build_cfg(self):
        """Copy dataset sizes and the slip/guess caps onto the instance."""
        self.n_user = self.data_cfg['dt_info']['user_count']
        self.n_item = self.data_cfg['dt_info']['item_count']

        self.max_slip = self.model_cfg['max_slip']
        self.max_guess = self.model_cfg['max_guess']
        self.emb_dim = self.data_cfg['dt_info']['cpt_count']

    def forward(self, user, item):
        """Predict the probability of a correct response per (user, item) pair.

        Args:
            user: index tensor of user ids, looked up in ``self.theta``.
            item: index tensor of item ids aligned with ``user``.

        Returns:
            ``(1 - slip) ** n * guess ** (1 - n)`` evaluated element-wise.
        """
        mastery = self.sign(self.theta(user))
        q_rows = self.Q_mat[item]
        slip = (torch.sigmoid(self.slip(item)) * self.max_slip).squeeze()
        guess = (torch.sigmoid(self.guess(item)) * self.max_guess).squeeze()

        # 1 where the concept is not required, binarized mastery where it is
        gated = (q_rows == 0) + (q_rows == 1) * mastery
        # NOTE(review): STEFunction in this file outputs 0/1, so a required
        # concept with mastery 0 contributes (0 + 1) / 2 = 0.5 to the product
        # rather than 0; the (x + 1) / 2 mapping yields 0/1 only for a -1/+1
        # output. Confirm this is intended.
        n = ((gated + 1) / 2).prod(dim=-1)
        return (1 - slip).pow(n) * guess.pow(1 - n)
