import torch
from typing import Optional, Tuple, List, Union
import warnings
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
import transformers
from transformers.models.llama.modeling_llama import *

def _prepare_decoder_attention_mask_inference(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
            (
                torch.full(
                    (input_shape[0], past_key_values_length),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
                attention_mask,
            ),
            dim=-1,
        )
 
    if attention_mask is not None and torch.all(attention_mask):
        return None  # This uses the faster call when training with full samples
 
    return attention_mask


def _head_slices(proj, n_heads, head_dim):
    """Split a linear layer's parameters into per-head pieces.

    Returns the weight viewed as ``(n_heads, head_dim, in_features)`` and the
    bias viewed as ``(n_heads, head_dim)``, or ``None`` when the layer has no bias.
    """
    weight = proj.weight.view(n_heads, head_dim, -1)
    bias = None if proj.bias is None else proj.bias.view(n_heads, head_dim)
    return weight, bias


def _zero_low_score_queries(attn_out, q_scores, topk, keep_head=100, keep_tail=500):
    """Zero, per head, the outputs of the ``topk`` query positions with the lowest score.

    ``attn_out`` has shape ``(bsz, q_len, num_heads, head_dim)`` and is modified
    in place; ``q_scores`` has shape ``(bsz, num_heads, q_len)``. The first
    ``keep_head`` and last ``keep_tail`` positions are never selected, and
    ``topk`` is clamped to the number of remaining positions, so short
    sequences are left untouched instead of making ``torch.topk`` fail.
    """
    q_len = q_scores.shape[-1]
    k = min(topk, q_len - keep_head - keep_tail)
    if k <= 0:
        return
    candidates = q_scores[..., keep_head:q_len - keep_tail]
    idx = candidates.topk(k, dim=-1, largest=False).indices + keep_head  # (bsz, H, k)
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, attn_out.shape[-1])
    # transpose() returns a view, so scatter_ writes straight into attn_out.
    attn_out.transpose(1, 2).scatter_(2, idx, 0.0)


def forward_flashattn_inference_spliced(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Causal self-attention computed one query head at a time.

    Replacement for ``LlamaFlashAttention2.forward``. For every query head, the
    matching rows of the q/k/v projection weights (and biases, if any) are
    applied separately. The rotary embedding is applied, and ``flash_attn_func``
    runs with ``causal=True``. This means the full-width Q/K/V tensors are never
    materialised. Key/value heads are shared between groups of
    ``num_key_value_groups`` query heads.

    Optional query pruning: when the config defines ``topk`` > 0 and
    ``self.layer_idx >= topk_from_layer`` (``topk_from_layer`` defaults to 3),
    the attention outputs are zeroed, per head, for the ``topk`` query
    positions with the smallest L1 norm of their rotary-embedded query. The
    first 100 and last 500 positions are never pruned.

    ``attention_mask``, ``padding_mask`` and ``past_key_value`` are not read,
    and ``use_cache`` has no effect.

    Returns:
        ``(output, None, None)``. ``output`` is written into the storage of
        ``hidden_states``, so the input tensor is overwritten.
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )
    bsz, q_len, hidden_dim = hidden_states.size()

    if "ne_inf" not in self.__dict__:
        # First call on this module: read the optional pruning settings once.
        config = self.config.to_dict()
        self.ne_inf = -100000  # also serves as the "already initialised" marker
        self.topk = config.get("topk", -1)
        # Layers with a smaller index are never pruned; 3 unless the config overrides it.
        self.topk_from_layer = config.get("topk_from_layer", 3)

    # topk == -1 (the default) disables pruning. Checking it first avoids torch.topk(k=-1).
    prune = self.topk is not None and self.topk > 0 and self.layer_idx >= self.topk_from_layer

    cos, sin = self.rotary_emb(hidden_states, seq_len=q_len)

    n_kv_heads = self.num_heads // self.num_key_value_groups
    q_weight, q_bias = _head_slices(self.q_proj, self.num_heads, self.head_dim)
    k_weight, k_bias = _head_slices(self.k_proj, n_kv_heads, self.head_dim)
    v_weight, v_bias = _head_slices(self.v_proj, n_kv_heads, self.head_dim)

    attn_out = hidden_states.new_empty(bsz, q_len, self.num_heads, self.head_dim)
    # Per-head query scores; only allocated when pruning is active.
    q_scores = hidden_states.new_empty(bsz, self.num_heads, q_len) if prune else None

    for head in range(self.num_heads):
        kv_head = head // self.num_key_value_groups  # query heads in a group share K/V
        part_q = F.linear(hidden_states, q_weight[head], None if q_bias is None else q_bias[head])
        part_k = F.linear(hidden_states, k_weight[kv_head], None if k_bias is None else k_bias[kv_head])
        part_v = F.linear(hidden_states, v_weight[kv_head], None if v_bias is None else v_bias[kv_head])
        # (bsz, q_len, head_dim) -> (bsz, 1, q_len, head_dim) for the rotary embedding.
        part_q, part_k = apply_rotary_pos_emb(part_q.unsqueeze(1), part_k.unsqueeze(1), cos, sin, position_ids)
        if prune:
            q_scores[:, head] = part_q[:, 0].abs().sum(-1)
        # flash_attn_func takes (bsz, seq_len, n_heads, head_dim).
        part_o = flash_attn_func(
            part_q.transpose(1, 2),
            part_k.transpose(1, 2),
            part_v.unsqueeze(2),
            0.0,
            softmax_scale=None,
            causal=True,
        )
        attn_out[:, :, head, :] = part_o.view(bsz, q_len, self.head_dim)

    if prune:
        _zero_low_score_queries(attn_out, q_scores, self.topk)

    # Write the output projection into hidden_states' storage to avoid another allocation.
    torch.matmul(attn_out.view(bsz, q_len, hidden_dim), self.o_proj.weight.T, out=hidden_states)
    if self.o_proj.bias is not None:
        hidden_states += self.o_proj.bias
    return (hidden_states, None, None)


