import os
import numpy
import json
import copy
import numpy
import numpy as np
import argparse
import yaml
import ast
import re

from enum import Enum

from texture_change import RGB_SETTING


class RoadOption(Enum):
    """Enumerate the topological choices for going from one lane segment to another.

    ``VOID`` carries the value -1; the remaining members are numbered 1 through 6
    in declaration order.
    """

    VOID = -1

    # Values 1-6, in declaration order.
    LEFT = 1
    RIGHT = 2
    STRAIGHT = 3
    LANEFOLLOW = 4
    CHANGELANELEFT = 5
    CHANGELANERIGHT = 6

# Lookup from an integer action code (1-6) to its upper-case action name.
value_dict = dict(enumerate(
    (
        'LEFT',
        'RIGHT',
        'STRAIGHT',
        'LANEFOLLOW',
        'CHANGELANELEFT',
        'CHANGELANERIGHT',
    ),
    start=1,
))


def read_yaml_to_dict(file_path):
    """Load the YAML document stored at *file_path*.

    Args:
        file_path: Path of the YAML file; it is opened in text read mode.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.
    """
    with open(file_path, 'r') as stream:
        return yaml.safe_load(stream)

def convert_string_to_list(input_str):
    """Turn a parenthesised, unquoted description string into Python lists.

    Runs of letters/whitespace are wrapped in double quotes, parentheses are
    rewritten as square brackets, and the resulting text is evaluated.

    Args:
        input_str: Text such as ``"(turn left, go straight),(stop)"``.

    Returns:
        ``list(...)`` of the evaluated expression, e.g.
        ``[["turn left", "go straight"], ["stop"]]``.
    """
    text = input_str

    # Drop newlines and the spaces that sit directly outside parentheses.
    for old, new in (("\n", ""), (" (", "("), (") ", ")")):
        text = text.replace(old, new)

    # Quote every run of letters and whitespace (runs containing spaces stay intact).
    text = re.sub(r'([A-Za-z\s]+)', r'"\1"', text)

    # Rewrite parentheses as list brackets, then strip the space left behind
    # right after an opening quote that follows a comma.
    for old, new in (("),(", "],["), ("(", "["), (")", "]"), (",\" ", ",\"")):
        text = text.replace(old, new)

    # NOTE(review): eval executes arbitrary expressions; only feed this trusted
    # text (ast.literal_eval would be the safer choice for pure literals).
    return list(eval(text))

def convert_string_to_list2(input_str):
    """Turn a square-bracketed, unquoted description string into Python lists.

    The text is first normalised (newlines removed, spaces directly outside
    brackets dropped). Then every innermost ``[...]`` group has its
    ``", "``-separated items wrapped in double quotes, and the result is
    evaluated as a Python expression.

    Args:
        input_str: Text such as ``"[[turn left, go straight], [stop]]"``.

    Returns:
        ``list(...)`` of the evaluated expression, e.g.
        ``[["turn left", "go straight"], ["stop"]]``.
    """
    # Normalise whitespace around brackets and remove newlines.
    formatted_str = input_str.replace("\n", "")
    formatted_str = formatted_str.replace(" [", "[")
    formatted_str = formatted_str.replace("] ", "]")

    def _quote_items(match):
        # Wrap each ", "-separated item of one innermost bracket group in quotes.
        return '["' + match.group(1).replace(", ", '", "') + '"]'

    # Apply the quoting to the normalised text. The previous version passed
    # ``input_str`` here, which silently threw the normalisation above away.
    formatted_str = re.sub(r'\[([^\[\]]+)\]', _quote_items, formatted_str)

    # Strip a space left directly after an opening quote that follows a comma.
    formatted_str = formatted_str.replace(",\" ", ",\"")

    # NOTE(review): eval executes arbitrary expressions; only feed this trusted
    # text (ast.literal_eval would be the safer choice for pure literals).
    result_list = list(eval(formatted_str))
    return result_list

# Template for one episode record. The script entry point at the bottom of this
# file deep-copies it for every route file and fills in the ``None`` fields.
# NOTE: building this dict reads 'carla_configs/dataset_config.yaml' at import time.
episode = dict(
    episode_start_node=None,
    episode_target_node=None,
    rgb_setting=None,
    config_setting=read_yaml_to_dict('carla_configs/dataset_config.yaml'),
    action_seq=None,
    # One dict per step is appended here, with the keys 'timesteps',
    # 'start_node', 'target_node', 'description', 'image_name', 'action_len'
    # and 'action_type'.
    trajectory=[],
)

