# -*- coding: utf-8 -*-
# !/usr/bin/python

import torch
from torch import nn
import numpy as np
# Target device for every tensor this module allocates. Falls back to CPU when
# CUDA is unavailable so the helpers stay usable on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _append_columns(tokenizer, tokens, segment_ids, cols, max_seq_length):
    """Append each column's subtokens plus a "</s>" separator to ``tokens`` in place.

    Column subtokens get segment id 1. Separators get segment id 0, except the
    separator after the final column, which gets 1. Appending stops at the first
    column whose subtokens would bring ``len(tokens)`` to ``max_seq_length`` or beyond.

    Returns:
        (spans, truncated): the (start, end) subtoken span of every appended
        column, and True if a column was dropped for lack of room.
    """
    spans = []
    last = len(cols) - 1
    for i, col in enumerate(cols):
        sub_tok = tokenizer.tokenize(col)
        if len(tokens) + len(sub_tok) >= max_seq_length:
            return spans, True
        start = len(tokens)
        tokens.extend(sub_tok)
        spans.append((start, len(tokens)))
        segment_ids.extend([1] * len(sub_tok))
        tokens.append("</s>")
        segment_ids.append(1 if i == last else 0)
    return spans, False


def generate_plm_inputs(tokenizer, nlu1_tok, hds1, max_seq_length=500, hard_prompt=None):
    """Build the PLM token sequence for one example.

    Layout: ``<s> question-subtokens </s> col1 </s> col2 </s> ...`` followed,
    when ``hard_prompt`` is given, by ``q </s> col </s> ...`` for each
    ``(q, cols)`` pair in the prompt.

    Args:
        tokenizer: object with a ``tokenize(str) -> list[str]`` method.
        nlu1_tok: question words; each word is tokenized separately.
        hds1: header (column) strings.
        max_seq_length: a word/column/prompt question is only appended while
            ``len(tokens) + len(its subtokens) < max_seq_length``.
        hard_prompt: optional iterable of ``(question, columns)`` pairs.

    Returns:
        (tokens, segment_ids, n_hds, i_hds): the token list, one segment id per
        token, and the (start, end) subtoken spans of each question word and
        each header that fit.
    """
    tokens = ["<s>"]
    segment_ids = [0]

    # Question words: one (start, end) span per word that fits.
    n_hds = []
    for word in nlu1_tok:
        sub_tok = tokenizer.tokenize(word)
        if len(tokens) + len(sub_tok) >= max_seq_length:
            break
        start = len(tokens)
        tokens.extend(sub_tok)
        n_hds.append((start, len(tokens)))
        segment_ids.extend([0] * len(sub_tok))

    tokens.append("</s>")
    segment_ids.append(0)

    i_hds, _ = _append_columns(tokenizer, tokens, segment_ids, hds1, max_seq_length)

    if hard_prompt is not None:
        for question, cols in hard_prompt:
            sub_tok = tokenizer.tokenize(question)
            if len(tokens) + len(sub_tok) >= max_seq_length:
                break
            tokens.extend(sub_tok)
            segment_ids.extend([0] * len(sub_tok))
            tokens.append("</s>")
            segment_ids.append(0)

            _, truncated = _append_columns(tokenizer, tokens, segment_ids, cols, max_seq_length)
            # Once a prompt's columns no longer fit, stop adding prompts; do not
            # append later prompts after a partially dropped one.
            if truncated:
                break

    return tokens, segment_ids, n_hds, i_hds

