from collections import OrderedDict
from typing import Tuple, Optional, Dict, List, Sequence
from typing import TypeVar

import gym
import torch
from gym.spaces.dict import Dict as SpaceDict
import numpy as np
import torch.nn as nn
import clip
import torchmetrics.functional as MF

from allenact.algorithms.onpolicy_sync.policy import (
    ActorCriticModel,
    LinearCriticHead,
    LinearActorHead,
    ObservationType,
)
from allenact.base_abstractions.distributions import CategoricalDistr
from allenact.embodiedai.aux_losses.losses import MultiAuxTaskNegEntropyLoss
from allenact.embodiedai.models.aux_models import AuxiliaryModel
from allenact.embodiedai.models.basic_models import RNNStateEncoder
from allenact.embodiedai.models.fusion_models import Fusion
from allenact.utils.model_utils import FeatureEmbedding
from allenact.utils.system import get_logger

# Type variable for annotations that accept any subclass of ``Fusion``.
FusionType = TypeVar("FusionType", bound=Fusion)

from allenact.algorithms.onpolicy_sync.policy import (
    ObservationType,
    DistributionType
)
from allenact.base_abstractions.misc import (
    Memory,
    ActorCriticOutput,
)
from allenact.embodiedai.models.visual_nav_models import VisualNavActorCritic
from allenact_plugins.clip_plugin.MVPT_AVG import AVGMultiVisualPromptTuningCLIP
from allenact_plugins.clip_plugin.MVPT_CAT import CATMultiVisualPromptTuningCLIP
from allenact_plugins.clip_plugin.MVPT_ConPE import CONPEMultiVisualPromptTuningCLIP