def _node_id(line):
    """Return the integer found between ``node_`` and ``.png`` in *line*."""
    return int(line.split('node_')[1].split('.png')[0])


def main():
    """Convert every route text file of one colour setting into a JSON episode.

    For each file in ``carla_dataset/train_text_<color>`` the script:

    * reads the episode start/target node ids from the 4th and 5th
      ``_``-separated fields of the file name,
    * loads the matching ``route_<start>_<target>.npy`` arrays from
      ``./carla_action_traj/train_traj_<color>`` and
      ``./carla_action_len/train_len_<color>``,
    * builds one trajectory entry per pair of consecutive text lines, and
    * writes the episode to
      ``./train_dataset/train_color_ver_<color>_route_<start>_<target>.json``,
      where start/target are taken from the first and last text line.

    Command line:
        --color: integer index into ``RGB_SETTING`` and into the directory
            names above (default 0).
    """
    parser = argparse.ArgumentParser(description='experiment setting')
    parser.add_argument('--color', type=int, default=0)
    args = parser.parse_args()

    dir_name = f'carla_dataset/train_text_{args.color}'
    traj_dir = f'./carla_action_traj/train_traj_{args.color}'
    len_dir = f'./carla_action_len/train_len_{args.color}'

    data_names = os.listdir(dir_name)
    # Make sure the output directory exists before any episode is written.
    os.makedirs('./train_dataset', exist_ok=True)

    for data_name in data_names:
        json_dict = copy.deepcopy(episode)
        json_dict['rgb_setting'] = RGB_SETTING[args.color]

        # Strip the extension first so its length does not matter.
        name_parts = os.path.splitext(data_name)[0].split('_')
        episode_start_node = int(name_parts[3])
        episode_target_node = int(name_parts[4])

        # Blank lines carry no step information and cannot be parsed.
        with open(os.path.join(dir_name, data_name), 'r') as f:
            datas = [line for line in f if line.strip()]
        if not datas:
            print("Skip empty file", data_name)
            continue

        route_file = f'route_{episode_start_node}_{episode_target_node}.npy'
        action_seq = np.load(os.path.join(traj_dir, route_file)).tolist()
        json_dict['action_seq'] = action_seq
        action_lens = np.load(os.path.join(len_dir, route_file)).tolist()

        first_action_node = _node_id(datas[0])
        end_action_node = _node_id(datas[-1])

        # Report each mismatch between file name and file content on its own.
        if episode_start_node != first_action_node:
            print("Start changed", str(episode_start_node), str(first_action_node))
        if episode_target_node != end_action_node:
            print("Target changed", str(episode_target_node), str(end_action_node))

        start_node = first_action_node
        # Each step runs from the node of one line to the node of the next line.
        for idx, (data, next_data) in enumerate(zip(datas, datas[1:])):
            order = int(data.split('order_')[1].split('_node')[0])
            target_node = _node_id(next_data)
            # maxsplit=1 keeps a description that itself contains ' | ' intact.
            image_name, description = data.split(' | ', 1)
            # The image name is printed twice to keep the log format unchanged.
            print(order, idx, image_name, image_name)

            json_dict['trajectory'].append({
                'timesteps': idx,
                'start_node': start_node,
                'target_node': target_node,
                'description': convert_string_to_list2(description),
                'image_name': image_name,
                'action_len': action_lens[idx][0],
                'action_type': value_dict[action_lens[idx][1]],
            })

            start_node = target_node

        # The text file is authoritative for the episode endpoints. Using the
        # last line's node (rather than the loop variable) also covers files
        # with a single line, where the loop body never runs.
        episode_start_node = first_action_node
        episode_target_node = end_action_node

        json_dict['episode_start_node'] = episode_start_node
        json_dict['episode_target_node'] = episode_target_node

        out_path = f'./train_dataset/train_color_ver_{args.color}_route_{episode_start_node}_{episode_target_node}.json'
        print("Save to", out_path, len(action_seq), np.array(action_lens)[..., 0].sum())
        print()
        with open(out_path, "w") as json_file:
            json.dump(json_dict, json_file)


if __name__ == '__main__':
    main()
