from functools import partial

import torch
import torch.nn as nn

from timm.models.vision_transformer import PatchEmbed, Block

from util.pos_embed import get_2d_sincos_pos_embed
import model.models_mae as models_mae
from model.fusion import guide_fusion

class ViT_MAE(nn.Module):
    """Wrap an MAE ViT-Base/16 (``models_mae.mae_vit_base_patch16``) as ``model_ref``.

    In ``'encode_decode'`` mode the forward pass returns a reconstruction loss
    of ``ref_imgs`` averaged over the patches selected by the encoder's mask;
    in ``'only_encode'`` mode it returns whatever
    ``model_ref.forward_encoder`` produces.
    """

    # Relative path (resolved against the current working directory) of the
    # checkpoint loaded when no ``checkpoint_path`` is given.
    _DEFAULT_CHECKPOINT = 'model/checkpoints/mae_pretrain_vit_base.pth'

    def __init__(self, embed_dim=768, hidden_dim=256, dropout_rate=0.1, checkpoint_path=None):
        """Build the MAE backbone and, by default, load pretrained weights.

        Args:
            embed_dim, hidden_dim, dropout_rate: unused by this module (they
                mirror the defaults of ``Fusion.__init__``).
            checkpoint_path: when ``None`` (default), weights are loaded from
                ``_DEFAULT_CHECKPOINT``. When a path IS given, nothing is
                loaded here.
                NOTE(review): the given path's value is otherwise ignored --
                presumably the caller restores weights itself; confirm.
        """
        super().__init__()
        self.model_ref = models_mae.__dict__['mae_vit_base_patch16']()

        if checkpoint_path is None:
            # map_location='cpu' lets a checkpoint saved from a GPU load on
            # hosts without that device; load_state_dict then copies the
            # values into the module's existing parameters.
            checkpoint = torch.load(self._DEFAULT_CHECKPOINT, map_location='cpu')
            # strict=False: keys missing from / extra to the checkpoint are
            # skipped silently instead of raising.
            self.model_ref.load_state_dict(checkpoint['model'], strict=False)

    def forward(self, imgs, ref_imgs=None, mask_ratio_ref=None, mode='encode_decode'):
        """Run the MAE backbone.

        Args:
            imgs: images fed to the encoder.
            ref_imgs: reconstruction target; required in 'encode_decode' mode.
            mask_ratio_ref: masking ratio forwarded to ``forward_encoder``;
                required in 'encode_decode' mode.
            mode: 'encode_decode' or 'only_encode'.

        Returns:
            'encode_decode': scalar loss -- per-patch mean squared error,
                averaged over the patches weighted by the encoder's mask.
            'only_encode': the raw output of ``model_ref.forward_encoder``.

        Raises:
            AssertionError: on an unknown ``mode`` or a missing required input.
        """
        assert mode in ['encode_decode', 'only_encode']
        if mode == 'only_encode':
            # ``mode`` is forwarded so the encoder can presumably skip random
            # masking in this mode (TODO confirm in models_mae.forward_encoder).
            return self.model_ref.forward_encoder(imgs, mask_ratio_ref, mode)

        assert mask_ratio_ref is not None
        assert ref_imgs is not None, "ref_imgs is required in 'encode_decode' mode"

        latent_ref, mask_ref, ids_restore_ref = self.model_ref.forward_encoder(imgs, mask_ratio_ref, mode)
        pred_ref = self.model_ref.forward_decoder(latent_ref, ids_restore_ref)

        target_ref = self.model_ref.patchify(ref_imgs)
        if self.model_ref.norm_pix_loss:
            # Normalize each patch's target to zero mean / unit variance.
            mean_ref = target_ref.mean(dim=-1, keepdim=True)
            var_ref = target_ref.var(dim=-1, keepdim=True)
            target_ref = (target_ref - mean_ref) / (var_ref + 1.e-6) ** .5

        # Per-patch MSE, then a mask-weighted average over patches.
        loss_ref = ((pred_ref - target_ref) ** 2).mean(dim=-1)
        return (loss_ref * mask_ref).sum() / mask_ref.sum()

class Fusion(nn.Module):
    """Fuse reference and distorted features and regress a single output.

    A guide vector is max-pooled from ``feats_ref`` over dim 1, combined with
    ``feats_dist`` by ``guide_fusion``, and mapped from ``embed_dim`` to 1 by
    a small dropout + linear head.
    """

    def __init__(self, embed_dim=768, hidden_dim=256, dropout_rate=0.1) -> None:
        """Build the fusion module and the regression head.

        Args:
            embed_dim: passed to ``guide_fusion`` and used as the input width
                of the regression head.
            hidden_dim: width of the head's hidden layer.
            dropout_rate: dropout probability applied before each linear layer.
        """
        super().__init__()
        self.fusion = guide_fusion(embed_dim=embed_dim)
        # Layer order is kept as-is so state_dict keys (regression.1.*,
        # regression.3.*) stay stable.
        # NOTE(review): there is no nonlinearity between the two Linear layers,
        # so the head is an affine map in eval mode -- confirm this is intended.
        self.regression = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(embed_dim, hidden_dim),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, feats_ref, feats_dist):
        """Return the regression head's output for the fused features.

        Args:
            feats_ref: 3-D tensor, presumably (batch, views, embed_dim) --
                TODO confirm. It is max-pooled over dim 1 to form the guide.
            feats_dist: passed unchanged to ``guide_fusion``; its shape
                contract is defined there.

        Raises:
            ValueError: if ``feats_ref`` is not 3-D.
        """
        # Explicit replacement for the former unused `B, view_length, C = shape`
        # unpacking, which also raised ValueError on non-3-D input.
        if feats_ref.dim() != 3:
            raise ValueError(
                f"feats_ref must be 3-D, got shape {tuple(feats_ref.shape)}"
            )
        guide = feats_ref.max(dim=1).values
        feat = self.fusion(guide, feats_dist)
        return self.regression(feat)


