from typing import Dict, Optional, Tuple

import numpy as np
import torch
import torch.linalg as LA
import torch.nn as nn
import torch.nn.functional as F

from src.own.models.base_ae import AbstractAE
from src.own.shared.constants import * 
from src.own.models.utils import init_embeddings, weights_init
from src.own.models.nearest_embed import NearestEmbed
from src.own.models.tpr import BaseTPREncoder

def mse_recon_loss_fn(x_hat, x, logging: bool=True, postprocessing=F.sigmoid):
    """Mean-squared-error reconstruction loss between ``postprocessing(x_hat)`` and ``x``.

    Args:
        x_hat: raw decoder output; ``postprocessing`` is applied to it first.
        x: reconstruction target.
        logging: if True, return the summed squared error divided by the
            size of the first dimension of ``x``; otherwise return the
            element-wise mean (the default reduction of ``F.mse_loss``).
        postprocessing: callable applied to ``x_hat`` before the comparison.

    Returns:
        A scalar tensor holding the loss.
    """
    recon = postprocessing(x_hat)
    if not logging:
        return F.mse_loss(recon, x)
    summed_error = F.mse_loss(recon, x, reduction='sum')
    return summed_error.div(x.shape[0])

def bce_recon_loss_fn(x_hat, x, logging: bool=True, logits: bool=True): 
    """Binary-cross-entropy reconstruction loss between ``x_hat`` and ``x``.

    Args:
        x_hat: decoder output; treated as logits when ``logits`` is True,
            otherwise passed to ``F.binary_cross_entropy`` unchanged.
        x: reconstruction target.
        logging: if True, return the summed loss divided by the size of the
            first dimension of ``x``; otherwise return the element-wise mean
            (the default reduction of the underlying function).
        logits: selects ``F.binary_cross_entropy_with_logits`` (True) or
            ``F.binary_cross_entropy`` (False).

    Returns:
        A scalar tensor holding the loss.
    """
    criterion = F.binary_cross_entropy_with_logits if logits else F.binary_cross_entropy
    if not logging: 
        return criterion(x_hat, x)
    summed_loss = criterion(x_hat, x, reduction='sum')
    return summed_loss.div(x.shape[0])
    
