import numpy as np
import random
import math
from pynverse import inversefunc

def auto_tuning(logw, p, reward, index, gamma):
    """Run one EXP3 round: credit the played arm, then sample the next one.

    Args:
        logw: per-arm log-weights. The entry at ``index`` is updated in place,
            and the same object is returned.
        p: sampling probabilities that were used when ``index`` was played.
        reward: reward observed for arm ``index``.
        index: the arm that was played.
        gamma: exploration rate; this share of probability mass is spread
            uniformly over all arms.

    Returns:
        Tuple ``(logw, p, nxt_index)``: the updated log-weights, the new
        sampling distribution, and the arm drawn from it.
    """
    n_arms = len(logw)
    uniform_share = gamma / n_arms

    # Importance-weighted reward (reward / p) scaled by gamma/K.
    logw[index] += uniform_share * reward / p[index]

    # Shift by the max before exponentiating so the largest weight is exp(0)=1.
    weights = np.exp(logw - np.max(logw))
    new_p = uniform_share + (1 - gamma) * weights / sum(weights)

    return logw, new_p, np.random.choice(n_arms, p=new_p)

def op_tuning(s, f, reward, index):
    """Run one Thompson-sampling round using Bernoulli-converted rewards.

    The reward is clipped to [0, 1] and used as the success probability of a
    single Bernoulli draw. The 0/1 outcome is added in place to the success
    count ``s`` or the failure count ``f`` of arm ``index``. A sample from
    Beta(s[i], f[i]) is then drawn for every arm, and the arm with the largest
    sample is selected.

    Returns:
        Tuple ``(s, f, index)``: the updated counts and the selected arm.
    """
    success = np.random.binomial(1, max(0, min(reward, 1)))
    s[index] += success
    f[index] += 1 - success

    n_arms = len(s)
    samples = np.empty(n_arms)
    for arm in range(n_arms):
        samples[arm] = np.random.beta(s[arm], f[arm])

    return s, f, np.argmax(samples)

def log_barrier(p, l, eta):
    """Solve the log-barrier normalisation step and return the new distribution.

    Finds the scalar ``lamda`` such that

        sum_i 1 / (1/p[i] + eta[i] * (l[i] - lamda)) == 1

    and returns the vector ``1 / (1/p + eta * (l - lamda))``. This is
    presumably the log-barrier OMD update used by Corral-style algorithms;
    confirm against the caller.

    The left-hand side is strictly increasing in ``lamda`` but has a pole
    where the first denominator reaches zero. The root therefore lies in
    ``[min(l), min(max(l), first_pole))``. It is located by bisection on that
    bracket, so every denominator stays positive and the result is a valid
    distribution.

    Args:
        p: current probabilities. Assumed strictly positive and summing to 1,
            which is what makes ``min(l)`` a valid lower bracket.
        l: per-arm loss estimates, same length as ``p``.
        eta: per-arm learning rates (same length as ``p``) or a single scalar.
            Assumed non-negative.

    Returns:
        np.ndarray with one probability per arm, summing to 1 (empty when
        ``p`` is empty).
    """
    p = np.asarray(p, dtype=float)
    l = np.asarray(l, dtype=float)
    eta = np.broadcast_to(np.asarray(eta, dtype=float), p.shape)
    if p.size == 0:
        return np.zeros(0)

    inv_p = 1.0 / p

    def total(lamda):
        return np.sum(1.0 / (inv_p + eta * (l - lamda)))

    # At lamda = min(l) every term is <= p[i], so total(lo) <= sum(p) = 1.
    lo = np.min(l)

    # Arm i's denominator vanishes at l[i] + 1/(p[i]*eta[i]). The root lies
    # left of the first such pole, so cap the upper end of the bracket there.
    poles = np.full(p.shape, np.inf)
    positive = eta > 0
    poles[positive] = l[positive] + inv_p[positive] / eta[positive]
    hi = min(np.max(l), np.min(poles))

    for _ in range(200):
        mid = 0.5 * (lo + hi)
        if mid <= lo or mid >= hi:
            break  # bracket has collapsed to adjacent floats
        if total(mid) < 1.0:
            lo = mid
        else:
            hi = mid

    # lo is strictly left of every pole, so all denominators are positive.
    res = 1.0 / (inv_p + eta * (l - lo))
    # Absorb the remaining bisection residual so res sums to 1 exactly
    # enough for samplers such as np.random.choice.
    return res / np.sum(res)