def plm_encode(hidden_size, plm_model, tokenizer, nlu_t, hds, max_seq_length, hard_prompt=None):
    """Run the PLM on a batch and pool its last hidden state into word/header embeddings.

    Returns:
        (wemb_n, wemb_h, l_n, n_hs, l_hpu, l_hs, nlu_tt, t_to_tt_idx, tt_to_t_idx)
        - wemb_n: from ``get_wemb_avg_list``, one vector per question word.
        - wemb_h: from ``get_wemb_h``, per-header subtoken sequences.
        - the remaining items are passed through unchanged from ``get_plm_output``.
    """
    (
        last_hidden_state, _pooled_output, _tokens, i_nlu, i_hds,
        l_n, n_hs, l_hpu, l_hs,
        nlu_tt, t_to_tt_idx, tt_to_t_idx,
    ) = get_plm_output(plm_model, tokenizer, nlu_t, hds, max_seq_length,
                       hard_prompt=hard_prompt)

    # i_nlu / i_hds hold (start, end) token spans into last_hidden_state.
    wemb_n = get_wemb_avg_list(i_nlu, n_hs, hidden_size, last_hidden_state)
    wemb_h = get_wemb_h(i_hds, l_hpu, l_hs, hidden_size, last_hidden_state)

    return (
        wemb_n, wemb_h, l_n, n_hs, l_hpu, l_hs,
        nlu_tt, t_to_tt_idx, tt_to_t_idx,
    )

def get_plm_output(plm_model, tokenizer, nlu_t, hds, max_seq_length, hard_prompt=None):
    """Tokenize a batch of (question, headers[, prompt]) examples and run the PLM.

    Each example is tokenized exactly once, with the same ``max_seq_length`` and
    ``hard_prompt`` used to build the model inputs. The padded batch width therefore
    always equals min(longest real sequence, ``max_seq_length``). A separate sizing
    pass with different arguments could under-estimate the width.

    Args:
        plm_model: called as ``plm_model(input_ids, attention_mask, return_dict=True)``;
            the result must expose ``last_hidden_state`` and ``pooler_output``.
        tokenizer: provides ``tokenize`` and ``convert_tokens_to_ids``. Its
            ``pad_token_id`` is used for padding when set, otherwise 0.
        nlu_t: per example, the list of question words.
        hds: per example, the list of header strings.
        max_seq_length: upper bound on the sequence length fed to the model.
        hard_prompt: optional per-example prompt, forwarded to ``generate_plm_inputs``.

    Returns:
        (last_hidden_state, pooler_output, tokens, i_nlu, i_hds,
         l_n, n_hs, l_hpu, l_hs, nlu_tt, t_to_tt_idx, tt_to_t_idx)
        - tokens: the unpadded token lists.
        - i_nlu / i_hds: per-example (start, end) spans of question words / headers.
        - l_n / l_hpu: those span lengths, flattened.
        - n_hs / l_hs: number of question words / headers per example.
        - The last three lists are always empty; they are kept for interface
          compatibility.
    """
    n_hs = []  # number of question words per example
    l_hs = []  # number of headers (columns) per example
    tokens = []
    i_nlu = []  # (start, end) spans used later to gather contextual vectors
    i_hds = []

    nlu_tt = []
    t_to_tt_idx = []
    tt_to_t_idx = []

    for b, nlu_t1 in enumerate(nlu_t):
        hds1 = hds[b]
        prompt1 = hard_prompt[b] if hard_prompt is not None else None
        # Segment ids are not passed to the model, so they are discarded.
        tokens1, _, n_hds, i_hds1 = generate_plm_inputs(
            tokenizer, nlu_t1, hds1,
            max_seq_length=max_seq_length, hard_prompt=prompt1)
        n_hs.append(len(nlu_t1))
        l_hs.append(len(hds1))
        tokens.append(tokens1)
        i_nlu.append(n_hds)
        i_hds.append(i_hds1)

    seq_len = min(max((len(t) for t in tokens), default=0), max_seq_length)
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0

    input_ids = []
    input_mask = []
    for tokens1 in tokens:
        ids1 = list(tokenizer.convert_tokens_to_ids(tokens1))[:seq_len]
        n_pad = seq_len - len(ids1)
        input_ids.append(ids1 + [pad_id] * n_pad)
        input_mask.append([1] * len(ids1) + [0] * n_pad)

    all_input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
    all_input_mask = torch.tensor(input_mask, dtype=torch.long, device=device)

    plm_output = plm_model(all_input_ids, all_input_mask, return_dict=True)

    l_hpu = gen_l_hpu(i_hds)
    l_n = gen_l_hpu(i_nlu)

    return plm_output.last_hidden_state, plm_output.pooler_output, tokens, i_nlu, i_hds, \
           l_n, n_hs, l_hpu, l_hs, \
           nlu_tt, t_to_tt_idx, tt_to_t_idx

