'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  linear_transform.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.utils.o3 import get_slices, get_shapes


class LinearTransform(nn.Module):
    """Simple linear transformation for irreducible Cartesian tensors (Cartesian harmonics).
    It preserves their properties, i.e., the resulting tensors are symmetric and traceless.

    Each rank ``l`` has its own weight matrix that mixes only the feature axis (with the input
    paths folded into it), so the Cartesian index structure of every tensor is left untouched.
    A bias, if enabled, is added to the rank-0 (scalar) outputs only.

    Args:
        in_l_max (int): Maximal rank of the input tensor.
        out_l_max (int): Maximal rank of the output tensor. Weights are built from
                         ``in_paths[:out_l_max+1]`` and ``in_paths`` has length ``in_l_max+1``,
                         so ``out_l_max`` is expected not to exceed ``in_l_max``.
        in_features (int): Number of features in the input tensor.
        out_features (int): Number of features in the output tensor.
        in_paths (List[int], optional): List of paths used to generate irreducible Cartesian tensors 
                                        contained in the input tensor. Defaults to None (in this case 
                                        a list of length `in_l_max+1` filled with ones is produced).
        bias (bool, optional): If True, apply bias to scalars. Defaults to False.
    """
    def __init__(self,
                 in_l_max: int,
                 out_l_max: int,
                 in_features: int, 
                 out_features: int,
                 in_paths: Optional[List[int]] = None,
                 bias: bool = False):
        super().__init__()
        self.in_l_max = in_l_max
        self.out_l_max = out_l_max
        self.in_features = in_features
        self.out_features = out_features

        # number of paths used to compute the Cartesian harmonics of each rank in the input tensor
        self.in_paths = [1] * (in_l_max + 1) if in_paths is None else in_paths
        assert len(self.in_paths) == in_l_max + 1, 'in_paths must have length in_l_max + 1.'

        # slices and shapes for tensors of rank l in the flattened input tensor
        self.in_slices = get_slices(in_l_max, in_features, self.in_paths)
        self.in_shapes = get_shapes(in_l_max, in_features, self.in_paths, use_prod=True)

        # expected size of the last input dimension (used for sanity checks in forward);
        # rank l contributes 3**l * in_features * n_paths entries
        self.in_dim = sum((3 ** l) * in_features * self.in_paths[l] for l in range(in_l_max + 1))

        # per-rank 1/sqrt(fan-in) normalization, applied to the weights in forward
        self.alpha = [(in_features * n_paths) ** (-0.5) for n_paths in self.in_paths]

        # one (out_features, in_features * n_paths) weight per output rank: the input paths are
        # folded into the feature axis, so every output rank ends up with a single path
        self.weight = torch.nn.ParameterList([])
        for n_paths in self.in_paths[:self.out_l_max + 1]:
            self.weight.append(nn.Parameter(torch.randn(out_features, in_features * n_paths)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_buffer('bias', None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies linear transformation to the input tensor `x`. This tensor must contain flattened 
        irreducible Cartesian tensors, i.e., Cartesian harmonics that are accessed using pre-computed 
        slices and shapes.

        Ranks 0 to `out_l_max` are processed in a loop, so any `out_l_max` for which slices,
        shapes, and weights exist is supported (not only ranks up to 3).

        Args:
            x (torch.Tensor): Input tensor; expected to be 2-D with last dimension `in_dim`
                              (first dimension presumably the number of atoms/neighbors).

        Returns:
            torch.Tensor: Output tensor containing irreducible Cartesian tensors after the linear 
                          transformation. For `out_l_max == 0` this is the scalar block exactly as
                          returned by `F.linear`; otherwise the rank-0 block and the flattened
                          rank-l blocks (each of size `3**l * out_features`) are concatenated
                          along the last dimension.

        Raises:
            AssertionError: If the last dimension of `x` differs from `in_dim` (via `torch._assert`).
        """
        torch._assert(x.shape[-1] == self.in_dim, 'Incorrect last dimension for x.')

        n_rows = x.shape[0]
        blocks = []
        for rank in range(self.out_l_max + 1):
            # un-flatten rank `rank`; the exact shape comes from get_shapes(..., use_prod=True),
            # presumably n_rows x 3**rank x (in_features * n_paths) for rank > 0 -- TODO confirm
            x_rank = x[:, self.in_slices[rank]].view(n_rows, *self.in_shapes[rank])
            # F.linear mixes only the last (feature) axis, leaving the Cartesian indices intact;
            # the bias is added to scalars only
            y_rank = F.linear(x_rank,
                              self.weight[rank] * self.alpha[rank],
                              self.bias if rank == 0 else None)
            if rank > 0:
                # re-flatten the Cartesian and feature axes of the rank-`rank` block
                y_rank = y_rank.view(n_rows, (3 ** rank) * self.out_features)
            blocks.append(y_rank)

        # scalars-only output is returned as-is (no concatenation), matching the l=0 shortcut
        if len(blocks) == 1:
            return blocks[0]
        return torch.cat(blocks, -1)

    def __repr__(self) -> str:
        out_paths = [1] * (self.out_l_max + 1)
        n_weights = sum(w.numel() for w in self.weight)
        return (f"{self.__class__.__name__} ({self.in_l_max} -> {self.out_l_max} | "
                f"{self.in_paths[:self.out_l_max+1]} -> {out_paths} paths | {n_weights} weights)")
