import argparse
import os
import socket
import sys
from pathlib import Path

import torch

from data.hyperspectra import getHyper
from data.tech import getTech
from data.videos import getVideos
from evaluate import evaluate,evaluate_both,getbest,evaluate_dense

def get_hostname(path="/etc/hostname"):
    """Return this machine's hostname.

    Reads the first line of ``path`` (``/etc/hostname`` by default) and
    returns it without its line terminator.  If the file cannot be read
    (it does not exist on every platform, e.g. macOS or some containers),
    falls back to :func:`socket.gethostname` instead of crashing.

    Args:
        path: File whose first line holds the hostname.

    Returns:
        The hostname string.
    """
    try:
        with open(path, encoding="utf-8") as f:
            first_line = f.readline()
    except OSError:
        return socket.gethostname()
    # readline() keeps the trailing newline (if any); drop it.
    return first_line.rstrip("\n")

def mysvd(init_A, k, n_iter=300):
    """Approximate the top-``k`` singular triplets of a 2-D tensor by power iteration.

    For each component, a random start vector is pushed through ``n_iter``
    power-iteration steps on ``A^T A`` (renormalised every step).  The result
    is taken as the right singular vector ``v``, ``s = ||A v||`` as the
    singular value and ``u = A v / s`` as the left singular vector.  ``A`` is
    then deflated by ``A - (A v) v^T`` before the next component is
    extracted.  Every step is an ordinary autograd op, so gradients flow
    back to ``init_A``.

    Args:
        init_A: 2-D tensor of shape (rows, cols).  The outputs live on the
            same device and have the same dtype.
        k: Number of components; clamped to ``min(rows, cols)``.
        n_iter: Power-iteration steps per component (300 by default, the
            value that used to be hard-coded).

    Returns:
        ``(U, S, V)`` with shapes (rows, k), (k,), (cols, k), so that
        ``U @ diag(S) @ V.T`` approximates the top-``k`` part of ``init_A``.
        If ``k`` is not positive, empty tensors of shapes (rows, 0), (0,)
        and (cols, 0) are returned.
    """
    rows, cols = init_A.size(0), init_A.size(1)
    k = min(k, rows, cols)
    if k <= 0:
        return init_A.new_zeros(rows, 0), init_A.new_zeros(0), init_A.new_zeros(cols, 0)

    # The start vectors are drawn from the CPU generator, as before, so the
    # random stream seen by the rest of the program is unchanged.  They are
    # then moved to init_A's device/dtype instead of assuming float32 CUDA.
    starts = [
        torch.empty(cols).uniform_().to(device=init_A.device, dtype=init_A.dtype)
        for _ in range(k)
    ]

    us, ss, vs = [], [], []
    A = init_A
    for start in starts:
        v = start
        for _ in range(n_iter):
            # One power-iteration step on A^T A, renormalised to unit length.
            w = A.t().mv(A.mv(v))
            v = w.div(torch.norm(w))
        v_hat = v / torch.norm(v)
        Av = A.mv(v_hat)
        s = torch.norm(Av)
        vs.append(v_hat)
        ss.append(s)
        us.append(Av / s)
        # Deflate: subtract the component just found so the next pass
        # converges to the next-largest singular direction.  The broadcast
        # product is the outer product (A v) v^T, equivalent to torch.ger.
        A = A - A.mv(v).unsqueeze(1) * v.unsqueeze(0)
    return torch.stack(us, dim=1), torch.stack(ss), torch.stack(vs, dim=1)


