'''
Script to run simple tabular RL experiments.
'''
import os.path
import pickle
import random

import numpy as np
import math
import pandas as pd
from tqdm import tqdm
import multiprocessing
from shutil import copyfile


def Interact_with_nondetermin_policy(env, agent):
    '''
    Play one episode with the agent's stochastic policy and report its value.

    The agent's softmax policy (parameterised by agent.policy_entropy) is
    evaluated on the environment *before* the episode is played, so the
    reported value belongs to the policy the agent held at episode start.
    During the episode the agent picks actions via pick_action and is fed
    every transition through update_obs.

    Args:
        env: the environment (reset/step/state/timestep and
            softmax_policy_evaluation are used)
        agent: the agent (pick_action/update_obs/policy_entropy/logp_gap
            are used)

    Returns:
        (env, agent, epValue, logp_gap): the environment and agent after the
        episode, the policy value at the episode's start (timestep, state),
        and agent.logp_gap read after the episode
    '''
    state = env.reset()
    s0, t0 = env.state, env.timestep

    # Only the state-value table is needed; the Q-table is discarded.
    _, value_table = env.softmax_policy_evaluation(agent.policy_entropy)

    finished = 0
    while finished == 0:
        act = agent.pick_action(state)
        reward, next_state, finished = env.step(act)
        agent.update_obs(state, act, reward, next_state, finished)
        state = next_state

    return env, agent, value_table[t0, s0], agent.logp_gap


def Interact_with_determin_policy(env, agent):
    '''
    Play one episode and report the value of the agent's greedy policy.

    The greedy policy (argmax of agent.qVals over the last axis) is
    evaluated with env.policy_evaluation *before* the episode is played.
    The agent then acts through its own pick_action and receives every
    transition via update_obs.

    Args:
        env: the environment (reset/step/state/timestep and
            policy_evaluation are used)
        agent: the agent (pick_action/update_obs/qVals are used)

    Returns:
        (env, agent, epValue): the environment and agent after the episode
        and the greedy-policy value at the episode's start (timestep, state)
    '''
    state = env.reset()
    s0, t0 = env.state, env.timestep

    greedy = np.argmax(agent.qVals, axis=-1)
    _, value_table = env.policy_evaluation(greedy)

    finished = 0
    while finished == 0:
        act = agent.pick_action(state)
        reward, next_state, finished = env.step(act)
        agent.update_obs(state, act, reward, next_state, finished)
        state = next_state

    return env, agent, value_table[t0, s0]

def generate_state_buffer(agents, envs, nTrials, seed, nEps=500):
    '''
    Pre-sample trajectories with a random stochastic policy and hand them to
    each agent via agent.get_state_buffer.

    For every trial a random per-state action distribution is drawn (each
    row of an (nState, nAction) uniform matrix normalised to sum to 1). That
    policy is then rolled out for nEps episodes in the trial's environment.

    Args:
        agents: sequence of agents, indexed by trial; each must expose
            nState, nAction and get_state_buffer(buffer)
        envs: sequence of environments, indexed by trial (reset/step used)
        nTrials: number of trials (entries of agents/envs) to process
        seed: seed for NumPy's *global* RNG (np.random.seed)
        nEps: number of episodes to sample per trial (default 500)

    Returns:
        None. Each agent receives a list of nEps episodes, each episode a
        list of (state, action, step_index) tuples.
    '''
    np.random.seed(seed)
    for trial in range(nTrials):
        env = envs[trial]
        agent = agents[trial]

        # Random stochastic policy: each row is a distribution over actions.
        policy = np.random.random((agent.nState, agent.nAction))
        policy = policy / policy.sum(axis=1)[:, None]
        # Alternative: uniform policy
        # policy = 1 / agent.nAction * np.ones((agent.nState, agent.nAction))

        state_buffer = []
        for _ in range(nEps):
            oldState = env.reset()
            done = 0
            state_list = []
            h = 0
            while done == 0:
                action = np.random.choice(agent.nAction, p=policy[oldState, :])
                _, newState, done = env.step(action)
                state_list.append((oldState, action, h))
                oldState = newState
                h += 1
            state_buffer.append(state_list)

        agent.get_state_buffer(state_buffer)

def _run_serial_trial(env, agent, alg, vOpt):
    '''
    Run one episode of a single trial in the current process.

    Args:
        env: the trial's environment
        agent: the trial's agent
        alg: algorithm name; "FiniteBOO" evaluates the agent's softmax policy
            (env.softmax_policy_evaluation), anything else its greedy policy
            (argmax of agent.qVals, via env.policy_evaluation)
        vOpt: optimal value table of env, indexed [timestep, state]

    Returns:
        (optimal value, policy value, logp_gap): both values read at the
        episode's start (timestep, state); logp_gap is agent.logp_gap read
        before the episode is played
    '''
    oldState = env.reset()
    start_state = env.state
    start_time = env.timestep
    logp_gap = agent.logp_gap
    # The policy is evaluated before the episode, i.e. before any update.
    if alg == "FiniteBOO":
        _, vVals = env.softmax_policy_evaluation(agent.policy_entropy)
    else:
        policy = np.argmax(agent.qVals, axis=-1)
        _, vVals = env.policy_evaluation(policy)

    done = 0
    while done == 0:
        action = agent.pick_action(oldState)
        reward, newState, done = env.step(action)
        agent.update_obs(oldState, action, reward, newState, done)
        oldState = newState

    return vOpt[start_time, start_state], vVals[start_time, start_state], logp_gap


