import numpy
import random
import time
import os
import matplotlib.pyplot as plt
import datasets
import models
import sys
import pickle

# Per-dataset hyper-parameters consumed by process().
#   *_STEP_SIZE        step size handed to update_step() for every non-SSVRG model
#   *_MU               the `mu` argument of update_step() / reg_loss()
#   *_SSVRG_M          extra constructor argument of the Opt.SSVRG model
#   *_SSVRG_STEP_SIZE  step size handed to update_step() for the SSVRG model
#                      (RCV has none of its own; process() reuses RCV_STEP_SIZE)

# MovieLens (matrix factorization)
MOVIELENS_STEP_SIZE = 0.005
MOVIELENS_MU = 0.1
MOVIELENS_SSVRG_M = 500
MOVIELENS_SSVRG_STEP_SIZE = 0.005

# a9a (logistic regression)
A9A_STEP_SIZE = 0.005
A9A_MU = 0.001
A9A_SSVRG_M = 600
A9A_SSVRG_STEP_SIZE = 0.01

# RCV (logistic regression)
RCV_STEP_SIZE = 0.5
RCV_MU = 0.00001
RCV_SSVRG_M = 1000

def _hyperparameters(dataset):
    """Return (step_size, mu, ssvrg_M, ssvrg_step_size) for *dataset*.

    Any dataset name other than 'rcv' or 'a9a' gets the MovieLens settings.
    """
    if dataset == 'rcv':
        # RCV has no dedicated SSVRG step size; the regular one is reused.
        return RCV_STEP_SIZE, RCV_MU, RCV_SSVRG_M, RCV_STEP_SIZE
    if dataset == 'a9a':
        return A9A_STEP_SIZE, A9A_MU, A9A_SSVRG_M, A9A_SSVRG_STEP_SIZE
    return (MOVIELENS_STEP_SIZE, MOVIELENS_MU, MOVIELENS_SSVRG_M,
            MOVIELENS_SSVRG_STEP_SIZE)


def _new_model(dataset, init_model, opt, *extra):
    """Build a fresh model for *dataset* starting from *init_model*.

    'rcv' and 'a9a' use models.LogisticRegression with *init_model* as the
    initial weights; every other dataset uses models.MatrixFactorization and
    expects *init_model* to be an (L, R) pair.  *extra* is forwarded
    positionally after *opt* (used for the SSVRG parameter).
    """
    if dataset == 'rcv' or dataset == 'a9a':
        return models.LogisticRegression(init_model, opt, *extra)
    init_L, init_R = init_model
    return models.MatrixFactorization(init_L, init_R, opt, *extra)


def _interleaved_steps(model, train_set, num_steps, seen, available,
                       step_size, mu):
    """Run *num_steps* updates of *model* and return the new *seen* count.

    On even-numbered steps the next unseen point train_set[seen] is consumed
    as long as seen < available; on every other step a point is drawn
    uniformly from train_set[0:seen].  Callers must ensure available > 0 when
    seen == 0, otherwise random.randrange(0) raises ValueError.
    """
    for s in xrange(num_steps):
        if s % 2 == 0 and seen < available:
            j = seen
            seen += 1
        else:
            j = random.randrange(seen)
        model.update_step(train_set[j], step_size, mu)
    return seen


