import numpy as np
from scipy.optimize import fmin_tnc
from scipy.stats import norm
from numpy.linalg import pinv
from scipy.linalg import eigh
from scipy.optimize import minimize, NonlinearConstraint
from tqdm.notebook import tqdm
import time

class Optimal_Policy(object):
    """
    Baseline that plays one fixed action ("go to the right") in every step.

    The fixed action index is 2 when the environment exposes 4 actions and
    1 otherwise.
    """

    def __init__(self, env, K):
        """
        env: environment providing ``nAction``, ``reset()`` and
            ``advance(a) -> (r, s_, done)``.
        K: number of episodes to play in ``run``.
        """
        self.K = K
        self.env = env
        # The index of the "right" action depends on the size of the action set.
        self.optimal_action = 2 if self.env.nAction == 4 else 1

    def run(self):
        """
        Play ``K`` episodes with the fixed action.

        Returns the list of undiscounted per-episode returns.
        """
        print("Optimal policy: Go to the right!")
        returns = []

        for _episode in range(self.K):
            self.env.reset()
            total, finished = 0, 0
            while not finished:
                reward, _next_state, finished = self.env.advance(self.optimal_action)
                total += reward
            returns.append(total)

        return returns

class RRL_MNL(object):
    """
    Our proposed algorithm, RRL-MNL in the paper.

    Transitions are modelled with a multinomial logistic (MNL) model over the
    states reachable from (s, a):
        P_theta(s' | s, a) = exp(varphi(s,a,s')^T theta) / sum_{s''} exp(varphi(s,a,s'')^T theta).
    The transition core ``theta`` is refit once per episode with an Online
    Newton Step, and optimism comes from Gaussian perturbations with
    covariance ``sig**2 * A_inv``.
    """
    def __init__(self, env, K, lam=1, sig=0.1, M=5, kappa=0.1, c=0.003):
        """
        env: environment. This class reads ``nState``, ``nAction``, ``epLen``,
            ``states`` (dict keyed by state), ``P`` (the non-zero entries of
            ``P[s, a]`` define the reachable next states), ``R``
            (``R[(s, a)][0]`` is used as the reward), ``state`` and
            ``timestep``, and calls ``reset()``, ``advance(a) -> (r, s_, done)``
            and ``argmax(values)``.
        K: number of episodes
        lam: regularization parameter of the Gram matrix
        sig: initial exploration scale (recomputed by ``compute_Q``)
        M: number of Gaussian samples drawn per (h, s, a)
        kappa: problem-dependent constant scaling the Gram-matrix update
        c: hyperparameter for the exploration variance
        """
        ## Inputs
        self.env = env
        self.K = K  # number of episodes
        self.lam = lam  # regularization parameter
        self.sig = sig  # exploration scale; overwritten in compute_Q
        self.M = M  # sampling size
        self.kappa = kappa  # problem-dependent constant
        self.c = c  # hyperparameter for confidence radius

        ### Setting feature map
        self.d1 = self.env.nState * self.env.nAction
        self.d2 = self.env.nState
        self.d = self.d1 * self.d2

        ## phi: one-hot feature for (s, a) in S x A
        self.phi = {(s, a): np.zeros(self.d1) for s in self.env.states.keys() for a in range(self.env.nAction)}
        for i, key in enumerate(self.phi):
            self.phi[key][i] = 1

        ## psi: one-hot feature for s in S
        self.psi = {s: np.zeros(self.d2) for s in self.env.states.keys()}
        for j, key in enumerate(self.psi):
            self.psi[key][j] = 1

        ## reachable states S_{s,a}: indices with non-zero transition probability
        self.reachable_states = {(s, a): set(np.where(self.env.P[s, a])[0])
                                 for s in self.env.states.keys() for a in range(self.env.nAction)}

        ## varphi(s, a, s') = vec(phi(s, a) psi(s')^T)
        self.varphi = {(s, a, s_): np.outer(self.phi[(s, a)], self.psi[s_]).flatten()
                       for s in self.psi.keys()
                       for a in range(self.env.nAction)
                       for s_ in self.reachable_states[s, a]}

        ## Q-value table
        self.Q = {(h, s, a): 0.0 for h in range(self.env.epLen) for s in self.env.states.keys()
                  for a in range(self.env.nAction)}

        ## Gram matrix (and its inverse) for the transition core
        self.A = self.lam * np.identity(self.d)
        self.A_inv = (1/self.lam) * np.identity(self.d)

        ## transition core
        self.theta_0 = np.zeros(self.d)  # optimization starting point
        self.theta = np.zeros(self.d)
        self.theta_prev = np.zeros(self.d)

        # optimizer for transition core
        self.optimizer = Online_Newton_Step()

    def act(self, s, h):
        """
        Return the action maximizing ``self.Q[(h, s, .)]``; ties are resolved
        by ``env.argmax``.
        """
        return self.env.argmax(np.array([self.Q[(h, s, a)] for a in range(self.env.nAction)]))

    def proj(self, x, lo, hi):
        """
        Projects the value of x into the [lo, hi] interval.
        """
        return max(min(x, hi), lo)

    def generate_multivariate_gaussian_sample(self, mean, covariance_matrix):
        """
        Draw one multivariate Gaussian sample using a Cholesky factor of
        ``covariance_matrix`` (which must be positive definite).
        """
        dimension = len(mean)
        L = np.linalg.cholesky(covariance_matrix)
        z = np.random.randn(dimension)
        return mean + np.dot(L, z)

    @staticmethod
    def _softmax(logits):
        """Numerically stable softmax of a 1-D array; empty input gives empty output."""
        logits = np.asarray(logits, dtype=float)
        if logits.size == 0:
            return logits
        # Subtracting the max avoids exp overflow without changing the result.
        weights = np.exp(logits - np.max(logits))
        return weights / np.sum(weights)

    def _transition_probs(self, s, a, theta=None):
        """
        MNL transition probabilities P_theta(. | s, a), ordered like iteration
        over ``self.reachable_states[(s, a)]``.

        theta: transition core to evaluate; defaults to ``self.theta``.
        """
        if theta is None:
            theta = self.theta
        logits = [np.dot(self.varphi[(s, a, s_)], theta) for s_ in self.reachable_states[(s, a)]]
        return self._softmax(logits)

    def compute_Q(self, k):
        """
        Compute stochastically optimistic Q-values by backward induction
        (h = epLen-1, ..., 0) and store them in ``self.Q``.

        k: current episode index (>= 1); the exploration scale grows with log(k+1).
        """
        H = self.env.epLen
        Q = {(h, s, a): 0.0 for h in range(H) for s in self.env.states.keys()
             for a in range(self.env.nAction)}
        V = {h: np.zeros(self.env.nState) for h in range(H + 1)}

        self.sig = self.c * H * np.sqrt(self.d) * np.log(k+1)  # exploration variance
        # Cholesky factor of sig^2 * A_inv: loop-invariant, so computed at most
        # once per call, and only if some (s, a) actually needs noise.
        chol = None

        for h in range(H-1, -1, -1):
            for s in self.env.states.keys():
                for a in range(self.env.nAction):

                    ## known reward
                    r = self.env.R[(s, a)][0]

                    next_states = list(self.reachable_states[(s, a)])
                    prob = self._transition_probs(s, a)

                    if len(next_states) == 1:  # deterministic transition: no perturbation
                        max_noise = 0
                    else:
                        ## dominating feature \hat{\varphi}: largest ||varphi||_{A^{-1}}
                        varphi_norm = [np.dot(np.dot(self.varphi[(s, a, s_)], self.A_inv), self.varphi[(s, a, s_)])
                                       for s_ in next_states]
                        varphi_hat = self.varphi[(s, a, next_states[int(np.argmax(varphi_norm))])]

                        ## multivariate Gaussian noise xi ~ N(0, sig^2 A^{-1}), M samples
                        if chol is None:
                            chol = np.linalg.cholesky((self.sig)**2 * self.A_inv)
                        max_noise = max(np.dot(varphi_hat, np.dot(chol, np.random.randn(self.d)))
                                        for _ in range(self.M))

                    next_value = np.array([V[h+1][s_] for s_ in next_states])

                    ## stochastically optimistic Q
                    Q[h, s, a] = r + np.dot(prob, next_value) + max_noise

                # Use the Q-values just computed for step h (not last episode's self.Q).
                V[h][s] = max(Q[(h, s, a)] for a in range(self.env.nAction))
        self.Q = Q

    def update_gram_matrix(self, x):
        """
        Add sum_row row row^T to the Gram matrix ``self.A``.

        x: iterable of d-dimensional feature vectors
        """
        self.A += sum([np.outer(row, row) for row in x])

    def update_gram_matrix_inverse(self, x):
        """
        Update ``self.A_inv`` with one Sherman-Morrison rank-one step per row.

        x: iterable of d-dimensional feature vectors
        """
        for row in x:
            A_row = np.dot(self.A_inv, row)
            self.A_inv = self.A_inv - np.outer(A_row, np.dot(row, self.A_inv))/(1 + np.dot(row, A_row))

    def update_theta(self, x, y):
        """Refit the transition core with the Online Newton Step optimizer."""
        self.theta_prev = self.theta
        self.optimizer.fit(theta=self.theta_0, theta_prev=self.theta_prev, x=x, y=y, A=self.A)
        self.theta = self.optimizer.w

    def run(self):
        """
        Run ``K`` episodes; returns the list of per-episode returns.
        """
        print("RRL-MNL")
        episode_return = []
        scale = (self.kappa/2)**(1/2)  # Gram-matrix features are scaled by sqrt(kappa/2)

        for k in tqdm(range(1, self.K+1)):

            # reset environment
            self.env.reset()
            done = 0
            R = 0

            X_k = []  # varphi(s^k_h, a^k_h, s') for all h in [H]
            Y_k = []  # y^k_h for all h in [H]

            # starting episode
            while not done:
                s = self.env.state
                h = self.env.timestep
                a = self.act(s, h)
                r, s_, done = self.env.advance(a)

                R += r

                next_states = list(self.reachable_states[s, a])

                # transition response variable: one-hot over reachable states
                y = np.array([1.0 if sp == s_ else 0.0 for sp in next_states])
                X_h = [self.varphi[(s, a, sp)] for sp in next_states]
                for x in X_h:
                    self.update_gram_matrix([scale * x])
                    self.update_gram_matrix_inverse([scale * x])
                Y_k.append(y)
                X_k.append(np.array(X_h))
            print(f'episode: {k}, cum_rewards: {R}')
            episode_return.append(R)

            # transition core update
            self.update_theta(X_k, Y_k)

            # Update Q value
            self.compute_Q(k)

        return episode_return

