import numpy as np
from math import sqrt
from math import log
import torch
from fc_nets import FCNet2Layers, FCNet, compute_grad_matrix


def data_vec_linear_network(a,b,p,n,k):
    """Sample a multi-output linear regression problem with a power-law spectrum.

    A random Gaussian ``p x p`` matrix is decomposed by SVD and rebuilt as
    ``U @ diag(j ** (-0.5 / (1 - a))) @ Vh.T`` (``j = 1..p``).  Standard Gaussian
    samples are multiplied by this matrix, so each row ``x`` of ``X`` satisfies
    ``E[x^T x] = H``.

    Parameters
    ----------
    a : float
        Controls the decay exponent ``0.5 / (1 - a)``; ``a == 1`` divides by zero.
    b : object
        Unused; kept for call compatibility.
    p, n, k : int
        Number of features, samples and outputs.

    Returns
    -------
    X : ndarray, shape (n, p)
    Y : ndarray, shape (n, k)
        Noiseless targets ``X @ w0``.
    H : ndarray, shape (p, p)
        ``cov.T @ cov``, the second-moment matrix of the rows of ``X``.
    pred : ndarray, shape (p, k)
        Ground-truth weights ``w0`` with unit-norm columns.
    """
    gaussian = np.random.randn(p, p)
    left, _, right_h = np.linalg.svd(gaussian)
    spectrum = 1 / np.arange(1, p + 1) ** (.5 / (1 - a))
    cov = left @ np.diag(spectrum) @ right_h.T
    H = cov.T @ cov

    X = np.random.randn(n, p) @ cov

    # Ground-truth weights: one Gaussian column per output, normalised per column.
    w0 = np.random.randn(p, k)
    w0 = w0 / np.linalg.norm(w0, axis=0)
    return X, X @ w0, H, w0


def data_orthogonal(p,n):
    """Sample a least-squares problem whose design matrix has orthonormal columns.

    Parameters
    ----------
    p : int
        Number of features (columns of ``X``).
    n : int
        Number of samples; must satisfy ``n >= p`` so that ``X`` can have
        ``p`` orthonormal columns.

    Returns
    -------
    X : ndarray, shape (n, p)
        Left singular vectors of a Gaussian matrix, so ``X.T @ X`` is ``I_p``.
    Y : ndarray, shape (n, 1)
        Noiseless targets ``X @ w0``.
    H : ndarray, shape (p, p)
        ``X.T @ X`` (the identity up to rounding).
    pred : ndarray, shape (p, 1)
        Ground-truth weights ``w0`` with Euclidean norm ``sqrt(2)``.

    Raises
    ------
    ValueError
        If ``n < p``.
    """
    if n < p:
        raise ValueError(f"data_orthogonal requires n >= p (got n={n}, p={p})")
    M = np.random.randn(n, p)
    # Reduced SVD: U is (n, p) rather than (n, n), so memory stays O(n*p) for n >> p.
    U, _, Vh = np.linalg.svd(M, full_matrices=False)
    X = U
    H = X.T @ X
    # w0 = Vh @ d with d_j = j**-0.5, i.e. power-law weights in the right singular basis.
    d = 1 / np.arange(1, p + 1) ** .5
    w0 = (Vh @ d).reshape(p, 1)
    # Rescale to norm sqrt(2). np.linalg.norm returns a scalar directly; the old
    # math.sqrt(w0.T @ w0) relied on deprecated (1, 1)-array-to-float conversion.
    w0 = sqrt(2) * w0 / np.linalg.norm(w0)
    Y = X @ w0
    return X, Y, H, w0


def indices( iterations, test_points ):
    """Map every iteration number to its position in ``test_points``.

    Returns an integer array of length ``iterations`` where entry
    ``test_points[j]`` holds ``j`` and every other entry holds ``-1``.
    """
    lookup = np.full(iterations, -1, dtype=int)
    lookup[test_points] = np.arange(len(test_points))
    return lookup


def gradient(a, W, X, Y):
    """Gradients of ``||X @ W.T @ a - Y||^2 / (2n)`` w.r.t. ``a`` and ``W``.

    ``a`` is the second-layer vector and ``W`` the first-layer matrix of a
    two-layer linear network whose end-to-end predictor is ``W.T @ a``.

    Returns
    -------
    (grad_a, grad_W), both averaged over the ``n = X.shape[0]`` samples.
    """
    num_samples = X.shape[0]
    residual = X @ (W.T @ a) - Y
    grad_a = W @ (X.T @ residual)
    grad_W = (a @ residual.T) @ X
    return grad_a / num_samples, grad_W / num_samples


def vec_gradient(W_1, W_2, X, Y):
    """Gradients of ``||X @ W_1 @ W_2 - Y||_F^2 / (2n)`` w.r.t. ``W_1`` and ``W_2``.

    Parameters
    ----------
    W_1 : ndarray, shape (p, l)
    W_2 : ndarray, shape (l, k)
    X : ndarray, shape (n, p)
    Y : ndarray, shape (n, k)

    Returns
    -------
    (grad_W_1, grad_W_2), averaged over the ``n = X.shape[0]`` samples.
    """
    n = X.shape[0]
    beta = W_1 @ W_2
    # X^T (X beta - Y): the shared back-propagated error, computed once
    # (previously the residual was evaluated twice and the first copy discarded).
    err = X.T @ (X @ beta - Y)
    gradW_2 = W_1.T @ err
    gradW_1 = err @ W_2.T
    return gradW_1 / n, gradW_2 / n