def forward_llama_for_causal_lm(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Causal-LM forward that never materialises full-sequence logits.

    With ``labels``: the LM head is applied to chunks of at most 16384
    positions. The next-token cross-entropy of each chunk
    (``reduction='sum'``) is accumulated, and the total is divided by the
    number of entries of ``labels[:, 1:]`` that are ``>= 0``. ``logits`` is
    returned as ``None``.

    Without ``labels``: only the last position's logits are computed, with
    shape ``(bsz, 1, vocab)`` and cast to float32. ``loss`` is ``None``.

    Works with either ``input_ids`` or ``inputs_embeds``. The result is always
    a ``CausalLMOutputWithPast`` with only ``loss`` and ``logits`` set,
    regardless of ``return_dict``, which is only forwarded to ``self.model``.
    """
    chunk_size = 16384

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # Release unoccupied cached CUDA memory before the LM head allocates logits.
    torch.cuda.empty_cache()

    hidden_states = outputs[0]
    if labels is None:
        logits = self.lm_head(hidden_states[:, -1:]).float()
        return CausalLMOutputWithPast(loss=None, logits=logits)

    loss_fct = CrossEntropyLoss(reduction='sum')
    # Position i predicts label i + 1, so the last position has no target.
    # Take the length from hidden_states: input_ids is None when inputs_embeds is given.
    n_predictions = hidden_states.shape[-2] - 1
    n_targets = torch.sum(labels[:, 1:] >= 0).item()

    loss = 0.0
    for start_idx in range(0, n_predictions, chunk_size):
        end_idx = min(start_idx + chunk_size, n_predictions)
        shift_logits = self.lm_head(hidden_states[..., start_idx:end_idx, :]).float()
        shift_labels = labels[..., start_idx + 1:end_idx + 1].contiguous()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        # Move labels to the logits' device (lm_head may live on another device).
        shift_labels = shift_labels.view(-1).to(shift_logits.device)
        loss += loss_fct(shift_logits, shift_labels)

    loss /= n_targets
    return CausalLMOutputWithPast(loss=loss, logits=None)
 
 
def forward_llama_model(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Inference-oriented replacement for ``LlamaModel.forward``.

    Embeds the tokens (unless ``inputs_embeds`` is given), builds the position
    ids and masks, and runs every decoder layer. It then applies ``self.norm``
    in place over chunks of at most 16384 positions.

    Attention/hidden-state outputs and training with gradient checkpointing
    are not supported (asserted). ``past_key_values`` is only used for its
    length and is forwarded per layer; no cache is returned.

    ``padding_mask`` is ``attention_mask`` when that mask contains a 0, and
    ``None`` otherwise. It is passed to every layer.

    Returns:
        ``BaseModelOutputWithPast`` with only ``last_hidden_state`` set, or
        ``(last_hidden_state,)`` when ``return_dict`` resolves to False.
    """
    assert not output_attentions
    assert not output_hidden_states

    if return_dict is None:
        return_dict = self.config.use_return_dict

    if input_ids is not None:
        if inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    past_len = 0 if past_key_values is None else past_key_values[0][0].shape[2]
    total_len = seq_length + past_len

    if position_ids is not None:
        position_ids = position_ids.view(-1, seq_length).long()
    else:
        anchor = input_ids if input_ids is not None else inputs_embeds
        position_ids = torch.arange(past_len, total_len, dtype=torch.long, device=anchor.device)
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if attention_mask is not None:
        # Only forward a padding mask when some position is actually padded.
        padding_mask = attention_mask if 0 in attention_mask else None
    else:
        padding_mask = None
        attention_mask = torch.ones(
            (batch_size, total_len), dtype=torch.bool, device=inputs_embeds.device
        )

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_len
    )

    assert not(self.gradient_checkpointing and self.training)

    hidden_states = inputs_embeds
    for layer_no, layer in enumerate(self.layers):
        layer_past = None if past_key_values is None else past_key_values[layer_no]
        hidden_states = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=layer_past,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
        )[0]

    # Final norm, applied chunk by chunk and written back in place.
    norm_chunk = 16384
    _, n_tokens, _ = hidden_states.shape
    for lo in range(0, n_tokens, norm_chunk):
        hi = min(n_tokens, lo + norm_chunk)
        hidden_states[:, lo:hi, :] = self.norm(hidden_states[:, lo:hi, :])

    if not return_dict:
        return (hidden_states,)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=None,
        hidden_states=None,
        attentions=None,
    )
 
 
