from typing import List, Callable, Optional, Any, cast, Dict

import gym
import numpy as np
import torch
import torch.nn as nn
from torchvision import models

from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.utils.misc_utils import prepare_locals_for_super


class ViTEmbedder(nn.Module):
    """Inference-only wrapper around a model whose ``fc`` layer is removed.

    The wrapper switches itself (and therefore the wrapped model) to
    evaluation mode and always runs the forward pass under
    ``torch.no_grad()``, so no autograd graph is recorded for its outputs.
    """

    def __init__(self, vit):
        """Wrap ``vit`` and delete its ``fc`` attribute.

        Args:
            vit: The model to wrap. It must expose an ``fc`` attribute.

        Raises:
            AttributeError: If ``vit`` has no ``fc`` attribute.
        """
        super().__init__()
        self.model = vit
        # NOTE(review): ``fc`` is presumably the classification head, removed
        # so that the wrapped model yields features instead of class logits --
        # confirm against the ViT implementation in use.
        del self.model.fc
        self.eval()

    def forward(self, x):
        """Return the wrapped model's output for ``x`` without gradients."""
        with torch.no_grad():
            return self.model(x)


class ViTPreprocessor(Preprocessor):
    """Preprocess RGB or depth image using a ViT model.

    The ViT model is created lazily, on first access of the ``vit``
    property, so constructing the preprocessor itself does not load any
    pretrained weights.
    """

    # Maps the user-facing ``model_type`` string to the variant name handed
    # to the ``ViT`` constructor.
    _VIT_VARIANTS: Dict[str, str] = {
        "ViT-B/32": "B_32",
        "ViT-B/16": "B_16",
    }

    def __init__(
        self,
        input_uuids: List[str],
        output_uuid: str,
        model_type: str,
        input_height: int,
        input_width: int,
        output_height: int,
        output_width: int,
        output_dims: int,
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        **kwargs: Any,
    ):
        """Store the configuration and declare the output observation space.

        Args:
            input_uuids: Identifiers of the observations to consume; exactly
                one is supported.
            output_uuid: Identifier of the produced observation.
            model_type: Either ``"ViT-B/32"`` or ``"ViT-B/16"``.
            input_height: Height of the input image (stored only).
            input_width: Width of the input image (stored only).
            output_height: Second dimension of the declared output shape.
            output_width: Third dimension of the declared output shape.
            output_dims: First dimension of the declared output shape.
            device: Device to run on; defaults to the CPU.
            device_ids: Candidate devices; defaults to the indices of all
                visible CUDA devices.
            **kwargs: Extra keyword arguments forwarded to the base class.

        Raises:
            AssertionError: If ``input_uuids`` does not hold exactly one
                entry.
        """
        self.input_height = input_height
        self.input_width = input_width
        self.output_height = output_height
        self.output_width = output_width
        self.output_dims = output_dims
        self.model_type = model_type
        self.device = torch.device("cpu") if device is None else device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )

        # Populated on demand by the ``vit`` property.
        self._vit: Optional[ViTEmbedder] = None

        low = -np.inf
        high = np.inf
        shape = (self.output_dims, self.output_height, self.output_width)

        assert (
            len(input_uuids) == 1
        ), "vit preprocessor can only consume one observation type"

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)

        # NOTE(review): ``locals()`` is handed to the helper below, so the
        # local variables of this constructor presumably end up as keyword
        # arguments of the base class -- avoid adding scratch locals here.
        super().__init__(**prepare_locals_for_super(locals()))

    @property
    def vit(self) -> ViTEmbedder:
        """Return the embedder, building it on ``self.device`` if needed.

        Raises:
            ValueError: If ``self.model_type`` is not a supported variant.
        """
        if self._vit is None:
            try:
                variant = self._VIT_VARIANTS[self.model_type]
            except KeyError:
                raise ValueError(
                    "Unsupported model_type {!r}; expected one of {}".format(
                        self.model_type, sorted(self._VIT_VARIANTS)
                    )
                ) from None
            # NOTE(review): ``ViT`` is not imported in the visible part of
            # this file (presumably ``pytorch_pretrained_vit.ViT``) -- confirm
            # and add the import at the top of the module.
            # Place the model on the same device ``process`` moves inputs to.
            self._vit = ViTEmbedder(ViT(variant, pretrained=True)).to(self.device)
        return self._vit

    def to(self, device: torch.device) -> "ViTPreprocessor":
        """Move the embedder to ``device`` and remember it for inputs."""
        # Update the device first so that a not-yet-built model is created
        # directly on the target device.
        self.device = device
        self._vit = self.vit.to(device)
        return self

    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Run the ViT embedder on the single configured input observation.

        Args:
            obs: Mapping from observation uuid to tensor; the entry for
                ``self.input_uuids[0]`` is read and permuted from
                (batch, height, width, channel) to
                (batch, channel, height, width).

        Returns:
            The output of the embedder for the permuted input.
        """
        x = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
        # If the input is depth, repeat it across all 3 channels
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.vit(x)
