from collections import defaultdict
import pandas as pd
import numpy as np

def kl(p, q):
    """Return the KL divergence KL(Bernoulli(p) || Bernoulli(q)).

    Uses the convention ``0 * log(0) = 0``, so the boundary cases ``p == 0``
    and ``p == 1`` give finite results whenever ``0 < q < 1``.

    Args:
        p: mean of the first Bernoulli distribution (expected in [0, 1]).
        q: mean of the second Bernoulli distribution (expected in (0, 1)).

    Returns:
        The divergence p*log(p/q) + (1-p)*log((1-p)/(1-q)).

    Note:
        ``q`` at 0 or 1 makes the divergence infinite and is not handled here
        (plain-float inputs then raise ZeroDivisionError).
    """
    if p == 0:
        # The p*log(p/q) term vanishes; evaluating it would give 0 * -inf.
        return (1 - p) * np.log((1 - p) / (1 - q))
    if p == 1:
        # The (1-p)*log((1-p)/(1-q)) term vanishes; evaluating it would
        # give 0 * -inf = nan.
        return p * np.log(p / q)
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

class PFOptimization(object):
    """Precomputed quantities and per-group objectives over a weight vector q.

    Arms may be shared by several groups. An arm is "suboptimal" when its mean
    is strictly below the best mean of every group that can pull it. For each
    suboptimal arm the flat vector ``q`` holds one weight per group, laid out
    arm-major: ``q[a_index * G + g]`` (see ``get_var_index``).

    NOTE(review): group ids are used directly as integer positions into
    ``self.opts`` and into the per-arm length-G blocks of ``q0`` / ``q``, so
    ``all_groups`` is expected to be exactly ``0, 1, ..., G-1`` in increasing
    order -- confirm with callers.
    """

    def __init__(self, all_groups, all_arms, mus, group_to_avail_arms, arms_to_groups):
        """Precompute group optima, suboptimal arms, their J values and q0.

        Args:
            all_groups: sized, re-iterable collection of group ids (see the
                class note on the assumed 0..G-1 layout).
            all_arms: sized, re-iterable collection of arm identifiers.
            mus: dict from arm to reward (mean). These are passed to ``kl``,
                which is a Bernoulli KL, so presumably in [0, 1] -- confirm.
            group_to_avail_arms: dict from group to the arms it can pull.
            arms_to_groups: dict from arm to the groups that can pull it.

        Raises:
            ValueError: if a group has no arms or an arm has no groups.
            AssertionError: if q0 does not give every group a positive
                utility gain.
        """
        self.all_groups = all_groups
        self.all_arms = all_arms
        self.G = len(all_groups)
        self.mus = mus
        self.total_num_arms = len(all_arms)
        self.group_to_avail_arms = group_to_avail_arms
        self.arms_to_groups = arms_to_groups

        # self.opts[g] is the best mean reward available to group g
        # (list position == group id, see class note).
        self.opts = [max(self.mus[a] for a in group_to_avail_arms[g]) for g in all_groups]

        # Js[a] = 1 / KL(mu_a, smallest optimum among a's groups), for suboptimal a.
        self.Js = dict()
        # Suboptimal arms in all_arms order; this order fixes the layout of q.
        self.suboptimal_arms = []
        # gmins[a]: first group of a (in arms_to_groups[a] order) whose optimum
        # is smallest. Filled for every arm, not only suboptimal ones.
        self.gmins = dict()
        # Initial point: one length-G block per suboptimal arm, concatenated.
        self.q0 = []
        # Sum over suboptimal arms of Js[a] * (smallest optimum - mu_a).
        self.optimal_regret = 0
        for a in all_arms:
            groups_of_a = self.arms_to_groups[a]
            # min() returns the first minimiser, i.e. the first-match scan.
            gmin = min(groups_of_a, key=lambda g: self.opts[g])
            self.gmins[a] = gmin
            smallest_opt = self.opts[gmin]
            mu_a = self.mus[a]
            if smallest_opt > mu_a:
                J_a = 1 / kl(mu_a, smallest_opt)
                self.Js[a] = J_a
                self.suboptimal_arms.append(a)

                q0_a = np.zeros(self.G)
                if len(groups_of_a) > 1:
                    # Only the first group other than gmin receives weight;
                    # any further groups of this arm stay at 0 in q0.
                    other_group = [g for g in groups_of_a if g != gmin][0]
                    # opts[other_group] >= smallest_opt > mu_a, and the
                    # Bernoulli KL grows with q for q > p, so J_other <= J_a
                    # and q_other_group lies in (0, 1/2].
                    J_other = 1 / kl(mu_a, self.opts[other_group])
                    q_other_group = J_other / (2 * J_a)

                    q0_a[other_group] = q_other_group
                    q0_a[gmin] = 1 - q_other_group
                else:
                    q0_a[gmin] = 1
                self.q0.extend(q0_a)

                self.optimal_regret += J_a * (smallest_opt - mu_a)

        self.K = len(self.suboptimal_arms)
        self.a_to_index = {a: i for i, a in enumerate(self.suboptimal_arms)}

        self.num_vars = self.K * self.G
        # Sanity check on the constructed initial point.
        assert all(self.utility_gain(g, self.q0) > 0 for g in self.all_groups), \
            "q0 must give every group a positive utility gain"

    def get_var_index(self, g, a):
        """Return the position in q of the variable for group g and arm a.

        ``a`` must be a suboptimal arm (KeyError otherwise). Variables are
        laid out arm-major: the first G entries belong to the first
        suboptimal arm, the next G to the second, and so on.
        """
        return self.a_to_index[a] * self.G + g

    def _arms_with_gap(self, g):
        """Yield (arm, mu, delta) for arms of group g with nonzero gap delta = opts[g] - mu."""
        opt_g = self.opts[g]
        for a in self.group_to_avail_arms[g]:
            mu = self.mus[a]
            delta = opt_g - mu
            if delta == 0:
                # This arm is optimal for g.
                continue
            yield a, mu, delta

    def total_disagreement(self):
        """Return the sum of ``disagreement(g)`` over all groups."""
        return sum(self.disagreement(g) for g in self.all_groups)

    def disagreement(self, g):
        """Return sum over g's non-optimal arms of delta / KL(mu_a, opts[g]).

        Each term uses the group-specific J value 1 / KL(mu_a, opts[g]).
        """
        opt_g = self.opts[g]
        disagreement = 0
        for _, mu, delta in self._arms_with_gap(g):
            J_g_a = 1 / kl(mu, opt_g)
            disagreement += delta * J_g_a
        return disagreement

    def total_regret(self, q):
        """Return the sum of ``regret(g, q)`` over all groups."""
        return sum(self.regret(g, q) for g in self.all_groups)

    def regret(self, g, q):
        """Return sum over g's suboptimal arms a of delta * q[(g, a)] * Js[a].

        Arms that are non-optimal for g but not globally suboptimal have no
        entry in q and contribute nothing.
        """
        regret = 0
        for a, _, delta in self._arms_with_gap(g):
            if a in self.a_to_index:
                regret += delta * q[self.get_var_index(g, a)] * self.Js[a]
        return regret

    def utility_gain(self, g, q):
        """Return ``disagreement(g)`` minus ``regret(g, q)``, accumulated per arm.

        For each arm of g with a nonzero gap, adds delta / KL(mu_a, opts[g]);
        for suboptimal arms also subtracts delta * q[(g, a)] * Js[a].
        """
        opt_g = self.opts[g]
        gain = 0
        for a, mu, delta in self._arms_with_gap(g):
            J_g_a = 1 / kl(mu, opt_g)
            gain += delta * J_g_a
            if a in self.a_to_index:
                gain -= delta * q[self.get_var_index(g, a)] * self.Js[a]
        return gain
        