class ORRL_MNL(RRL_MNL):
    """
    Our proposed algorithm, ORRL-MNL in the paper.

    Reuses the MNL feature construction of RRL_MNL, updates the transition
    core with Online Mirror Descent, and uses a Gram matrix ``B`` whose
    updates are weighted by the estimated transition probabilities for its
    randomized exploration bonus.
    """
    def __init__(self, env, K, lam=1, sig=0.1, M=3, beta=0.1, eta=0.1, c=0.001):

        ## Inputs -- forwarded by keyword: positionally, ``c`` would land in
        ## the parent's ``kappa`` slot.
        super().__init__(env, K, lam=lam, sig=sig, M=M, c=c)
        self.beta = beta
        # NOTE(review): the ``eta`` argument is ignored and the step size is
        # fixed to 2 + log(4)/2 (kept from the original code) -- confirm.
        self.eta = 2 + np.log(4)/2
        self.c = c

        ## Optimizer: Online Mirror Descent
        self.optimizer = Online_Mirror_Descent()

        ## Gram matrix
        # NOTE(review): ``lam`` is forced to 1 here regardless of the argument.
        self.lam = 1
        self.B = (self.lam)*np.identity(self.d)
        self.B_inv = (1/self.lam)*np.identity(self.d)
        self.B_tilde = (self.lam)*np.identity(self.d)
        self.grad_2_ell = np.zeros_like(self.B_tilde)

    @staticmethod
    def _softmax(logits):
        """Numerically stable softmax of a 1-D array; empty input gives empty output."""
        logits = np.asarray(logits, dtype=float)
        if logits.size == 0:
            return logits
        weights = np.exp(logits - np.max(logits))
        return weights / np.sum(weights)

    def _transition_probs(self, s, a, theta=None):
        """
        MNL transition probabilities P_theta(. | s, a), ordered like iteration
        over ``self.reachable_states[(s, a)]``; ``theta`` defaults to ``self.theta``.
        """
        if theta is None:
            theta = self.theta
        logits = [np.dot(self.varphi[(s, a, s_)], theta) for s_ in self.reachable_states[(s, a)]]
        return self._softmax(logits)

    def update_gram_matrix(self, x):
        """
        Add sum_row row row^T to the Gram matrix ``self.B``.

        x: iterable of d-dimensional feature vectors
        """
        self.B += sum([np.outer(row, row) for row in x])

    def update_gram_matrix_inverse(self, x):
        """
        Update ``self.B_inv`` with one Sherman-Morrison rank-one step per row.

        x: iterable of d-dimensional feature vectors
        """
        for row in x:
            B_row = np.dot(self.B_inv, row)
            self.B_inv = self.B_inv - np.outer(B_row, np.dot(row, self.B_inv))/(1 + np.dot(row, B_row))

    def expected_varphi(self, s, a, theta):
        """
        E_{s' ~ P_theta(. | s, a)}[varphi(s, a, s')] under the given ``theta``
        (``None`` falls back to ``self.theta``).
        """
        prob = self._transition_probs(s, a, theta)
        E_varphi = np.zeros(self.d)
        for p_i, sprime in zip(prob, self.reachable_states[(s, a)]):
            E_varphi += p_i * self.varphi[(s, a, sprime)]
        return E_varphi

    def compute_Q(self, k):
        """
        Compute optimistic Q-values (model estimate plus randomized bonus) by
        backward induction and store them in ``self.Q``.

        k: current episode index (>= 1); ``sig`` and ``beta`` grow with log(k+1).
        """
        H = self.env.epLen
        Q = {(h, s, a): 0.0 for h in range(H) for s in self.env.states.keys()
             for a in range(self.env.nAction)}
        V = {h: np.zeros(self.env.nState) for h in range(H + 1)}

        self.sig = self.c*H * np.sqrt(self.d) * np.log(k+1)  # exploration variance
        self.beta = self.c*np.sqrt(self.d) * np.log(k+1)  # confidence radius
        chol = None  # Cholesky factor of sig^2 * B_inv, computed lazily once per call

        for h in range(H-1, -1, -1):
            for s in self.env.states.keys():
                for a in range(self.env.nAction):

                    ## known reward
                    r = self.env.R[(s, a)][0]

                    next_states = list(self.reachable_states[(s, a)])
                    prob = self._transition_probs(s, a)
                    feats = np.array([self.varphi[(s, a, sp)] for sp in next_states])

                    ## bonus term 2: 2 H beta^2 max_{s'} ||varphi(s,a,s')||^2_{B^{-1}}
                    varphi_norm_sq = np.sum(np.dot(feats, self.B_inv) * feats, axis=1)
                    bonus2 = 2*H*(self.beta**2)*np.max(varphi_norm_sq)

                    if len(next_states) == 1:  # only one reachable state: no randomized term
                        bonus1 = 0
                    else:
                        ## bonus term 1: centralized features against M Gaussian samples
                        varphi_bars = feats - np.dot(prob, feats)
                        if chol is None:
                            chol = np.linalg.cholesky((self.sig)**2 * self.B_inv)
                        xi = np.array([np.dot(chol, np.random.randn(self.d)) for _ in range(self.M)])
                        max_noises = np.max(np.dot(varphi_bars, xi.T), axis=1)
                        bonus1 = np.dot(prob, max_noises)

                    next_value = np.array([V[h+1][sp] for sp in next_states])

                    ## stochastically optimistic Q
                    Q[h, s, a] = r + np.dot(prob, next_value) + (bonus1 + bonus2)

                # Use the Q-values just computed for step h (not last episode's self.Q).
                V[h][s] = max(Q[(h, s, a)] for a in range(self.env.nAction))
        self.Q = Q

    def update_theta(self, x, y):
        """Refit the transition core with Online Mirror Descent."""
        self.theta_prev = self.theta
        self.optimizer.fit(theta=self.theta_0, theta_prev=self.theta_prev, x=x, y=y, Btilde=self.B_tilde, eta=self.eta, ell_2=self.grad_2_ell)
        self.theta = self.optimizer.w

    def run(self):
        """
        Run ``K`` episodes; returns the list of per-episode returns.
        """
        print("ORRL-MNL")
        episode_return = []

        for k in tqdm(range(1, self.K+1)):

            # reset environment
            self.env.reset()
            done = 0
            R = 0

            X_k = []  # [[varphi(s_h, a_h, s')]]
            Y_k = []  # [[y_h]]

            # starting episode
            while not done:
                s = self.env.state
                h = self.env.timestep
                a = self.act(s, h)
                r, s_, done = self.env.advance(a)

                R += r

                next_states = list(self.reachable_states[s, a])
                prob_k = self._transition_probs(s, a)

                ## transition response variable: one-hot over reachable states
                y = np.array([1.0 if sp == s_ else 0.0 for sp in next_states])
                X_h = np.array([self.varphi[(s, a, sp)] for sp in next_states])

                ## grad_2_ell = eta * sum_i p_i x_i x_i^T ;  B_tilde = B + grad_2_ell
                self.grad_2_ell = np.zeros_like(self.B_tilde)
                for p_i, x in zip(prob_k, X_h):
                    self.grad_2_ell += self.eta*p_i*np.outer(x, x)
                # A fresh array: aliasing B and updating in place would corrupt B.
                self.B_tilde = self.B + self.grad_2_ell

                Y_k.append(y)
                X_k.append(X_h)
            print(f'episode: {k}, cum_rewards: {R}')
            episode_return.append(R)

            ## transition core update
            self.update_theta(X_k, Y_k)

            ## Gram matrix B update, weighted by probabilities under the new theta
            for X_h in X_k:
                if len(X_h) == 0:
                    continue
                prob_k_plus_1 = self._softmax(np.dot(X_h, self.theta))
                for p_i, x in zip(prob_k_plus_1, X_h):
                    scaled = np.sqrt(p_i)*x
                    self.update_gram_matrix([scaled])
                    self.update_gram_matrix_inverse([scaled])

            self.compute_Q(k)

        return episode_return