def process(dataset, arrival_type, train_set, test_set, b, arrivals, rho, init_model):
    """Simulate a stream of training points and compare several optimizers.

    At each of the time steps 1..b-1 new points of train_set "arrive"
    (S counts how many have arrived so far) and every algorithm gets to
    update its model:

      Inc       rho interleaved steps per time step, model kept across steps
      SGD_pass  rho SGD steps, each on the next unprocessed point if any,
                otherwise on a uniformly drawn arrived point
      SGD_unif  rho SGD steps on uniformly drawn arrived points
      S         SSVRG, only run when rho == arrivals
      B         restarted from init_model whenever new points arrived and
                given rho*t interleaved steps; otherwise rho more steps
      A         restarted whenever new points arrived and given 30*S
                interleaved steps; its training loss is the reference that
                all 'subopt' values are measured against

    Args:
        dataset: 'rcv' or 'a9a' (logistic regression); any other value is
            treated as a MovieLens matrix-factorization problem.
        arrival_type: 'poisson' (Poisson(arrivals) new points per step) or
            'skewed' (a burst of 8*arrivals points with probability 1/8).
        train_set: indexable, sliceable sequence of training points.
        test_set: data passed to each model's loss().
        b: number of time steps; index 0 holds the initial state.
        arrivals: mean number of arriving points per time step.
        rho: number of update steps per time step.
        init_model: initial weights, or an (L, R) pair for MovieLens.

    Returns:
        dict with per-algorithm lists of length b: 'test' and 'subopt' for
        'Inc', 'S', 'B', 'SGD_pass' and 'SGD_unif', 'test' only for 'A',
        plus 'sample', the ratio TI/TB of points consumed by Inc versus B
        (1 when B has consumed none).  The 'S' entries after index 0 stay 0
        unless rho == arrivals.

    Raises:
        ValueError: if arrival_type is neither 'skewed' nor 'poisson'.
            Previously such a value was silently ignored, so no point ever
            arrived and the returned curves were meaningless.
    """
    if arrival_type not in ('skewed', 'poisson'):
        raise ValueError(
            "unknown arrival_type %r (expected 'skewed' or 'poisson')"
            % (arrival_type,))

    opt = models.Opt.SAGA
    step_size, mu, ssvrg_M, ssvrg_step_size = _hyperparameters(dataset)

    M = arrivals * (1 << 3)      # burst size for 'skewed' arrivals
    run_ssvrg = (rho == arrivals)

    loss = {
                'Inc': {'test': [0]*b, 'subopt': [0]*b},
                'S': {'test': [0]*b, 'subopt': [0]*b},
                'B': {'test': [0]*b, 'subopt': [0]*b},
                'A': {'test': [0]*b},
                'sample': [0]*b,
                'SGD_pass': {'test': [0]*b, 'subopt': [0]*b},
                'SGD_unif': {'test': [0]*b, 'subopt': [0]*b}
           }

    S = 0       # points arrived so far: S_i = train_set[0:S]
    TI = 0      # Algo I has consumed train_set[0:TI]
    TB = 0      # Algo B has consumed train_set[0:TB]
    TS = 0      # next point SSVRG will process
    T_SGD = 0   # SGD_pass has already processed train_set[0:T_SGD]

    # Last known training losses; S and A are only refreshed when they run.
    trainS = 0
    trainA = 0

    wI = _new_model(dataset, init_model, opt)
    wS = _new_model(dataset, init_model, models.Opt.SSVRG, ssvrg_M)
    wB = _new_model(dataset, init_model, opt)
    wA = _new_model(dataset, init_model, opt)
    wPass = _new_model(dataset, init_model, models.Opt.SGD)
    wUnif = _new_model(dataset, init_model, models.Opt.SGD)

    # Time step 0: test loss of the untrained models; 'subopt' stays 0.
    loss['Inc']['test'][0] = wI.loss(test_set)
    loss['S']['test'][0] = wS.loss(test_set)
    loss['B']['test'][0] = wB.loss(test_set)
    loss['A']['test'][0] = wA.loss(test_set)
    loss['sample'][0] = 1
    loss['SGD_pass']['test'][0] = wPass.loss(test_set)
    loss['SGD_unif']['test'][0] = wUnif.loss(test_set)

    for t in xrange(1, b):
        S_prev = S

        if arrival_type == 'skewed':
            # arrivals/M == 1/8, so the expected arrival count stays `arrivals`.
            if random.random() < 1.0*arrivals/M:
                S += M
        else:
            S += numpy.random.poisson(arrivals)

        if S > len(train_set):
            S = len(train_set)

        new_points = (S != S_prev)

        if S != 0:
            # Algo I
            TI = _interleaved_steps(wI, train_set, rho, TI, S, step_size, mu)

            # Algo Pass
            for s in xrange(rho):
                if T_SGD < S:
                    j = T_SGD
                    T_SGD += 1
                else:
                    j = random.randrange(S)
                wPass.update_step(train_set[j], step_size, mu)

            # Algo Unif
            for s in xrange(rho):
                j = random.randrange(S)
                wUnif.update_step(train_set[j], step_size, mu)

            # Algo S
            if run_ssvrg:
                # update_step's return value is charged against the budget
                # of rho steps (presumably the gradient steps it performed).
                steps = 0
                while steps < rho and TS != S:
                    steps += wS.update_step(train_set[TS], ssvrg_step_size, mu)
                    TS += 1
                trainS = wS.reg_loss(train_set[:S], mu)

            # Algo B: restart with the whole budget spent so far when new
            # points arrived, otherwise just continue for rho more steps.
            if new_points:
                wB = _new_model(dataset, init_model, opt)
                TB = 0
                budget_B = rho*t
            else:
                budget_B = rho
            TB = _interleaved_steps(wB, train_set, budget_B, TB, S,
                                    step_size, mu)

            # Algo A: retrain from scratch for 30*S steps on every arrival.
            if new_points:
                wA = _new_model(dataset, init_model, opt)
                _interleaved_steps(wA, train_set, 30*S, 0, S, step_size, mu)
                trainA = wA.reg_loss(train_set[:S], mu)

        trainI = wI.reg_loss(train_set[:S], mu)
        trainB = wB.reg_loss(train_set[:S], mu)
        trainPass = wPass.reg_loss(train_set[:S], mu)
        trainUnif = wUnif.reg_loss(train_set[:S], mu)

        loss['Inc']['subopt'][t] = trainI - trainA
        loss['B']['subopt'][t] = trainB - trainA
        loss['Inc']['test'][t] = wI.loss(test_set)
        loss['B']['test'][t] = wB.loss(test_set)
        loss['A']['test'][t] = wA.loss(test_set)
        loss['sample'][t] = 1.0 * TI / TB if TB != 0 else 1

        loss['SGD_pass']['subopt'][t] = trainPass - trainA
        loss['SGD_unif']['subopt'][t] = trainUnif - trainA
        loss['SGD_pass']['test'][t] = wPass.loss(test_set)
        loss['SGD_unif']['test'][t] = wUnif.loss(test_set)

        if run_ssvrg:
            loss['S']['subopt'][t] = trainS - trainA
            loss['S']['test'][t] = wS.loss(test_set)

    return loss

