'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  cartesian_harmonics.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
import torch
import torch.nn as nn

import torch.fx
import opt_einsum_fx


# Highest rank of Cartesian harmonics supported by `CartesianHarmonics`.
L_MAX: int = 3
# Batch size of the random example inputs used when tracing and optimizing the einsum contractions.
BATCH_SIZE: int = 10


class CartesianHarmonics(nn.Module):
    """Computes irreducible Cartesian tensors, sometimes referred to as Cartesian harmonics.

    Args:
        l_max (int): Maximal rank of Cartesian harmonics; must satisfy ``0 <= l_max <= L_MAX``.

    Raises:
        ValueError: If ``l_max`` is outside ``[0, L_MAX]``.

    Note:
        1. In the current implementation, all Cartesian harmonics are built using unit vectors;
           the input of :meth:`forward` is normalized to unit length first.
        2. We normalize Cartesian harmonics of rank `l` such that their products with the respective
           unit vector reduce the rank of these tensors by one, i.e., to `l-1`. Applying `l` unit
           vectors to a Cartesian harmonic of rank `l` yields unity.
    """
    def __init__(self, l_max: int):
        super().__init__()
        # Explicit check instead of `assert` (stripped under `python -O`). Negative values are
        # rejected as well, since `forward` would otherwise index contractions that were never built.
        if not 0 <= l_max <= L_MAX:
            raise ValueError(f'Cartesian harmonics are implemented for 0 <= l <= {L_MAX}, got l_max={l_max}.')
        self.l_max = l_max

        # identity matrix used by the trace-removal terms of ranks 2 and 3
        self.register_buffer('eye', torch.eye(3))

        # trace and optimize contractions; `forward` relies on this index order:
        #   [0] 'ai,aj->aij'   (x outer x)
        #   [1] 'ai,jk->aijk'  (x outer identity)
        #   [2] 'aij,ak->aijk' ((x outer x) outer x)
        self.contractions = torch.nn.ModuleList()

        if self.l_max > 1:
            self.contractions.append(self._optimized_contraction(
                'ai,aj->aij', (torch.randn(BATCH_SIZE, 3), torch.randn(BATCH_SIZE, 3))))

        if self.l_max > 2:
            self.contractions.append(self._optimized_contraction(
                'ai,jk->aijk', (torch.randn(BATCH_SIZE, 3), torch.eye(3))))
            self.contractions.append(self._optimized_contraction(
                'aij,ak->aijk', (torch.randn(BATCH_SIZE, 3, 3), torch.randn(BATCH_SIZE, 3))))

    @staticmethod
    def _optimized_contraction(equation: str, example_inputs: tuple) -> nn.Module:
        """Traces ``torch.einsum(equation, x, y)`` with torch.fx and optimizes it with opt_einsum_fx.

        Args:
            equation (str): Einsum equation with two operands.
            example_inputs (tuple): Example operands used by opt_einsum_fx to pick the contraction path.

        Returns:
            nn.Module: The optimized traced module, called as ``module(x, y)``.
        """
        # `equation` is a parameter of this helper, so every traced lambda closes over its own
        # binding rather than a local that is reassigned between traces.
        traced = torch.fx.symbolic_trace(lambda x, y: torch.einsum(equation, x, y))
        return opt_einsum_fx.optimize_einsums_full(model=traced, example_inputs=example_inputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes Cartesian harmonics/irreducible Cartesian tensors of the input vectors.

        Args:
            x (torch.Tensor): Input vectors, shape: n_neighbors x 3. They are normalized along the
                last dimension before use, so they do not have to be unit vectors.

        Returns:
            torch.Tensor: Cartesian harmonics of ranks 0..l_max, each flattened and concatenated along
                the last dimension, shape: n_neighbors x (1 + 3 + ... + 3 ** l_max), i.e., a width of
                1, 4, 13, or 40 for l_max = 0, 1, 2, 3.
        """
        x = nn.functional.normalize(x, dim=-1)
        # explicit row count: `view(-1, ...)` cannot infer `-1` when there are zero rows
        n_neighbors = x.shape[0]

        # l=0, shape: n_neighbors x 1; same dtype/device as x so the result's dtype follows the input
        ch_0 = x.new_ones((n_neighbors, 1))
        if self.l_max == 0:
            return ch_0

        # l=1, shape: n_neighbors x 3
        ch_1 = x
        if self.l_max == 1:
            return torch.cat([ch_0, ch_1], -1)

        # identity in x's dtype so the contractions/subtractions below operate in a single dtype
        # (no-op when the buffer already matches x)
        eye = self.eye.to(dtype=x.dtype)

        # l=2, shape: n_neighbors x 3 x 3
        x_x = self.contractions[0](x, x)
        ch_2 = 3. / 2. * (x_x - 1. / 3. * eye.unsqueeze(0))
        if self.l_max == 2:
            return torch.cat([ch_0,
                              ch_1,
                              ch_2.reshape(n_neighbors, 3 ** 2)], -1)

        # l=3, shape: n_neighbors x 3 x 3 x 3
        # symmetrized x_i delta_jk + x_k delta_ij + x_j delta_ik (trace-removal term)
        x_e = self.contractions[1](x, eye)
        x_e = x_e + x_e.permute(0, 2, 3, 1) + x_e.permute(0, 3, 1, 2)
        x_x_x = self.contractions[2](x_x, x)
        ch_3 = 5. / 2. * (x_x_x - 1. / 5. * x_e)
        # l_max == 3 here: __init__ rejects any l_max outside [0, L_MAX]
        return torch.cat([ch_0,
                          ch_1,
                          ch_2.reshape(n_neighbors, 3 ** 2),
                          ch_3.reshape(n_neighbors, 3 ** 3)], -1)

    def __repr__(self):
        return f'{self.__class__.__name__}(l_max={self.l_max})'
