import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.geometry import poses_to_rays, rays_to_plucker

from models.ginr_attention.modules.data_encoders.image_encoder import Unfold


class MultiviewEncoder(nn.Module):
    """Encode posed multi-view images into a single sequence of patch tokens.

    Each support view is augmented with a per-pixel ray embedding before the patch
    encoder runs. The embedding is either the concatenation of ray origin and direction,
    or Plucker coordinates when ``use_plucker_coordinate`` is set. It is concatenated to
    the view's image channels.
    """

    def __init__(self, config, use_plucker_coordinate=False):
        """
        Args:
            config: encoder config. Reads ``type``, ``trainable``, ``n_channel`` and
                ``encoder_spec``; the "unfold" type uses ``patch_size`` and ``padding``.
                NOTE(review): ``n_channel`` is presumably the channel count *after* the
                6 ray channels are appended (e.g. 3 + 6 = 9 for RGB), since
                ``output_dim`` is derived from it. Confirm against the config.
            use_plucker_coordinate: if True, rays are embedded with ``rays_to_plucker``
                instead of the raw concatenation of origin and direction.

        Raises:
            NotImplementedError: if ``config.type`` is not a supported encoder type.
        """
        super().__init__()
        self.config = config
        self.type = config.type
        self.trainable = config.trainable
        self.input_dim = config.n_channel
        self.output_dim = None

        self.use_plucker_coordinate = use_plucker_coordinate

        spec = config.encoder_spec
        if self.type == "unfold":
            self.encoder = Unfold(spec.patch_size, spec.padding)
            # np.product was deprecated in NumPy 1.25 and removed in 2.0; np.prod is the
            # supported spelling. Cast so the feature size is a plain Python int.
            self.output_dim = int(self.input_dim * np.prod(self.encoder.patch_size))
            # Unfold output is treated as channels-first ((b n), ppd, l) by forward().
            self.is_encoder_out_channels_last = False
        else:
            # If necessary, implement additional wrapper for extracting features of data
            raise NotImplementedError(f"unsupported multiview encoder type: {self.type!r}")

        if not self.trainable:
            for p in self.parameters():
                p.requires_grad_(False)

    def forward(self, support_imgs, support_poses, support_focals, put_channels_last=False):
        """Concatenate ray embeddings to the support images and encode them into tokens.

        The 6-channel per-pixel ray embedding (origin + direction, or Plucker
        coordinates) is concatenated to the image channels of every view. For example,
        an RGB view becomes a nine-channel image. The result is passed through the
        patch encoder.

        Args:
            support_imgs: ``(b, n, c, h, w)`` tensor holding ``n`` views per batch item.
            support_poses: per-view camera poses, forwarded to ``poses_to_rays``
                (presumably ``(b, n, ...)``; confirm against ``poses_to_rays``).
            support_focals: per-view focal lengths, forwarded to ``poses_to_rays``.
            put_channels_last: if True return ``(b, n * l, ppd)``. Otherwise return
                ``(b, ppd, n * l)``. Here ``l`` is the number of patches per view and
                ``ppd`` is the flattened feature size of one patch.

        Returns:
            Patch tokens from all views, in the layout selected by ``put_channels_last``.
        """
        batch_size = support_imgs.shape[0]
        height, width = support_imgs.shape[-2:]

        rays_o, rays_d = poses_to_rays(support_poses, support_focals, height, width)
        if self.use_plucker_coordinate:
            rays_coord = rays_to_plucker(rays_o, rays_d)  # (b n h w 6)
        else:
            rays_coord = torch.cat([rays_o, rays_d], dim=-1)  # (b n h w 6)

        rays_coord = einops.rearrange(rays_coord, "b n h w c -> b n c h w")
        xs = torch.cat([support_imgs, rays_coord], dim=2)  # channel-wise concatenation
        xs = einops.rearrange(xs, "b n d h w -> (b n) d h w")
        xs_embed = self.encoder(xs)

        # Merge the tokens of all views into one sequence in channels-last layout,
        # reading the encoder output according to the layout it declares.
        if self.is_encoder_out_channels_last:
            pattern = "(b n) l ppd -> b (n l) ppd"
        else:
            pattern = "(b n) ppd l -> b (n l) ppd"
        xs_embed = einops.rearrange(xs_embed, pattern, b=batch_size)

        if put_channels_last:
            return xs_embed
        # Channels-first: (b, ppd, n * l).
        return xs_embed.permute(0, 2, 1).contiguous()
