import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.models as cnn_models
import torch.distributions.normal as t_normal
import random
# from .discriminator import Discriminator
from collections import OrderedDict
from torchmeta.modules import MetaSequential, MetaLinear

from metamodules import FCBlock, BatchLinear, HyperNetwork, get_subdict
from encoding import get_encoder
from activation import trunc_exp
from .renderer import NeRFRenderer
from .utils import get_rays
from .clip_utils import CLIPLoss

# spherical linear interpolation (slerp)
def slerp(val, low, high):
    """Spherically interpolate between two vectors.

    Args:
        val: Interpolation ratio; 0 yields ``low`` and 1 yields ``high``.
        low: Start vector (NumPy array).
        high: End vector (NumPy array), same shape as ``low``.

    Returns:
        The point at ratio ``val`` on the arc between ``low`` and ``high``.
        When the sine of the angle between the two directions is exactly
        zero, plain linear interpolation is used instead.
    """
    unit_low = low / np.linalg.norm(low)
    unit_high = high / np.linalg.norm(high)

    # Clip guards arccos against dot products that drift just outside [-1, 1].
    cos_omega = np.clip(np.dot(unit_low, unit_high), -1, 1)
    omega = np.arccos(cos_omega)
    sin_omega = np.sin(omega)

    if sin_omega == 0:
        # Degenerate angle: the slerp weights would divide by zero, so use LERP.
        return (1.0 - val) * low + val * high

    weight_low = np.sin((1.0 - val) * omega) / sin_omega
    weight_high = np.sin(val * omega) / sin_omega
    return weight_low * low + weight_high * high

# uniform interpolation between two points in latent space
def interpolate_points(p1, p2, n_steps=50):
    """Return evenly spaced spherical interpolations between two latent points.

    Args:
        p1: Start vector (NumPy array).
        p2: End vector (NumPy array), same shape as ``p1``.
        n_steps: Number of interpolation ratios, spaced uniformly over
            ``[0, 1]`` with both endpoints included. Defaults to 50.

    Returns:
        A list of ``n_steps`` vectors; element ``i`` is ``slerp(ratio_i, p1, p2)``.
    """
    # The ratios are uniform, but each vector is a *spherical* interpolation.
    return [slerp(ratio, p1, p2) for ratio in np.linspace(0, 1, num=n_steps)]


