import numpy as np
import scipy
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import itertools
import sklearn
from sklearn.datasets import load_svmlight_file
import torch
import gc
from torch_sparse import coalesce, transpose, spmm, spspmm

# Default torch device for tensors created in this module. Uses the GPU when
# one is available and falls back to the CPU otherwise, so the module can still
# be imported and run on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def chol_sample(mean, cov):
    """Draw one Gaussian sample with the given mean and covariance.

    The sample is built as ``mean + L @ z`` where ``L`` is the lower Cholesky
    factor of ``cov`` and ``z`` is a standard normal vector of length
    ``mean.size`` taken from NumPy's global random state.

    Args:
        mean: NumPy array holding the mean vector.
        cov: Covariance matrix; ``np.linalg.cholesky`` raises ``LinAlgError``
            if it is not positive definite.

    Returns:
        The sampled vector.
    """
    factor = np.linalg.cholesky(cov)
    white_noise = np.random.standard_normal(mean.size)
    return mean + factor @ white_noise

def getA(dist, cov, n, d):
    """Sample ``n`` random ``d``-dimensional rows for a design matrix.

    Args:
        dist: ``'T1'`` for a zero-mean multivariate t with 1 degree of
            freedom, ``'T2'`` for a zero-mean multivariate t with 2 degrees of
            freedom; any other value selects a multivariate normal whose mean
            is the all-ones vector.
        cov: Shape (t) / covariance (normal) matrix handed to SciPy.
        n: Number of draws requested from ``rvs``.
        d: Dimension used to build the mean vector; expected to match ``cov``.

    Returns:
        Whatever ``rvs(n)`` returns for the chosen distribution
        (NOTE(review): SciPy squeezes singleton dimensions, so ``n == 1`` or
        ``d == 1`` may not give a 2-D array -- confirm if callers rely on it).
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.stats`` is loaded on every SciPy version.
    from scipy import stats

    if dist == 'T1':
        df = 1
    elif dist == 'T2':
        df = 2
    else:
        df = None

    if df is not None:
        return stats.multivariate_t(np.zeros(d), cov, df=df).rvs(n)
    return stats.multivariate_normal(np.ones(d), cov).rvs(n)
    
def getb(A, n, d):
    """Generate a noisy linear response ``b = A @ x + noise``.

    Both the coefficient vector ``x`` (length ``d``) and the noise vector
    (length ``n``) are i.i.d. zero-mean Gaussians with variance ``1/d``, drawn
    from NumPy's global random state (``x`` first, then the noise).

    The draws are scaled directly rather than by Cholesky-factoring
    ``eye(d)/d`` and ``eye(n)/d``: the factor of a scaled identity is just
    ``sqrt(1/d)`` times the identity, so this yields the same values while
    avoiding an n-by-n matrix and its O(n^3) factorisation.

    Args:
        A: Design matrix with ``d`` columns and ``n`` rows.
        n: Number of rows of ``A`` (length of the noise vector).
        d: Number of columns of ``A`` (length of the coefficient vector).

    Returns:
        The response vector of length ``n``.
    """
    scale = np.sqrt(1.0 / d)
    coefs = np.random.standard_normal(d) * scale
    noise = np.random.standard_normal(n) * scale
    return A @ coefs + noise

def _row_quadratic_forms(A, M):
    """Return the vector of ``a_i^T M a_i`` over the rows ``a_i`` of ``A``."""
    return (A[:, None, :] @ M @ A[:, :, None]).flatten()


def _rademacher_signs(shape, dev):
    """Draw independent +/-1 signs of the given shape on device ``dev``."""
    return 2 * torch.bernoulli(torch.full(shape, 0.5, device=dev)) - 1


def getS(A, b, method, size, s, nnz=1):
    """Draw a batch of ``size`` random sketches with ``s`` rows each.

    All random tensors are created on ``A.device`` so the result can be
    multiplied with ``A`` directly.

    Args:
        A: 2-D tensor (n rows, d columns).
        b: Response vector of length n; only read when ``method == 'LESSb'``.
        method: One of
            ``'LESS'``        - entry probabilities proportional to
                                ``a_i^T (A^T A)^{-1} a_i``;
            ``'LessUniform'`` - uniform entry probabilities;
            ``'LessNorm'``    - probabilities from the pseudo-inverse based
                                quadratic forms divided by ``matrix_rank(A)``;
            ``'LESSb'``       - ``'LESS'`` scores plus the squared entries of
                                ``b / ||b||``;
            ``'Leverage Score Subsampling'`` - row sampling (see Returns);
            ``'Sparse Rademacher 0.1'`` - sparse signs with density 0.1;
            anything else     - sparse signs with density 0.01.
        size: Number of independent sketches in the batch.
        s: Number of rows per sketch.
        nnz: For the LESS family, each sketch row is built from
            ``int(nnz * d)`` multinomial draws.

    Returns:
        For ``'Leverage Score Subsampling'`` a tuple ``(inds, scales)`` of
        ``(size, s)`` tensors: sampled row indices and the factors
        ``sqrt(prob * s)`` the sampled rows are meant to be divided by.
        For every other method a dense ``(size, s, n)`` sketch tensor.
    """
    n, d = A.shape
    dev = A.device

    if method in ('LESS', 'LessUniform', 'LessNorm', 'LESSb'):
        # Only compute the quantities the chosen variant actually needs, so
        # that e.g. 'LessUniform' does not fail on a singular Gram matrix.
        if method == 'LessUniform':
            probs = torch.ones(n, device=dev).double() / n
        elif method == 'LessNorm':
            probs = (_row_quadratic_forms(A, torch.linalg.pinv(A.T @ A))
                     / torch.linalg.matrix_rank(A))
            probs = probs / torch.sum(probs)
        else:
            scores = _row_quadratic_forms(A, torch.linalg.inv(A.T @ A)) / d
            if method == 'LESSb':
                scores = scores + (b / torch.linalg.norm(b)) ** 2
            probs = scores / torch.sum(scores)

        draws = int(nnz * d)
        mults = torch.distributions.Multinomial(
            total_count=draws, probs=probs.repeat(size, s, 1)).sample()
        ratio = mults / (draws * probs)
        # An entry with zero probability is never drawn, which would give
        # 0/0 = NaN here; such entries must simply be absent from the sketch.
        etas = torch.sqrt(torch.where(mults > 0, ratio, torch.zeros_like(ratio)))
        signs = _rademacher_signs((size, s, n), dev)
        return (signs * etas) / np.sqrt(s)

    if method == 'Leverage Score Subsampling':
        probs = (_row_quadratic_forms(A, torch.linalg.pinv(A.T @ A))
                 / torch.linalg.matrix_rank(A))
        inds = torch.multinomial((probs / torch.sum(probs)).repeat(size, 1),
                                 num_samples=s, replacement=True)
        return inds, torch.sqrt(probs[inds] * s)

    # Sparse Rademacher sketches: each entry is kept with probability q and
    # rescaled by 1/sqrt(q*s). Unrecognised method names use q = 0.01.
    q = 0.1 if method == 'Sparse Rademacher 0.1' else 0.01
    signs = _rademacher_signs((size, s, n), dev)
    keep = torch.bernoulli(torch.full((size, s, n), q, device=dev))
    return ((signs * keep) / np.sqrt(q * s)).double()


def getPreds(A, b, x, method, size, s, nnz=1, lamp=0):
    """Return prediction errors ``A @ (xtilde - x)`` for a batch of sketches.

    For each of ``size`` sketches drawn by ``getS`` the sketched system
    ``(Atilde, btilde)`` is solved through the regularised normal equations
    ``xtilde = pinv(Atilde^T Atilde + lamp * I) Atilde^T btilde``.

    Args:
        A: 2-D design tensor.
        b: Response vector with one entry per row of ``A``.
        x: Reference coefficients subtracted from every sketched solution
            (presumably a length-d vector -- confirm against callers).
        method: Sketching method name forwarded to ``getS``.
        size: Number of sketches in the batch.
        s: Number of rows per sketch.
        nnz: Forwarded to ``getS``.
        lamp: Ridge coefficient added to the sketched Gram matrix.

    Returns:
        The batched product ``A @ (xtilde - x)`` with singleton dimensions
        squeezed out; ``(size, n)`` in the usual case.
    """
    S = getS(A, b, method=method, size=size, s=s, nnz=nnz)
    if method == 'Leverage Score Subsampling':
        inds, levs = S
        Atilde = A[inds] / levs[..., None]
        btilde = b[inds] / levs
    else:
        Atilde = S @ A
        btilde = S @ b

    At = Atilde.transpose(-1, -2)
    # Build the identity next to the data it is added to (same device/dtype).
    ridge = lamp * torch.eye(Atilde.shape[-1], device=Atilde.device, dtype=Atilde.dtype)
    # Drop only the trailing column axis: a blanket squeeze() would also drop
    # the batch axis when size == 1 (or the feature axis when d == 1) and
    # break the indexing below.
    xtilde = (torch.linalg.pinv(At @ Atilde + ridge) @ At @ btilde[..., None]).squeeze(-1)
    # Broadcasting matmul instead of materialising `size` copies of A.
    return (A @ (xtilde - x)[:, :, None]).squeeze()


def getNorms(A, b, x, space, method, nnz=1, lamp=0, size=None, s=None):
    """Norm of the averaged prediction error for several repetition counts.

    For every ``k`` in ``space`` the function calls ``getPreds`` ``int(k)``
    times, averages the results over those calls, and takes the Euclidean norm
    along the last axis.

    Args:
        A, b, x, method: Forwarded to ``getPreds``.
        space: Iterable of repetition counts.
        nnz, lamp: Forwarded to ``getPreds``.
        size, s: Optional sketch batch size and sketch row count. They are
            forwarded to ``getPreds`` only when given, so existing calls behave
            as before; a ``getPreds`` that requires them (as the one defined
            alongside this function does) needs both to be supplied.

    Returns:
        ``np.ndarray`` with one entry (or row) per element of ``space``,
        singleton dimensions squeezed out.
    """
    pred_kwargs = {'nnz': nnz, 'lamp': lamp}
    if size is not None:
        pred_kwargs['size'] = size
    if s is not None:
        pred_kwargs['s'] = s

    norms = []
    for k in space:
        runs = np.array([getPreds(A, b, x, method, **pred_kwargs).cpu().numpy()
                         for _ in range(int(k))])
        norms.append([np.linalg.norm(runs.mean(0), ord=2, axis=-1)])
    return np.array(norms).squeeze()

def _sketch_system(A, b, method, size, s, nnz):
    """Draw sketches with ``getS`` and return the sketched ``(Atilde, btilde)``."""
    S = getS(A, b, method=method, size=size, s=s, nnz=nnz)
    if method == 'Leverage Score Subsampling':
        inds, levs = S
        return A[inds] / levs[..., None], b[inds] / levs
    return S @ A, S @ b


def getCoefs(A, b, x, method, size, s, nnz=1, bigdata=False):
    """Solve the sketched least-squares problem for ``size`` random sketches.

    Args:
        A: 2-D design tensor.
        b: Response vector with one entry per row of ``A``.
        x: Unused; kept for signature compatibility.
        method: Sketching method name forwarded to ``getS``.
        size: Number of sketches.
        s: Number of rows per sketch.
        nnz: Forwarded to ``getS``.
        bigdata: When false, all sketches are drawn and solved as one batch.
            When true, sketches are drawn and solved one at a time to bound
            peak memory.

    Returns:
        ``bigdata=False``: one tensor with the batched solutions.
        ``bigdata=True``: a list with ``size`` solutions, each solved from a
        batch-of-one sketch (empty when ``size == 0``).
    """
    if bigdata:
        xtildes = []
        for _ in range(size):
            Atilde, btilde = _sketch_system(A, b, method, 1, s, nnz)
            xtildes.append(torch.linalg.lstsq(Atilde, btilde)[0])
            # Release per iteration; this is also safe when size == 0.
            del Atilde, btilde
        gc.collect()
        torch.cuda.empty_cache()
        return xtildes

    Atilde, btilde = _sketch_system(A, b, method, size, s, nnz)
    try:
        xtilde = torch.linalg.lstsq(Atilde, btilde)[0]
    except RuntimeError:
        # torch reports solver failures as RuntimeError subclasses; fall back
        # to the pseudo-inverse normal equations. Anything else (e.g.
        # KeyboardInterrupt) is left to propagate.
        At = Atilde.transpose(-1, -2)
        # squeeze(-1) keeps the batch axis even when size == 1 or d == 1.
        xtilde = (torch.linalg.pinv(At @ Atilde) @ At @ btilde[..., None]).squeeze(-1)
        del At
    del Atilde, btilde
    gc.collect()
    torch.cuda.empty_cache()
    return xtilde
            