class ObjectNavActorCritic(VisualNavActorCritic):
    """Object-navigation actor-critic that embeds RGB frames with a frozen CLIP visual tower.

    The CLIP model is loaded on the CPU and cast to float32; its visual encoder
    is registered as ``self.embedder`` and every one of its parameters whose
    name does not contain ``"prompt"`` is frozen. Each observation is expected
    to hold a flattened ``3 x 224 x 224`` image followed by goal features; the
    image embedding and the goal features are concatenated and passed through
    the state encoder(s) and actor/critic heads set up via the ``create_*``
    calls in ``__init__``.
    """

    def __init__(
        self,
        # base params
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        goal_sensor_uuid: str = "goal_object_type_ind",
        # RNN
        hidden_size=513,
        num_rnn_layers=1,
        rnn_type="GRU",
        add_prev_actions=False,
        add_prev_action_null_token=False,
        action_embed_size=6,
        # custom params
        clip_rgb_preprocessor_uuid: str = 'rgb_clip_resnet',
        clip_embedding_dim: int = 512,
        clip_model_type: str = "ViT-B/32",
        prompt: Optional[str] = None,
        multi_p_mode: Optional[Sequence[str]] = None,
        meta_mode: bool = False,
        noise_std: float = 0.0,
        source_model: Optional[Sequence[Optional[str]]] = None,
    ):
        """Build the CLIP encoder, state encoders, heads and auxiliary models.

        Args:
            action_space: Discrete action space forwarded to the base class.
            observation_space: Observation space forwarded to the base class.
            goal_sensor_uuid: Stored on the instance; not read by this class.
            hidden_size: Forwarded to the base class and used as the
                observation-embedding size of the state encoders and
                auxiliary models.
            num_rnn_layers: Number of recurrent layers in the state encoders.
            rnn_type: Recurrent cell type passed to ``create_state_encoders``.
            add_prev_actions: Passed to ``create_state_encoders``.
            add_prev_action_null_token: Passed to ``create_state_encoders``.
            action_embed_size: Size of the previous-action embedding.
            clip_rgb_preprocessor_uuid: Key of the observation that carries the
                flattened image plus goal features. Must not be ``None``.
            clip_embedding_dim: Stored on the instance; not read by this class.
            clip_model_type: Model name handed to ``clip.load``.
            prompt: Stored on the instance; not read by this class.
            multi_p_mode: Stored on the instance; not read by this class.
            meta_mode: Stored on the instance; not read by this class.
            noise_std: Standard deviation of the clipped Gaussian noise added
                to the image embedding; ``0`` disables the noise.
            source_model: Accepted for signature parity with the sibling
                models; ignored by this class.
        """
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size
        )

        assert clip_rgb_preprocessor_uuid is not None

        self.clip_rgb_preprocessor_uuid = clip_rgb_preprocessor_uuid
        self.goal_sensor_uuid = goal_sensor_uuid
        self.prompt = prompt
        self.multi_p_mode = multi_p_mode
        self.meta_mode = meta_mode
        self.noise_std = noise_std
        self.clip_embedding_dim = clip_embedding_dim

        # Kept for backward compatibility with code that reads `model.device`;
        # the forward pass derives the device from the encoder weights instead.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        clip_model = clip.load(clip_model_type, device="cpu")[0]
        # Zero momentum keeps any BatchNorm running statistics from drifting.
        for module in clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        clip_model.eval().float()

        self.embedder = clip_model.visual

        # The text transformer is never used here; drop it.
        del clip_model.transformer

        self.create_state_encoders(
            obs_embed_size=hidden_size,
            num_rnn_layers=num_rnn_layers,
            rnn_type=rnn_type,
            add_prev_actions=add_prev_actions,
            add_prev_action_null_token=add_prev_action_null_token,
            prev_action_embed_size=action_embed_size,
        )

        self.create_actorcritic_head()

        self.create_aux_models(
            obs_embed_size=hidden_size,
            action_embed_size=action_embed_size,
        )

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.embedder.named_parameters():
            if "prompt" not in name:
                param.requires_grad_(False)
        # Double check
        enabled = {
            name for name, param in self.named_parameters() if param.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")
        # Freeze whatever is left of the CLIP model as well.
        for name, param in clip_model.named_parameters():
            param.requires_grad_(False)
            assert not param.requires_grad

        self.train()

    @property
    def is_blind(self) -> bool:
        """Always ``False``: this model consumes visual observations."""
        return False

    def forward_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Embed the image part of the observation and append the goal features.

        Args:
            observations: Mapping that contains ``self.clip_rgb_preprocessor_uuid``.
                The indexing below requires a 3-D tensor whose last axis
                starts with ``3 * 224 * 224`` image values; the remainder of
                that axis is treated as goal features.

        Returns:
            Tensor with the same two leading axes as the observation whose
            last axis is the image embedding followed by the goal features.
        """
        # Follow the encoder's actual device rather than a hard-coded
        # "cuda:0": the module may have been moved (e.g. to another GPU or
        # kept on the CPU) after construction.
        device = next(self.embedder.parameters()).device
        image_numel = 3 * 224 * 224

        obs = observations[self.clip_rgb_preprocessor_uuid].to(device)
        x = obs[:, :, :image_numel].detach().clone()
        B, env_n, _ = x.shape
        x = x.view(B * env_n, 3, 224, 224)
        goal = obs[:, :, image_numel:].detach().clone()

        # embedding
        x = self.embedder(x)
        x = x.unsqueeze(1)
        if self.noise_std:
            # Gaussian noise truncated to +/- 1.5 standard deviations.
            noise = torch.clip(
                torch.normal(0, self.noise_std, size=x.size()),
                -1.5 * self.noise_std,
                1.5 * self.noise_std,
            ).to(x.device)
            x += noise
        x = x.view(B, env_n, x.size(-1))
        x = torch.cat([x, goal], dim=-1)
        return x

    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Run one actor-critic step.

        Args:
            observations: Observations passed to ``forward_encoder``.
            memory: Recurrent memory; read and updated in place per state encoder.
            prev_actions: Previously taken actions.
            masks: Episode masks; a zero entry selects the "start of episode"
                embedding when the previous-action embedder has a null token.

        Returns:
            The actor-critic output (action distributions, values, extras for
            auxiliary losses) and the updated memory.
        """
        # 1.1 use perception model (i.e. encoder) to get observation embeddings
        obs_embeds = self.forward_encoder(observations)

        # 1.2 use embedding model to get prev_action embeddings
        if self.prev_action_embedder.input_size == self.action_space.n + 1:
            # In this case we have a unique embedding for the start of an episode
            prev_actions_embeds = self.prev_action_embedder(
                torch.where(
                    condition=0 != masks.view(*prev_actions.shape),
                    input=prev_actions + 1,
                    other=torch.zeros_like(prev_actions),
                )
            )
        else:
            prev_actions_embeds = self.prev_action_embedder(prev_actions)

        joint_embeds = torch.cat((obs_embeds, prev_actions_embeds), dim=-1)  # (T, N, *)

        # 2. use RNNs to get single/multiple beliefs
        beliefs_dict = {}
        for key, model in self.state_encoders.items():
            beliefs_dict[key], rnn_hidden_states = model(
                joint_embeds, memory.tensor(key), masks
            )
            memory.set_tensor(key, rnn_hidden_states)  # update memory here

        # 3. fuse beliefs for multiple belief models
        beliefs, task_weights = self.fuse_beliefs(
            beliefs_dict, obs_embeds
        )  # fused beliefs

        # 4. prepare output
        extras = (
            {
                aux_uuid: {
                    "beliefs": (
                        beliefs_dict[aux_uuid] if self.multiple_beliefs else beliefs
                    ),
                    "obs_embeds": obs_embeds,
                    "aux_model": (
                        self.aux_models[aux_uuid]
                        if aux_uuid in self.aux_models
                        else None
                    ),
                }
                for aux_uuid in self.auxiliary_uuids
            }
            if self.auxiliary_uuids is not None
            else {}
        )

        if self.multiple_beliefs:
            extras[MultiAuxTaskNegEntropyLoss.UUID] = task_weights

        actor_critic_output = ActorCriticOutput(
            distributions=self.actor(beliefs),
            values=self.critic(beliefs),
            extras=extras,
        )

        # NOTE(review): releasing the CUDA cache on every step is presumably a
        # memory-pressure workaround; confirm it is still needed.
        torch.cuda.empty_cache()

        return actor_critic_output, memory


class COMObjectNavActorCritic(VisualNavActorCritic):
    """Object-navigation actor-critic with an optional multi-prompt CLIP visual encoder.

    Without a ``prompt`` the plain CLIP visual tower is used as the embedder.
    With a ``prompt`` the embedder is an ``AVG`` or ``CAT`` multi visual prompt
    tuning wrapper, selected by ``multi_p_mode[2]``. Image embeddings are
    L2-normalised, optionally perturbed with clipped Gaussian noise, and
    concatenated with the goal features taken from the tail of the observation.
    In ``meta_mode`` weights may be warm-started from source checkpoints and
    the recurrent/actor/critic parameters listed in ``_META_FROZEN_PARAMS``
    are frozen.
    """

    # Parameter names frozen when `meta_mode` is enabled.
    _META_FROZEN_PARAMS = (
        "state_encoders.single_belief.rnn.bias_ih_l0",
        "state_encoders.single_belief.rnn.weight_ih_l0",
        "state_encoders.single_belief.rnn.weight_hh_l0",
        "state_encoders.single_belief.rnn.bias_hh_l0",
        "actor.linear.weight",
        "actor.linear.bias",
        "critic.fc.weight",
        "critic.fc.bias",
    )

    def __init__(
        self,
        # base params
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        goal_sensor_uuid: str = "goal_object_type_ind",
        # RNN
        hidden_size=513,
        num_rnn_layers=1,
        rnn_type="GRU",
        add_prev_actions=False,
        add_prev_action_null_token=False,
        action_embed_size=6,
        # custom params
        clip_rgb_preprocessor_uuid: str = 'rgb_clip_resnet',
        clip_embedding_dim: int = 512,
        clip_model_type: str = "ViT-B/32",
        prompt: Optional[str] = None,
        multi_p_mode: Optional[Sequence[str]] = None,
        meta_mode: bool = False,
        noise_std: float = 0.0,
        source_model: Optional[Sequence[Optional[str]]] = None,
    ):
        """Build the (optionally prompt-tuned) CLIP encoder and policy modules.

        Args:
            action_space: Discrete action space forwarded to the base class.
            observation_space: Observation space forwarded to the base class.
            goal_sensor_uuid: Stored on the instance; not read by this class.
            hidden_size: Forwarded to the base class and used as the
                observation-embedding size of the state encoders and
                auxiliary models.
            num_rnn_layers: Number of recurrent layers in the state encoders.
            rnn_type: Recurrent cell type passed to ``create_state_encoders``.
            add_prev_actions: Passed to ``create_state_encoders``.
            add_prev_action_null_token: Passed to ``create_state_encoders``.
            action_embed_size: Size of the previous-action embedding.
            clip_rgb_preprocessor_uuid: Key of the observation that carries the
                flattened image plus goal features. Must not be ``None``.
            clip_embedding_dim: Stored on the instance; not read by this class.
            clip_model_type: Model name handed to ``clip.load``.
            prompt: Forwarded to the embedder's ``prompt_init``; ``None``
                selects the plain CLIP visual encoder.
            multi_p_mode: Sequence whose third entry (``"AVG"`` or ``"CAT"``)
                selects the prompt wrapper; required when ``prompt`` is given.
            meta_mode: Enables checkpoint warm-start and freezing of the
                parameters in ``_META_FROZEN_PARAMS``.
            noise_std: Standard deviation of the clipped Gaussian noise added
                to the image embedding; ``0`` disables the noise.
            source_model: Up to two checkpoint paths (entries may be ``None``)
                loaded in order when ``meta_mode`` is set.

        Raises:
            ValueError: If ``prompt`` is given but ``multi_p_mode[2]`` is
                neither ``"AVG"`` nor ``"CAT"``.
        """
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size
        )

        assert clip_rgb_preprocessor_uuid is not None

        self.clip_rgb_preprocessor_uuid = clip_rgb_preprocessor_uuid
        self.goal_sensor_uuid = goal_sensor_uuid
        self.prompt = prompt
        self.multi_p_mode = multi_p_mode
        self.meta_mode = meta_mode
        self.noise_std = noise_std
        self.clip_embedding_dim = clip_embedding_dim
        self.source_model = source_model

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        clip_model = clip.load(clip_model_type, device="cpu")[0]
        # Zero momentum keeps any BatchNorm running statistics from drifting.
        for module in clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        clip_model.eval().float()

        if self.prompt is None:
            self.embedder = clip_model.visual
        else:
            fusion = None if self.multi_p_mode is None else self.multi_p_mode[2]
            if fusion == "AVG":
                self.embedder = AVGMultiVisualPromptTuningCLIP(clip_model, self.device)
            elif fusion == "CAT":
                self.embedder = CATMultiVisualPromptTuningCLIP(clip_model, self.device)
            else:
                # Fail here instead of with an AttributeError on the missing
                # `self.embedder` a few lines later.
                raise ValueError(
                    "multi_p_mode[2] must be 'AVG' or 'CAT' when a prompt is "
                    f"given, got multi_p_mode={self.multi_p_mode!r}"
                )
            self.embedder.prompt_init(self.prompt, multi_p_mode=self.multi_p_mode, meta_mode=self.meta_mode)

        # The text transformer is never used here; drop it.
        del clip_model.transformer

        self.create_state_encoders(
            obs_embed_size=hidden_size,
            num_rnn_layers=num_rnn_layers,
            rnn_type=rnn_type,
            add_prev_actions=add_prev_actions,
            add_prev_action_null_token=add_prev_action_null_token,
            prev_action_embed_size=action_embed_size,
        )

        self.create_actorcritic_head()

        self.create_aux_models(
            obs_embed_size=hidden_size,
            action_embed_size=action_embed_size,
        )

        if self.meta_mode:
            self._load_source_checkpoints()

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.embedder.named_parameters():
            if "prompt" not in name:
                param.requires_grad_(False)

        if self.meta_mode:
            for name, param in self.named_parameters():
                if name in self._META_FROZEN_PARAMS:
                    param.requires_grad_(False)

        # Double check
        enabled = {
            name for name, param in self.named_parameters() if param.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")
        # Freeze whatever is left of the CLIP model as well.
        for name, param in clip_model.named_parameters():
            param.requires_grad_(False)
            assert not param.requires_grad

        self.train()

    def _load_source_checkpoints(self) -> None:
        """Copy matching weights from the first two ``source_model`` checkpoints.

        ``None`` (for the whole argument or for an individual entry) means
        "nothing to load". Only keys that already exist in this model's state
        dict are taken from a checkpoint.
        """
        if self.source_model is None:
            return
        if isinstance(self.source_model, str):
            # A bare string is one path, not a sequence of characters.
            paths = [self.source_model]
        else:
            paths = list(self.source_model)[:2]

        for path in paths:
            if path is None:
                continue
            # map_location keeps loading independent of the device the
            # checkpoint was saved from; load_state_dict copies the values
            # onto whatever device the parameters live on.
            source_dict = torch.load(path, map_location="cpu")["model_state_dict"]
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in source_dict.items() if k in model_dict}
            print(pretrained_dict.keys())
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)

    @property
    def is_blind(self) -> bool:
        """Always ``False``: this model consumes visual observations."""
        return False

    def forward_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Embed the image part of the observation and append the goal features.

        Args:
            observations: Mapping that contains ``self.clip_rgb_preprocessor_uuid``.
                The indexing below requires a 3-D tensor whose last axis
                starts with ``3 * 224 * 224`` image values; the remainder of
                that axis is treated as goal features.

        Returns:
            Tensor with the same two leading axes as the observation whose
            last axis is the normalised image embedding followed by the goal
            features.
        """
        image_numel = 3 * 224 * 224

        # NOTE(review): `self.device` is fixed at construction time and is also
        # handed to the prompt wrappers; confirm it matches the device the
        # module is actually placed on.
        obs = observations[self.clip_rgb_preprocessor_uuid].to(self.device)
        x = obs[:, :, :image_numel].detach().clone()
        B, env_n, _ = x.shape
        x = x.view(B * env_n, 3, 224, 224)
        goal = obs[:, :, image_numel:].detach().clone()

        # embedding
        x = self.embedder(x)
        x = x / x.norm(dim=-1, keepdim=True)
        if self.noise_std:
            # Gaussian noise truncated to +/- 1.5 standard deviations, followed
            # by re-normalisation.
            noise = torch.clip(
                torch.normal(0, self.noise_std, size=x.size()),
                -1.5 * self.noise_std,
                1.5 * self.noise_std,
            ).to(self.device)
            x += noise
            x = x / x.norm(dim=-1, keepdim=True)
        # `multi_p_mode` is None on the prompt-free path, so guard the lookup.
        if self.multi_p_mode is not None and self.multi_p_mode[2] == "AVG":
            x = x.unsqueeze(1)
        x = x.view(B, env_n, x.size(-1))
        x = torch.cat([x, goal], dim=-1)
        return x

    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Run one actor-critic step.

        Args:
            observations: Observations passed to ``forward_encoder``.
            memory: Recurrent memory; read and updated in place per state encoder.
            prev_actions: Previously taken actions.
            masks: Episode masks; a zero entry selects the "start of episode"
                embedding when the previous-action embedder has a null token.

        Returns:
            The actor-critic output (action distributions, values, extras for
            auxiliary losses) and the updated memory.
        """
        # 1.1 use perception model (i.e. encoder) to get observation embeddings
        obs_embeds = self.forward_encoder(observations)

        # 1.2 use embedding model to get prev_action embeddings
        if self.prev_action_embedder.input_size == self.action_space.n + 1:
            # In this case we have a unique embedding for the start of an episode
            prev_actions_embeds = self.prev_action_embedder(
                torch.where(
                    condition=0 != masks.view(*prev_actions.shape),
                    input=prev_actions + 1,
                    other=torch.zeros_like(prev_actions),
                )
            )
        else:
            prev_actions_embeds = self.prev_action_embedder(prev_actions)

        joint_embeds = torch.cat((obs_embeds, prev_actions_embeds), dim=-1)  # (T, N, *)

        # 2. use RNNs to get single/multiple beliefs
        beliefs_dict = {}
        for key, model in self.state_encoders.items():
            beliefs_dict[key], rnn_hidden_states = model(
                joint_embeds, memory.tensor(key), masks
            )
            memory.set_tensor(key, rnn_hidden_states)  # update memory here

        # 3. fuse beliefs for multiple belief models
        beliefs, task_weights = self.fuse_beliefs(
            beliefs_dict, obs_embeds
        )  # fused beliefs

        # 4. prepare output
        extras = (
            {
                aux_uuid: {
                    "beliefs": (
                        beliefs_dict[aux_uuid] if self.multiple_beliefs else beliefs
                    ),
                    "obs_embeds": obs_embeds,
                    "aux_model": (
                        self.aux_models[aux_uuid]
                        if aux_uuid in self.aux_models
                        else None
                    ),
                }
                for aux_uuid in self.auxiliary_uuids
            }
            if self.auxiliary_uuids is not None
            else {}
        )

        if self.multiple_beliefs:
            extras[MultiAuxTaskNegEntropyLoss.UUID] = task_weights

        actor_critic_output = ActorCriticOutput(
            distributions=self.actor(beliefs),
            values=self.critic(beliefs),
            extras=extras,
        )

        # NOTE(review): releasing the CUDA cache on every step is presumably a
        # memory-pressure workaround; confirm it is still needed.
        torch.cuda.empty_cache()

        return actor_critic_output, memory