class UCRL_MNL(RRL_MNL):
    """
    baseline algorithm: UCRL-MNL described in [Model-Based Reinforcement Learning with Multinomial Logistic Function Approximation](https://arxiv.org/pdf/2212.13540.pdf)

    The transition core is refit every episode by regularized MLE on all data
    collected so far, and Q-values receive a deterministic UCB bonus based on
    ``A_inv``.
    """
    def __init__(self, env, K, lam=1, beta=0.1, kappa=0.1, c=0.001):

        ## Inputs -- forwarded by keyword: positionally, ``c`` would land in the
        ## parent's ``sig`` slot and ``kappa`` would be dropped.
        super().__init__(env, K, lam=lam, kappa=kappa, c=c)
        self.beta = beta  # initial confidence radius; recomputed in compute_Q

        ## Optimizer: Regularized MLE
        self.optimizer = RegularizedMNLRegression()
        self.X = []  # all observed feature matrices
        self.Y = []  # all observed one-hot responses
        self.c = c

    @staticmethod
    def _softmax(logits):
        """Numerically stable softmax of a 1-D array; empty input gives empty output."""
        logits = np.asarray(logits, dtype=float)
        if logits.size == 0:
            return logits
        weights = np.exp(logits - np.max(logits))
        return weights / np.sum(weights)

    def _transition_probs(self, s, a, theta=None):
        """
        MNL transition probabilities P_theta(. | s, a), ordered like iteration
        over ``self.reachable_states[(s, a)]``; ``theta`` defaults to ``self.theta``.
        """
        if theta is None:
            theta = self.theta
        logits = [np.dot(self.varphi[(s, a, s_)], theta) for s_ in self.reachable_states[(s, a)]]
        return self._softmax(logits)

    def compute_Q(self, k):
        """
        Compute UCB Q-values by backward induction and store them in ``self.Q``.

        k: current episode index (>= 1); the confidence radius grows with log(k+1).
        """
        H = self.env.epLen
        Q = {(h, s, a): 0.0 for h in range(H) for s in self.env.states.keys()
             for a in range(self.env.nAction)}
        V = {h: np.zeros(self.env.nState) for h in range(H + 1)}

        self.beta = self.c/10*H * np.sqrt(self.d) * np.log(k+1)/self.kappa  # confidence radius

        for h in range(H-1, -1, -1):
            for s in self.env.states.keys():
                for a in range(self.env.nAction):

                    ## known reward
                    r = self.env.R[(s, a)][0]

                    next_states = list(self.reachable_states[(s, a)])
                    prob = self._transition_probs(s, a)

                    ## Optimistic bonus: 2 H beta max_{s'} ||varphi(s,a,s')||_{A^{-1}}
                    if len(next_states) == 1:
                        bonus = 0
                    else:
                        feats = np.array([self.varphi[(s, a, sp)] for sp in next_states])
                        varphi_norm_sq = np.sum(np.dot(feats, self.A_inv) * feats, axis=1)
                        bonus = 2*H*(self.beta)*np.sqrt(np.max(varphi_norm_sq))

                    ## UCB Q value
                    next_value = np.array([V[h+1][sp] for sp in next_states])
                    Q[h, s, a] = r + np.dot(prob, next_value) + bonus

                # Use the Q-values just computed for step h (not last episode's self.Q).
                V[h][s] = max(Q[(h, s, a)] for a in range(self.env.nAction))
        self.Q = Q

    def update_gram_matrix_inverse(self, x):
        """
        Update ``self.A_inv`` with one Sherman-Morrison rank-one step per row.

        x: iterable of d-dimensional feature vectors
        """
        for row in x:
            A_row = np.dot(self.A_inv, row)
            self.A_inv = self.A_inv - np.outer(A_row, np.dot(row, self.A_inv))/(1 + np.dot(row, A_row))

    def update_theta(self, x, y):
        """Refit the transition core by regularized MLE on all data so far."""
        self.theta_prev = self.theta
        self.optimizer.fit(x=x, y=y, theta=self.theta_0, lam=self.lam)
        self.theta = self.optimizer.w

    def run(self):
        """
        Run ``K`` episodes; returns the list of per-episode returns.
        """
        print("UCRL-MNL")
        episode_return = []

        for k in tqdm(range(1, self.K+1)):

            # reset environment
            self.env.reset()
            done = 0
            R = 0

            # starting episode
            while not done:
                s = self.env.state
                h = self.env.timestep
                a = self.act(s, h)
                r, s_, done = self.env.advance(a)

                R += r

                next_states = list(self.reachable_states[s, a])

                # transition response variable: one-hot over reachable states
                y = np.array([1.0 if sp == s_ else 0.0 for sp in next_states])
                X = [self.varphi[(s, a, sp)] for sp in next_states]
                for x in X:
                    self.update_gram_matrix_inverse([x])

                self.Y.append(y)
                self.X.append(np.array(X))
            print(f'episode: {k}, cum_rewards: {R}')
            episode_return.append(R)

            ## transition core update
            self.update_theta(x=self.X, y=self.Y)

            # Update Q value
            self.compute_Q(k)

        return episode_return