def forward_llama_decoder_layer(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """Pre-norm decoder layer whose norms and MLP run over sequence chunks.

    Computes ``h = x + self_attn(input_layernorm(x))`` and then
    ``h = h + mlp(post_attention_layernorm(h))``. Both layer norms and the MLP
    are applied to at most 16384 positions at a time (presumably to bound peak
    activation memory on long sequences).

    Args:
        hidden_states (`torch.Tensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            It is not modified. ``self_attn`` receives a normalised copy, which it
            may overwrite.
        attention_mask, position_ids, past_key_value, padding_mask: forwarded
            unchanged to ``self.self_attn``.
        output_attentions (`bool`, *optional*): append the attention weights
            returned by ``self.self_attn`` to the output tuple.
        use_cache (`bool`, *optional*): append the ``present_key_value``
            returned by ``self.self_attn`` to the output tuple.

    Returns:
        ``(hidden_states,)``, extended with the attention weights and/or the
        present key/value when requested.
    """
    chunk_size = 16384
    _, seq_len, _ = hidden_states.shape

    # Normalising the input in place would clobber the caller's tensor (e.g.
    # user-supplied inputs_embeds reaching the first layer). Instead, keep the
    # input as the residual and normalise a private copy.
    residual = hidden_states
    hidden_states = hidden_states.clone()
    for start_idx in range(0, seq_len, chunk_size):
        end_idx = min(seq_len, start_idx + chunk_size)
        hidden_states[:, start_idx:end_idx, :] = self.input_layernorm(hidden_states[:, start_idx:end_idx, :])

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        padding_mask=padding_mask,
    )
    hidden_states = residual + hidden_states

    # Fully connected block, chunked; the residual add is done in place.
    for start_idx in range(0, seq_len, chunk_size):
        end_idx = min(seq_len, start_idx + chunk_size)
        part_hidden_states = hidden_states[:, start_idx:end_idx, :].clone()
        part_hidden_states = self.post_attention_layernorm(part_hidden_states)
        part_hidden_states = self.mlp(part_hidden_states)
        hidden_states[:, start_idx:end_idx, :] += part_hidden_states

    outputs = (hidden_states,)
    if output_attentions:
        outputs += (self_attn_weights,)
    if use_cache:
        outputs += (present_key_value,)
    return outputs


# Install the replacements by assigning them as class attributes of the
# transformers Llama classes. This runs when this module is imported and
# affects every instance of these classes, including instances created before
# the import. The attention replacement only applies to attention modules that
# are `LlamaFlashAttention2` instances.
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_inference
transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = forward_flashattn_inference_spliced
transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = forward_llama_for_causal_lm
transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = forward_llama_decoder_layer
transformers.models.llama.modeling_llama.LlamaModel.forward = forward_llama_model