class NeRFNetwork(NeRFRenderer):
    """NeRF field (density + view-dependent color) with overridable weights.

    The density ("sigma") and color MLPs are ``BatchLinear`` stacks held in
    ``MetaSequential`` containers. ``density`` and ``color`` accept an optional
    ``params`` mapping that is routed to the position encoder and the MLPs via
    ``get_subdict``, so weights produced elsewhere (e.g. by a hypernetwork) can
    be used instead of the module's own parameters.

    NOTE: each MLP is wrapped as ``MetaSequential(inner_sequential)``; the
    wrapper therefore has exactly one child (index 0), which is why external
    parameter names look like ``sigma_net.0.layer_0.weight``.
    """

    def __init__(self,
                 encoding="hashgrid",
                 encoding_dir="sphere_harmonics",
                 encoding_bg="hashgrid",
                 num_layers=2,
                 hidden_dim=64,
                 geo_feat_dim=15,
                 num_layers_color=3,
                 hidden_dim_color=64,
                 num_layers_bg=2,
                 hidden_dim_bg=64,
                 bound=1,
                 encoder=None,
                 encoder_in_dim=None,
                 encoder_dir=None,
                 encoder_in_dim_dir=None,
                 log2_hashmap_size=11,
                 **kwargs,
                 ):
        """Build the encoders and the sigma / color / background networks.

        Args:
            encoding: Encoder name for 3D positions (passed to ``get_encoder``).
            encoding_dir: Encoder name for view directions.
            encoding_bg: Encoder name for 2D background coordinates.
            num_layers: Number of linear layers in the sigma network.
            hidden_dim: Hidden width of the sigma network.
            geo_feat_dim: Size of the geometry feature emitted next to sigma.
            num_layers_color: Number of linear layers in the color network.
            hidden_dim_color: Hidden width of the color network.
            num_layers_bg: Number of linear layers in the background network.
            hidden_dim_bg: Hidden width of the background network.
            bound: Scene bound; forwarded to ``NeRFRenderer`` and used to scale
                the position encoder's desired resolution.
            encoder, encoder_in_dim, encoder_dir, encoder_in_dim_dir: Accepted
                for interface compatibility; not used by this constructor.
            log2_hashmap_size: Hash-table size (log2) of the position encoder.
            **kwargs: Forwarded to ``NeRFRenderer.__init__``.
        """
        super().__init__(bound, **kwargs)

        # sigma network
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim
        self.encoder, self.in_dim = get_encoder(
            encoding, desired_resolution=2048 * bound, log2_hashmap_size=log2_hashmap_size)

        sigma_net = MetaSequential()
        for l in range(num_layers):
            if l == 0:
                in_dim = self.in_dim
            else:
                in_dim = hidden_dim

            if l == num_layers - 1:
                # 1 sigma channel + geo_feat_dim feature channels for the color net
                out_dim = 1 + self.geo_feat_dim
            else:
                out_dim = hidden_dim

            sigma_net.add_module(f"layer_{l}", BatchLinear(
                in_dim, out_dim, bias=False))
            if l != num_layers - 1:
                sigma_net.add_module(f"act_{l}", nn.ReLU(inplace=True))

        # Wrapped once more on purpose: keeps the "sigma_net.0.layer_k" naming.
        self.sigma_net = MetaSequential(sigma_net)

        # color network
        self.num_layers_color = num_layers_color
        self.hidden_dim_color = hidden_dim_color
        self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir)

        color_net = MetaSequential()
        for l in range(num_layers_color):
            if l == 0:
                in_dim = self.in_dim_dir + self.geo_feat_dim
            else:
                in_dim = hidden_dim_color

            if l == num_layers_color - 1:
                out_dim = 3  # 3 rgb
            else:
                out_dim = hidden_dim_color

            color_net.add_module(f"layer_{l}", BatchLinear(
                in_dim, out_dim, bias=False))
            if l != num_layers_color - 1:
                color_net.add_module(f"act_{l}", nn.ReLU(inplace=True))

        # Same single-child wrapping as sigma_net ("color_net.0.layer_k").
        self.color_net = MetaSequential(color_net)

        # background network
        # NOTE(review): bg_radius is not assigned here; presumably it is set by
        # NeRFRenderer.__init__ from **kwargs - confirm against the renderer.
        if self.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder(
                encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048)  # much smaller hashgrid

            bg_net = []
            for l in range(num_layers_bg):
                if l == 0:
                    in_dim = self.in_dim_bg + self.in_dim_dir
                else:
                    in_dim = hidden_dim_bg

                if l == num_layers_bg - 1:
                    out_dim = 3  # 3 rgb
                else:
                    out_dim = hidden_dim_bg

                bg_net.append(nn.Linear(in_dim, out_dim, bias=False))

            self.bg_net = nn.ModuleList(bg_net)
        else:
            self.bg_net = None

    def forward(self, x, d):
        """Evaluate density and color using the module's own weights.

        Args:
            x: Positions, encoded with ``self.encoder`` (called with ``bound``).
            d: View directions, encoded with ``self.encoder_dir``.

        Returns:
            Tuple ``(sigma, color)``: ``sigma`` is ``trunc_exp`` of channel 0 of
            the sigma network output, ``color`` is the sigmoid of the color
            network output.
        """
        # sigma
        x = self.encoder(x, bound=self.bound)

        # sigma_net / color_net each have a single child (index 0) that already
        # contains every linear layer and ReLU. They must be called as a whole:
        # indexing them per layer applied the full stack at l == 0, added a
        # spurious ReLU, and raised IndexError at l == 1.
        h = self.sigma_net(x)

        sigma = trunc_exp(h[..., 0])
        geo_feat = h[..., 1:]

        # color
        d = self.encoder_dir(d)
        h = self.color_net(torch.cat([d, geo_feat], dim=-1))

        # sigmoid activation for rgb
        color = torch.sigmoid(h)

        return sigma, color

    def density(self, x, params=None, idx=None):
        """Compute density and geometry features, optionally with external weights.

        Args:
            x: Positions, [N, 3], in [-bound, bound].
            params: Optional parameter mapping; the ``encoder.embeddings`` and
                ``sigma_net`` sub-dicts are extracted with ``get_subdict``.
            idx: Forwarded to the position encoder.

        Returns:
            Dict with ``'sigma'`` (``trunc_exp`` of channel 0) and
            ``'geo_feat'`` (remaining channels of the sigma network output).
        """
        x = self.encoder(x, bound=self.bound, idx=idx, params=get_subdict(params, 'encoder.embeddings'))
        h = x

        h = self.sigma_net(h, params=get_subdict(params, 'sigma_net'))

        sigma = trunc_exp(h[..., 0])
        geo_feat = h[..., 1:]

        return {
            'sigma': sigma,
            'geo_feat': geo_feat,
        }

    def background(self, x, d):
        """Predict background RGB from 2D coordinates and view directions.

        Only usable when the background network was built (``bg_radius > 0``).

        Args:
            x: Background coordinates, [N, 2], in [-1, 1].
            d: View directions, encoded with ``self.encoder_dir``.

        Returns:
            RGB values after a sigmoid.
        """
        h = self.encoder_bg(x)  # [N, C]
        d = self.encoder_dir(d)

        h = torch.cat([d, h], dim=-1)
        for l in range(self.num_layers_bg):
            h = self.bg_net[l](h)
            if l != self.num_layers_bg - 1:
                h = F.relu(h, inplace=True)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # allow masked inference
    def color(self, x, d, mask=None, geo_feat=None, params=None, **kwargs):
        """Predict RGB from view directions and geometry features.

        Args:
            x: Positions, [N, 3] in [-bound, bound]; only used for dtype/device
                of the output buffer and for masking.
            d: View directions.
            mask: Optional [N,] bool tensor selecting the rows for which rgb is
                actually computed; unselected rows are returned as zeros.
            geo_feat: Geometry features from ``density``; required in practice
                because it is concatenated with the encoded directions.
            params: Optional parameter mapping; the ``color_net`` sub-dict is
                extracted with ``get_subdict``.
            **kwargs: Ignored.

        Returns:
            RGB tensor; [N, 3] when ``mask`` is given, otherwise one row per
            input row.
        """
        if mask is not None:
            rgbs = torch.zeros(
                mask.shape[0], 3, dtype=x.dtype, device=x.device)  # [N, 3]
            # in case of empty mask
            if not mask.any():
                return rgbs
            x = x[mask]
            d = d[mask]
            geo_feat = geo_feat[mask]

        d = self.encoder_dir(d)
        h = torch.cat([d, geo_feat], dim=-1)
        h = self.color_net(h, params=get_subdict(params, 'color_net'))

        # sigmoid activation for rgb
        h = torch.sigmoid(h)

        if mask is not None:
            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32
        else:
            rgbs = h

        return rgbs

    # optimizer utils
    def get_params(self, lr):
        """Return optimizer parameter groups, all using learning rate ``lr``.

        The background encoder and network are included only when
        ``bg_radius > 0``.
        """
        params = [
            {'params': self.encoder.parameters(), 'lr': lr},
            {'params': self.sigma_net.parameters(), 'lr': lr},
            {'params': self.encoder_dir.parameters(), 'lr': lr},
            {'params': self.color_net.parameters(), 'lr': lr},
        ]
        if self.bg_radius > 0:
            params.append({'params': self.encoder_bg.parameters(), 'lr': lr})
            params.append({'params': self.bg_net.parameters(), 'lr': lr})

        return params