class UCRL_MNL_PLUS(RRL_MNL):
    """
    Our proposed algorithm, UCRL-MNL+ in the paper.

    Like ORRL-MNL it updates the transition core with Online Mirror Descent
    and keeps a probability-weighted Gram matrix ``B``, but its exploration
    bonus is deterministic (based on centralized-feature norms).
    """

    def __init__(self, env, K, lam=1, beta=0.1, eta=0.1, c=0.003):

        # Forward by keyword: positionally, ``c`` would land in the parent's ``sig`` slot.
        super().__init__(env, K, lam=lam, c=c)
        self.beta = beta
        # NOTE(review): the ``eta`` argument is ignored and the step size is
        # fixed to 2 + log(4)/2 (kept from the original code) -- confirm.
        self.eta = 2 + np.log(4)/2
        self.c = c

        ## Optimizer: Online Mirror Descent
        self.optimizer = Online_Mirror_Descent()

        ## Gram matrix
        # NOTE(review): ``lam`` is forced to 1 here regardless of the argument.
        self.lam = 1
        self.B = (self.lam)*np.identity(self.d)
        self.B_inv = (1/self.lam)*np.identity(self.d)
        self.B_tilde = (self.lam)*np.identity(self.d)
        self.grad_2_ell = np.zeros_like(self.B_tilde)

    @staticmethod
    def _softmax(logits):
        """Numerically stable softmax of a 1-D array; empty input gives empty output."""
        logits = np.asarray(logits, dtype=float)
        if logits.size == 0:
            return logits
        weights = np.exp(logits - np.max(logits))
        return weights / np.sum(weights)

    def _transition_probs(self, s, a, theta=None):
        """
        MNL transition probabilities P_theta(. | s, a), ordered like iteration
        over ``self.reachable_states[(s, a)]``; ``theta`` defaults to ``self.theta``.
        """
        if theta is None:
            theta = self.theta
        logits = [np.dot(self.varphi[(s, a, s_)], theta) for s_ in self.reachable_states[(s, a)]]
        return self._softmax(logits)

    def expected_varphi(self, s, a, theta):
        """
        E_{s' ~ P_theta(. | s, a)}[varphi(s, a, s')] under the given ``theta``
        (``None`` falls back to ``self.theta``).
        """
        prob = self._transition_probs(s, a, theta)
        E_varphi = np.zeros(self.d)
        for p_i, sprime in zip(prob, self.reachable_states[(s, a)]):
            E_varphi += p_i * self.varphi[(s, a, sprime)]
        return E_varphi

    def update_gram_matrix(self, x):
        """
        Add sum_row row row^T to the Gram matrix ``self.B``.

        x: iterable of d-dimensional feature vectors
        """
        self.B += sum([np.outer(row, row) for row in x])

    def update_gram_matrix_inverse(self, x):
        """
        Update ``self.B_inv`` with one Sherman-Morrison rank-one step per row.

        x: iterable of d-dimensional feature vectors
        """
        for row in x:
            B_row = np.dot(self.B_inv, row)
            self.B_inv = self.B_inv - np.outer(B_row, np.dot(row, self.B_inv))/(1 + np.dot(row, B_row))

    def compute_Q(self, k):
        """
        Compute optimistic Q-values (model estimate plus deterministic bonus)
        by backward induction and store them in ``self.Q``.

        k: current episode index (>= 1); ``beta`` grows with log(k+1).
        """
        H = self.env.epLen
        Q = {(h, s, a): 0.0 for h in range(H) for s in self.env.states.keys()
             for a in range(self.env.nAction)}
        V = {h: np.zeros(self.env.nState) for h in range(H + 1)}

        self.beta = self.c*H * np.sqrt(self.d) * np.log(k+1)  # confidence radius

        for h in range(H-1, -1, -1):
            for s in self.env.states.keys():
                for a in range(self.env.nAction):

                    ## known reward
                    r = self.env.R[(s, a)][0]

                    next_states = list(self.reachable_states[(s, a)])
                    prob = self._transition_probs(s, a)
                    feats = np.array([self.varphi[(s, a, sp)] for sp in next_states])

                    ## bonus term 2: 2 H beta^2 max_{s'} ||varphi(s,a,s')||^2_{B^{-1}}
                    varphi_norm_sq = np.sum(np.dot(feats, self.B_inv) * feats, axis=1)
                    bonus2 = 2*H*(self.beta**2)*np.max(varphi_norm_sq)

                    if len(next_states) == 1:  # only one reachable state
                        bonus1 = 0
                    else:
                        ## bonus term 1: H beta E_p[||varphi - E_p[varphi]||_{B^{-1}}]
                        varphi_bars = feats - np.dot(prob, feats)
                        quad = np.sum(np.dot(varphi_bars, self.B_inv) * varphi_bars, axis=1)
                        # clip tiny negative round-off before sqrt
                        varphi_bar_norms = np.sqrt(np.maximum(quad, 0.0))
                        bonus1 = H*(self.beta)*np.dot(varphi_bar_norms, prob)

                    next_value = np.array([V[h+1][sp] for sp in next_states])

                    ## optimistic Q
                    Q[h, s, a] = r + np.dot(prob, next_value) + (bonus1 + bonus2)

                # Use the Q-values just computed for step h (not last episode's self.Q).
                V[h][s] = max(Q[(h, s, a)] for a in range(self.env.nAction))
        self.Q = Q

    def update_theta(self, x, y):
        """Refit the transition core with Online Mirror Descent."""
        self.theta_prev = self.theta
        self.optimizer.fit(theta=self.theta_0, theta_prev=self.theta_prev, x=x, y=y, Btilde=self.B_tilde, eta=self.eta, ell_2=self.grad_2_ell)
        self.theta = self.optimizer.w

    def run(self):
        """
        Run ``K`` episodes; returns the list of per-episode returns.
        """
        print("UCRL-MNL+")
        episode_return = []

        for k in tqdm(range(1, self.K+1)):

            # reset environment
            self.env.reset()
            done = 0
            R = 0

            X_k = []  # [[varphi(s_h, a_h, s')]]
            Y_k = []  # [[y_h]]

            # starting episode
            while not done:
                s = self.env.state
                h = self.env.timestep
                a = self.act(s, h)
                r, s_, done = self.env.advance(a)

                R += r

                next_states = list(self.reachable_states[s, a])
                prob_k = self._transition_probs(s, a)

                ## transition response variable: one-hot over reachable states
                y = np.array([1.0 if sp == s_ else 0.0 for sp in next_states])
                X_h = np.array([self.varphi[(s, a, sp)] for sp in next_states])

                ## grad_2_ell = eta * sum_i p_i x_i x_i^T ;  B_tilde = B + grad_2_ell
                self.grad_2_ell = np.zeros_like(self.B_tilde)
                for p_i, x in zip(prob_k, X_h):
                    self.grad_2_ell += self.eta*p_i*np.outer(x, x)
                # A fresh array: aliasing B and updating in place would corrupt B.
                self.B_tilde = self.B + self.grad_2_ell

                Y_k.append(y)
                X_k.append(X_h)
            print(f'episode: {k}, cum_rewards: {R}')

            episode_return.append(R)

            ## transition core update
            self.update_theta(X_k, Y_k)

            ## Gram matrix B update, weighted by probabilities under the new theta
            for X_h in X_k:
                if len(X_h) == 0:
                    continue
                prob_k_plus_1 = self._softmax(np.dot(X_h, self.theta))
                for p_i, x in zip(prob_k_plus_1, X_h):
                    scaled = np.sqrt(p_i)*x
                    self.update_gram_matrix([scaled])
                    self.update_gram_matrix_inverse([scaled])

            # Update Q value
            self.compute_Q(k)

        return episode_return