if __name__ == "__main__":
    # Usage: <script> <dataset> <arrival_type> <rate>
    # Runs process() N times and pickles the list of result dicts to
    # output/output_data/<dataset>-<arrival_type>_r<rate>.pkl
    if len(sys.argv) < 4:
        print("needs 3 arguments: dataset (a9a, rcv, movielens100k, movielens1m), arrival_type (skewed, poisson), rho/lambda")
        # Non-zero status so calling scripts can detect the usage error.
        sys.exit(1)
    dataset = sys.argv[1]
    arrival_type = sys.argv[2]
    rate = float(sys.argv[3])

    # Reject bad arguments before loading any data; process() would otherwise
    # simulate a stream in which no point ever arrives.
    if arrival_type not in ('skewed', 'poisson'):
        print("unknown arrival_type '{0}' (expected skewed or poisson)".format(arrival_type))
        sys.exit(1)

    if dataset == 'a9a':
        train_data, test_data, d = datasets.a9a()
        init_model = numpy.random.rand(d)
    elif dataset == 'rcv':
        train_data, test_data, d = datasets.rcv()
        init_model = numpy.random.rand(d)
    elif dataset == 'movielens100k' or dataset == 'movielens1m':
        if dataset == 'movielens100k':
            train_data, test_data, m, n = datasets.movielens100k()
        else:
            train_data, test_data, m, n = datasets.movielens1m()
        r = 10   # number of columns of L / rows of R in the factorization
        init_model = (numpy.random.rand(m, r), numpy.random.rand(r, n))
    else:
        # Previously this fell through to a NameError on train_data.
        print("unknown dataset '{0}' (expected a9a, rcv, movielens100k or movielens1m)".format(dataset))
        sys.exit(1)

    name = dataset + '-' + arrival_type
    b = 100
    # Floor division keeps `arrivals` an integer on Python 2 and Python 3.
    arrivals = len(train_data) // (2*b)
    # NOTE(review): looks intended to keep rho even (process() alternates
    # new/old samples on even/odd steps) - confirm; for some rates rho can
    # still come out odd after the decrement.
    if int(arrivals*rate) % 2 != 0:
        arrivals -= 1
    rho = int(arrivals * rate)

    # Make sure the results can be saved before spending time on the runs.
    output_path = 'output/output_data/{0}_r{1}.pkl'.format(name, rate)
    output_dir = os.path.dirname(output_path)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    N = 5   # number of independent repetitions
    # Every repetition is handed the same init_model object.
    outputs = [
        process(dataset, arrival_type, train_data, test_data, b, arrivals, rho, init_model)
        for _ in range(N)
    ]

    # pickle requires a binary-mode file.
    with open(output_path, 'wb') as f:
        pickle.dump(outputs, f)
