import numpy as np
import random
import torch

from collections import defaultdict
import copy

from torch.utils.data import RandomSampler, BatchSampler

from .models import *

"""
Helper functions needed throughout the code
"""

def get_layers_grad(model,layer_list,flatten=True):
    """
    Collect the gradients of the parameters whose name matches `layer_list`.

    A parameter is selected when at least one entry of `layer_list` is a
    substring of its name (as yielded by `model.named_parameters()`) and its
    `.grad` is not None. Each selected parameter is reported exactly once,
    even when several entries of `layer_list` match its name.

    Args:
        model: object exposing `named_parameters()` yielding (name, param).
        layer_list: iterable of substrings to look for in parameter names.
        flatten: if True each gradient is returned as a 1-D tensor, otherwise
            it keeps the shape of `param.grad`.

    Returns:
        Tuple `(grads, params, names)` of three parallel lists: detached
        copies of the gradients, the matching parameter objects themselves,
        and the matching parameter names.
    """
    # Materialise once so a one-shot iterable can be re-scanned per parameter.
    layer_list = list(layer_list)
    grads, params, names = [], [], []
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        # `any` (rather than one append per match) keeps the output free of
        # duplicates when several patterns hit the same parameter name.
        if not any(layer in name for layer in layer_list):
            continue
        grad = param.grad.detach().clone()
        if flatten:
            # reshape (not view) also handles non-contiguous gradients.
            grad = grad.reshape(-1)
        grads.append(grad)
        params.append(param)
        names.append(name)
    return grads, params, names

def partition_list_by_lists(data,groups):
    partition = [[] for _ in range(len(groups)+1)]
    for el in data:
        found = False
        for g,group in enumerate(groups):
            for kw in group:
                if kw in el:
                        partition[g].append(el)
                        found = True
        if not found:
                partition[-1].append(el)
    return partition

def compute_model_grad_norm(model):
    """
    Return the mean gradient norm over the parameters of `model`.

    The norm of each available gradient (`Tensor.norm()` with its default
    arguments) is summed, and the sum is divided by the total number of
    parameters yielded by `model.parameters()`. Parameters whose `.grad` is
    None contribute nothing to the sum but still count in the denominator.

    Args:
        model: object exposing `parameters()`.

    Returns:
        A 0-dim tensor when at least one gradient is present, otherwise the
        float 0.0 (including for a model that has no parameters at all).
    """
    total_norm = 0
    n_params = 0
    for n_params, p in enumerate(model.parameters(), start=1):
        if p.grad is not None:
            total_norm += p.grad.detach().norm()
    if n_params == 0:
        # No parameters: avoid using an unbound loop counter / dividing by 0.
        return 0.0
    return total_norm / n_params

def compute_network_output_size(h,w,kernels_h,kernels_w,strides_h,strides_w):
    """
    Compute the flattened spatial size left after a stack of layers.

    For every layer, both dimensions are updated with
    `(size - kernel) / stride + 1` (true division); the final height and
    width are truncated with `int` and multiplied together.
    NOTE(review): this looks like the output-size rule of un-padded
    convolutions - confirm against the models that use it.

    Args:
        h, w: initial height and width.
        kernels_h, kernels_w: per-layer kernel sizes along each dimension.
        strides_h, strides_w: per-layer strides along each dimension.
            The four sequences are zipped together, so the shortest one
            determines how many layers are applied.

    Returns:
        `int(final_height) * int(final_width)`.
    """
    height, width = h, w
    layers = zip(kernels_h, kernels_w, strides_h, strides_w)
    for kernel_h, kernel_w, stride_h, stride_w in layers:
        height = (height - kernel_h) / stride_h + 1
        width = (width - kernel_w) / stride_w + 1
    return int(height) * int(width)