class Online_Newton_Step:
    """
    Online Newton Step update for the MNL transition core.

    ``fit`` minimizes
        1/2 (theta - theta_prev)^T A (theta - theta_prev)
        + (theta - theta_prev)^T grad_ell(theta_prev)
    with ``scipy.optimize.fmin_tnc`` and stores the minimizer in ``self.w``.
    """

    def compute_prob(self, theta, x):
        """
        theta: d-dimensional vector
        x: sequence of (n_i, d) feature matrices, one per observed transition
           (row j is the feature of the j-th candidate next state)

        Returns a list whose i-th entry is softmax(x[i] @ theta).
        """
        probs = []
        for x_i in x:
            logits = np.dot(x_i, theta)
            # Subtracting the max avoids exp overflow without changing the result.
            u = np.exp(logits - np.max(logits))
            probs.append(u / u.sum())
        return probs

    def gradient_of_ell(self, theta, x, y):
        """
        Inputs
        theta: d-dimensional vector
        x: sequence of (n_i, d) feature matrices
        y: sequence of n_i-dimensional one-hot response vectors
        Output
        grad ell(theta): gradient of the average negative log-likelihood,
            (1/m) sum_i x[i]^T (p_i - y[i])
        """
        m = len(x)
        probs = self.compute_prob(theta, x)

        grad = np.zeros(len(theta))
        for i, p_i in enumerate(probs):
            grad += np.dot(p_i - y[i], x[i])

        return (1/m)*grad

    def cost_function(self, theta, theta_prev, x, y, A):
        """
        Objective of the online transition-core update.
        """
        v = theta - theta_prev
        grad = self.gradient_of_ell(theta_prev, x, y)
        return (1/2)*np.dot(np.dot(v, A), v) + np.dot(v, grad)

    def gradient(self, theta, theta_prev, x, y, A):
        """
        Gradient of ``cost_function`` with respect to ``theta`` (A assumed symmetric).
        """
        v = theta - theta_prev
        return np.dot(A, v) + self.gradient_of_ell(theta_prev, x, y)

    def fit(self, theta, theta_prev, x, y, A):
        """
        theta: optimization starting point
        """
        opt_weights = fmin_tnc(func=self.cost_function, x0=theta, fprime=self.gradient, args=(theta_prev, x, y, A), messages=0)
        self.w = opt_weights[0]

        return self