def orthonormalize_basis(basis):
    """
    Project each matrix onto an orthonormal basis via its SVD.

    The singular values are discarded and the result is U @ V^T, i.e. the
    input with every singular value replaced by one.

    basis - B, 3, 3

    out - B, 3, 3
    """
    left, _, right = torch.svd(basis)
    # torch.svd returns V (not V^T), hence the explicit transpose.
    return torch.matmul(left, right.transpose(-2, -1))

def gram_schmidt(vv):
    """Orthonormalize the columns of ``vv`` with classical Gram-Schmidt.

    Args:
        vv: 2-D tensor whose columns ``vv[:, k]`` are the vectors to
            orthonormalize, processed left to right.

    Returns:
        Tensor with the same shape as ``vv`` whose columns have unit norm and
        are mutually orthogonal. Column 0 is ``vv[:, 0]`` normalized.

    Note:
        If a column is linearly dependent on the previous ones its residual is
        zero and the normalization yields NaN for that column.
    """
    def projection(u, v):
        # Component of v along u.
        return (v * u).sum() / (u * u).sum() * u

    # Vectors live in columns, so count along dim 1 and always read vv[:, k].
    # (Reading row vv[k] here mixed rows with columns and produced wrong or
    # NaN results for any non-symmetric input.)
    nk = vv.size(1)
    if nk == 0:
        return torch.zeros_like(vv)

    ortho = []
    for k in range(nk):
        vk = vv[:, k]
        residual = vk
        for uj in ortho:
            residual = residual - projection(uj, vk)
        ortho.append(residual)

    return torch.stack([u / u.norm() for u in ortho], dim=1)