class Quantiser(nn.Module): 
    """Quantises approximate filler embeddings against a learned filler codebook.

    The codebook is a ``NearestEmbed`` module. ``forward`` returns both the
    quantised fillers and a dict of quantisation losses / diagnostics.
    """

    def __init__(self, n_fillers: int, filler_embed_dim: int, init_embeddings_orth: bool, 
                 embedding_postproc: str, lambdas_loss: Optional[Dict]=None) -> None: 
        """
        Args:
            n_fillers: number of codebook entries.
            filler_embed_dim: dimensionality of each filler embedding.
            init_embeddings_orth: forwarded to ``NearestEmbed`` as ``init_orth``.
            embedding_postproc: stored on the instance; not read by this class.
            lambdas_loss: weights for the three penalty terms, keyed by
                ``VQ_PENALTY``, ``COMMITMENT_PENALTY`` and
                ``ORTH_PENALTY_FILLER``. Defaults to 1, 0.5 and 0 respectively.
        """
        super().__init__()
        # Build the default per instance: a dict literal in the signature would
        # be a single object shared (and mutable) across every Quantiser.
        if lambdas_loss is None: 
            lambdas_loss = {VQ_PENALTY: 1, COMMITMENT_PENALTY: 0.5, 
                            ORTH_PENALTY_FILLER: 0}
        self.filler_embeddings = NearestEmbed(num_embeddings=n_fillers, 
                                       embedding_dim=filler_embed_dim, 
                                       init_orth=init_embeddings_orth)
        self.filler_embed_dim = filler_embed_dim
        self.embedding_postproc = embedding_postproc
        self.lambdas_loss = lambdas_loss 

    def make_state(self, approx_fillers: torch.Tensor, 
                   bound_fillers: torch.Tensor, bound_fillers_sg: torch.Tensor, 
                   filler_idxs: torch.Tensor) -> Dict: 
        """Pack the quantiser's tensors into the state dict returned by ``forward``."""
        return {'approx_fillers': approx_fillers, 
                'quantised_fillers': bound_fillers, 
                'quantised_fillers_sg': bound_fillers_sg,
                'idxs': filler_idxs}
    
    def forward(self, approx_fillers: torch.Tensor) -> Dict:
        """Quantise ``approx_fillers`` and compute the quantisation losses.

        Args:
            approx_fillers: tensor whose first two dimensions are
                (batch, n_roles); the source comments elsewhere give the full
                shape as (N_{B}, N_{R}, D_{F}).

        Returns:
            ``{'loss': <dict from get_loss>, 'state': <dict from make_state>}``.
        """
        N, n_roles = approx_fillers.shape[:2]
        # Two codebook look-ups: one on the live input with weight_sg=True
        # (presumably a stop-gradient on the codebook weights -- confirm in
        # NearestEmbed), and one on the detached input so that its gradient
        # cannot reach whatever produced approx_fillers.
        bound_fillers_sg, filler_idxs = self.filler_embeddings(approx_fillers, weight_sg=True)
        bound_fillers, _ = self.filler_embeddings(approx_fillers.detach())

        bound_fillers_sg = bound_fillers_sg.view(N, n_roles, self.filler_embed_dim)
        filler_idxs = filler_idxs.view(N, -1)

        bound_fillers = bound_fillers.view(N, n_roles, self.filler_embed_dim)
        state = self.make_state(approx_fillers=approx_fillers, bound_fillers=bound_fillers,
                                bound_fillers_sg=bound_fillers_sg, 
                                filler_idxs=filler_idxs)
        loss = self.get_loss(bound_fillers=bound_fillers, approx_fillers=approx_fillers)
        return {'loss': loss, 'state': state}
    
    def get_loss(self, bound_fillers: torch.Tensor, approx_fillers: torch.Tensor) -> Dict: 
        """Compute the weighted quantisation loss and its individual terms.

        Returns:
            Dict with ``'quantiser_total_loss'`` (the lambda-weighted sum of the
            three penalties), each penalty under its constant key, and the
            codebook rank under ``FILLER_RANK``.
        """
        # NOTE(review): both terms take the 2-norm over dim 1 of the *squared*
        # element-wise difference. For (N_{B}, N_{R}, D_{F}) inputs dim 1 is the
        # role axis, not the embedding axis -- confirm this is intended. Left
        # unchanged because altering it would change training behaviour.
        vq_loss = torch.mean(
            torch.norm((bound_fillers - approx_fillers.detach())**2, 2, 1)
        )
        commit_loss = torch.mean(
            torch.norm((bound_fillers.detach() - approx_fillers)**2, 2, 1))
        # Earlier cosine-similarity variant of the two terms, kept for reference:
        #vq_loss = torch.mean(
        #    F.cosine_similarity(bound_fillers, approx_fillers.detach(), -1) # (N_{B}, N_{R}, D_{F}) -> (N_{B}, N_{R})
        #)
        #commit_loss = torch.mean(
        #    F.cosine_similarity(bound_fillers.detach(), approx_fillers, -1) # (N_{B}, N_{R}, D_{F}) -> (N_{B}, N_{R})
        #)

        # A simple, sensible regulariser: a semi-orthogonality penalty on the
        # filler codebook.
        orth_penalty_filler = BaseTPREncoder.get_semi_orth_penalty(self.filler_embeddings.weight.t())
        penalties = torch.stack([vq_loss, commit_loss, orth_penalty_filler], dim=0)
        lambdas = torch.tensor([self.lambdas_loss[VQ_PENALTY], 
                                self.lambdas_loss[COMMITMENT_PENALTY], 
                                self.lambdas_loss[ORTH_PENALTY_FILLER]], 
                               device=penalties.device)
        quantisation_loss = torch.sum(lambdas*penalties)

        filler_rank = BaseTPREncoder.get_rank(self.filler_embeddings.weight)

        return {'quantiser_total_loss': quantisation_loss, 
                ORTH_PENALTY_FILLER: orth_penalty_filler, 
                VQ_PENALTY: vq_loss, 
                COMMITMENT_PENALTY: commit_loss,
                FILLER_RANK: filler_rank}
    