class Online_Mirror_Descent:
    """
    Online Mirror Descent update for the MNL transition core.

    ``fit`` minimizes
        theta^T grad_ell(theta_prev)
        + 1/2 sqrt(v^T ell_2 v)
        + 1/(2 eta) v^T Btilde v,      v = theta - theta_prev
    with ``scipy.optimize.fmin_tnc`` and stores the minimizer in ``self.w``.
    """

    def compute_prob(self, theta, x):
        """
        theta: d-dimensional vector
        x: sequence of (n_i, d) feature matrices, one per observed transition

        Returns a list whose i-th entry is softmax(x[i] @ theta).
        """
        probs = []
        for x_i in x:
            logits = np.dot(x_i, theta)
            # Subtracting the max avoids exp overflow without changing the result.
            u = np.exp(logits - np.max(logits))
            probs.append(u / u.sum())
        return probs

    def gradient_of_ell(self, theta, x, y):
        """
        Inputs
        theta: d-dimensional vector
        x: sequence of (n_i, d) feature matrices
        y: sequence of n_i-dimensional one-hot response vectors
        Output
        grad ell(theta): gradient of the average negative log-likelihood,
            (1/m) sum_i x[i]^T (p_i - y[i])
        """
        m = len(x)
        probs = self.compute_prob(theta, x)

        grad = np.zeros(len(theta))
        for i, p_i in enumerate(probs):
            grad += np.dot(p_i - y[i], x[i])

        return (1/m)*grad

    def cost_function(self, theta, theta_prev, x, y, Btilde, eta, ell_2):
        """
        Objective of the online transition-core update.
        """
        # term 1: linearized loss
        grad = self.gradient_of_ell(theta_prev, x, y)
        term1 = np.dot(theta, grad)

        # term 2: 1/2 ||v||_{ell_2}; zero when the quadratic form is not positive
        v = theta - theta_prev
        quad = np.dot(np.dot(v, ell_2), v)
        term2 = (1/2)*np.sqrt(quad) if quad > 0 else 0

        # term 3: proximal term
        term3 = np.dot(np.dot(v, Btilde), v)

        return term1 + term2 + 1/(2*eta)*term3

    def gradient(self, theta, theta_prev, x, y, Btilde, eta, ell_2):
        """
        Gradient of ``cost_function`` with respect to ``theta``
        (``ell_2`` and ``Btilde`` assumed symmetric).
        """
        term1 = self.gradient_of_ell(theta_prev, x, y)

        v = theta - theta_prev
        quad = np.dot(np.dot(v, ell_2), v)
        term2 = (1/2)*np.dot(ell_2, v)/np.sqrt(quad) if quad > 0 else 0

        term3 = np.dot(Btilde, v)
        return term1 + term2 + (1/eta)*term3

    def fit(self, theta, theta_prev, x, y, Btilde, eta, ell_2):
        """
        theta: optimization starting point
        """
        opt_weights = fmin_tnc(func=self.cost_function, x0=theta, fprime=self.gradient, args=(theta_prev, x, y, Btilde, eta, ell_2), messages=0)
        self.w = opt_weights[0]

        return self

