import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence
from collections import defaultdict
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader
from lib.utils import tensor2npy
from .BaseModel import BaseModel
from lib.callback import CallbackList, Callback


class IRT(BaseModel):
    """Item Response Theory model with optional fixed discrimination/guessing.

    Three configurations are supported:

    1. ``fix_a = True,  fix_c = True``  -- only ability and difficulty are
       learned (1PL-style).
    2. ``fix_a = False, fix_c = True``  -- discrimination is learned as well
       (2PL-style).
    3. ``fix_a = False, fix_c = False`` -- discrimination and guessing are
       both learned (3PL-style).

    ``fix_a = True`` together with ``fix_c = False`` is rejected in
    ``build_cfg``.
    """
    expose_default_cfg = {
        "a_range": -1.0,  # upper bound of discrimination; negative disables the bound
        "diff_range": -1.0,  # width of the difficulty interval; negative disables the bound
        "fix_a": False,
        "fix_c": True,
    }

    def __init__(self, cfg, xavier_init=True):
        super().__init__(cfg, xavier_init)

    def build_cfg(self):
        """Normalise the model config and read dataset sizes.

        Negative ``a_range`` / ``diff_range`` values mean "no range
        constraint" and are replaced by ``None``. ``None`` and ``0`` are left
        untouched.
        """
        # Each key is checked against itself; the two options are independent.
        for key in ('a_range', 'diff_range'):
            value = self.model_cfg[key]
            if value is not None and value < 0:
                self.model_cfg[key] = None

        self.n_user = self.data_cfg['dt_info']['user_count']
        self.n_item = self.data_cfg['dt_info']['item_count']

        # When c is learnable (fix_c is False), a must be learnable too.
        if self.model_cfg['fix_c'] is False:
            assert self.model_cfg['fix_a'] is False, \
                "fix_a must be False when fix_c is False"

    def build_model(self):
        """Create the embeddings; fixed parameters are stored as plain floats."""
        self.theta = nn.Embedding(self.n_user, 1)  # student ability
        # Item discrimination. A fixed discrimination is the constant 1.0
        # (a value of 0 would make the response curve flat).
        self.a = 1.0 if self.model_cfg['fix_a'] else nn.Embedding(self.n_item, 1)
        self.b = nn.Embedding(self.n_item, 1)  # item difficulty
        # Item guessing parameter; fixed guessing is 0.
        self.c = 0.0 if self.model_cfg['fix_c'] else nn.Embedding(self.n_item, 1)

    def forward(self, user_idx, item_idx):
        """Return the predicted probability of a correct response.

        Args:
            user_idx: index tensor used to look up ``theta``.
            item_idx: index tensor used to look up ``a``, ``b`` and ``c``.

        Returns:
            The result of ``irf`` for the looked-up parameters.

        Raises:
            ValueError: if ``theta``, ``a`` or ``b`` contains NaN.
        """
        theta = self.theta(user_idx)

        b = self.b(item_idx)
        if self.model_cfg['diff_range'] is not None:
            # Squash difficulty into (-diff_range / 2, diff_range / 2).
            b = self.model_cfg['diff_range'] * (torch.sigmoid(b) - 0.5)

        if self.model_cfg['fix_a']:
            # Fixed discrimination is a plain float: nothing to look up or transform.
            a = self.a
            checked = (theta, b)
        else:
            a = self.a(item_idx)
            if self.model_cfg['a_range'] is not None:
                a = self.model_cfg['a_range'] * torch.sigmoid(a)
            else:
                # Keep discrimination positive to preserve the monotonicity assumption.
                a = F.softplus(a)
            checked = (theta, a, b)

        c = self.c if self.model_cfg['fix_c'] else self.c(item_idx).sigmoid()

        if any(bool(torch.isnan(t).any()) for t in checked):  # pragma: no cover
            raise ValueError('ValueError:theta,a,b may contains nan!  The diff_range or a_range is too large.')
        return self.irf(theta, a, b, c)

    @staticmethod
    def irf(theta, a, b, c, D=1.702):
        """Item response function: ``c + (1 - c) / (1 + exp(-D * a * (theta - b)))``."""
        return c + (1 - c) / (1 + torch.exp(-D * a * (theta - b)))

    def fit(self, train_dataset, val_dataset=None,callbacks: Sequence[Callback]=()):
        """Train the model with binary cross-entropy.

        Args:
            train_dataset: dataset whose batches are 2-D tensors with user
                index, item index and label in columns 0, 1 and 2.
            val_dataset: optional dataset evaluated after every epoch; its
                metrics are added to the epoch logs with a ``val_`` prefix.
            callbacks: callbacks wrapped in a ``CallbackList`` and notified at
                train/epoch boundaries.
        """
        lr = self.train_cfg['lr']
        epoch_num = self.train_cfg['epoch_num']
        batch_size = self.train_cfg['batch_size']
        num_workers = self.train_cfg['num_workers']
        eval_batch_size = self.train_cfg['eval_batch_size']
        weight_decay = self.train_cfg['weight_decay']
        eps = self.train_cfg['eps']

        model = self.train()
        optimizer = self._get_optim(optimizer=self.train_cfg['optim'], lr=lr, weight_decay=weight_decay, eps=eps)
        self.optimizer = optimizer

        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
        if val_dataset is not None:
            val_loader = DataLoader(val_dataset, shuffle=False, batch_size=eval_batch_size, num_workers=num_workers)

        callback_list = CallbackList(callbacks=callbacks, model=model, logger=self.logger)
        callback_list.on_train_begin()
        for epoch in range(epoch_num):
            callback_list.on_epoch_begin(epoch + 1)
            # evaluate() leaves the module in eval mode; switch back before training.
            self.train()
            # One slot per batch; NaN marks batches that were never filled.
            logs = defaultdict(lambda: np.full((len(train_loader),), np.nan, dtype=np.float32))
            for batch_id, batch in enumerate(
                    tqdm(train_loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[EPOCH={:03d}]".format(epoch + 1))
            ):
                batch = batch.to(self.device)
                users = batch[:, 0]
                items = batch[:, 1]
                labels = batch[:, 2].float()
                pd = model(users, items).flatten()
                loss = F.binary_cross_entropy(input=pd, target=labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                logs['loss'][batch_id] = loss.item()

            # Collapse the per-batch arrays into one scalar per logged quantity.
            for name in logs:
                logs[name] = float(np.nanmean(logs[name]))

            if val_dataset is not None:
                val_metrics = self.evaluate(val_loader)
                logs.update({f"val_{metric}": val_metrics[metric] for metric in val_metrics})

            callback_list.on_epoch_end(epoch + 1, logs=logs)
            # Presumably set by a callback (e.g. early stopping) -- confirm against lib.callback.
            if self.share_obj_dict.get('stop_training', False):
                break

        callback_list.on_train_end()

    @torch.no_grad()
    def evaluate(self, loader):
        """Compute every metric in ``eval_cfg['metrics']`` over ``loader``.

        Puts the module in eval mode and leaves it there.

        Args:
            loader: iterable of 2-D batch tensors with user index, item index
                and label in columns 0, 1 and 2.

        Returns:
            Dict mapping each metric name to its value.
        """
        self.eval()
        pd_list = []
        gt_list = []
        for batch in tqdm(loader, ncols=self.environ_cfg['tqdm_ncols'], desc="[PREDICT]"):
            batch = batch.to(self.device)
            u = batch[:, 0]
            i = batch[:, 1]
            r = batch[:, 2]
            pd_list.append(self.forward(u, i).flatten())
            gt_list.append(r.flatten())
        y_pd = tensor2npy(torch.hstack(pd_list))
        y_gt = tensor2npy(torch.hstack(gt_list))
        eval_result = {
            metric: self._get_metrics(metric)(y_gt, y_pd) for metric in self.eval_cfg['metrics']
        }
        return eval_result