def _as_float_or_str(value):
    """Return `value` as a float when it converts cleanly, else as a string."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return str(value)

def filter_dict_by_dict(source,query):
    """
    Tell whether `source` satisfies every key/value constraint of `query`.

    Values are compared after normalisation: each side is converted to a
    float when possible and to its `str()` form otherwise, so e.g. the string
    '0.1' matches the number 0.1.

    Args:
        source: mapping to test. An empty (or None) source never matches.
        query: mapping of required key -> value pairs.

    Returns:
        True if `source` is non-empty and, for every key of `query`, `source`
        holds that key with an equal normalised value; False otherwise.
    """
    if not source:
        return False
    for key, wanted in query.items():
        if key not in source:
            return False
        # Locals are named apart from `query` so the parameter is never rebound.
        if _as_float_or_str(source[key]) != _as_float_or_str(wanted):
            return False
    return True

def select_architecture(args,class_list):
    """
    Pick the model entry of `class_list` that matches the run configuration.

    Only the 'Mnih' architecture is handled:
      * `args.nce_loss == 'CURL'` selects `class_list['CURL']`;
      * a loss name containing 'action' but not 'no_action' selects
        'infoNCE_Mnih_84x84_action_data_aug' when `args.data_aug` is truthy,
        and 'infoNCE_Mnih_84x84_action' otherwise;
      * any other loss selects 'infoNCE_Mnih_84x84'.

    Args:
        args: object with `architecture`, `nce_loss` and `data_aug` attributes.
        class_list: mapping from the keys above to the objects to return.

    Returns:
        The selected entry of `class_list`, or None when `args.architecture`
        is not 'Mnih'.
    """
    architecture = args.architecture
    loss_fn = args.nce_loss
    data_aug = args.data_aug
    if architecture != 'Mnih':
        # Unknown architecture: keep the historical "no selection" result.
        return None
    if loss_fn == 'CURL':
        return class_list['CURL']
    uses_action = 'action' in loss_fn and 'no_action' not in loss_fn
    if not uses_action:
        return class_list['infoNCE_Mnih_84x84']
    if data_aug:
        return class_list['infoNCE_Mnih_84x84_action_data_aug']
    return class_list['infoNCE_Mnih_84x84_action']
    

def init(module, weight_init, bias_init):
    """
    Apply initialisers to the weight (and bias, if any) of `module`.

    `weight_init` is called with `module.weight.data`; `bias_init` is called
    with `module.bias.data` only when `module.bias` is not None.

    Args:
        module: object with `weight` and `bias` attributes.
        weight_init: callable applied to the weight data.
        bias_init: callable applied to the bias data.

    Returns:
        The same `module` object, to allow call chaining.
    """
    weight_data = module.weight.data
    weight_init(weight_data)
    bias = module.bias
    if bias is None:
        return module
    bias_init(bias.data)
    return module

def make_one_hot(labels, C=2):
    """
    Convert a vector of class indices into a one-hot float matrix.

    Args:
        labels: tensor whose first dimension is the batch size; its values
            are cast to integer class indices with `.long()`.
            Assumes a 1-D tensor of indices in `[0, C)` - TODO confirm
            against callers.
        C: number of classes, i.e. number of columns of the result.

    Returns:
        A float32 tensor of shape `(labels.size(0), C)` holding 1 at each
        label's column and 0 elsewhere. It lives on the same device as
        `labels`, so the scatter never mixes CPU and CUDA tensors.
    """
    index = labels.unsqueeze(-1).long()
    # Allocate next to `labels` instead of unconditionally moving to CUDA:
    # scatter_ requires the index and the destination on the same device.
    one_hot = torch.zeros(labels.size(0), C, dtype=torch.float32, device=labels.device)
    return one_hot.scatter_(1, index, 1)

def set_seed(seed,cuda):
    """
    Seed the random number generators used by the code base.

    NumPy's global RNG, Python's `random` module and torch are seeded with
    `seed`. When `cuda` is truthy, the torch CUDA RNG is seeded as well and
    cuDNN is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: seed value passed to every generator.
        cuda: whether to also apply the CUDA / cuDNN settings.
    """
    for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
        seed_rng(seed)
    if not cuda:
        return
    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