def _write_results_csv(data, targetPath):
    '''Write the logged rows to targetPath as csv via a sibling temp file.'''
    dt = pd.DataFrame(data,
                      columns=['episode', 'realValue', 'scaling', 'kkValue', 'model_value', 'm_start_policy', 'epValue', 'cumReward',
                               'cumRegret', 'cumRegretSEM', 'diffcount'])
    print('Writing to file ' + targetPath)
    tmpPath = targetPath + '.tmp'
    dt.to_csv(tmpPath, index=False, float_format='%.5f')
    # Atomic rename on the same filesystem. Unlike the previous
    # copyfile('tmp.csv', targetPath), this does not raise SameFileError
    # when targetPath is 'tmp.csv' (the default).
    os.replace(tmpPath, targetPath)
    print('****************************')


def run_finite_tabular_experiment(agent, env_sampler, nEps, nTrials, seed=1, pal=False, alg='none',
                                  recFreq=100, fileFreq=1000, targetPath='tmp.csv', folderName='temp'):
    '''
    A simple script to run a finite MDP experiment

    Args:
        agent - a callable returning a list of agents, one per trial
        env_sampler - a callable returning a list of environments, one per trial
        nEps - number of episodes to run
        nTrials - number of trials to run
        seed - numpy random seed (only used by the optional, currently
               disabled, generate_state_buffer pre-sampling step)
        pal - if truthy, play each episode's trials in a multiprocessing pool
        alg - algorithm name; "FiniteBOO" evaluates the agents' softmax
              policy, anything else their greedy policy
        recFreq - how many episodes between logging
                  (NOTE: currently overridden inside the episode loop)
        fileFreq - how many episodes between writing file
                   (NOTE: currently overridden inside the episode loop)
        targetPath - where to write the csv
        folderName - prefix of the pickle output path folderName + "Totalv_k"

    Returns:
        None - logged rows are written to targetPath as csv, and the
        per-episode logp_gap arrays are pickled to folderName + "Totalv_k"
    '''
    envs = env_sampler()
    agents = agent()

    # Optional: pre-sample random-policy trajectories for each agent.
    # generate_state_buffer(agents, envs, nTrials, seed)

    # Optimal value table of every trial's MDP (regret baseline).
    vOptVals = []
    for i in range(nTrials):
        _, vVals = envs[i].value_iteration()
        vOptVals.append(vVals)

    data = []
    cumRegret = np.zeros(nTrials)
    cumReward = 0

    cores = multiprocessing.cpu_count()
    print(cores)

    Totalv_k = []
    for ep in tqdm(range(1, nEps + 1)):
        v_k = np.zeros(nTrials)
        epValue = np.zeros(nTrials)
        epOptVals = np.zeros(nTrials)
        # Logged as columns but never filled in by this function.
        model_iteration_v = np.zeros(nTrials)
        m_start_policy = np.zeros(nTrials)
        sc = np.zeros(nTrials)

        if pal:
            interact = (Interact_with_nondetermin_policy if alg == "FiniteBOO"
                        else Interact_with_determin_policy)
            # NOTE(review): a fresh pool is created every episode; if agents
            # draw from NumPy's global RNG, workers may start each episode
            # from the same RNG state — confirm this is intended.
            with multiprocessing.Pool(processes=cores) as pool:
                T = pool.starmap(interact, zip(envs, agents))
            # Workers return pickled copies; keep them as the new state.
            envs = [T[i][0] for i in range(nTrials)]
            agents = [T[i][1] for i in range(nTrials)]
            epValue = np.array([T[i][2] for i in range(nTrials)])
            # NOTE(review): the workers do not return the episode's start
            # (timestep, state), so the optimal value is read at [0, 0]. This
            # matches the serial branch only if every episode starts at
            # timestep 0 in state 0 — confirm. v_k stays zero in this branch.
            epOptVals = np.array([vOptVals[i][0, 0] for i in range(nTrials)])
        else:
            for trial in range(nTrials):
                epOptVals[trial], epValue[trial], v_k[trial] = _run_serial_trial(
                    envs[trial], agents[trial], alg, vOptVals[trial])

        cumRegret += epOptVals - epValue
        Totalv_k.append(v_k)
        # Number of trials whose policy value is more than 0.01 below optimal.
        diffcount = int(np.sum(epOptVals - epValue > 0.01))
        epValue = np.mean(epValue)
        cumReward += epValue

        # Variable granularity.
        # NOTE(review): this overrides the recFreq/fileFreq arguments.
        recFreq = 10 ** math.floor(np.log10(ep)) if ep < 100 else 100
        fileFreq = 100

        # Logging to dataframe
        if ep % recFreq == 0:
            meanRegret = np.mean(cumRegret)
            semRegret = np.std(cumRegret) / np.sqrt(nTrials)
            data.append([ep, np.mean(epOptVals), np.mean(sc), np.mean(v_k), np.mean(model_iteration_v),
                         np.mean(m_start_policy), epValue, cumReward, meanRegret, semRegret, diffcount])
            print(
                f'episode: {ep:6d} realValue: {np.mean(epOptVals)} scaling: {np.mean(sc):3.4f} kkValue: {np.mean(v_k):3.4f} model_value:{np.mean(model_iteration_v):3.4f} m_start_policy: {np.mean(m_start_policy):3.4f} epValue: {epValue:5.4f} cumRegret: {meanRegret:6.2f} cumRegretSEM: {semRegret:5.5f} diffcount：{diffcount}')

        if ep % max(fileFreq, recFreq) == 0:
            _write_results_csv(data, targetPath)

    # Close the file deterministically (the original left the handle open).
    with open(folderName + "Totalv_k", "wb") as fh:
        pickle.dump(Totalv_k, fh)
    print('**************************************************')
    print('Experiment complete')
    print('**************************************************')