class RegularizedMNLRegression:
    """
    L2-regularized maximum-likelihood estimation of the MNL transition core.

    Minimizes
        -(1/m) sum_i sum_j y[i][j] log p_i[j] + lam/(2m) ||theta||^2,
    where p_i = softmax(x[i] @ theta), using ``scipy.optimize.fmin_tnc``.
    """

    def compute_prob(self, theta, x):
        """
        Returns a list whose i-th entry is softmax(x[i] @ theta) for each
        (n_i, d) feature matrix x[i].
        """
        probs = []
        for x_i in x:
            logits = np.dot(x_i, theta)
            # Subtracting the max avoids exp overflow without changing the result.
            u = np.exp(logits - np.max(logits))
            probs.append(u / u.sum())
        return probs

    def cost_function(self, theta, x, y, lam):
        """
        Regularized average negative log-likelihood. The penalty is
        lam/(2m) ||theta||^2 so that it matches ``gradient`` exactly.
        """
        res = 0
        for i, x_i in enumerate(x):
            logits = np.dot(x_i, theta)
            shifted = logits - np.max(logits)
            # log-softmax computed directly: stays finite where p would underflow to 0
            log_p = shifted - np.log(np.sum(np.exp(shifted)))
            res += np.sum(np.multiply(y[i], log_p))

        m = len(x)
        res *= -(1/m)
        res += (1/(2*m))*lam*np.dot(theta, theta)

        return res

    def gradient(self, theta, x, y, lam):
        """
        Gradient of ``cost_function``:
            (1/m) sum_i x[i]^T (p_i - y[i]) + (lam/m) theta
        """
        m = len(x)
        probs = self.compute_prob(theta, x)

        grad = np.zeros(len(theta))
        for i, p_i in enumerate(probs):
            grad += np.dot(p_i - y[i], x[i])

        grad = (1/m)*grad
        grad += (1/m)*lam*theta

        return grad

    def fit(self, x, y, theta, lam):
        """
        theta: optimization starting point; the minimizer is stored in ``self.w``.
        """
        opt_weights = fmin_tnc(func=self.cost_function, x0=theta, fprime=self.gradient, messages=0, args=(x, y, lam))
        self.w = opt_weights[0]
        return self