def stochastic_gradient(a,W,X,Y,i):
    """Single-sample squared-loss gradient w.r.t. the end-to-end predictor.

    Computes ``x_i (x_i . w - y_i)`` with ``w = W.T @ a``, the effective linear
    predictor of the two-layer linear network (the same one ``gradient`` uses).

    Returns
    -------
    ndarray, shape (p, 1)
    """
    p = X.shape[1]
    # The original body used an undefined name `w` (always a NameError).
    # NOTE(review): `w` is taken to be the network's effective predictor W.T @ a — confirm.
    w = W.T @ a
    residual = np.dot(X[i], w) - Y[i]
    s_grad = X[i].T * residual
    return s_grad.reshape(p, 1)


def generate_test_samples(iterations , freq = .02):
    """Return log-spaced integer evaluation points, prefixed with 0.

    Points are ``round(10 ** e)`` for ``e`` in ``[0, log10(iterations))`` with
    step ``freq``, deduplicated and sorted, with ``0`` prepended.
    """
    exponents = np.arange(0, log(iterations, 10), freq)
    points = np.unique(np.round(10 ** exponents)).astype(int)
    return np.concatenate(([0], points))


def init_vec_linear_network_wtopw_I_basic(l,p,k,scale):
    """Deterministic init for a vector-output two-layer linear network.

    The first layer is ``scale`` times the first ``min(l, p)`` columns of the
    ``l x l`` identity; the second layer is zero.

    Returns
    -------
    W_1 : ndarray, shape (p, l) when l >= p
        Transposed first-layer matrix.
    W_2 : ndarray, shape (l, k), all zeros.
    U_0 : first ``p`` left singular vectors of the ``l x l`` matrix ``W @ W.T``.
    U_perp : the remaining left singular vectors (its complement).
    """
    zero_head = np.zeros((l, k))
    hidden_by_input = scale * np.eye(l)[:, :p]
    basis = np.linalg.svd(hidden_by_input @ hidden_by_input.T)[0]
    return hidden_by_input.T, zero_head, basis[:, :p], basis[:, p:]


def init_linear_network_wtopw_I_a_zero(l,p,scale):
    """Random orthonormal-column init for a two-layer linear network, ``a = 0``.

    ``W`` is ``scale`` times the first ``p`` left singular vectors of an
    ``l x p`` Gaussian matrix, so ``W.T @ W = scale**2 * I`` when ``l >= p``.

    Returns
    -------
    a : ndarray, shape (l, 1), all zeros.
    W : ndarray, shape (l, p) when l >= p.
    U_0 : first ``p`` left singular vectors of ``W @ W.T`` (spanning range(W)).
    U_perp : the remaining left singular vectors.
    """
    a = np.zeros((l, 1))
    gaussian = np.random.randn(l, p)
    left_vectors = np.linalg.svd(gaussian)[0]
    W = scale * left_vectors[:, :p]
    basis = np.linalg.svd(W @ W.T)[0]
    return a, W, basis[:, :p], basis[:, p:]

def init_linear_network_wtopw_I_basic(l,p,scale):
    """Deterministic identity-block init for a two-layer linear network, ``a = 0``.

    ``W`` is ``scale`` times the first ``min(l, p)`` columns of the ``l x l``
    identity.

    Returns
    -------
    a : ndarray, shape (l, 1), all zeros.
    W : ndarray, shape (l, p) when l >= p.
    U_0 : first ``p`` left singular vectors of ``W @ W.T``.
    U_perp : the remaining left singular vectors.
    """
    a = np.zeros((l, 1))
    W = scale * np.eye(l)[:, :p]
    basis = np.linalg.svd(W @ W.T)[0]
    return a, W, basis[:, :p], basis[:, p:]


def init_vec_gaussian(l,p,k,scale):
    """Gaussian first-layer init for a vector-output two-layer linear network.

    Returns
    -------
    W_1 : ndarray, shape (p, l)
        Gaussian entries scaled by ``scale``.
    W_2 : ndarray, shape (l, k), all zeros.
    U_0 : ndarray, shape (l, min(l, p))
        Leading left singular vectors of the hidden-space Gram matrix
        ``W_1.T @ W_1`` (spanning the row space of ``W_1``).
    U_perp : ndarray, shape (l, l - p) when l >= p
        Orthonormal complement of ``U_0`` (null space of ``W_1`` when it has rank p).
    """
    W_2 = np.zeros((l,k))
    W_1 = scale*np.random.randn(p,l)
    # Gram matrix in the l-dimensional hidden space, matching
    # init_vec_linear_network_wtopw_I_basic. The previous W_1 @ W_1.T was only
    # p x p, which made U_0 a p x p basis and U_perp always empty.
    R = W_1.T @ W_1
    U,_,_ = np.linalg.svd(R)
    U_0 = U[:,:p]
    U_perp = U[:,p:]
    return W_1,W_2,U_0,U_perp


