import re
from transformers import BertForMaskedLM,BertTokenizer
from transformers.models.bert.modeling_bert import BertConfig, BertEmbeddings
import numpy as np
import torch
from util.box_ops import generalized_box_iou
import logging
import re
import os
import sys
import copy
import tensorflow_hub as hub
from filter_words import filter_words

class Feature:
    """Mutable per-sequence record used while attacking one text query."""

    def __init__(self, seq_a):
        # The working sequence and the adversarial result both begin as the
        # very same input object.
        self.seq = self.final_adverse = seq_a
        # Counters: model queries issued, words changed, success flag.
        self.query, self.change, self.success = 0, 0, 0
        self.sim = 0.0
        self.changes = []


class BERTATT:
    """BERT-Attack style word-substitution attack on a text-conditioned box-grounding model.

    Words of the query are ranked by the generalized IoU between the box
    predicted with that word masked and ``ori_box`` (lowest first). Each
    ranked word is then replaced by masked-LM (``bert-base-uncased``)
    candidates whose Universal Sentence Encoder cosine similarity to the
    original text exceeds ``cos_sim``.
    """

    def __init__(self, models, evaluator, postprocessors, cos_sim=0.95, text_budget=1):
        self.model = models
        self.postprocessors = postprocessors
        self.evaluator = evaluator
        self.cos_sim = cos_sim
        self.text_budget = text_budget
        self.USE_model = hub.load('universal-sentence-encoder-large_5')
        self.tokenizer_mlm = BertTokenizer.from_pretrained("bert-base-uncased",
                                                           do_lower_case="uncased" in "bert-base-uncased")
        config_atk = BertConfig.from_pretrained('bert-base-uncased')
        self.mlm_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config_atk).cuda()

    @staticmethod
    def _module_device(module):
        """Return the device of *module*'s first parameter, falling back to CUDA."""
        try:
            return next(module.parameters()).device
        except (AttributeError, StopIteration):
            # Non-Module callables: keep the original hard-coded ``.cuda()`` behaviour.
            return torch.device('cuda')

    def pre_caption(self, caption, max_words=None):
        """Normalise a caption: lowercase, strip/replace punctuation, collapse spaces.

        Leading ``,.!?*#:;~`` are stripped, ``-`` and ``/`` become spaces and
        ``<person>`` becomes ``person``. If *max_words* is given, the caption is
        truncated to that many space-separated words.
        """
        caption = caption.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
        caption = re.sub(r"\s{2,}", ' ', caption)
        caption = caption.rstrip('\n').strip(' ')
        caption_words = caption.split(' ')
        if max_words is not None and len(caption_words) > max_words:
            caption = ' '.join(caption_words[:max_words])
        return caption

    def encode_text_2_ids(self, text, max_src_length=80):
        """Wrap *text* in the region-query prompt and encode it as a (1, T) tensor.

        Requires ``self.bos``, ``self.eos`` and ``self.dataset`` to have been
        assigned by the caller; ``__init__`` does not create them.

        Raises:
            AssertionError: if any of those tools is missing or ``None``.
        """
        # getattr(): these attributes are never set in __init__, so plain access
        # raised AttributeError instead of the intended message. An explicit
        # raise (not ``assert``) also survives ``python -O``.
        if (getattr(self, 'bos', None) is None or getattr(self, 'eos', None) is None
                or getattr(self, 'dataset', None) is None):
            raise AssertionError("not load text encode tools right now")
        prompt = ' which region does the text " {} " describe?'
        src_caption = self.pre_caption(text, max_src_length)
        src_item = self.dataset.encode_text(prompt.format(src_caption))
        src_tokens = torch.cat([self.bos, src_item, self.eos])
        return src_tokens.unsqueeze(0)

    def tokenize(self, seq, tokenizer):
        """Split *seq* into words and sub-words.

        Returns:
            (words, sub_words, keys), where ``keys[i] == [start, end)`` is the
            span of word ``i``'s sub-tokens inside ``sub_words``.
        """
        seq = seq.replace('\n', '').lower()
        words = seq.split(' ')

        sub_words = []
        keys = []
        index = 0
        for word in words:
            sub = tokenizer.tokenize(word)
            sub_words += sub
            keys.append([index, index + len(sub)])
            index += len(sub)

        return words, sub_words, keys

    def _get_masked(self, words):
        """Return one copy of *words* per position, with that position set to '<unk>'.

        Every word, including the last one, is masked. The previous
        ``range(len - 1)`` never masked the final word, so it could never be
        scored or attacked.
        """
        return [words[:i] + ['<unk>'] + words[i + 1:] for i in range(len(words))]

    def _calculate_ap_score(self, hyps, refs, thresh=0.5):
        """Compute per-row IoU of (x0, y0, x1, y1) boxes and a hit flag.

        Returns:
            (hits, ious). ``hits`` is a float tensor that is 1 where
            IoU >= *thresh* and the intersection has positive width and height.
        """
        interacts = torch.cat(
            [torch.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2]),
             torch.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])],
            dim=1
        )
        area_predictions = (hyps[:, 2] - hyps[:, 0]) * (hyps[:, 3] - hyps[:, 1])
        area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
        interacts_w = interacts[:, 2] - interacts[:, 0]
        interacts_h = interacts[:, 3] - interacts[:, 1]
        area_interacts = interacts_w * interacts_h
        ious = area_interacts / (area_predictions + area_targets - area_interacts + 1e-6)

        return ((ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)).float(), ious

    def _predict(self, samples, text, targets, ori_size):
        """Run the grounding model on the single caption *text* and post-process boxes."""
        with torch.no_grad():  # inference only; don't build autograd graphs
            memory_cache = self.model(samples, [text], targets, encode_and_save=True)
            outputs = self.model(samples, [text], targets, encode_and_save=False,
                                 memory_cache=memory_cache)
            return self.postprocessors["bbox"](outputs, ori_size)

    def _use_similarity(self, text_a, text_b):
        """Return the cosine similarity of the USE embeddings of two texts."""
        embs = self.USE_model([text_a, text_b]).numpy()
        embs = embs / np.linalg.norm(embs, axis=1)[:, None]
        return (embs[:1] * embs[1:]).sum(axis=1)[0]

    def get_important_scores(self, words, batch, ori_box, ori_size, targets):
        """Score each word by the GIoU of the box predicted with that word masked.

        Lower scores mean masking the word moved the prediction further from
        ``ori_box``. Uses entry [0][0] of the GIoU matrix, presumably the first
        predicted box against the first reference box (TODO confirm).

        Returns:
            A 1-D numpy array with one score per word.
        """
        important_scores = []
        for masked in self._get_masked(words):
            results = self._predict(batch, ' '.join(masked), targets, ori_size)
            pred_box = results[0]['boxes'].cpu()
            score = generalized_box_iou(pred_box, ori_box)
            important_scores.append(score[0][0].cpu().numpy())
        return np.array(important_scores)

    def get_substitues(self, substitutes, tokenizer, mlm_model, substitutes_score=None, use_bpe=True, threshold=0.3):
        """Turn an (L, k) matrix of top-k MLM token ids for one word into candidate strings.

        For a single sub-token, candidates are taken in order until a score
        drops below *threshold* (no cut-off when *threshold* is 0 or no
        scores are given). Multi-sub-token words are delegated to
        :meth:`get_bpe_substitues` when *use_bpe* is set.
        """
        sub_len, _ = substitutes.size()

        if sub_len == 0:
            return []

        if sub_len == 1:
            words = []
            # Tolerate the default substitutes_score=None (previously a TypeError).
            scores = substitutes_score[0] if substitutes_score is not None else [None] * len(substitutes[0])
            for token_id, score in zip(substitutes[0], scores):
                # Scores are in descending top-k order, so the first miss ends the run.
                if threshold != 0 and score is not None and score < threshold:
                    break
                words.append(tokenizer._convert_id_to_token(int(token_id)))
            return words

        if use_bpe == 1:
            return self.get_bpe_substitues(substitutes, tokenizer, mlm_model)
        return []

    def get_bpe_substitues(self, substitutes, tokenizer, mlm_model,
                           max_subwords=12, top_per_subword=4, max_candidates=24):
        """Combine per-sub-token candidates into whole-word strings ranked by MLM perplexity.

        The first *max_subwords* rows and *top_per_subword* columns of the
        (L, k) id matrix are used. Combinations are enumerated in
        ``itertools.product`` order (last position varies fastest) and the
        first *max_candidates* are kept.

        Returns:
            Candidate strings ordered from lowest to highest perplexity.
        """
        import itertools

        rows = [[int(c) for c in row] for row in substitutes[0:max_subwords, 0:top_per_subword]]
        # Take only the needed prefix lazily: materialising the full product
        # first (up to 4**12 lists) and truncating afterwards was exponential.
        combos = list(itertools.islice(itertools.product(*rows), max_candidates))
        if not combos:
            return []

        all_substitutes = torch.tensor(combos, device=self._module_device(mlm_model))  # [N, L]
        N, L = all_substitutes.size()
        c_loss = torch.nn.CrossEntropyLoss(reduction='none')
        with torch.no_grad():
            word_predictions = mlm_model(all_substitutes)[0]  # [N, L, vocab]
            ppl = c_loss(word_predictions.reshape(N * L, -1), all_substitutes.reshape(-1))  # [N*L]
            ppl = torch.exp(torch.mean(ppl.reshape(N, L), dim=-1))  # [N]
        _, order = torch.sort(ppl)

        final_words = []
        for idx in order:
            tokens = [tokenizer._convert_id_to_token(int(i)) for i in all_substitutes[idx]]
            final_words.append(tokenizer.convert_tokens_to_string(tokens))
        return final_words

    def attack(self, samples, text, ori_box, ori_size, targets):
        """Try one-word substitutions of ``text[0]`` until the evaluator reports success.

        ``samples``, ``targets`` and ``ori_size`` are passed straight through
        to the grounding model and its ``bbox`` post-processor.

        Returns:
            (texts, success, sim_list).

            On success (``evaluator.summarize()`` second value ``== 1``),
            ``texts`` lists every accepted candidate in discovery order, ending
            with the successful one, and ``success`` is 1.

            Otherwise ``texts`` holds the accepted candidates sorted by USE
            similarity (descending) and ``success`` is 0.

            ``sim_list`` is always in discovery order.
        """
        ori_text = text[0]
        feature = Feature(ori_text.lower())
        tokenizer = self.tokenizer_mlm
        words, sub_words, keys = self.tokenize(feature.seq, tokenizer)
        max_length = 512

        # Top-10 MLM predictions for every sub-token, computed once on the
        # unmasked sentence.
        sub_words = ['[CLS]'] + sub_words[:max_length - 2] + ['[SEP]']
        input_ids_ = torch.tensor([tokenizer.convert_tokens_to_ids(sub_words)],
                                  device=self._module_device(self.mlm_model))
        with torch.no_grad():
            word_predictions = self.mlm_model(input_ids_)[0].squeeze()  # [seq_len, vocab]
        word_pred_scores_all, word_predictions = torch.topk(word_predictions, 10, -1)
        # Drop the [CLS] row so row i lines up with the sub-token offsets in `keys`.
        word_predictions = word_predictions[1:len(sub_words) + 1, :]
        word_pred_scores_all = word_pred_scores_all[1:len(sub_words) + 1, :]

        important_scores = self.get_important_scores(words, samples, ori_box, ori_size, targets)
        feature.query += int(len(words))
        # Ascending GIoU: words whose masking moved the box most come first.
        list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1])

        text_bank = []
        sim_list = []
        for word_idx, _ in list_of_index:
            # NOTE(review): feature.change is never incremented in this method, so
            # this budget check never fires. Each candidate differs from the
            # original text in exactly one word.
            if feature.change >= self.text_budget:
                feature.success = 1  # exceed
                break
            tgt_word = words[word_idx]
            if tgt_word in filter_words:
                continue
            start, end = keys[word_idx]
            if start > max_length - 2:
                continue
            substitutes = self.get_substitues(word_predictions[start:end], tokenizer, self.mlm_model,
                                              substitutes_score=word_pred_scores_all[start:end])
            for substitute in substitutes:
                # Skip the original word, bare sub-word pieces, and filtered words.
                if substitute == tgt_word or '##' in substitute or substitute in filter_words:
                    continue
                temp_replace = list(words)
                temp_replace[word_idx] = substitute
                temp_text = tokenizer.convert_tokens_to_string(temp_replace)
                sim = self._use_similarity(ori_text, temp_text)
                if sim > self.cos_sim:
                    sim_list.append(sim)
                    text_bank.append(temp_text)
                    results = self._predict(samples, temp_text, targets, ori_size)
                    res = {target["image_id"].item(): output for target, output in zip(targets, results)}
                    self.evaluator.update(res)
                    _, scores = self.evaluator.summarize()
                    if scores == 1:
                        return text_bank, 1, sim_list

        # Stable descending sort: ties keep discovery order, matching the
        # previous repeated-argmax selection.
        order = sorted(range(len(sim_list)), key=sim_list.__getitem__, reverse=True)
        text_cand = [text_bank[i] for i in order]
        return text_cand, 0, sim_list
