import numpy as np
import itertools
from pomdp_env import sample


class NPG:
    """Tabular policy learner with count-based bonuses for an episodic POMDP.

    NOTE(review): "NPG" presumably stands for natural policy gradient --
    confirm against the accompanying write-up.

    The learner keeps, per time step ``h`` in ``0..p.H``:

    * ``pi[h]``      -- an ``(s_num, a_num)`` array of action probabilities,
      initialised uniform;
    * ``Q_list[k][h]`` -- the ``(s_num, a_num)`` Q estimate after ``k``
      updates (``Q_list[0]`` is all zeros).

    It also keeps ``est[(s, a)]``, a vector of next-state visit counts, and
    ``decode``, a lookup used to guess the current state from observations:
    ``decode[(o,)]`` for the first observation and ``decode[(s, a, o)]`` for
    later ones.

    Trajectories produced by :meth:`run_traj` are flat lists laid out as
    ``[s_0, o_0, a_0, r_0, s_1, o_1, a_1, r_1, ..., s_H, o_H]``, so index
    ``4 * h`` is the state at step ``h``, ``4 * h + 1`` the observation,
    ``4 * h + 2`` the action and ``4 * h + 3`` the reward.

    Args:
        p: Environment object. The code reads ``p.H``, ``p.s_num``,
            ``p.a_num``, ``p.o_num`` and ``p.reward[h][s, a]`` and calls
            ``p.reset()`` and ``p.step(a)``.
        Z: History window parameter; only used by :meth:`truncate`.
        alpha: Stored on the instance; not read by any method in this class.
        tau: Step size multiplying Q in the multiplicative policy update.
    """

    def __init__(self, p, Z, alpha, tau):
        self.as_space = []
        self.s_space = []
        self.p = p
        self.Z = Z
        self.alpha = alpha
        self.tau = tau

        # One table per step 0..H inclusive; the extra step-H Q table stays
        # zero and serves as the terminal value in update_Q.
        Q = [np.zeros(shape=(p.s_num, p.a_num)) for _ in range(p.H + 1)]
        self.pi = [
            np.ones(shape=(p.s_num, p.a_num)) / p.a_num for _ in range(p.H + 1)
        ]

        # Next-state visit counts for every (state, action) pair.
        self.est = {
            (state, a): np.zeros(shape=(p.s_num,))
            for state, a in itertools.product(range(p.s_num), range(p.a_num))
        }

        # Observation -> state decoder. First observations default to state 0,
        # later (state, action, observation) triples to a random state.
        self.decode = {(o,): 0 for o in range(p.o_num)}
        for state, a, o in itertools.product(
            range(p.s_num), range(p.a_num), range(p.o_num)
        ):
            self.decode[(state, a, o)] = np.random.randint(p.s_num)

        self.Q_list = [Q]

    def learn(self, K):
        """Run ``K`` rounds of collect / evaluate / update.

        Each round samples one trajectory with the current policy, prints and
        records the mean return of 100 evaluation episodes, then updates the
        Q estimates and the policy from the sampled trajectory.

        Returns:
            List of the ``K`` evaluation returns, in order.
        """
        r_sum_list = []
        for _ in range(K):
            traj, _ = self.run_traj()
            r_sum_list.append(self.evaluate(100))
            print(r_sum_list[-1])
            self.update_Q(traj)
            self.update_pi()
        return r_sum_list

    def update_Q(self, traj):
        """Fold one trajectory into the model and recompute all Q tables.

        Args:
            traj: Flat trajectory in the layout described on the class
                (as returned by :meth:`run_traj`).

        Side effects:
            Updates ``est`` and ``decode``, sets ``self.new_Q`` and appends it
            to ``self.Q_list``. Earlier entries of ``Q_list`` are left intact.
        """
        H = self.p.H
        s_num = self.p.s_num

        # Copy each per-step array: a plain list.copy() would share the arrays
        # with the previous entry of Q_list and overwrite the history.
        self.new_Q = [q.copy() for q in self.Q_list[-1]]

        for h in range(H):
            s_h, o_h, a_h = traj[4 * h], traj[4 * h + 1], traj[4 * h + 2]
            s_next = traj[4 * h + 4]
            self.est[(s_h, a_h)][s_next] += 1
            if h == 0:
                # Must be a 1-tuple to match the key read by decode_s.
                self.decode[(o_h,)] = s_h
            else:
                self.decode[(traj[4 * h - 4], traj[4 * h - 2], o_h)] = s_h

        uniform = np.ones(shape=(s_num,)) / s_num
        for h in reversed(range(H)):
            q_next = self.new_Q[h + 1]
            # Policy-weighted value of each next state; (s', a') entries whose
            # Q is below 0.01 are dropped. Independent of (s, a), so computed
            # once per step.
            v_next = np.where(
                q_next < 0.01, 0.0, self.pi[h + 1] * q_next
            ).sum(axis=1)
            for s in range(s_num):
                for a in range(self.p.a_num):
                    counts = self.est[(s, a)]
                    total = np.sum(counts)
                    bonus = 1 / max(total, 1)
                    # Unvisited pairs fall back to a uniform transition guess.
                    emp_tran = uniform if total < 0.5 else counts / total
                    # Next states with empirical probability below 0.01 are
                    # ignored.
                    emp_tran = np.where(emp_tran < 0.01, 0.0, emp_tran)
                    expectation = float(emp_tran @ v_next)
                    self.new_Q[h][s][a] = (
                        self.p.reward[h][s, a] + bonus + expectation
                    )
        self.Q_list.append(self.new_Q)

    def truncate(self, his):
        """Return the last ``2 * Z + 1`` items of ``his`` as a tuple.

        Shorter histories are returned whole (still as a tuple).
        """
        window = self.Z * 2 + 1
        if len(his) > window:
            return tuple(his[-window:])
        return tuple(his)

    def update_pi(self):
        """Multiplicatively reweight the policy by the latest Q estimate.

        For every step ``h < H`` and state ``s`` the new policy is
        proportional to ``pi[h][s, a] * exp(tau * Q[h][s, a] / (h + 1))``.
        The per-state maximum exponent is subtracted before exponentiating;
        it cancels in the normalisation and avoids overflow to ``inf``
        (which would turn the row into ``nan``).
        """
        Q = self.Q_list[-1]
        for h in range(self.p.H):
            logits = self.tau * Q[h] / (h + 1)
            logits = logits - logits.max(axis=1, keepdims=True)
            self.pi[h] *= np.exp(logits)
            self.pi[h] /= self.pi[h].sum(axis=1, keepdims=True)

    def run_traj(self, pa=False):
        """Roll out one episode of ``p.H`` steps with the current policy.

        Args:
            pa: If false, actions are sampled using the state returned by the
                environment; if true, using the state guessed by
                :meth:`decode_s` from the history so far.

        Returns:
            ``(extended_hist, r_sum)`` where ``extended_hist`` is the flat
            trajectory in the layout described on the class and ``r_sum`` is
            the sum of rewards.
        """
        s, o = self.p.reset()
        r_sum = 0
        extended_hist = [s, o]
        for h in range(self.p.H):
            if pa:
                acting_state = self.decode_s(extended_hist)
            else:
                # The state sits two slots from the end: [..., s_h, o_h].
                acting_state = extended_hist[-2]
            a = sample(self.pi[h][acting_state])
            s, o, r, _ = self.p.step(a)
            extended_hist += [a, r, s, o]
            r_sum += r
        return extended_hist, r_sum

    def evaluate(self, K):
        """Return the mean return over ``K`` episodes run with ``pa=True``."""
        r_sum = 0
        for _ in range(K):
            r_sum += self.run_traj(pa=True)[1]
        return r_sum / K

    def decode_s(self, extended_hist):
        """Guess the current state by chaining ``decode`` along a history.

        Starts from the first observation, then repeatedly maps
        ``(guessed state, action, next observation)`` to the next guessed
        state. Only actions and observations in ``extended_hist`` are read;
        the stored states and rewards are ignored.
        """
        s = self.decode[(extended_hist[1],)]
        for h in range(len(extended_hist) // 4):
            s = self.decode[(s, extended_hist[4 * h + 2], extended_hist[4 * h + 5])]
        return s