def init_vec_gaussian_a_non_zero(l,p,k,scale_1,scale_2):
    """Gaussian init of both layers of a vector-output two-layer linear network.

    Returns
    -------
    W_1 : ndarray, shape (p, l)
        Gaussian entries scaled by ``scale_1``.
    W_2 : ndarray, shape (l, k)
        Gaussian entries scaled by ``scale_2`` (drawn before ``W_1``).
    U_0 : ndarray, shape (l, min(l, p))
        Leading left singular vectors of the hidden-space Gram matrix ``W_1.T @ W_1``.
    U_perp : ndarray, shape (l, l - p) when l >= p
        Orthonormal complement of ``U_0``.
    """
    W_2 = scale_2*np.random.randn(l,k)
    W_1 = scale_1*np.random.randn(p,l)
    # Hidden-space (l x l) Gram matrix; W_1 @ W_1.T would be p x p and leave
    # U_perp empty, unlike init_vec_linear_network_wtopw_I_basic.
    R = W_1.T @ W_1
    U,_,_ = np.linalg.svd(R)
    U_0 = U[:,:p]
    U_perp = U[:,p:]
    return W_1,W_2,U_0,U_perp


def rankW(W, threshold):
    """Numerical rank of ``W`` from its squared singular values.

    Returns
    -------
    s : ndarray
        Singular values of ``W`` as returned by ``np.linalg.svd``.
    rank : numpy integer
        Number of singular values with ``s**2 > threshold``.
    u, v : ndarray
        The SVD factors; note ``v`` is ``V^H`` as returned by ``np.linalg.svd``,
        so ``W == u[:, :len(s)] @ diag(s) @ v[:len(s)]``.
    """
    u, s, v = np.linalg.svd(W)
    # The comparison already yields a boolean array; summing counts the Trues.
    rank = np.sum(s ** 2 > threshold)
    return s, rank, u, v



def commutative_lie_vectors(u,v):
    """Frobenius norm of ``u @ v.T - v @ u.T``.

    This is zero exactly when ``u @ v.T`` is symmetric.
    """
    forward = u @ v.T
    backward = v @ u.T
    return np.linalg.norm(forward - backward)


def get_data_two_layer_relu_net(n, d, m_teacher, init_scales_teacher, seed):
    """Sample unit-norm Gaussian inputs and label them with a random teacher net.

    Seeds both NumPy and torch with ``seed + 1``.  Draws ``n`` training and
    1000 test inputs from ``N(0, I_d)``, rescales every row to unit L2 norm, and
    builds an ``FCNet2Layers`` teacher whose first-layer weight rows are
    normalised to unit norm and whose second-layer weights are replaced by
    their signs.  (Presumably a two-layer ReLU net, per the function name —
    see ``fc_nets``.)

    Returns
    -------
    X, y, X_test, y_test, net_teacher
    """
    def _unit_rows(t):
        # Divide each row of `t` by its Euclidean norm.
        return t / torch.sum(t**2, 1, keepdim=True)**0.5

    np.random.seed(seed + 1)
    torch.manual_seed(seed + 1)

    n_test = 1000
    H = np.eye(d)
    mean = np.zeros(d)
    X = _unit_rows(torch.tensor(np.random.multivariate_normal(mean, H, n)).float())
    X_test = _unit_rows(torch.tensor(np.random.multivariate_normal(mean, H, n_test)).float())

    # Ground-truth labels come from the teacher; no autograd graph is needed.
    with torch.no_grad():
        net_teacher = FCNet2Layers(n_feature=d, n_hidden=m_teacher)
        net_teacher.init_gaussian(init_scales_teacher)
        first, second = net_teacher.layer1, net_teacher.layer2
        first.weight.data = _unit_rows(first.weight.data)
        second.weight.data = torch.sign(second.weight.data)

        y = net_teacher(X)
        y_test = net_teacher(X_test)

    return X, y, X_test, y_test, net_teacher

def get_iters_eval(n_iter_power, x_log_scale, n_iters_first=101, n_iters_next=151):
    """Choose the iterations at which to evaluate the loss.

    Always includes iterations ``0..99``, plus either:

    - ``n_iters_first`` log-spaced points in ``[1, 10**n_iter_power]``, when
      ``x_log_scale`` is true; or
    - ``n_iters_next`` linearly spaced points in ``[0, num_iter]``, minus the
      largest, otherwise.

    Returns
    -------
    num_iter : int
        ``int(10**n_iter_power) + 1``.
    iters_loss : ndarray of float64
        Sorted, deduplicated rounded iteration indices.
    """
    num_iter = int(10 ** n_iter_power) + 1
    dense_head = np.arange(100)
    if not x_log_scale:
        tail = np.unique(np.round(np.linspace(0, num_iter, n_iters_next)))[:-1]
    else:
        tail = np.unique(np.round(np.logspace(0, n_iter_power, n_iters_first)))
    return num_iter, np.union1d(dense_head, tail)