if __name__ == '__main__':
    # Command-line entry point: load a data set, make sure the reference
    # ("best") results exist, optionally evaluate dense Gaussian sketches
    # (--dense), and optionally train a learned sparse sketch (--single).
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        """Shorthand for ``parser.add_argument``."""
        parser.add_argument(*args, **kwargs)


    aa("--data", type=str, default="tech", help="tech|video|hyper")
    aa("--dataname", type=str, default="mit", help="transformer|mit|friends")
    aa("--m", type=int, default=10, help="m for S")
    aa("--k", type=int, default=10, help="target: rank k approximation")
    aa("--mp", type=int, default=10, help="mp for R")
    aa("--iter", type=int, default=5000, help="total iterations")
    aa("--size", type=int, default= -1, help="dataset size")

    aa("--single", dest='single',default=False, action='store_true',help="train the learned sparse sketch?")
    aa("--dense", type=int, default= -1, help="dense Gaussian sketch baseline: -1 skip, >=0 evaluate it, 1 evaluate it and exit")
    aa("--raw", dest='raw', default=False,action='store_true',help="generate raw?")
    aa("--bestonly", dest='bestonly', default=False,action='store_true',help="only compute best?")

    args = parser.parse_args()
    # Choose the root directory by machine (looked up once).  The data root
    # (passed to the data loaders) and the results root are the same path.
    base_dir = "/git/big-lowrank/" if get_hostname() == "Dragon" else "/big-lowrank/"
    rawdir = base_dir
    rltdir = base_dir

    print(args)
    m=args.m
    mp=args.mp
    k=args.k



    # Results are written under <rltdir>rlt/<data>/ (plus <dataname>/ for video).
    if args.data=='tech':
        save_dir=rltdir+'rlt/tech/'
    elif args.data=='hyper':
        save_dir=rltdir+'rlt/hyper/'
    elif args.data=='video':
        save_dir=rltdir+'rlt/video/'+args.dataname+'/'
    else:
        print("Wrong data option!")
        # Exit non-zero so calling scripts / job runners see the failure.
        sys.exit(1)

    # Skip the run when the result file that the training loop writes at step
    # --iter already exists (not when only computing reference results, and
    # not when --raw asks for regeneration).
    if (not args.bestonly) and os.path.isfile(save_dir+'m='+str(m)+'_k='+str(k)+'_iter='+str(args.iter)+'_N='+str(args.size)) and (not args.raw):
        print("This one is already done! Now exiting.")
        sys.exit()

    # Step size for the sketch values; the video data uses a larger one.
    lr=1
    if args.data=='tech':
        A_train,A_test,n,d=getTech(args.raw,rawdir)
    elif args.data=='hyper':
        A_train,A_test,n,d=getHyper(args.raw,args.size,rawdir)
    else:
        A_train,A_test,n,d=getVideos(args.dataname,args.raw,args.size,rawdir)
        lr=10


    print("Working on data ", args.data)

    Path(save_dir).mkdir(parents=True, exist_ok=True)

    N_train=len(A_train)
    N_test=len(A_test)
    print("Dim= ", n,d)
    print("N train=", N_train, "N test=", N_test)

    # Make sure the reference files exist for k = 10, 20, 30 and for the
    # requested k (it is loaded below; without it a non-standard --k used to
    # crash in torch.load).  Recompute when missing or when --raw is given.
    for tmpk in sorted({10, 20, 30, k}):
        best_file=save_dir+"N="+str(args.size)+"_k="+str(tmpk)+'_best'
        if (not os.path.isfile(best_file)) or args.raw:
            print("computing ",best_file)
            getbest(A_train,A_test, tmpk, args.data,best_file)

    if args.bestonly:
        sys.exit()

    # Reference values for the requested k.  The progress prints below report
    # (error / N - reference).  NOTE(review): presumably the optimal rank-k
    # error produced by getbest; confirm in evaluate.py.
    best_file=save_dir+"N="+str(args.size)+"_k="+str(k)+'_best'
    best_train,best_test=torch.load(best_file)

    rlt_dic={}

    # The tech data set is stored sparsely: each sample is a dict with keys
    # 'M', 'd', 'n' and 'Map' (see the training loop below).
    sparse=(args.data=='tech')
    if args.dense>=0:
        # Baseline: five independent dense Gaussian m x n sketches.
        for take in range(5):
            f_name ='m='+str(m)+'_k='+str(k)+'_N='+str(args.size)+'_full_take='+str(take)
            sketch = torch.randn(m, n).cuda()
            rlt_dic[f_name] = (evaluate_dense(sparse,A_train,sketch,m,k),
                               evaluate_dense(sparse,A_test,sketch,m,k))
            torch.save([rlt_dic[f_name], N_train, N_test], save_dir+f_name)
            print(f_name, rlt_dic[f_name][0]/N_train-best_train, rlt_dic[f_name][1]/N_test-best_test)
        if args.dense==1:
            sys.exit()

    print_freq=50

    if args.single:
        # Learned sparse sketch S (m x n): column i has exactly one non-zero,
        # at row sketch_vector[i] with value sketch_value[i].  The positions
        # are random and fixed; the values start at random +-1 and are trained.
        cur_diff = []
        sketch_vector = torch.randint(m, [n]).int()  # m*n
        sketch_vector.requires_grad = False
        sketch_value = ((torch.randint(2, [n]).float() - 0.5) * 2).cuda()
        sketch_value.requires_grad = False
        for bigstep in range(args.iter+1):
            # Multiply the step size by 0.3 every 1000 steps while it exceeds 1.
            if ((bigstep+1)%1000==0) and lr>1:
                lr=lr*0.3
            if bigstep>200:
                print_freq=200
            # One randomly sampled training matrix per step.
            A = A_train[int(torch.randint(N_train, [1]).item())]
            if sparse:
                AM=A['M'].cuda()
                Ad=A['d']
                An=A['n']
                AMap=A['Map']
            else:
                AM = A.cuda()
                Ad=d
                An=n

            # Periodically evaluate the current sketch on train/test and save it.
            if bigstep % print_freq == 0:
                print(bigstep, '.')
                f_name ='m='+str(m)+'_k='+str(k)+'_iter=' + str(bigstep)+'_N='+str(args.size)
                rlt_dic[f_name] = (evaluate(sparse,A_train,sketch_vector,sketch_value,m,k,An,Ad),
                                   evaluate(sparse,A_test,sketch_vector,sketch_value,m,k,An,Ad))
                torch.save([sketch_vector, sketch_value, rlt_dic[f_name], N_train, N_test], save_dir+f_name)
                print(f_name, rlt_dic[f_name][0]/N_train-best_train, rlt_dic[f_name][1]/N_test-best_test)


            # Build SA = S @ A row by row: row i of A is added, scaled by its
            # sketch value, into sketch row sketch_vector[i].
            SA = torch.Tensor(m, Ad).fill_(0).cuda()

            if sparse:
                for i in range(An):  # A has this many rows, not mapped yet
                    actR = AMap[i]  # Actual row in the matrix
                    mapR = sketch_vector[actR]  # row is mapped to this row in the sketch
                    SA[mapR] += AM[i] * sketch_value[actR]  # remember: times the weight
            else:
                for i in range(n):  # A has this many rows, not mapped yet
                    mapR = sketch_vector[i]  # row is mapped to this row in the sketch
                    SA[mapR] += AM[i] * sketch_value[i]  # remember: times the weight

            # Detach SA into a leaf tensor so dloss/dSA can be read from SH.grad.
            SH = SA.detach()
            SH.requires_grad = True
            # V2: right singular vectors of S A.  Project A onto them, keep the
            # top-k part of the projection, and map back: ans = [A V2]_k V2^T.
            U2, Sigma2, V2 = mysvd(SH, SH.size(1))
            AU = AM.mm(V2)
            U3, Sigma3, V3 = mysvd(AU, k)
            ans = U3[:, :k].mm(torch.diag(Sigma3[:k]).cuda()).mm(V3.t()[:k]).mm(V2.t())
            loss = torch.norm(ans - AM)
            loss.backward()
            if bigstep%10==0:
                loss_val = loss.item()
                print(loss_val, loss_val-best_train, end=",")

            # Chain rule: SA[sketch_vector[i]] gains AM-row * sketch_value[i],
            # so dloss/dsketch_value[i] = <dloss/dSA[sketch_vector[i]], AM-row>.
            # Take a plain gradient-descent step on each value.
            if sparse:
                for i in range(An):
                    actR = AMap[i]  # Actual row in the matrix
                    sketch_value[actR] -=lr* torch.dot(SH.grad.data[int(sketch_vector[actR]), :], AM[i, :])
            else:
                for i in range(n):
                    sketch_value[i] -=lr* torch.dot(SH.grad.data[int(sketch_vector[i]), :], AM[i, :])

            # Drop per-step GPU tensors and release cached memory before the next sample.
            del SA, SH, U2, Sigma2, V2, AU, U3, Sigma3, V3, ans, loss, AM
            torch.cuda.empty_cache()
