from collections import OrderedDict
from typing import Tuple, Optional, List, Dict
import math
from operator import mul
from functools import reduce

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout


from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level CLIP BPE tokenizer instance.
# NOTE(review): not referenced by the code in this file as shown (text is
# tokenised through ``clip.tokenize`` instead); confirm it is used elsewhere
# before removing it.
_tokenizer = _Tokenizer()

class Prompter(nn.Module):
    """Learnable padding-style visual prompt.

    A frame of learnable pixels, ``pad_size`` wide, surrounds a zero-filled
    ``base_size x base_size`` centre (``base_size = image_size - 2 * pad_size``).
    ``forward`` passes the input through a learnable, bias-free 1x1
    convolution and then adds the frame to it.

    Args:
        device: device passed at construction. Kept (as ``self.device``) for
            backward compatibility; ``forward`` now builds the centre on the
            parameters' own device/dtype so it follows ``.to()`` / ``.half()``.
        pad_size: width of the prompt border in pixels.
        image_size: side length of the square input images.
    """

    def __init__(self, device, pad_size=30, image_size=224):
        super().__init__()
        self.device = device
        self.base_size = image_size - pad_size * 2
        # Creation order is kept identical so seeded initialisation is unchanged.
        self.pad_up = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
        self.pad_down = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
        self.pad_left = nn.Parameter(torch.randn([1, 3, self.base_size, pad_size]))
        self.pad_right = nn.Parameter(torch.randn([1, 3, self.base_size, pad_size]))
        self.conv2d = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return ``conv2d(x) + prompt``.

        ``x`` is expected to broadcast against a ``(1, 3, image_size,
        image_size)`` prompt, i.e. a batch of 3-channel square images.
        """
        # Allocate the zero centre directly with the parameters' dtype/device
        # rather than on the CPU followed by a copy on every call.
        base = self.pad_left.new_zeros(1, 3, self.base_size, self.base_size)
        prompt = torch.cat([self.pad_left, base, self.pad_right], dim=3)
        prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=2)
        # The prompt has batch size 1 and broadcasts over x's batch dimension;
        # no need to materialise one copy per sample.
        return self.conv2d(x) + prompt

class CustomCLIPVisual(nn.Module):
    """CLIP visual tower preceded by a learnable :class:`Prompter`.

    Args:
        clip_model: a loaded CLIP model; its ``visual`` tower is shared (not
            copied) and its ``dtype`` is recorded.
        device: forwarded to the ``Prompter``.
    """

    def __init__(self, clip_model, device):
        super().__init__()
        self.dtype = clip_model.dtype
        self.prompter = Prompter(device)
        self.visual_encoder = clip_model.visual

    def forward(self, images):
        """Add the visual prompt to ``images`` and encode them.

        Returns the visual tower's output cast to ``self.dtype``.
        """
        prompted_images = self.prompter(images)
        # The prompter's parameters are created with the default dtype, while
        # the CLIP visual tower may be in a different precision (e.g. fp16).
        # Cast to the model dtype before encoding, as CLIP's own encode_image
        # does; the cast is differentiable, so prompt gradients still flow.
        image_features = self.visual_encoder(prompted_images.type(self.dtype))
        return image_features.type(self.dtype)

class VisualPromptCLIP(nn.Module):
    """CLIP classifier whose image branch is adapted by a learnable visual prompt.

    Class text embeddings are computed once from ``"a photo of a {c}"`` and
    kept fixed; only the prompter inside ``visual_backbone`` is meant to be
    trained (see :meth:`get_parameters`).

    Args:
        clip_model: a loaded CLIP model (visual tower, text encoder and
            ``logit_scale`` are reused).
        classnames: class names used to build the text prompts.
        feat_dim: value reported by :attr:`features_dim`.
        device: device for the tokenised text; also forwarded to the prompter.
    """

    def __init__(self, clip_model, classnames, feat_dim, device):
        super().__init__()
        self.visual_backbone = CustomCLIPVisual(clip_model, device)
        self._features_dim = feat_dim
        # Shared with clip_model. CLIP stores the *log* of the logit scale,
        # which is why forward() exponentiates it.
        self.logit_scale = clip_model.logit_scale
        text_inputs = clip.tokenize([f"a photo of a {c}" for c in classnames]).to(device)
        # The text embeddings are constants. Without no_grad they would carry
        # an autograd graph into the (unfrozen) text encoder: every training
        # step would backpropagate into it, and the second backward through
        # the same graph would raise.
        with torch.no_grad():
            text_features = clip_model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        # A buffer follows .to()/.cuda()/.half() with the module; non-persistent
        # so previously saved state_dicts remain loadable.
        self.register_buffer("text_features", text_features, persistent=False)

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.visual_backbone.named_parameters():
            if "prompter" not in name:
                param.requires_grad_(False)
        # Double check which backbone parameters remain trainable.
        enabled = {
            name
            for name, param in self.visual_backbone.named_parameters()
            if param.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, images):
        """Score ``images`` against the fixed class text embeddings.

        Returns ``(logits, logits, image_features)`` in training mode and
        ``(logits, image_features)`` in eval mode; ``image_features`` are
        L2-normalised along the last dimension.
        """
        image_features = self.visual_backbone(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # logit_scale holds log(scale); exponentiate it as CLIP's forward does.
        logits = self.logit_scale.exp() * image_features @ self.text_features.t()

        if self.training:
            # NOTE(review): logits is returned twice, presumably to match the
            # output arity a training loop expects; confirm with the caller.
            return logits, logits, image_features
        return logits, image_features

    def freeze_bn(self):
        """Put every BatchNorm1d/BatchNorm2d submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                m.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """Return optimizer parameter groups.

        Only the prompter's parameters are returned, at learning rate
        ``base_lr``.

        Args:
            optimize_head: accepted for interface compatibility; unused here.
            base_lr: learning rate assigned to the prompter parameters.
        """
        return [
            {"params": self.visual_backbone.prompter.parameters(), "lr": 1.0 * base_lr},
        ]