# coding=utf-8
import threading
import numpy as np
import tensorflow as tf
from ...utils.memory import ReplayMemory
from .dqn_algo import DqnAlgo

GAMMA = 0.9     # reward discount
TAU = 0.01      # soft replacement rate for the target network
RENDER = False
BATCH_SIZE = 1024  # number of transitions requested from the replay buffer per learning step

REPLAY_BUFFER_SIZE = 1000000  # size argument passed to ReplayMemory
UPDATE_FREQ = 100  # a learning step only runs when the step counter is a multiple of this
FIRST_UPDATE_SAMPLE_NUM = 25600  # learning starts once the buffer holds more than this many samples
MODEL_UPDATE_FREQ = 1000  # a checkpoint is written every this many learning steps

# high frequency target soft update is better for DQN
TARGET_UPDATE_FREQ = 1  # the target network is updated every this many learning steps


class DqnTrainer(object):
    """Training-side wrapper that feeds a ``DqnAlgo`` from a replay buffer.

    Transitions are stored through :meth:`experience`; :meth:`update`
    periodically samples a minibatch, runs one learning step, updates the
    target network and checkpoints the model.
    """

    def __init__(self, state_dim, action_num, algo_name="dqn"):
        """Remember the problem dimensions and build the algorithm.

        Args:
            state_dim: state dimension, forwarded unchanged to ``DqnAlgo``.
            action_num: number of actions, forwarded unchanged to ``DqnAlgo``.
            algo_name: name forwarded to ``DqnAlgo``; it is also the
                directory component of the checkpoint paths.
        """
        self.state_dim = state_dim
        self.action_num = action_num
        self.algo_name = algo_name

        self.init_algo()

    def init_algo(self):
        """Create the TF graph/session, the algorithm and the replay buffer."""
        self.graph = tf.Graph()
        self.session = tf.Session(graph=self.graph)
        self.algorithm = DqnAlgo(self.session, self.graph, self.state_dim, self.action_num,
                                 epsilon_decay="exponential", is_test=False,
                                 algo_name=self.algo_name, grad_clip=True)

        # number of learning steps performed so far
        self.update_cnt = 0

        # the replay buffer
        self.memory = ReplayMemory(REPLAY_BUFFER_SIZE)

        # This graph is not referenced anywhere else in this class; the
        # attribute is kept so existing external users do not break.
        self.writer_graph = tf.Graph()
        # Reuse the algorithm's summary writer for scalars that
        # write_summary_scalar logs directly (train_info falsy).
        self.my_writer = self.algorithm.train_writer

    def action(self, state, test_model):
        """Ask the algorithm for an action.

        Returns:
            The ``(action, action_info)`` pair produced by
            ``DqnAlgo.choose_action``.
        """
        a, action_info = self.algorithm.choose_action(state, test_model)

        return a, action_info

    def experience(self, s, a_n, r_n, s_n, terminal, **kwargs):
        """Append one ``(s, a_n, r_n, s_n, terminal)`` transition to memory.

        Extra keyword arguments are accepted and ignored.
        """
        self.memory.add((s, a_n, r_n, s_n, terminal))

    def update(self, t):
        """Run one learning step if the schedule allows it.

        Nothing happens until the replay buffer holds more than
        ``FIRST_UPDATE_SAMPLE_NUM`` samples, and afterwards only when ``t``
        is a multiple of ``UPDATE_FREQ``.

        Args:
            t: step counter supplied by the caller.
        """
        # not enough experience collected yet
        if len(self.memory.store) <= FIRST_UPDATE_SAMPLE_NUM:
            return

        # update frequency
        if t % UPDATE_FREQ != 0:
            return

        self.update_cnt += 1

        # get mini batch from replay buffer and split it into columns
        sample = self.memory.get_minibatch(BATCH_SIZE)
        s_batch = [transition[0] for transition in sample]
        a_batch = [transition[1] for transition in sample]
        r_batch = [transition[2] for transition in sample]
        sp_batch = [transition[3] for transition in sample]
        done_batch = [transition[4] for transition in sample]

        # rewards and terminal flags are passed as column vectors
        self.algorithm.learn(np.array(s_batch), np.array(a_batch),
                             np.array(r_batch).reshape([-1, 1]),
                             np.array(sp_batch), np.array(done_batch).reshape([-1, 1]))

        # update target network, using the module-level TAU instead of a
        # duplicated literal so the constant actually takes effect
        if self.update_cnt % TARGET_UPDATE_FREQ == 0:
            self.algorithm.update_target_soft(tau=TAU)

        # save param
        self.save_params()

    def _checkpoint_path(self, update_cnt):
        """Return the checkpoint path used for the given update count."""
        return './data/' + self.algo_name + '/model/{}.ckpt'.format(update_cnt)

    def save_params(self):
        """Checkpoint the model every ``MODEL_UPDATE_FREQ`` learning steps."""
        if self.update_cnt % MODEL_UPDATE_FREQ == 0 and self.update_cnt > 0:
            import os

            print('model saved for update', self.update_cnt)
            save_path = self._checkpoint_path(self.update_cnt)
            # The saver does not create missing parent directories, so make
            # sure the target directory exists before writing.
            model_dir = os.path.dirname(save_path)
            if not os.path.isdir(model_dir):
                os.makedirs(model_dir)
            self.algorithm.saver.save(self.algorithm.sess, save_path)

    def load_params(self, load_cnt):
        """Restore the checkpoint that was saved at update ``load_cnt``."""
        print("load model for update %s " % load_cnt)
        load_path = self._checkpoint_path(load_cnt)
        self.algorithm.saver.restore(self.algorithm.sess, load_path)

    def episode_done(self, test_model):
        """Forward the end-of-episode notification to the algorithm."""
        self.algorithm.episode_done(test_model)

    def write_summary_scalar(self, iteration, tag, value, train_info):
        """Log a scalar summary.

        Args:
            iteration: step index the value is recorded at.
            tag: summary tag.
            value: scalar value to record.
            train_info: when truthy the call is delegated to the algorithm's
                own ``write_summary_scalar``; otherwise a ``tf.Summary`` with
                a single ``simple_value`` is written to ``self.my_writer``.
        """
        if train_info:
            self.algorithm.write_summary_scalar(iteration, tag, value)
        else:
            self.my_writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]), iteration)