def get_wemb_avg_list(i_nlu, l_n, hS, last_hidden_state):
    """Average the PLM subtoken vectors of each question word.

    Args:
        i_nlu: per example, a list of (start, end) subtoken spans, one per word.
        l_n: per example, the number of words. Only its length (batch size) and
            its maximum (padded word count) are used.
        hS: hidden size of ``last_hidden_state``.
        last_hidden_state: indexed as [batch, position, hidden].

    Returns:
        Tensor [len(l_n), max(l_n), hS] with the device and dtype of
        ``last_hidden_state``. Entry [b, w] is the mean of
        ``last_hidden_state[b, start:end]`` for word w of example b. Padding
        words and words with an empty span stay zero.
    """
    bS = len(l_n)
    l_n_max = max(l_n, default=0)
    wemb_n = last_hidden_state.new_zeros(bS, l_n_max, hS)
    for b, spans in enumerate(i_nlu):
        for w, span in enumerate(spans):
            start, end = span[0], span[1]
            # The mean of an empty slice is NaN; leave such words as zeros.
            if end > start:
                wemb_n[b, w, :] = last_hidden_state[b, start:end, :].mean(dim=0)
    return wemb_n

def get_wemb_n(i_nlu, l_n, hS, last_hidden_state):
    """Copy each example's question subtoken vectors into a zero-padded tensor.

    Args:
        i_nlu: per example, a single (start, end) span of question subtokens.
        l_n: per example, a length. Its size gives the batch size and its
            maximum gives the padded length.
        hS: hidden size of ``last_hidden_state``.
        last_hidden_state: indexed as [batch, position, hidden].

    Returns:
        Tensor [len(l_n), max(l_n), hS] with the device and dtype of
        ``last_hidden_state``. Row b holds ``last_hidden_state[b, start:end]``
        followed by zeros.
    """
    bS = len(l_n)
    l_n_max = max(l_n, default=0)
    wemb_n = last_hidden_state.new_zeros(bS, l_n_max, hS)
    for b in range(bS):
        start, end = i_nlu[b][0], i_nlu[b][1]
        wemb_n[b, 0:(end - start), :] = last_hidden_state[b, start:end, :]
    return wemb_n
    
def get_wemb_h(i_hds, l_hpu, l_hs, hS, last_hidden_state):
    """Gather each header's subtoken vectors into one zero-padded row per header.

    Headers of all examples are stacked along dim 0, in order:
    [ [table-1-col-1-tok1, t1-c1-t2, ...],
      [t1-c2-t1, t1-c2-t2, ...],
      ...
      [t2-c1-t1, ...],
    ]

    Args:
        i_hds: per example, the (start, end) subtoken span of each header.
        l_hpu: subtoken count per header; its maximum is the padded row length.
        l_hs: number of headers per example; ``sum(l_hs)`` rows are allocated.
        hS: hidden size of ``last_hidden_state``.
        last_hidden_state: indexed as [batch, position, hidden].

    Returns:
        Tensor [sum(l_hs), max(l_hpu), hS] with the device and dtype of
        ``last_hidden_state``.

    NOTE(review): rows are filled contiguously from ``i_hds``. If some headers
    were truncated away (``len(i_hds[b]) < l_hs[b]``), later examples' headers
    shift up and no longer line up with the per-example grouping by ``l_hs``
    used downstream. Confirm whether truncation can occur in practice.
    """
    l_hpu_max = max(l_hpu, default=0)
    wemb_h = last_hidden_state.new_zeros(sum(l_hs), l_hpu_max, hS)
    row = 0
    for b, spans in enumerate(i_hds):
        for span in spans:
            start, end = span[0], span[1]
            wemb_h[row, 0:(end - start), :] = last_hidden_state[b, start:end, :]
            row += 1
    return wemb_h