class NeRFGen(nn.Module):
    """NeRF generator: per-instance latent codes -> hypernetwork -> NeRF weights.

    Every instance owns a shape code and a color code. The codes (optionally
    sampled from a learned Gaussian, or fused with a CLIP image embedding) are
    passed to ``HyperNetwork``, whose output is handed to ``NeRFNetwork.render``
    as the ``params`` argument.
    """

    def __init__(self, opt, num_instances=1, mode='nerf', type='relu',
                 hn_hidden_features=512, hn_hidden_layers=1, hn_in=512, std=0.01, custom_hashhn=False, **kwargs):
        """Create latent codes, optional CLIP heads, the NeRF and its hypernetwork.

        Args:
            opt: Options object; reads ``cuda_ray``, ``bg_radius``, ``bound``,
                ``device``, ``clipcondition``, ``clip_mapping``, ``varprior``,
                ``min_near`` and ``density_thresh``.
            num_instances: Number of rows in each latent embedding table.
            mode: Stored on the instance; not otherwise used here.
            type: Accepted but unused.
            hn_hidden_features: Hidden width of the hypernetwork.
            hn_hidden_layers: Number of hidden layers of the hypernetwork.
            hn_in: Dimensionality of the latent fed to the hypernetwork.
            std: Std-dev used to initialise the codes; also the prior std-dev
                in ``kl_divergence``.
            custom_hashhn: Forwarded to ``HyperNetwork``.
            **kwargs: Ignored.
        """
        super().__init__()

        self.mode = mode
        self.num_instances = num_instances
        self.std = std

        self.cuda_ray = opt.cuda_ray
        self.bg_radius = opt.bg_radius
        self.bound = opt.bound
        self.device = opt.device
        self.clipcondition = opt.clipcondition
        self.clip_mapping = opt.clip_mapping
        self.varprior = opt.varprior

        if self.varprior:
            # Variational codes: a mean table and a log-variance table per code.
            # All four tables are created first and initialised afterwards.
            table_names = ('shape_code_mu', 'shape_code_std',
                           'color_code_mu', 'color_code_std')
            for table_name in table_names:
                setattr(self, table_name,
                        nn.Embedding(self.num_instances, hn_in, max_norm=1e6))
            for table_name in table_names:
                nn.init.normal_(getattr(self, table_name).weight, mean=0, std=std)
        else:
            # Deterministic codes. With CLIP conditioning the codes are
            # 512-dimensional and later fused with the CLIP embedding.
            clip_dim = 512
            code_dim = clip_dim if self.clipcondition else hn_in

            if self.clipcondition:
                self.clip_encoder = CLIPLoss()

            self.shape_code = nn.Embedding(self.num_instances, code_dim)
            nn.init.normal_(self.shape_code.weight, mean=0, std=std)

            self.color_code = nn.Embedding(self.num_instances, code_dim)
            nn.init.normal_(self.color_code.weight, mean=0, std=std)

            if self.clipcondition:
                # Fuse [clip embedding, instance code] into a hn_in-sized latent.
                self.mergeclipinstance_shape = nn.Sequential(
                    nn.Linear(clip_dim * 2, clip_dim),
                    nn.ReLU(True),
                    nn.Linear(clip_dim, hn_in))

                self.mergeclipinstance_color = nn.Sequential(
                    nn.Linear(clip_dim * 2, clip_dim),
                    nn.ReLU(True),
                    nn.Linear(clip_dim, hn_in))

        if self.clip_mapping:
            # Regress the latent codes directly from a CLIP image embedding.
            self.clip_encoder = CLIPLoss()
            self.clip_fc_shape = self._build_clip_mapper(hn_in)
            self.clip_fc_color = self._build_clip_mapper(hn_in)

        self.net = NeRFNetwork(
            encoding="hashgrid",
            bound=opt.bound,
            cuda_ray=opt.cuda_ray,
            density_scale=1,
            min_near=opt.min_near,
            density_thresh=opt.density_thresh,
            bg_radius=opt.bg_radius,
        )

        # Keep custom_hashhn False: per the original author it did not work
        # well when enabled.
        self.hyper_net = HyperNetwork(hyper_in_features=hn_in,
                                      hyper_hidden_layers=hn_hidden_layers,
                                      hyper_hidden_features=hn_hidden_features,
                                      hypo_module=self.net,
                                      custom_hashhn=custom_hashhn)

        # Log the names of the trainable parameters of the hypo-network.
        for param_name, param in self.net.named_parameters():
            if param.requires_grad:
                print(param_name)

    @staticmethod
    def _build_clip_mapper(hn_in):
        """Build the 512 -> 512 -> 256 -> 256 -> hn_in MLP used for CLIP mapping."""
        return nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, hn_in),
        )

    def init_from_learned_latents(self, shape_codes, color_codes):
        """Copy previously learned codes into the embedding tables, row by row.

        Only the first ``min(len(shape_codes), num_rows)`` rows are overwritten.
        Requires the deterministic ``shape_code`` / ``color_code`` tables.
        """
        n_rows = min(len(shape_codes), len(self.shape_code.weight))
        with torch.no_grad():
            for row in range(n_rows):
                self.shape_code.weight[row] = shape_codes[row]
                self.color_code.weight[row] = color_codes[row]

            print('Loaded latent codes from the previous checkpoint!')

    def kl_divergence(self, z, mu, std):
        """Single-sample Monte Carlo estimate of KL(q || p).

        ``q = Normal(mu, std)`` and ``p = Normal(0, self.std)``. The per-
        dimension log-ratio ``log q(z) - log p(z)`` is averaged (not summed)
        over the last dimension.
        """
        prior = torch.distributions.Normal(
            torch.zeros_like(mu), torch.ones_like(std) * self.std)
        posterior = torch.distributions.Normal(mu, std)

        log_ratio = posterior.log_prob(z) - prior.log_prob(z)
        return log_ratio.mean(-1)

    def sample(self, mu_block, std_block, idx):
        """Draw a reparameterised latent for ``idx`` (used when ``varprior`` is set).

        ``std_block`` stores log-variances, so the std-dev is ``exp(log_var / 2)``.

        Returns:
            Tuple ``(z, kl)`` with the sample and its KL estimate.
        """
        mu = mu_block(idx)
        std = torch.exp(std_block(idx) / 2)

        z = torch.distributions.Normal(mu, std).rsample()
        return z, self.kl_divergence(z, mu, std)

    def get_params(self, idx, input_dict):
        """Produce hypo-network parameters for the given instance indices.

        Returns:
            With ``varprior``: tuple ``(params, kld_shape + kld_color)``.
            Otherwise: the hypernetwork output only.
        """
        if self.varprior:
            z_shape, kld_shape = self.sample(self.shape_code_mu, self.shape_code_std, idx)
            z_color, kld_color = self.sample(self.color_code_mu, self.color_code_std, idx)
            return self.hyper_net(z_shape, z_color), (kld_shape + kld_color)

        if not self.clipcondition:
            z_shape = self.shape_code(idx)
            z_color = self.color_code(idx)
            return self.hyper_net(z_shape, z_color)

        # CLIP-conditioned path: move the last image axis to position 1
        # (presumably channels-last -> channels-first; confirm with the loader).
        img_input = input_dict['img_original'].permute(0, -1, 1, 2)

        # The CLIP encoder is used as a fixed feature extractor (no gradients).
        with torch.no_grad():
            z_clip = self.clip_encoder(img_input, mode="image").float()

        shape_in = torch.cat((z_clip, self.shape_code(idx)), 1)
        color_in = torch.cat((z_clip, self.color_code(idx)), 1)

        z_shape = self.mergeclipinstance_shape(shape_in)
        z_color = self.mergeclipinstance_color(color_in)
        return self.hyper_net(z_shape, z_color)

    def run_clip_mapping(self, idx, input_dict):
        """Predict latent codes from the CLIP image embedding.

        Returns:
            Dict with the predicted codes (``pred_shape``, ``pred_color``) and
            the stored per-instance codes (``shape_code``, ``color_code``).
        """
        img_input = input_dict['img_original'].permute(0, -1, 1, 2)

        with torch.no_grad():
            clip_embedding = self.clip_encoder(img_input, mode="image").float()

        return {
            'pred_shape': self.clip_fc_shape(clip_embedding),
            'pred_color': self.clip_fc_color(clip_embedding),
            'shape_code': self.shape_code(idx),
            'color_code': self.color_code(idx)
        }

    def forward(self, idx, input_dict, rays_o, rays_d, staged=False, bg_color=None, perturb=True, force_all_rays=False, test_finally=False, **kwargs):
        """Render rays for the given instances.

        Returns:
            * ``clip_mapping``: list ``[pred_render, gt_render, clip_output]``;
              both renders are computed without gradients.
            * ``varprior``: tuple ``(render, 1e-6 * kld)``.
            * otherwise: the render output.

        ``test_finally`` is accepted but unused.
        """
        render_opts = dict(staged=staged, bg_color=bg_color,
                           perturb=perturb, force_all_rays=force_all_rays)

        if self.clip_mapping:
            clip_output = self.run_clip_mapping(idx, input_dict)

            with torch.no_grad():
                # Render from the codes predicted out of the CLIP embedding.
                pred_params = self.hyper_net(clip_output['pred_shape'], clip_output['pred_color'])
                pred_rendered_output = self.net.render(
                    rays_o, rays_d, params=pred_params, idx=idx, **render_opts, **kwargs)

                # Render from the stored instance codes as the reference.
                gt_params = self.hyper_net(clip_output['shape_code'], clip_output['color_code'])
                gt_rendered_output = self.net.render(
                    rays_o, rays_d, params=gt_params, idx=idx, **render_opts, **kwargs)

            return [pred_rendered_output, gt_rendered_output, clip_output]

        hypo_params = self.get_params(idx, input_dict)
        kld = None
        if self.varprior:
            hypo_params, kld = hypo_params

        nerf = self.net.render(
            rays_o, rays_d, params=hypo_params, idx=idx, **render_opts, **kwargs)

        if self.varprior:
            return (nerf, 1e-6 * kld)
        return nerf