class ENSObjectNavActorCritic(VisualNavActorCritic):
    """Object-navigation actor-critic that ensembles several visual-prompt CLIP embeddings.

    A multi visual prompt tuning wrapper (``AVG`` or ``CAT``, chosen by
    ``multi_p_mode[2]``) yields one embedding per prompt, computed without
    gradient tracking. These are combined with the plain CLIP visual embedding
    in one of three ways: an attention-weighted sum (``multi_p_mode[0] ==
    "SESoM"`` with ``multi_p_mode[1] == "WEIGHTED"``), an unweighted mean
    added to the CLIP embedding (``AVG``), or a concatenation (``CAT``). The
    result is concatenated with the goal features and fed to the state
    encoder(s) and actor/critic heads.
    """

    # Parameter names frozen when `meta_mode` is enabled.
    _META_FROZEN_PARAMS = (
        "state_encoders.single_belief.rnn.bias_ih_l0",
        "state_encoders.single_belief.rnn.weight_ih_l0",
        "state_encoders.single_belief.rnn.weight_hh_l0",
        "state_encoders.single_belief.rnn.bias_hh_l0",
        "actor.linear.weight",
        "actor.linear.bias",
        "critic.fc.weight",
        "critic.fc.bias",
    )

    def __init__(
        self,
        # base params
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        goal_sensor_uuid: str = "goal_object_type_ind",
        # RNN
        hidden_size=513,
        num_rnn_layers=1,
        rnn_type="GRU",
        add_prev_actions=False,
        add_prev_action_null_token=False,
        action_embed_size=6,
        # custom params
        clip_rgb_preprocessor_uuid: str = 'rgb_clip_resnet',
        clip_embedding_dim: int = 512,
        clip_model_type: str = "ViT-B/32",
        prompt: Optional[str] = None,
        multi_p_mode: Optional[Sequence[str]] = None,
        meta_mode: bool = False,
        noise_std: float = 0.0,
        source_model: Optional[Sequence[Optional[str]]] = None,
    ):
        """Build the prompt-ensemble CLIP encoder, attention modules and policy modules.

        Args:
            action_space: Discrete action space forwarded to the base class.
            observation_space: Observation space forwarded to the base class.
            goal_sensor_uuid: Stored on the instance; not read by this class.
            hidden_size: Forwarded to the base class and used as the
                observation-embedding size of the state encoders and
                auxiliary models.
            num_rnn_layers: Number of recurrent layers in the state encoders.
            rnn_type: Recurrent cell type passed to ``create_state_encoders``.
            add_prev_actions: Passed to ``create_state_encoders``.
            add_prev_action_null_token: Passed to ``create_state_encoders``.
            action_embed_size: Size of the previous-action embedding.
            clip_rgb_preprocessor_uuid: Key of the observation that carries the
                flattened image plus goal features. Must not be ``None``.
            clip_embedding_dim: Width of a single CLIP embedding; sizes the
                attention modules and the per-prompt reshape.
            clip_model_type: Model name handed to ``clip.load``.
            prompt: Forwarded to the embedder's ``prompt_init``.
            multi_p_mode: Three-entry sequence: ``[0]`` selects the ``"SESoM"``
                attention modules, ``[1] == "WEIGHTED"`` enables weighting and
                ``[2]`` (``"AVG"`` or ``"CAT"``) selects the prompt wrapper.
            meta_mode: Enables checkpoint warm-start and freezing of the
                parameters in ``_META_FROZEN_PARAMS``.
            noise_std: Standard deviation of the clipped Gaussian noise added
                to the prompt embeddings; ``0`` disables the noise.
            source_model: Up to two checkpoint paths (entries may be ``None``)
                loaded in order when ``meta_mode`` is set.

        Raises:
            ValueError: If ``multi_p_mode`` is ``None`` or its third entry is
                neither ``"AVG"`` nor ``"CAT"``.
        """
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size
        )

        assert clip_rgb_preprocessor_uuid is not None

        self.clip_rgb_preprocessor_uuid = clip_rgb_preprocessor_uuid
        self.goal_sensor_uuid = goal_sensor_uuid
        self.prompt = prompt
        self.multi_p_mode = multi_p_mode
        self.meta_mode = meta_mode
        self.noise_std = noise_std
        self.clip_embedding_dim = clip_embedding_dim
        self.source_model = source_model

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.clip_model = clip.load(clip_model_type, device="cpu")[0]
        # Zero momentum keeps any BatchNorm running statistics from drifting.
        for module in self.clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        self.clip_model.eval().float()

        fusion = None if self.multi_p_mode is None else self.multi_p_mode[2]
        if fusion == "AVG":
            self.embedder = AVGMultiVisualPromptTuningCLIP(self.clip_model, self.device)
        elif fusion == "CAT":
            self.embedder = CATMultiVisualPromptTuningCLIP(self.clip_model, self.device)
        else:
            # Fail here instead of with an AttributeError on the missing
            # `self.embedder` in the next statement.
            raise ValueError(
                "multi_p_mode[2] must be 'AVG' or 'CAT', "
                f"got multi_p_mode={self.multi_p_mode!r}"
            )
        self.embedder.prompt_init(self.prompt, multi_p_mode=self.multi_p_mode)

        # The text transformer is never used here; drop it.
        del self.clip_model.transformer

        # ATTENTION MODULE
        use_sesom = self.multi_p_mode[0] == "SESoM"
        if use_sesom:
            # BASELINE REFERENCE: SESoM -- bottleneck projection for the query.
            self.attn_W_down = nn.Linear(clip_embedding_dim, 128, bias=False)
            self.attn_W_up = nn.Linear(128, clip_embedding_dim, bias=False)
            self.attn_non_linear = nn.SiLU()
            self.attn_layer_norm = nn.LayerNorm(clip_embedding_dim)

        if use_sesom or self.multi_p_mode[1] == "WEIGHTED":
            # One bottleneck network per prompt, producing the attention keys.
            self.source_prompt_attn_weight_list = nn.ModuleList(
                nn.Sequential(
                    nn.Linear(clip_embedding_dim, 128, bias=False),
                    nn.SiLU(),
                    nn.Linear(128, clip_embedding_dim, bias=False),
                    nn.LayerNorm(clip_embedding_dim),
                )
                for _ in range(self.embedder.visual_backbone.prompt_num)
            )

        self.create_state_encoders(
            obs_embed_size=hidden_size,
            num_rnn_layers=num_rnn_layers,
            rnn_type=rnn_type,
            add_prev_actions=add_prev_actions,
            add_prev_action_null_token=add_prev_action_null_token,
            prev_action_embed_size=action_embed_size,
        )

        self.create_actorcritic_head()

        self.create_aux_models(
            obs_embed_size=hidden_size,
            action_embed_size=action_embed_size,
        )

        if self.meta_mode:
            self._load_source_checkpoints()

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.embedder.named_parameters():
            if "prompt" not in name:
                param.requires_grad_(False)

        if self.meta_mode:
            for name, param in self.named_parameters():
                if name in self._META_FROZEN_PARAMS:
                    param.requires_grad_(False)

        # Freeze every parameter of the registered CLIP model.
        for name, param in self.clip_model.named_parameters():
            param.requires_grad_(False)
            assert not param.requires_grad
        # Double check
        enabled = {
            name for name, param in self.named_parameters() if param.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

        self.train()

    def _load_source_checkpoints(self) -> None:
        """Copy matching weights from the first two ``source_model`` checkpoints.

        ``None`` (for the whole argument or for an individual entry) means
        "nothing to load". Only keys that already exist in this model's state
        dict are taken from a checkpoint.
        """
        if self.source_model is None:
            return
        if isinstance(self.source_model, str):
            # A bare string is one path, not a sequence of characters.
            paths = [self.source_model]
        else:
            paths = list(self.source_model)[:2]

        for path in paths:
            if path is None:
                continue
            # map_location keeps loading independent of the device the
            # checkpoint was saved from; load_state_dict copies the values
            # onto whatever device the parameters live on.
            source_dict = torch.load(path, map_location="cpu")["model_state_dict"]
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in source_dict.items() if k in model_dict}
            print(pretrained_dict.keys())
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)

    @property
    def is_blind(self) -> bool:
        """Always ``False``: this model consumes visual observations."""
        return False

    def forward_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Embed the image with every prompt, fuse the embeddings and append the goal.

        Args:
            observations: Mapping that contains ``self.clip_rgb_preprocessor_uuid``.
                The indexing below requires a 3-D tensor whose last axis
                starts with ``3 * 224 * 224`` image values; the remainder of
                that axis is treated as goal features.

        Returns:
            Tensor with the same two leading axes as the observation whose
            last axis is the fused image embedding followed by the goal
            features.
        """
        image_numel = 3 * 224 * 224

        # Prompt embeddings are computed without gradient tracking.
        with torch.no_grad():
            # NOTE(review): `self.device` is fixed at construction time and is
            # also handed to the prompt wrappers; confirm it matches the
            # device the module is actually placed on.
            obs = observations[self.clip_rgb_preprocessor_uuid].to(self.device)
            x = obs[:, :, :image_numel].detach().clone()
            B, env_n, _ = x.shape
            image = x.view(B * env_n, 3, 224, 224)
            goal = obs[:, :, image_numel:].detach().clone()

            # embedding
            x = self.embedder(image)
            if self.multi_p_mode[2] == "CAT":
                # Tensor.view returns a new tensor, so the result has to be
                # assigned for the normalisation below to act per prompt.
                x = x.view(B * env_n, -1, self.clip_embedding_dim)
            x = x / x.norm(dim=-1, keepdim=True)
            if self.noise_std:
                # Gaussian noise truncated to +/- 1.5 standard deviations,
                # followed by re-normalisation.
                noise = torch.clip(
                    torch.normal(0, self.noise_std, size=x.size()),
                    -1.5 * self.noise_std,
                    1.5 * self.noise_std,
                ).to(self.device)
                x += noise
                x = x / x.norm(dim=-1, keepdim=True)

            # (B*env_n, prompt_num, embedding width)
            x = x.view(B * env_n, self.embedder.visual_backbone.prompt_num, -1)

        # Embedding from the plain CLIP visual encoder (parameters are frozen).
        ori_x = self.clip_model.visual(image)

        if self.multi_p_mode[1] == "WEIGHTED":
            # NOTE(review): only the SESoM variant is implemented; for any
            # other multi_p_mode[0] the per-prompt embeddings pass through
            # unweighted and `ori_x` is unused.
            if self.multi_p_mode[0] == "SESoM":
                # BASELINE REFERENCE: SESoM
                # Query: bottleneck projection of the plain CLIP embedding.
                H = self.attn_W_down(ori_x)
                H = self.attn_non_linear(H)
                H = self.attn_W_up(H)
                H = self.attn_layer_norm(H)

                # Keys: projected prompt embeddings; values: the raw ones.
                P_emb_list = []
                P_emb_hat_list = []
                for i, prompt_attn in enumerate(self.source_prompt_attn_weight_list):
                    P_emb = x[:, i, :]
                    P_emb_list.append(P_emb.unsqueeze(1))
                    P_emb_hat_list.append(prompt_attn(P_emb).unsqueeze(1))
                key = torch.cat(P_emb_hat_list, dim=1)
                value = torch.cat(P_emb_list, dim=1)
                query = H.unsqueeze(1)

                # Scaled dot-product attention over the prompts.
                score = torch.bmm(query, key.transpose(1, 2)) / np.sqrt(query.size(-1))
                attn_weights = torch.softmax(score, -1)
                # weighted sum
                x = torch.bmm(attn_weights, value)
        else:
            if self.multi_p_mode[2] == "AVG":
                x = ori_x + torch.mean(x, dim=1)
                x = x.unsqueeze(1)
            elif self.multi_p_mode[2] == "CAT":
                ori_x = ori_x.unsqueeze(1)
                x = x / self.embedder.visual_backbone.prompt_num
                x = torch.cat([ori_x, x], dim=1)

        x = x.view(B, env_n, -1)
        x = torch.cat([x, goal], dim=-1)
        return x

    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Run one actor-critic step.

        Args:
            observations: Observations passed to ``forward_encoder``.
            memory: Recurrent memory; read and updated in place per state encoder.
            prev_actions: Previously taken actions.
            masks: Episode masks; a zero entry selects the "start of episode"
                embedding when the previous-action embedder has a null token.

        Returns:
            The actor-critic output (action distributions, values, extras for
            auxiliary losses) and the updated memory.
        """
        # 1.1 use perception model (i.e. encoder) to get observation embeddings
        obs_embeds = self.forward_encoder(observations)

        # 1.2 use embedding model to get prev_action embeddings
        if self.prev_action_embedder.input_size == self.action_space.n + 1:
            # In this case we have a unique embedding for the start of an episode
            prev_actions_embeds = self.prev_action_embedder(
                torch.where(
                    condition=0 != masks.view(*prev_actions.shape),
                    input=prev_actions + 1,
                    other=torch.zeros_like(prev_actions),
                )
            )
        else:
            prev_actions_embeds = self.prev_action_embedder(prev_actions)
        joint_embeds = torch.cat((obs_embeds, prev_actions_embeds), dim=-1)  # (T, N, *)

        # 2. use RNNs to get single/multiple beliefs
        beliefs_dict = {}
        for key, model in self.state_encoders.items():
            beliefs_dict[key], rnn_hidden_states = model(
                joint_embeds, memory.tensor(key), masks
            )
            memory.set_tensor(key, rnn_hidden_states)  # update memory here

        # 3. fuse beliefs for multiple belief models
        beliefs, task_weights = self.fuse_beliefs(
            beliefs_dict, obs_embeds
        )  # fused beliefs

        # 4. prepare output
        extras = (
            {
                aux_uuid: {
                    "beliefs": (
                        beliefs_dict[aux_uuid] if self.multiple_beliefs else beliefs
                    ),
                    "obs_embeds": obs_embeds,
                    "aux_model": (
                        self.aux_models[aux_uuid]
                        if aux_uuid in self.aux_models
                        else None
                    ),
                }
                for aux_uuid in self.auxiliary_uuids
            }
            if self.auxiliary_uuids is not None
            else {}
        )

        if self.multiple_beliefs:
            extras[MultiAuxTaskNegEntropyLoss.UUID] = task_weights

        actor_critic_output = ActorCriticOutput(
            distributions=self.actor(beliefs),
            values=self.critic(beliefs),
            extras=extras,
        )

        # NOTE(review): releasing the CUDA cache on every step is presumably a
        # memory-pressure workaround; confirm it is still needed.
        torch.cuda.empty_cache()

        return actor_critic_output, memory