def gen_l_hpu(i_hds):
    """Flatten nested (start, end) spans into a flat list of span lengths.

    Columns are treated as if they were a batch of natural-language utterances
    whose batch size is (# of columns summed over the batch).

    Example:
        i_hds = [[(17, 18), (19, 21), (22, 23)], [(24, 25), (26, 29), (30, 34)]]
        -> [1, 2, 1, 1, 3, 4]
    """
    return [span[1] - span[0] for spans in i_hds for span in spans]

def generate_perm_inv(perm):
    """Return the inverse of ``perm`` as an int32 array.

    ``result[perm[i]] == i`` for every position ``i``. Entries of ``perm`` are
    converted with ``int()``, so float-valued indices are accepted.
    """
    size = len(perm)
    inverse = np.zeros(size, dtype=np.int32)
    for position, target in enumerate(perm):
        inverse[int(target)] = position
    return inverse

def mask_seq(seq, seq_lens):
    """Build a 0/1 mask shaped like ``seq`` that keeps the first ``seq_lens[i]`` steps of row i.

    Callers are responsible for ``seq`` having batch on dim 0 and time on dim 1.
    Return: tensor with the same shape, dtype and device as ``seq`` ([B, T, ...]).
    """
    mask = torch.zeros_like(seq)
    for batch_idx, length in enumerate(seq_lens):
        mask[batch_idx, :length] = 1
    return mask

def max_pooling_by_lens(seq, seq_lens):
    """Max-pool ``seq`` over dim 1, ignoring positions at or beyond each row's length.

    Args:
        seq: tensor with batch on dim 0 and time on dim 1 (e.g. [B, T, D]).
        seq_lens: valid length of each row.

    Returns:
        ``seq.max(dim=1)`` values restricted to each row's valid prefix. A row
        of length 0 yields the fill value.
    """
    mask = mask_seq(seq, seq_lens)
    fill_value = -1e18
    if seq.is_floating_point():
        # -1e18 overflows float16 and masked_fill raises; clamp to the dtype's
        # lowest finite value. float32/float64 keep -1e18 exactly as before.
        fill_value = max(fill_value, torch.finfo(seq.dtype).min)
    seq = seq.masked_fill(mask == 0, fill_value)
    return seq.max(dim=1)[0]

def encode_hpu(wemb_hpu, l_hpu, l_hs):
    """Max-pool each header's subtoken vectors, then regroup headers by example.

    Args:
        wemb_hpu: [num_headers_total, max_subtokens, hS] header subtoken vectors.
        l_hpu: valid subtoken count for each header row of ``wemb_hpu``.
        l_hs: number of headers per example. Rows of ``wemb_hpu`` are consumed
            in order, ``l_hs[i]`` rows for example i.

    Returns:
        Tensor [len(l_hs), max(l_hs), hS] on the device and dtype of
        ``wemb_hpu``. Slots beyond an example's header count are zero.
    """
    wenc_hpu = max_pooling_by_lens(wemb_hpu, l_hpu)  # [num_headers_total, hS]
    hS = wenc_hpu.size(-1)

    wenc_hs = wenc_hpu.new_zeros(len(l_hs), max(l_hs, default=0), hS)

    # Re-pack according to batch: ret = [B_NLq, max_len_headers_all, hS]
    st = 0
    for i, n_cols in enumerate(l_hs):
        wenc_hs[i, :n_cols] = wenc_hpu[st:(st + n_cols)]
        st += n_cols
    return wenc_hs