'''
Author: 
Email: 
Date: 2022-03-10 12:28:43
LastEditTime: 2022-03-20 16:05:25
Description: 
'''

from .roomgrid import RoomGrid
import numpy as np
import gym

class Unlock(RoomGrid):
    """
    Single-room environment whose task is to unlock a door.

    A 1x1 room grid with a room of size 6 and a 15-step episode limit is
    built. The room receives one or two locked doors and one key whose color
    matches the door.

    ``test_mode`` selects the door layout:

    * ``'IID'``: one locked door at door index 0, regardless of stage.
    * ``'OOD'`` (extrapolation): when the stage is ``'train'``, one locked
      door at a randomly chosen index (0 or 1); for any other stage, two
      locked doors of the same color at indices 0 and 1.
    """

    def __init__(self, test_mode='IID', stage='train', seed=None):
        """
        Args:
            test_mode: ``'IID'`` or ``'OOD'``; any other value fails the
                assertion below.
            stage: stage label passed to the parent's ``reset`` on every
                reset. Only the value ``'train'`` is treated specially (and
                only in ``'OOD'`` mode).
            seed: forwarded unchanged to ``RoomGrid.__init__``.
        """
        assert test_mode in ['IID', 'OOD'], \
            "test_mode must be 'IID' or 'OOD', got {!r}".format(test_mode)

        # These attributes are assigned before super().__init__().
        # NOTE(review): presumably the parent constructor triggers reset(),
        # which reads self.stage -- confirm before reordering.
        self.test_mode = test_mode
        self.stage = stage
        room_size = 6
        max_steps = 15
        self.room_size = room_size

        # Flat observation vector with room_size * room_size * 3 entries in [0, 1].
        self.observation_space = gym.spaces.Box(-0., 1., shape=(room_size*room_size*3, ), dtype=np.float32)
        # 8-dimensional continuous action vector in [0, 1].
        self.action_space = gym.spaces.Box(-0., 1., shape=(8, ), dtype=np.float32)

        super().__init__(num_rows=1, num_cols=1, room_size=room_size, max_steps=max_steps, seed=seed)

    def _gen_grid(self, width, height, stage):
        """
        Populate the grid with the agent, the locked door(s) and the key.

        Args:
            width: grid width, forwarded to the parent's ``_gen_grid``.
            height: grid height, forwarded to the parent's ``_gen_grid``.
            stage: stage label; in ``'OOD'`` mode ``'train'`` yields a single
                door and any other value yields two doors.
        """
        super()._gen_grid(width, height)

        # place an agent in the room (0, 0)
        self.place_agent(0, 0)

        # Must be bound on every path: it is passed to add_object() below.
        # It used to be assigned only in the IID branch, so OOD mode crashed
        # with a NameError. The OOD branches now use the same value as IID.
        # NOTE(review): None presumably means "no row constraint" -- confirm
        # against RoomGrid.add_object.
        fixed_row = None

        # extrapolation setting
        if self.test_mode == 'OOD':
            if stage == 'train':
                # in training stage, we only have one door, at a random index
                # NOTE(review): this draws from NumPy's global RNG, not from
                # an RNG derived from the `seed` constructor argument.
                door_idx = np.random.choice([0, 1])
                door, _ = self.add_door(0, 0, door_idx=door_idx, color=None, locked=True)
            else:
                # in testing stage, there are two doors sharing one color, so
                # the single key matches both of them
                door, _ = self.add_door(0, 0, door_idx=0, color=None, locked=True)
                door_2, _ = self.add_door(0, 0, door_idx=1, color=door.color, locked=True)
                assert door.color == door_2.color
        else:
            # IID setting: always a single door at index 0
            door, _ = self.add_door(0, 0, door_idx=0, color=None, locked=True)

        # add one key with the same color as the door
        self.add_object(0, 0, 'key', door.color, fixed_row)
        self.mission = "unlock the door"

    def reset(self):
        """Reset the environment, passing the configured stage to the parent's reset."""
        obs = super().reset(self.stage)
        return obs

    def step(self, action):
        """Forward ``action`` to the parent's ``step`` and return its 4-tuple unchanged."""
        obs, reward, done, info = super().step(action)
        return obs, reward, done, info
