from typing import Union, Mapping, Any
from omegaconf import DictConfig
import torch


def perturb_model(model: torch.nn.Module, dist=1e-3, w=None, V=None, perturb_strat='max_chaos', model_id=None):
    """Shift every parameter of ``model`` in place by a small perturbation.

    Two modes:

    * ``V is None`` -- each parameter tensor receives an independent Gaussian
      perturbation rescaled so that *its own* norm is ``dist`` (the norm of the
      whole-model perturbation therefore grows with the number of parameter
      tensors). ``w``, ``perturb_strat`` and ``model_id`` are not used here.
    * ``V`` given -- one unit-norm direction is built from the columns of ``V``
      according to ``perturb_strat``, scaled by ``dist``, and sliced
      consecutively across the parameters in ``model.parameters()`` order.

    Args:
        model: module whose parameters are modified in place.
        dist: perturbation magnitude (see the modes above).
        w: 1-D tensor with one entry per column of ``V``; the columns with
            ``w > 1`` / ``w < 1`` are randomly mixed by the ``'randn_chaos'`` /
            ``'randn_convergence'`` strategies. Required whenever ``V`` is
            given, because every strategy is evaluated up front.
        V: 2-D tensor whose columns are candidate directions; presumably its
            rows index the flattened model parameters in
            ``model.parameters()`` order -- TODO confirm against callers.
            ``V`` itself is not modified.
        perturb_strat: ``'max_chaos'`` (last column of ``V``),
            ``'randn_chaos'``, ``'max_convergence'`` (first column of ``V``),
            ``'randn_convergence'``, ``'randn'``, or ``'all'`` to pick one of
            these by ``model_id``.
        model_id: integer used (modulo the number of strategies) to choose the
            strategy when ``perturb_strat == 'all'``.

    Returns:
        ``(dist, strategy)`` where ``strategy`` is the concrete strategy name
        that was used when ``V`` is given (resolved from ``'all'`` if needed),
        otherwise ``perturb_strat`` unchanged.

    Raises:
        KeyError: ``V`` is given and ``perturb_strat`` is not a known strategy.
    """
    with torch.no_grad():
        if V is not None:
            # Every candidate is built eagerly (as before), so the random draws
            # consumed per call do not depend on which strategy is chosen.
            perturb_dict = {
                'max_chaos': V[:, -1],
                'randn_chaos': V[:, w > 1] @ torch.randn(int((w > 1).sum()), device=V.device),
                'max_convergence': V[:, 0],
                'randn_convergence': V[:, w < 1] @ torch.randn(int((w < 1).sum()), device=V.device),
                'randn': torch.randn_like(V[:, 0]),
            }
            if perturb_strat == 'all':
                # Cycle through the concrete strategies by model index.
                perturb_strat = list(perturb_dict)[model_id % len(perturb_dict)]
            direction = perturb_dict[perturb_strat]
            # Out-of-place on purpose: V[:, -1] and V[:, 0] are views of V, so an
            # in-place '/=' would silently rescale the caller's V.
            direction = direction / direction.norm()
        param_id = 0
        for p in model.parameters():
            if V is None:
                eps_ = torch.randn(p.shape, device=p.device)
                eps = dist * eps_ / eps_.norm()
            else:
                size = p.numel()
                eps = dist * direction[param_id: param_id + size].reshape(p.shape)
                param_id += size
            # Match the parameter's device and dtype so that, e.g., float16
            # parameters are not promoted to the perturbation's dtype.
            p.data = p.data + eps.to(device=p.device, dtype=p.dtype)

    return dist, perturb_strat


def _rec_log_dictconf(run, conf, key='', sep='/'):
    if isinstance(conf, dict) or isinstance(conf, DictConfig):
        for k in conf.keys():
            _rec_log_dictconf(run, conf[k], key=f'{key}{sep}{k}', sep=sep)
    else:
        run[key] = conf
        

def log_dictconf(run, conf: Union[DictConfig, Mapping[str, Any]], path='parameters'):
    """Log every leaf of a (possibly nested) config into ``run`` under ``path``.

    Nested keys are flattened into ``'/'``-separated names rooted at ``path``;
    for example ``{'opt': {'lr': 0.1}}`` is written as
    ``run['parameters/opt/lr'] = 0.1`` with the default ``path``.
    """
    _rec_log_dictconf(run=run, conf=conf, key=path)
          

def lstrip_multiline(x):
    """Strip leading whitespace from every ``'\\n'``-separated line of ``x``.

    Only ``'\\n'`` is treated as a line break, and the breaks themselves are
    kept (including empty lines and a trailing newline). Trailing whitespace
    on each line is left untouched.
    """
    return '\n'.join(map(str.lstrip, x.split('\n')))