class SoftTPRAutoencoder(AbstractAE): 
    """Autoencoder that turns the encoder's soft TPR into an explicit TPR before decoding.

    The encoder output ``z`` is unbound with the role embeddings to give one
    approximate filler per role; a ``Quantiser`` maps those to codebook
    fillers; the quantised fillers are re-bound to the roles (outer product,
    summed over roles) to form ``z_tpr``, which is what the decoder sees.
    When ``weakly_supervised`` is set, training additionally swaps one role's
    binding between the two halves of the batch (see ``get_swapped_tpr``).
    """

    def __init__(self, encoder: nn.Module, decoder: nn.Module, 
                 n_roles: int, n_fillers: int, role_embed_dim: int, 
                 filler_embed_dim: int, 
                 lambdas_reg: Dict,
                 filler_postprocessing: str,
                 init_fillers_orth: bool,
                 init_roles_orth: bool, 
                 freeze_role_embeddings: bool, 
                 role_postprocessing: str,
                 weakly_supervised: bool, 
                 recon_loss_fn: str, 
                 use_cached: bool=True) -> None: 
        """
        Args:
            encoder: module mapping inputs to the soft TPR ``z``.
            decoder: module mapping a flattened TPR to a reconstruction.
            n_roles: number of role embeddings.
            n_fillers: number of filler codebook entries.
            role_embed_dim: dimensionality of a role embedding.
            filler_embed_dim: dimensionality of a filler embedding.
            lambdas_reg: loss weights. Keys read here: ``VQ_PENALTY``,
                ``COMMITMENT_PENALTY``, ``ORTH_PENALTY_FILLER``,
                ``RECON_PENALTY``, ``ORTH_PENALTY_ROLE``,
                ``WS_RECON_LOSS_PENALTY``, ``WS_EMBED_PENALTY``,
                ``WS_ARGMAX_EMBED_PENALTY`` and ``WS_DIS_PENALTY``.
            filler_postprocessing: forwarded to the quantiser.
            init_fillers_orth: forwarded to the quantiser's codebook init.
            init_roles_orth: orthogonal init for the roles (ignored when the
                roles are frozen, which always initialises them orthogonally).
            freeze_role_embeddings: if True the role embeddings get no gradient.
            role_postprocessing: if not None (and roles are not frozen), the
                role weights are re-projected on every ``get_quantised`` call.
            weakly_supervised: enables the swapped-binding losses in training.
            recon_loss_fn: ``'mse'`` selects the MSE loss; any other value
                selects BCE-with-logits.
            use_cached: stored on the instance; not read by this class.
        """
        super().__init__()
        self.role_embeddings = nn.Embedding(num_embeddings=n_roles, 
                                                 embedding_dim=role_embed_dim).requires_grad_(not freeze_role_embeddings)
        self.quantiser = Quantiser(n_fillers=n_fillers, filler_embed_dim=filler_embed_dim, 
                                   init_embeddings_orth=init_fillers_orth, embedding_postproc=filler_postprocessing,
                                   lambdas_loss={VQ_PENALTY: lambdas_reg[VQ_PENALTY], 
                                                 COMMITMENT_PENALTY: lambdas_reg[COMMITMENT_PENALTY],
                                                 ORTH_PENALTY_FILLER: lambdas_reg[ORTH_PENALTY_FILLER]})
        
        self.n_roles = n_roles 
        self.n_fillers = n_fillers
        self.filler_embed_dim = filler_embed_dim
        self.role_embed_dim = role_embed_dim
        self.freeze_role_embeddings = freeze_role_embeddings
        self.role_postproc = role_postprocessing
        self.lambda_recon = lambdas_reg[RECON_PENALTY]
        self.lambda_orth_penalty_role = lambdas_reg[ORTH_PENALTY_ROLE]
        self.lambda_ws_recon = lambdas_reg[WS_RECON_LOSS_PENALTY]
        self.lambda_ws_embed = lambdas_reg[WS_EMBED_PENALTY]
        self.lambda_ws_argmax_embed = lambdas_reg[WS_ARGMAX_EMBED_PENALTY]
        self.lambda_ws_dis = lambdas_reg[WS_DIS_PENALTY]
        self.weakly_supervised = weakly_supervised
        self.use_cached = use_cached
        self.recon_loss_fn = recon_loss_fn
        
        self.encoder = encoder 
        self.decoder = decoder 
        self.embed_dim = role_embed_dim * filler_embed_dim

        # Frozen roles can never be corrected by training, so they are always
        # initialised orthogonally regardless of init_roles_orth.
        if freeze_role_embeddings: 
            init_embeddings(init_orth=True, weights=self.role_embeddings.weight)
        else: 
            init_embeddings(init_orth=init_roles_orth, weights=self.role_embeddings.weight)

        # NOTE(review): the bound method `.modules` is passed, not its result;
        # presumably weights_init calls it itself -- confirm in models.utils.
        weights_init(self.encoder.modules)
        weights_init(self.decoder.modules)

        # Constructor kwargs (minus encoder/decoder/use_cached) kept so the
        # model can be re-instantiated when loading.
        self.kwargs_for_loading = {
            'n_roles': n_roles, 
            'n_fillers': n_fillers, 
            'role_embed_dim': role_embed_dim, 
            'filler_embed_dim': filler_embed_dim, 
            'lambdas_reg': lambdas_reg, 
            'filler_postprocessing': filler_postprocessing, 
            'init_fillers_orth': init_fillers_orth,
            'init_roles_orth': init_roles_orth,
            'freeze_role_embeddings': freeze_role_embeddings, 
            'role_postprocessing': role_postprocessing,
            'weakly_supervised': weakly_supervised,
            'recon_loss_fn': recon_loss_fn
        }

    def decode(self, x: torch.Tensor) -> torch.Tensor: 
        """Run the decoder on ``x``."""
        return self.decoder(x)

    def encode(self, x: torch.Tensor) -> torch.Tensor: 
        """Run the encoder on ``x``."""
        return self.encoder(x)
    
    def repn_fn(self, x: torch.Tensor, key: str='bound_fillers') -> torch.Tensor: 
        """Return the representation of ``x`` selected by ``key``.

        Supported selections, checked in this order:
          * ``APPROX_FILLERS`` / ``BOUND_FILLERS`` contained in ``key``: the
            corresponding fillers, flattened per sample if ``CONCATENATED``
            is also contained in ``key``;
          * ``key == FILLER_IDXS``: codebook indices as float32;
          * ``key == Z_TPR``: the quantised, flattened TPR;
          * ``key == TPR_BINDINGS``: per-role bindings, one flattened
            binding per role;
          * ``key == Z_SOFT_TPR``: the encoder output flattened to
            ``embed_dim`` columns.

        Returns ``None`` when ``key`` matches none of the above.
        """
        z = self.encode(x) 
        quantised_out = self.get_quantised(z) 
        if APPROX_FILLERS in key or BOUND_FILLERS in key:
            if APPROX_FILLERS in key: 
                fillers = quantised_out[APPROX_FILLERS] # (N_{B}, N_{R}, D_{F})
            elif BOUND_FILLERS in key: 
                fillers = quantised_out[BOUND_FILLERS]
            if CONCATENATED in key: 
                fillers = fillers.view(x.shape[0], -1) # (N_{B}, N_{R}*D_{F})
            return fillers
        if key == FILLER_IDXS: 
            return quantised_out[key].to(torch.float32) # (N_{B}, N_{R})
        if key == Z_TPR: 
            return quantised_out['z_tpr'] # (N_{B}, D_{F}*D_{R})
        if key == TPR_BINDINGS: 
            bindings = quantised_out['tpr_bindings_sg'].view(x.shape[0], self.n_roles, -1) # (N_{B}, N_{R}, D_{F}, D_{R}) -> (N_{B}, N_{R}, D_{F}*D_{R})
            # NOTE(review): with the equality test above this only fires if
            # CONCATENATED is a substring of TPR_BINDINGS -- confirm intent.
            if CONCATENATED in key: 
                bindings = bindings.view(x.shape[0], -1) # (N_{B}, N_{R}*D_{F}*D_{R})
            # Previously the bindings were computed and then dropped, so this
            # key silently yielded None.
            return bindings
        if key == Z_SOFT_TPR: 
            return z.view(-1, self.embed_dim) # soft approximation to a TPR (N_{B}, D_{F}*D_{R})
    
    def make_state(self, x: torch.Tensor, x_hat: torch.Tensor, z: torch.Tensor, 
                   z_tpr: torch.Tensor, bound_fillers: torch.Tensor, approx_fillers: torch.Tensor, 
                   filler_idxs: torch.Tensor) -> Dict: 
        """Pack the forward-pass tensors into the state dict returned by ``forward``."""
        return {'x': x, 
                'x_hat': x_hat, 
                'z': z, 
                'z_tpr': z_tpr, 
                'bound_fillers': bound_fillers, 
                'approx_fillers': approx_fillers,
                'filler_idxs': filler_idxs}
    
    def get_quantised(self, z: torch.Tensor) -> Dict: 
        """Unbind ``z`` into fillers, quantise them, and re-bind them into a TPR.

        Args:
            z: encoder output, batched and with the role-embedding dimension
                last (the source comments give (N_{B}, D_{F}, D_{R})).

        Returns:
            Dict with ``'z_tpr'`` (flattened to ``embed_dim`` columns),
            ``'tpr_bindings_sg'`` (per-role outer products),
            ``'bound_fillers_sg'``, ``'bound_fillers'``, ``'filler_idxs'``,
            ``'approx_fillers'`` and ``'quantiser_loss'``.
        """
        # perform unbinding operation 
        # provided that self.role_embeddings is semi-orthogonal matrix \in \mathbb{R}^{D_{R} \times N_{R}}
        # then, it is left-invertible, and the left inverse is the transpose of self.role_embeddings
        if self.role_postproc is not None and not self.freeze_role_embeddings: 
            # Overwrites the stored role weights in place (outside autograd).
            self.role_embeddings.weight.data = BaseTPREncoder.get_new_embeddings(self.role_embeddings.weight.t(), self.role_postproc).t()
        # (N_{B}, D_{F}, D_{R}) @ (N_{B}, D_{R}, N_{R}) -> (N_{B}, D_{F}, N_{R}) -> (N_{B}, N_{R}, D_{F})
        approx_fillers = torch.bmm(z, self.role_embeddings.weight.unsqueeze(0).expand(z.shape[0], -1, -1).permute(0, 2, 1)) 
        approx_fillers = approx_fillers.permute(0, 2, 1).contiguous()
        
        quantiser_out = self.quantiser(approx_fillers)
        quantiser_state, quantiser_loss = quantiser_out['state'], quantiser_out['loss']
        bound_fillers, bound_fillers_sg, approx_fillers, filler_idxs = (quantiser_state['quantised_fillers'], 
                                                                        quantiser_state['quantised_fillers_sg'], 
                                                                        quantiser_state['approx_fillers'],
                                                                        quantiser_state['idxs'])

        # Re-bind: outer product of each quantised filler with its role, then
        # sum over roles to obtain the TPR.
        batched_roles = self.role_embeddings.weight.unsqueeze(0).expand(bound_fillers.shape[0], -1, -1)
        tpr_bindings_sg = torch.einsum('bsf,bsr->bsfr', bound_fillers_sg, batched_roles)
        z_tpr = tpr_bindings_sg.sum(dim=1).view(-1, self.embed_dim)
        return {'z_tpr': z_tpr, 
                'tpr_bindings_sg': tpr_bindings_sg,
                'bound_fillers_sg': bound_fillers_sg,
                'bound_fillers': bound_fillers,
                'filler_idxs': filler_idxs, 
                'approx_fillers': approx_fillers,
                'quantiser_loss': quantiser_loss}
    
    def get_swapped_tpr(self, quantised_out: Dict, 
                        x: torch.Tensor, gt_factor_classes: torch.Tensor) -> Dict: 
        """Swap one role's binding between paired samples and score the result.

        The batch is treated as pairs: sample ``i`` of the first half is
        paired with sample ``i`` of the second half. For each pair one role is
        sampled (hard Gumbel-softmax over the per-role filler distances), that
        role's binding is exchanged between the two samples, and the swapped
        TPRs are decoded and compared against the partner's input.

        Args:
            quantised_out: output of ``get_quantised`` for the full batch.
            x: the inputs, first dimension 2*N_{B}.
            gt_factor_classes: per-sample factor labels, first dimension
                2*N_{B}; compared half-against-half to mark differing factors.

        Returns:
            ``{'loss': {...}, 'state': {...}}`` with the weakly-supervised
            loss terms and the swapped tensors.
        """
        quantised_fillers, quantised_fillers_sg = quantised_out['bound_fillers'], quantised_out['bound_fillers_sg']
        # (2*N_{B}, N_{R}, D_{F}) -> (N_{B}, 2, N_{R}, D_{F})
        quantised_fillers = torch.stack(torch.chunk(quantised_fillers, 2, 0), dim=1)
        quantised_fillers_sg = torch.stack(torch.chunk(quantised_fillers_sg, 2, 0), dim=1) 
        
        N = quantised_fillers.shape[0]
        # Per-role distance between the paired fillers: (N_{B}, N_{R}, D_{F}) -> (N_{B}, N_{R})
        dist = LA.vector_norm(quantised_fillers_sg[:, 0] - quantised_fillers_sg[:, 1], 2, dim=-1) + 1e-8
        gt1, gt2 = torch.chunk(gt_factor_classes, 2, 0)
        one_hot = (gt1 != gt2).to(torch.float16) # 1 where the pair's factor labels differ
        
        if one_hot.shape != dist.shape: 
            # More roles than labelled factors: pad the target with zeros.
            # Created on one_hot's device rather than via .cuda() so the model
            # also runs on CPU and non-default devices.
            n_missing = dist.shape[1] - one_hot.shape[1]
            padding = torch.zeros(size=(one_hot.shape[0], n_missing), device=one_hot.device)
            one_hot = torch.concatenate([one_hot, padding], dim=-1)
            
        ws_dis_loss = F.cross_entropy(dist, one_hot)
        
        # Hard one-hot role selection per pair, broadcast over the filler dim: (N_{B}, N_{R}, D_{F})
        mask = F.gumbel_softmax(dist, dim=1, hard=True).unsqueeze(-1).expand(-1, -1, self.filler_embed_dim).to(bool)
            
        tpr_bindings = quantised_out['tpr_bindings_sg'].detach() # (2*N_{B}, N_{R}, D_{F}, D_{R})
        tpr_bindings_split = torch.stack(torch.chunk(tpr_bindings, 2, 0), dim=1)  # (N_{B}, 2, N_{R}, D_{F}, D_{R})
        
        mask_for_tpr = mask.unsqueeze(-1).expand(-1, -1, -1, self.role_embed_dim)
        # Boolean indexing copies, so temp0/temp1 hold the pre-swap values.
        temp0 = tpr_bindings_split[:, 0][mask_for_tpr].reshape(N, self.filler_embed_dim, self.role_embed_dim) 
        temp1 = tpr_bindings_split[:, 1][mask_for_tpr].reshape(N, self.filler_embed_dim, self.role_embed_dim)
        
        tpr_bindings_split[:, 0][mask_for_tpr] = temp1.view(-1)
        tpr_bindings_split[:, 1][mask_for_tpr] = temp0.view(-1)
        swapped_bindings = tpr_bindings_split 
        tpr_bindings_split = torch.concatenate(torch.unbind(tpr_bindings_split, dim=1), dim=0) # (2*N_{B}, N_{R}, D_{F}, D_{R})
        swapped_tpr = tpr_bindings_split.sum(dim=1).view(-1, self.embed_dim)
        
        diff = quantised_fillers[:, 0] - quantised_fillers[:, 1] # (N_{B}, N_{R}, D_{F})
        diff_along_argmax = diff[mask].view(-1, self.filler_embed_dim) # (N_{B}*1, D_{F})
                
        # Negative sign: this term rewards a large difference on the swapped role ...
        ws_l2_argmax_loss = -torch.sum(
                        torch.norm(diff_along_argmax**2, 2, -1)
                    ) / N 
        # ... while the remaining roles are penalised for differing at all.
        diff[mask] = 0 
        diff = diff.view(-1, self.filler_embed_dim)
        norm_constant = (self.n_roles - 1)*N 
        ws_l2_loss = torch.sum(torch.norm(diff**2, 2, -1)) / norm_constant

        x_hat = self.decode(swapped_tpr)
        x1, x2 = torch.chunk(x, 2, dim=0)
        # After the swap each sample is compared against its partner's input.
        x_swapped = torch.concatenate([x2, x1], dim=0)
        
        # Both reconstruction metrics are always logged (without gradient) ...
        with torch.no_grad():
            mse_recon_loss = mse_recon_loss_fn(x_hat, x_swapped, logging=True)
            bce_recon_loss = bce_recon_loss_fn(x_hat, x_swapped, logging=True)
        
        # ... but only the configured one is used for optimisation.
        if self.recon_loss_fn == 'mse': 
            recon_loss = mse_recon_loss_fn(x_hat, x_swapped, logging=False)
        else: 
            recon_loss = bce_recon_loss_fn(x_hat, x_swapped, logging=False)
        
        return {'loss': {'ws_embed_loss': ws_l2_loss, 
                         'ws_argmax_embed_loss': ws_l2_argmax_loss, 
                         'ws_mse_recon_loss': mse_recon_loss, 
                         'ws_bce_recon_loss': bce_recon_loss,
                         'ws_recon_loss': recon_loss,
                         'ws_dis_loss': ws_dis_loss}, 
                'state': {'swapped_tpr': swapped_tpr, 
                'ws_argmax': torch.argwhere(mask == 1)[:, 1], 
                'swapped_bindings': swapped_bindings,
                'ws_x_hat': x_hat}}  
        
    def forward(self, x: torch.Tensor, gt_factor_classes: torch.Tensor) -> Dict: 
        """Encode, quantise, decode and compute all losses.

        Args:
            x: input batch.
            gt_factor_classes: factor labels; only read when the model is
                weakly supervised and in training mode.

        Returns:
            ``{'loss': ..., 'state': ...}``. In weakly-supervised training the
            swapped-binding losses and state are merged in and their weighted
            sum is added to ``loss['total_loss']``.
        """
        z = self.encode(x) # (N_{B}, D_{F}, D_{R})
        quantised_out = self.get_quantised(z) 
        z_tpr, bound_fillers_sg, bound_fillers, filler_idxs, approx_fillers, quantiser_loss = (quantised_out['z_tpr'], 
                                                                quantised_out['bound_fillers_sg'], 
                                                                quantised_out['bound_fillers'],
                                                                quantised_out['filler_idxs'], quantised_out['approx_fillers'],
                                                                quantised_out['quantiser_loss'])
        x_hat = self.decode(z_tpr)
        state = self.make_state(x=x, x_hat=x_hat, z=z, z_tpr=z_tpr, 
                                bound_fillers=bound_fillers_sg, 
                                approx_fillers=approx_fillers, 
                                filler_idxs=filler_idxs)
        loss = self.get_loss(x=x, x_hat=x_hat, quantiser_loss=quantiser_loss)
        
        # `self.training` is the nn.Module mode flag. The previous check used
        # `self.train`, a bound method that is always truthy, so the swap
        # branch also ran in eval mode. (Assumes AbstractAE derives from
        # nn.Module -- confirm.)
        if self.weakly_supervised and self.training: 
            swapped_tpr_out = self.get_swapped_tpr(quantised_out=quantised_out,
                                                   x=x, gt_factor_classes=gt_factor_classes)
            loss['total_loss'] += (self.lambda_ws_embed*swapped_tpr_out['loss']['ws_embed_loss'] + 
                                   self.lambda_ws_recon*swapped_tpr_out['loss']['ws_recon_loss'] + 
                                   self.lambda_ws_argmax_embed*swapped_tpr_out['loss']['ws_argmax_embed_loss'] + 
                                   self.lambda_ws_dis*swapped_tpr_out['loss']['ws_dis_loss']) 
            return {'loss': {**loss, **swapped_tpr_out['loss']}, 'state': {**state, **swapped_tpr_out['state']}}
        
        return {'loss': loss, 'state': state}

    def get_loss(self, x: torch.Tensor, x_hat: torch.Tensor, quantiser_loss: Dict) -> Dict: 
        """Combine the reconstruction, quantiser and role-orthogonality losses.

        Returns:
            Dict with ``'mse_recon_loss'`` and ``'bce_recon_loss'`` (logged
            without gradient), ``'total_loss'``, the role rank under
            ``ROLE_RANK``, every entry of ``quantiser_loss``, and -- when the
            role penalty is active -- the weighted penalty under
            ``ORTH_PENALTY_ROLE``.
        """
        loss_logs = {}

        # Both metrics are logged; only the configured one is optimised.
        with torch.no_grad():
            mse_recon_loss = mse_recon_loss_fn(x_hat, x, logging=True)
            bce_recon_loss = bce_recon_loss_fn(x_hat, x, logging=True)
        if self.recon_loss_fn == 'mse': 
            recon_loss = mse_recon_loss_fn(x_hat, x, logging=False)
        else: 
            recon_loss = bce_recon_loss_fn(x_hat, x, logging=False)
            
        total_loss = self.lambda_recon*recon_loss + quantiser_loss['quantiser_total_loss']
        
        # Frozen roles cannot respond to the penalty, so it is skipped for them.
        if self.lambda_orth_penalty_role != 0 and not self.freeze_role_embeddings: 
            orth_penalty_role_loss = BaseTPREncoder.get_semi_orth_penalty(self.role_embeddings.weight.t()) * self.lambda_orth_penalty_role
            total_loss += orth_penalty_role_loss
            loss_logs = {ORTH_PENALTY_ROLE: orth_penalty_role_loss}

        role_rank = BaseTPREncoder.get_rank(self.role_embeddings.weight.t())

        return {'mse_recon_loss': mse_recon_loss,
                'bce_recon_loss': bce_recon_loss,
                'total_loss': total_loss, 
                ROLE_RANK: role_rank,
                **quantiser_loss,
                **loss_logs}

    def count_params(self): 
        """Return the number of parameters in the encoder, decoder and quantiser.

        The role embeddings are not included in the count.
        """
        return sum(p.numel() for p in self.encoder.parameters()) + sum(p.numel() for p in self.decoder.parameters()) + sum(p.numel() for p in self.quantiser.parameters())
         

            
