import yaml
import numpy as np
from Baird import Baird
from ThetaTwoTheta import ThetaTwoTheta
from utils import get_agent
import argparse
# from mdp import MDP
# from mdp_config import get_mdp
# from graph_utils import generate_mixing_matrix,get_graph
# from utils import sovleBellman_equation,calc_bellman_error
# import copy
# from agent import get_agent


def make_parser(argv=None):
    """Parse the command-line options for this experiment script.

    Despite its name, this returns the parsed ``argparse.Namespace`` rather
    than the parser itself (the name is kept for backward compatibility).

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            previous behaviour.

    Returns:
        argparse.Namespace: parsed options; currently only ``eta`` (float,
        default 1.0).
    """
    parser = argparse.ArgumentParser()
    # argparse applies ``type`` only to string values, so the default is given
    # as a float to keep ``args.eta`` a float whether or not --eta is passed.
    parser.add_argument("--eta", type=float, default=1.0)

    return parser.parse_args(argv)


def main(config):
    """Run one training run of the configured agent on the ``Baird`` environment.

    Builds ``Baird(config["normalize"], False, 0)`` and the agent class returned
    by ``get_agent(config["agent"])``. It then performs ``config["num_iter"]``
    agent updates on transitions sampled with a fixed random behaviour policy:
    action 0 with probability 5/6, action 1 with probability 1/6.

    Args:
        config: Experiment configuration mapping. This function reads the keys
            ``"num_iter"``, ``"normalize"`` and ``"agent"``. The whole mapping
            is also forwarded to the agent constructor.

    Returns:
        list: ``num_iter`` entries. Each is the infinity norm of
        ``agent.primal_weight()`` recorded *before* that iteration's update.
    """
    num_iter = config["num_iter"]
    normalize = config["normalize"]

    env = Baird(normalize, False, 0)
    agent = get_agent(config["agent"])(env, config)

    # Fixed behaviour policy used to generate transitions.
    actions = [0, 1]
    action_probs = [5 / 6, 1 / 6]

    state = env.reset()

    norm_hist = []
    for _ in range(num_iter):
        norm_hist.append(np.linalg.norm(agent.primal_weight(), ord=np.inf))

        action = np.random.choice(actions, p=action_probs)
        next_state, reward, done = env.step(action)

        # 0 marks a terminal transition for the agent's update.
        done_mask = 0 if done else 1

        agent.update(state, next_state, action, reward, done_mask)

        # Advance to the successor state so the next update sees the current
        # state. After a terminal transition, start a fresh episode.
        state = env.reset() if done else next_state

    return norm_hist


    
if __name__ == '__main__':
    import os

    args = make_parser()

    # NOTE: config.yaml is a local, trusted file. Prefer yaml.safe_load if it
    # could ever come from an untrusted source.
    with open('config.yaml', encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    # Command-line options override values loaded from config.yaml.
    config.update(vars(args))

    # Five independent runs; each returns the per-iteration weight-norm history.
    results = [main(config) for _ in range(5)]

    # np.save does not create missing directories, so make sure it exists.
    os.makedirs('results', exist_ok=True)
    np.save(f'results/env-{config["mdp_name"]}-{config["exp"]}_agent_{config["agent"]}_normalize-{config["normalize"]}.npy',results)


    # results = []
    # for i in range(config["num_runs"]):

    #     print('i-th run:',i)
        
    #     error = main(config)
    #     results.append(error)
    # np.save(f'result/{config["exp_num"]}_N-{config["N"]}_graph-{config["graph_type"]}-agent_{config["agent"]}.npy',results)