class BasicBlock(nn.Module):
    """Learned per-tensor affine map applied to a flattened weight tensor.

    The block is a single ``Conv1d(1, 1, kernel_size=1)``, i.e. one shared
    scale and one shared bias applied to every element of the input.
    """

    def __init__(self, in_size, dims=(1024, 512, 256)):
        """Create the block.

        Args:
            in_size: Accepted for interface compatibility; currently unused.
            dims: Accepted for interface compatibility; currently unused. The
                default is an immutable tuple so it can never be shared and
                mutated across instances.
        """
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv1d(1, 1, 1)

    def forward(self, x):
        """Apply the affine map to every element of ``x``.

        Args:
            x: Tensor of any shape; it is flattened to ``(1, 1, numel)``.

        Returns:
            1-D tensor with ``x.numel()`` elements.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well;
        # for contiguous inputs the two are equivalent.
        x = x.reshape(1, 1, -1)
        x = self.conv1(x)
        return x.reshape(-1)

# class BasicBlock(nn.Module):
#     def __init__(self, in_size, dims=[1024, 512, 256]):
#         super().__init__()
#         self.downsample = nn.ModuleList()
#         self.upsample = nn.ModuleList()
#         dims.insert(0, in_size)
#         self.num_hidden = len(dims)
#         self.relu = nn.ReLU()

#         for i in range(self.num_hidden-1):
#             self.downsample.append(nn.Linear(dims[i], dims[i+1]))

#         for i in reversed(range(self.num_hidden-1)):
#             self.upsample.append(nn.Linear(dims[i+1], dims[i]))

#     def forward(self, x):
#         outputs = [x]
#         for i in range(self.num_hidden-1):
#             x = self.relu(self.downsample[i](x))
#             outputs.insert(0, x)

#         for i in range(self.num_hidden-1):
#             x = self.relu(self.upsample[i](x))
#             x += outputs[i+1]

#         return x

class NeRFSuperresolution(nn.Module):
    """Refine hypernetwork-generated NeRF weights with per-tensor blocks.

    ``hn_model.get_params`` supplies a parameter mapping (evaluated without
    gradients). Each selected tensor is flattened, passed through its own
    ``BasicBlock`` and reshaped, and the refined mapping is rendered with this
    module's own ``NeRFNetwork``.
    """

    def __init__(self, opt, hn_model, hash_in, hash_out, datalen, hn_hdn_lyrs=1, custom_hashhn=False):
        """Create the refinement blocks and the render network.

        Args:
            opt: Options object; reads ``cuda_ray``, ``bg_radius``, ``bound``,
                ``min_near`` and ``density_thresh``.
            hn_model: Model exposing ``get_params(idx, input_dict)``.
            hash_in: Passed as ``in_size`` to the hash-grid block.
            hash_out, datalen, hn_hdn_lyrs, custom_hashhn: Accepted but unused.
        """
        super().__init__()

        self.cuda_ray = opt.cuda_ray
        self.bg_radius = opt.bg_radius
        self.bound = opt.bound
        self.opt = opt

        # ModuleDict keys cannot contain ".", so parameter names are stored
        # with "-" and converted back in forward().
        self.hashgrid_key = "encoder-embeddings-weight"

        # (key, in_size, dims) for every refined tensor, in registration order.
        block_specs = (
            (self.hashgrid_key, hash_in, [1024, 512, 256]),
            ("sigma_net-0-layer_0-weight", 2048, [512, 256]),
            ("sigma_net-0-layer_1-weight", 1024, [512, 256]),
            ("color_net-0-layer_0-weight", 1984, [512, 256]),
            ("color_net-0-layer_1-weight", 4096, [512, 256]),
            ("color_net-0-layer_2-weight", 192, [128, 64]),
        )
        self.L = nn.ModuleDict({
            key: BasicBlock(in_size, dims) for key, in_size, dims in block_specs
        })

        self.hn_model = hn_model

        self.net = NeRFNetwork(
            encoding="hashgrid",
            bound=opt.bound,
            cuda_ray=opt.cuda_ray,
            density_scale=1,
            min_near=opt.min_near,
            density_thresh=opt.density_thresh,
            bg_radius=opt.bg_radius,
            log2_hashmap_size=11
        )

    def forward(self, idx, input_dict, rays_o, rays_d, staged=False, bg_color=None, perturb=True, force_all_rays=False,test_finally=False, **kwargs):
        """Refine the generated parameters and render the rays with them.

        ``test_finally`` is accepted but unused.

        NOTE(review): the result of ``hn_model.get_params`` is indexed like a
        mapping, so this presumably expects a model that returns the parameter
        dict alone (not a ``(params, kld)`` tuple) - confirm with the caller.
        """
        # The base generator only provides inputs here; no gradients flow to it.
        with torch.no_grad():
            base_params = self.hn_model.get_params(idx, input_dict)

        refined_params = OrderedDict()
        for key, block in self.L.items():
            dotted_name = key.replace("-", ".")
            source = base_params[dotted_name]
            flat_out = block(source.view(-1))

            # The hash-grid tensor is always emitted as (1, 1, N); every other
            # tensor keeps the shape it came in with.
            target_shape = (1, 1, -1) if key == self.hashgrid_key else source.shape
            refined_params[dotted_name] = flat_out.reshape(target_shape)

        return self.net.render(
            rays_o, rays_d,
            staged=staged, bg_color=bg_color, perturb=perturb,
            force_all_rays=force_all_rays, params=refined_params, idx=idx,
            **kwargs)