class CONPEObjectNavActorCritic(VisualNavActorCritic):
    def __init__(
        # base params
        self,
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        goal_sensor_uuid: str = "goal_object_type_ind",
        # RNN
        hidden_size=513,
        num_rnn_layers=1,
        rnn_type="GRU",
        add_prev_actions=False,
        add_prev_action_null_token=False,
        action_embed_size=6,
        # custom params
        clip_rgb_preprocessor_uuid: str = 'rgb_clip_resnet',
        clip_embedding_dim: int = 512,
        clip_model_type: str = "ViT-B/32",
        prompt: str = None,
        multi_p_mode: Optional[Sequence[str]] = None,
        meta_mode: bool = False,
        noise_std: float = 0.0,
        sm_noise: tuple = 0.0,
        source_model: Optional[Sequence[Optional[str]]] = None,
    ):
        """Build the prompt-tuned CLIP encoder, recurrent state encoders and heads.

        Args:
            action_space: Discrete action space, forwarded to the base class.
            observation_space: Observation space, forwarded to the base class.
            goal_sensor_uuid: Stored on the instance as ``goal_sensor_uuid``.
            hidden_size: Forwarded to the base class and used as the
                observation embedding size of the state encoders and
                auxiliary models.
            num_rnn_layers: Forwarded to ``create_state_encoders``.
            rnn_type: Forwarded to ``create_state_encoders``.
            add_prev_actions: Forwarded to ``create_state_encoders``.
            add_prev_action_null_token: Forwarded to ``create_state_encoders``.
            action_embed_size: Previous-action embedding size.
            clip_rgb_preprocessor_uuid: Key of the observation read by
                ``forward_encoder``; must not be ``None``.
            clip_embedding_dim: Input/output width of the per-prompt attention
                projections.
            clip_model_type: Model name handed to ``clip.load``.
            prompt: Forwarded unchanged to ``embedder.prompt_init``.
            multi_p_mode: Indexable mode specification. Item ``1`` equal to
                ``"WEIGHTED"`` creates one attention projection per prompt;
                ``forward_encoder`` additionally reads items ``0`` and ``2``.
            meta_mode: When truthy, weights are loaded from ``source_model``
                and the recurrent/actor/critic parameters are frozen.
            noise_std: Standard deviation of the Gaussian feature noise.
            sm_noise: Indexable with four items. Item ``0`` scales the
                attention score, item ``1`` selects the noise-free variant,
                item ``2`` is a path to text features loaded with
                ``torch.load`` and item ``3`` is the F1 threshold used by
                ``semantic_noise``. A falsy scalar (the default) or ``None``
                is expanded to an all-disabled tuple.
            source_model: Indexable holding two checkpoint paths (either may
                be ``None``); only consulted when ``meta_mode`` is truthy.
        """
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size
        )

        assert clip_rgb_preprocessor_uuid is not None

        # The scalar default cannot be indexed, yet this class reads
        # sm_noise[0..3]; expand the "disabled" scalar into an equivalent tuple.
        if sm_noise is None or (isinstance(sm_noise, (int, float)) and not sm_noise):
            sm_noise = (0.0, False, None, 0.0)

        self.clip_rgb_preprocessor_uuid = clip_rgb_preprocessor_uuid
        self.goal_sensor_uuid = goal_sensor_uuid
        self.prompt = prompt
        self.multi_p_mode = multi_p_mode
        self.meta_mode = meta_mode
        self.noise_std = noise_std
        self.sm_noise = sm_noise
        self.clip_embedding_dim = clip_embedding_dim
        self.source_model = source_model

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.clip_model = clip.load(clip_model_type, device="cpu")[0]
        for module in self.clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        self.clip_model.eval().float()
        self.embedder = CONPEMultiVisualPromptTuningCLIP(self.clip_model, self.device)
        self.embedder.prompt_init(self.prompt, multi_p_mode=self.multi_p_mode, meta_mode=self.meta_mode)

        # Only the visual tower is used afterwards (see forward_encoder)
        del self.clip_model.transformer

        # ATTENTION MODULE: one projection network per visual prompt
        if self.multi_p_mode[1] == "WEIGHTED":
            self.source_prompt_attn_weight_list = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Linear(clip_embedding_dim, 128, bias=False),
                        nn.SiLU(),
                        nn.Linear(128, clip_embedding_dim, bias=False),
                        nn.LayerNorm(clip_embedding_dim),
                    )
                    for _ in range(self.embedder.visual_backbone.prompt_num)
                ]
            )

        self.create_state_encoders(
            obs_embed_size=hidden_size,
            num_rnn_layers=num_rnn_layers,
            rnn_type=rnn_type,
            add_prev_actions=add_prev_actions,
            add_prev_action_null_token=add_prev_action_null_token,
            prev_action_embed_size=action_embed_size,
        )

        self.create_actorcritic_head()

        self.create_aux_models(
            obs_embed_size=hidden_size,
            action_embed_size=action_embed_size,
        )

        def load_matching_weights(checkpoint_path):
            # Copy only the checkpoint entries whose keys also exist in this
            # model. map_location="cpu" lets GPU-saved checkpoints load on
            # hosts without CUDA; load_state_dict copies onto each parameter's
            # own device.
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            source_dict = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
            model_dict = self.state_dict()
            pretrained_dict = {k: v for k, v in source_dict.items() if k in model_dict}
            print(pretrained_dict.keys())
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)

        if self.meta_mode and self.source_model is not None:
            # The second checkpoint is applied after the first, so it wins on
            # any key both of them provide.
            for slot in (0, 1):
                checkpoint_path = self.source_model[slot]
                if checkpoint_path is not None:
                    load_matching_weights(checkpoint_path)

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.embedder.named_parameters():
            if "prompt" not in name:
                param.requires_grad_(False)

        if self.meta_mode:
            # Recurrent core and both heads stay fixed in meta mode
            for_meta = {
                "state_encoders.single_belief.rnn.bias_ih_l0", "state_encoders.single_belief.rnn.weight_ih_l0",
                "state_encoders.single_belief.rnn.weight_hh_l0", "state_encoders.single_belief.rnn.bias_hh_l0",
                "actor.linear.weight", "actor.linear.bias", "critic.fc.weight", "critic.fc.bias",
            }
            for name, param in self.named_parameters():
                if name in for_meta:
                    param.requires_grad_(False)

        for param in self.clip_model.parameters():
            param.requires_grad_(False)

        # Double check: report what is still trainable
        enabled = {name for name, param in self.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        if self.sm_noise[2]:
            # Load straight onto the device the image features live on so the
            # matmul in semantic_noise does not mix devices.
            self.text_features = torch.load(self.sm_noise[2], map_location=self.device)
            self.cnt = 0

        # NOTE: this also puts clip_model (a submodule) back into training mode
        self.train()
        self.cos_sim_list = []
        self.attn_list = []

    @property
    def is_blind(self) -> bool:
        return False

    def denormalize(self, x,):
        """Undo per-channel mean/std normalisation and clamp the result to [0, 1].

        Args:
            x: 4-D tensor indexed as (batch, channel, height, width); the
                first three channels are rescaled. The input is not modified.

        Returns:
            A new tensor in the same (batch, channel, height, width) layout.
        """
        # (mean, std) pairs, one per channel
        channel_stats = ((0.485, 0.229), (0.456, 0.224), (0.406, 0.225))
        # Work on a copy with the channel axis first so that iterating the
        # tensor yields one writable view per channel.
        channel_major = x.clone().permute(1, 2, 3, 0)
        for plane, (mu, sigma) in zip(channel_major, channel_stats):
            plane.mul_(sigma).add_(mu)
        # Back to batch-first layout
        return channel_major.clamp(0, 1).permute(3, 0, 1, 2)
    
    def semantic_noise(self, x):
        """Sample clipped Gaussian noise that keeps the text-based labelling of ``x``.

        The class probabilities of the clean features (softmax over
        ``100 * x @ text_features.T``) are turned into one-hot targets. Noise
        is redrawn until the binary F1 score between the noisy probabilities
        and those targets reaches ``self.sm_noise[3]``, or until the retry
        counter ``self.cnt`` hits 300, in which case "giveup" is printed.

        NOTE(review): the first candidate is scored without re-normalising
        ``x + noise``, whereas every retry is normalised to unit length first.

        Args:
            x: Feature tensor whose last dimension matches the text features.

        Returns:
            The most recently drawn noise tensor (same shape as ``x``).
        """
        limit = 1.5 * self.noise_std

        def class_probs(feats):
            return torch.softmax(100. * feats @ self.text_features.t(), dim=-1)

        def draw_noise(shape):
            # Gaussian noise truncated to +/- 1.5 standard deviations
            sample = torch.normal(0, self.noise_std, size=shape)
            return torch.clip(sample, -limit, limit).to(self.device)

        def agreement(feats):
            return MF.f1_score(
                task="binary",
                preds=class_probs(feats).flatten(),
                target=targets.flatten(),
            )

        clean = x.detach()
        clean_probs = class_probs(clean)
        targets = torch.nn.functional.one_hot(
            torch.argmax(clean_probs, dim=-1), clean_probs.size(-1)
        )

        noise = draw_noise(x.size())
        f1 = agreement(x + noise)

        while f1 < self.sm_noise[3]:
            noise = draw_noise(clean.size())
            candidate = clean + noise
            f1 = agreement(candidate / candidate.norm(dim=-1, keepdim=True))

            self.cnt += 1
            if self.cnt == 300:
                print("giveup")
                break

        # The counter always starts from zero on the next call
        self.cnt = 0
        return noise

    def forward_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Embed the RGB part of the observation and append the goal part.

        The observation stored under ``self.clip_rgb_preprocessor_uuid`` is
        split along its last axis: the first ``3 * 224 * 224`` values are
        reshaped into images, the remainder is passed through as the goal.
        Images are embedded by the prompt-tuned ``self.embedder`` (without
        gradients), optionally perturbed with noise, and - in
        WEIGHTED/ENSEMBLE mode - fused with an image feature through the
        learned attention projections.

        Args:
            observations: Mapping that contains the packed RGB + goal tensor
                with three dimensions (unpacked below as ``B, env_n, _``).

        Returns:
            Tensor of shape ``(B, env_n, feature_dim + goal_dim)``.
        """
        image_numel = 3 * 224 * 224

        with torch.no_grad():
            obs = observations[self.clip_rgb_preprocessor_uuid].to(self.device)
            flat_image = obs[:, :, :image_numel].detach().clone()
            B, env_n, _ = flat_image.shape
            image = flat_image.view(B * env_n, 3, 224, 224)
            goal = obs[:, :, image_numel:].detach().clone()

            # Prompt embeddings, L2-normalised along the feature axis
            x = self.embedder(image)
            x = x / x.norm(dim=-1, keepdim=True)

            # Data augmentation
            if B > 1 and self.sm_noise[2]:
                x_c = x.clone().detach()
                for x_i in x:
                    # x_i is a view into x, so this perturbs x in place
                    x_i += self.semantic_noise(x_i.clone().detach())
                assert not (x_c == x).all()
                x = x / x.norm(dim=-1, keepdim=True)
            elif self.noise_std:
                bound = 1.5 * self.noise_std
                noise = torch.clip(
                    torch.normal(0, self.noise_std, size=x.size()), -bound, bound
                ).to(self.device)
                x += noise
                x = x / x.norm(dim=-1, keepdim=True)

            # (B*env_n, prompt_num, embedding_dim)
            x = x.view(B * env_n, self.embedder.visual_backbone.prompt_num, -1)

        # Query feature; computed outside no_grad
        if self.meta_mode:
            clip_x = self.embedder(image, self.meta_mode)
        else:
            clip_x = self.clip_model.visual(image)

        if self.multi_p_mode[1] == "WEIGHTED" and self.multi_p_mode[0] == "ENSEMBLE":
            query = clip_x.unsqueeze(1)  # (B*env_n, 1, embedding_dim)

            prompt_embs = [
                x[:, idx, :] for idx in range(len(self.source_prompt_attn_weight_list))
            ]
            # Keys are the projected prompt embeddings, values the raw ones
            key = torch.stack(
                [
                    proj(emb)
                    for proj, emb in zip(self.source_prompt_attn_weight_list, prompt_embs)
                ],
                dim=1,
            )  # (B*env_n, prompt_num, embedding_dim)
            value = torch.stack(prompt_embs, dim=1)  # (B*env_n, prompt_num, embedding_dim)

            # Cosine-similarity guidance between the query and each prompt
            q_norm = torch.norm(query, dim=2, keepdim=True)
            v_norm = torch.norm(value, dim=2, keepdim=True)
            cos_sim = torch.bmm(query, value.permute(0, 2, 1)) / torch.bmm(
                q_norm, v_norm.permute(0, 2, 1)
            )  # (B*env_n, 1, prompt_num)

            # Scaled dot-product score
            score = torch.bmm(query, key.transpose(1, 2)) / np.sqrt(query.size(-1))

            temperature = self.sm_noise[0]
            if temperature:
                if self.sm_noise[1]:
                    # "hard" variant: no sampling, only rescale the score
                    score = score / temperature
                else:
                    # Gumbel noise sampling. torch.rand_like samples from
                    # [0, 1); a draw of exactly 0 would give log(0) and an
                    # infinite perturbation, so keep the sample strictly positive.
                    uniform = torch.rand_like(cos_sim).clamp_min(
                        torch.finfo(cos_sim.dtype).tiny
                    )
                    gumbel_noise = -torch.log(-torch.log(uniform))
                    score = score + gumbel_noise / temperature

            attn_weights = torch.softmax(cos_sim * score, -1)  # (B*env_n, 1, prompt_num)

            if self.multi_p_mode[2] == "AVG":
                context = torch.bmm(attn_weights, value)  # (B*env_n, 1, embedding_dim)
                x = query + context

        x = x.view(B, env_n, -1)
        return torch.cat([x, goal], dim=-1)

    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Run one actor-critic step.

        Args:
            observations: Observations passed to ``forward_encoder``.
            memory: Recurrent memory; its tensors are read and overwritten
                once per state encoder.
            prev_actions: Previously taken actions.
            masks: Tensor reshaped to ``prev_actions``' shape; entries equal
                to zero select the null action token when the embedder has
                one extra input slot.

        Returns:
            The ``ActorCriticOutput`` (distributions, values, extras) and the
            updated ``memory``.
        """
        # 1.1 observation embeddings from the perception model
        obs_embeds = self.forward_encoder(observations)

        # 1.2 previous-action embeddings
        action_tokens = prev_actions
        if self.prev_action_embedder.input_size == self.action_space.n + 1:
            # Index 0 is reserved as a unique embedding for the start of an
            # episode, so real actions are shifted up by one.
            action_tokens = torch.where(
                condition=0 != masks.view(*prev_actions.shape),
                input=prev_actions + 1,
                other=torch.zeros_like(prev_actions),
            )
        prev_actions_embeds = self.prev_action_embedder(action_tokens)

        joint_embeds = torch.cat((obs_embeds, prev_actions_embeds), dim=-1)  # (T, N, *)

        # 2. one belief per recurrent state encoder
        beliefs_dict = {}
        for key, encoder in self.state_encoders.items():
            belief, hidden_states = encoder(joint_embeds, memory.tensor(key), masks)
            beliefs_dict[key] = belief
            memory.set_tensor(key, hidden_states)  # update memory here

        # 3. fuse the beliefs of multiple belief models
        beliefs, task_weights = self.fuse_beliefs(beliefs_dict, obs_embeds)

        # 4. collect the inputs needed by the auxiliary tasks
        extras = {}
        if self.auxiliary_uuids is not None:
            for aux_uuid in self.auxiliary_uuids:
                if self.multiple_beliefs:
                    aux_beliefs = beliefs_dict[aux_uuid]
                else:
                    aux_beliefs = beliefs
                if aux_uuid in self.aux_models:
                    aux_model = self.aux_models[aux_uuid]
                else:
                    aux_model = None
                extras[aux_uuid] = {
                    "beliefs": aux_beliefs,
                    "obs_embeds": obs_embeds,
                    "aux_model": aux_model,
                }

        if self.multiple_beliefs:
            extras[MultiAuxTaskNegEntropyLoss.UUID] = task_weights

        actor_critic_output = ActorCriticOutput(
            distributions=self.actor(beliefs),
            values=self.critic(beliefs),
            extras=extras,
        )

        torch.cuda.empty_cache()

        return actor_critic_output, memory


class CLIPNavActorCritic(VisualNavActorCritic):
    """Single-belief actor-critic whose encoders are supplied by subclasses.

    ``forward`` combines the visual embedding, the recurrent belief and a goal
    embedding as ``(vis_embeds + beliefs) * goal_embeds`` and feeds the result
    to the actor and critic heads. Subclasses provide the two embeddings by
    overriding ``forward_encoder`` and ``forward_goal_encoder``.
    """

    action_space: gym.spaces.Discrete

    def __init__(
        self,
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        hidden_size=1024,
    ):
        """Configure the base model for one belief and no auxiliary tasks."""
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size,
            multiple_beliefs=False,
            beliefs_fusion=None,
            auxiliary_uuids=None,
        )

    def forward(  # type:ignore
        self,
        observations: ObservationType,
        memory: Memory,
        prev_actions: torch.Tensor,
        masks: torch.FloatTensor,
    ) -> Tuple[ActorCriticOutput[DistributionType], Optional[Memory]]:
        """Run one actor-critic step.

        Args:
            observations: Observations handed to both encoder hooks.
            memory: Recurrent memory; its ``'single_belief'`` tensor is read
                and overwritten.
            prev_actions: Unused by this implementation.
            masks: Passed through to the state encoder.

        Returns:
            The ``ActorCriticOutput`` (with empty extras) and the updated
            ``memory``.
        """
        # 1.1 use perception model (i.e. encoder) to get observation embeddings
        vis_embeds = self.forward_encoder(observations)
        # 1.2 goal embeddings; previously this name was read without ever being
        # assigned, so every call ended in a NameError
        goal_embeds = self.forward_goal_encoder(observations)

        # 2. use the RNN to get the single belief
        belief, rnn_hidden_states = self.state_encoders['single_belief'](
            vis_embeds,
            memory.tensor('single_belief'),
            masks
        )
        beliefs_dict = {'single_belief': belief}
        memory.set_tensor('single_belief', rnn_hidden_states)

        # 3. fuse beliefs (task weights are not used here)
        beliefs, _ = self.fuse_beliefs(beliefs_dict, None)

        output = (vis_embeds + beliefs) * goal_embeds

        # 4. prepare output
        actor_critic_output = ActorCriticOutput(
            distributions=self.actor(output),
            values=self.critic(output),
            extras={},
        )

        return actor_critic_output, memory

    def forward_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Return the visual embedding; must be overridden by subclasses."""
        raise NotImplementedError("Obs Encoder Not Implemented")

    def forward_goal_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Return the goal embedding that multiplies ``vis_embeds + beliefs``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError("Goal Encoder Not Implemented")


class CLIPObjectNavActorCritic(VisualNavActorCritic):
    """Actor-critic that consumes an already-computed CLIP observation.

    ``forward_encoder`` returns the observation stored under
    ``clip_rgb_preprocessor_uuid`` unchanged, so that observation is used
    directly as the embedding fed to the state encoders.
    """

    def __init__(
        # base params
        self,
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        goal_sensor_uuid: str,
        # RNN
        hidden_size=1024,
        num_rnn_layers=1,
        rnn_type="GRU",
        add_prev_actions=False,
        add_prev_action_null_token=False,
        action_embed_size=6,
        # custom params
        clip_rgb_preprocessor_uuid: str = 'rgb_clip_resnet',
        clip_embedding_dim: int = 1024
    ):
        """Create the state encoders, actor/critic heads and auxiliary models.

        Args:
            action_space: Discrete action space, forwarded to the base class.
            observation_space: Observation space, forwarded to the base class.
            goal_sensor_uuid: Accepted but not used by this constructor.
            hidden_size: Forwarded to the base class.
            num_rnn_layers: Forwarded to ``create_state_encoders``.
            rnn_type: Forwarded to ``create_state_encoders``.
            add_prev_actions: Forwarded to ``create_state_encoders``.
            add_prev_action_null_token: Forwarded to ``create_state_encoders``.
            action_embed_size: Previous-action embedding size.
            clip_rgb_preprocessor_uuid: Observation key read by
                ``forward_encoder``; must not be ``None``.
            clip_embedding_dim: Observation embedding size given to the state
                encoders and auxiliary models.
        """
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size,
        )

        assert clip_rgb_preprocessor_uuid is not None
        self.clip_rgb_preprocessor_uuid = clip_rgb_preprocessor_uuid

        recurrent_config = dict(
            obs_embed_size=clip_embedding_dim,
            num_rnn_layers=num_rnn_layers,
            rnn_type=rnn_type,
            add_prev_actions=add_prev_actions,
            add_prev_action_null_token=add_prev_action_null_token,
            prev_action_embed_size=action_embed_size,
        )
        self.create_state_encoders(**recurrent_config)

        self.create_actorcritic_head()

        self.create_aux_models(
            obs_embed_size=clip_embedding_dim,
            action_embed_size=action_embed_size,
        )

        self.train()

    @property
    def is_blind(self) -> bool:
        """Always ``False``."""
        return False

    def forward_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Return the observation stored under ``clip_rgb_preprocessor_uuid``."""
        sensor_key = self.clip_rgb_preprocessor_uuid
        return observations[sensor_key]