'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  misc.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
import json
import pickle
from pathlib import Path
from typing import List, Union, Any, Dict

import torch
import torch.nn as nn

import numpy as np

import yaml
from yaml import Dumper, Loader


def save_object(filename: Union[str, Path],
                obj: Any,
                use_json: bool = False,
                use_yaml: bool = False):
    """Serialize ``obj`` to ``filename``.

    Exactly one format is used, chosen in this order of precedence:
    JSON (``use_json``), then YAML (``use_yaml``), otherwise pickle (protocol 3).
    JSON and YAML are written in text mode; pickle is written in binary mode.

    Args:
        filename: Destination path; an existing file is overwritten.
        obj: Object to serialize. It must be representable in the chosen format.
        use_json: Write JSON via ``json.dump``. Takes precedence over ``use_yaml``.
        use_yaml: Write YAML via ``yaml.dump`` with the module-level ``Dumper``.
    """
    mode = 'w' if (use_json or use_yaml) else 'wb'
    # The context manager guarantees the handle is closed even if serialization raises.
    with open(filename, mode) as file:
        if use_json:
            json.dump(obj, file)
        elif use_yaml:
            yaml.dump(obj, file, Dumper=Dumper)
        else:
            pickle.dump(obj, file, protocol=3)


def load_object(filename: Union[str, Path],
                use_json: bool = False,
                use_yaml: bool = False) -> Any:
    """Deserialize and return the object stored in ``filename``.

    The format is selected with the same precedence as ``save_object``:
    JSON (``use_json``), then YAML (``use_yaml``), otherwise pickle.

    Args:
        filename: Path of the file to read.
        use_json: Read JSON via ``json.load``. Takes precedence over ``use_yaml``.
        use_yaml: Read YAML via ``yaml.load`` with the module-level ``Loader``.

    Returns:
        The deserialized object.

    Note:
        SECURITY: ``pickle.load`` and ``yaml.load`` with the full ``yaml.Loader`` can
        construct arbitrary Python objects and execute code. Only load files from
        trusted sources.
    """
    mode = 'r' if (use_json or use_yaml) else 'rb'
    # The context manager guarantees the handle is closed even if parsing raises.
    with open(filename, mode) as file:
        if use_json:
            return json.load(file)
        if use_yaml:
            # NOTE(review): unsafe on untrusted input; consider yaml.SafeLoader
            # if no Python-specific tags are needed.
            return yaml.load(file, Loader=Loader)
        # NOTE(review): unpickling untrusted data can execute arbitrary code.
        return pickle.load(file)


def padded_str(strs: List[str],
               lens: List[int]) -> str:
    """Concatenate ``strs``, right-padding each with spaces to its target width.

    Each ``strs[i]`` is padded with trailing spaces until it is at least
    ``lens[i]`` characters long. Strings that are already long enough are kept
    unchanged and are never truncated.

    Args:
        strs: Strings to concatenate.
        lens: Minimum width for each string. It should have the same length as
            ``strs``; extra entries in either list are ignored, as with ``zip``.

    Returns:
        The concatenated, padded string.
    """
    # str.ljust pads to at least `width` and leaves longer strings untouched,
    # which matches ' ' * max(0, width - len(s)). join avoids quadratic concatenation.
    return ''.join(s.ljust(width) for s, width in zip(strs, lens))


def get_default_device() -> str:
    """Pick the device string PyTorch work should target by default.

    Returns:
        ``'cuda:0'`` (the first CUDA device) when CUDA is available, else ``'cpu'``.
    """
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'


def get_available_devices():
    """List the device strings usable on this machine.

    Returns:
        One ``'cuda:<index>'`` entry per visible CUDA device when CUDA is
        available, otherwise the single-element list ``['cpu']``.
    """
    if not torch.cuda.is_available():
        return ['cpu']
    gpu_count = torch.cuda.device_count()
    return [f'cuda:{idx}' for idx in range(gpu_count)]
    
    
def set_default_dtype(default_dtype: str):
    """Set PyTorch's global default floating-point dtype from its name.

    Args:
        default_dtype: Either ``'float32'`` or ``'float64'``.

    Raises:
        KeyError: If ``default_dtype`` is not one of the supported names.
    """
    name_to_dtype = dict(float32=torch.float32, float64=torch.float64)
    chosen = name_to_dtype[default_dtype]
    torch.set_default_dtype(chosen)


def recursive_detach(inputs: Any) -> Any:
    """Return a copy of a nested container with every leaf ``.detach()``-ed.

    Lists, tuples (including namedtuples) and dicts are traversed recursively and
    rebuilt. Lists and dicts come back as plain ``list``/``dict``; tuples keep
    their type. Any other object is treated as a leaf and must provide a
    ``.detach()`` method (e.g. ``torch.Tensor``); otherwise AttributeError is raised.

    Args:
        inputs: A leaf or an arbitrarily nested list/tuple/dict of leaves.

    Returns:
        The same structure with each leaf replaced by ``leaf.detach()``.
    """
    if isinstance(inputs, list):
        return [recursive_detach(item) for item in inputs]
    if isinstance(inputs, tuple):
        items = [recursive_detach(item) for item in inputs]
        # namedtuples take positional fields; plain tuples take a single iterable.
        if hasattr(inputs, '_fields'):
            return type(inputs)(*items)
        return tuple(items)
    if isinstance(inputs, dict):
        return {key: recursive_detach(value) for key, value in inputs.items()}
    return inputs.detach()


def count_parameters(module: nn.Module) -> int:
    """Return the total number of scalar entries in ``module.parameters()``.

    Args:
        module: The module whose parameters are counted.

    Returns:
        Sum of ``p.numel()`` over every tensor yielded by ``module.parameters()``
        (0 for a module without parameters).
    """
    # Tensor.numel() returns an exact int, including 1 for 0-dim parameters.
    # This avoids the numpy round-trip and float accumulation of np.prod(p.shape).
    return sum(p.numel() for p in module.parameters())
