import glob
import os
import platform
from abc import ABC
from math import ceil
from typing import Dict, Any, List, Optional, Sequence

import ai2thor
import gym
import numpy as np
import torch
from packaging import version

from allenact.base_abstractions.experiment_config import MachineParams
from allenact.base_abstractions.preprocessor import SensorPreprocessorGraph
from allenact.base_abstractions.sensor import SensorSuite, ExpertActionSensor
from allenact.base_abstractions.task import TaskSampler
from allenact.utils.experiment_utils import evenly_distribute_count_into_bins
from allenact.utils.system import get_logger
from projects.plugins.ithor_plugin.ithor_util import (
    horizontal_to_vertical_fov,
    get_open_x_displays,
)
from projects.plugins.robothor_plugin.robothor_sensors import DepthSensorThor
from projects.plugins.robothor_plugin.robothor_task_samplers import (
    PointNavDatasetTaskSampler,
)
# from projects.plugins.robothor_plugin.robothor_tasks import PointNavTask
from projects.plugins.robothor_plugin.robothor_tasks import ObjectNavTask
from projects.point_navigation.baseline_configs.pointnav_base import (
    PointNavBaseConfig,
    PointNavAUTOTESTBaseConfig
)
# Fail fast at import time when the installed ai2thor predates 2.7.2.
# The versions "0.0.1" and None are exempted from the check (presumably
# placeholders reported by development builds -- TODO confirm).
if ai2thor.__version__ not in ["0.0.1", None] and version.parse(
    ai2thor.__version__
) < version.parse("2.7.2"):
    # The message states the same minimum version the condition enforces.
    raise ImportError(
        "To run the PointNav baseline experiments you must use"
        " ai2thor version 2.7.2 or higher."
    )


class PointNavThorBaseConfig(PointNavBaseConfig, ABC):
    """The base config for all iTHOR PointNav experiments.

    Simulator settings are chosen per task sampler: `env_args` indexes the
    sequence-valued attributes `LIGHTING_VALUE`, `HORIZONTAL_FIELD_OF_VIEW`,
    `STEP_SIZE`, `ROTATION_DEGREES` and `LOOK_DEGREES` with the counter
    `self.process_ind`, which `machine_params` initialises.

    Those attributes, as well as `CAMERA_WIDTH`, `CAMERA_HEIGHT`,
    `STOCHASTIC`, `SENSORS`, `DATA_GEN`, `MAX_STEPS` and `REWARD_CONFIG`, are
    not defined in this class; presumably they come from `PointNavBaseConfig`
    or from concrete subclasses -- TODO confirm.
    """

    # Not read by this class: the number of training processes is taken from
    # `len(self.STEP_SIZE)` in `machine_params`.
    NUM_PROCESSES: Optional[int] = None
    # All visible CUDA devices are used for training.
    TRAIN_GPU_IDS = list(range(torch.cuda.device_count()))
    # The last visible CUDA device. Without any GPU this evaluates to [-1],
    # but `machine_params` only uses these lists when CUDA is available.
    VALID_GPU_IDS = [torch.cuda.device_count() - 1]
    TEST_GPU_IDS = [torch.cuda.device_count() - 1]

    TRAIN_DATASET_DIR: Optional[str] = None
    VAL_DATASET_DIR: Optional[str] = None
    TEST_DATASET_DIR: Optional[str] = None

    AGENT_MODE = "bot"

    TARGET_TYPES: Optional[Sequence[str]] = None

    THOR_COMMIT_ID: Optional[str] = None
    DEFAULT_THOR_IS_HEADLESS: bool = False
    # NOTE: the action space is sized from `ObjectNavTask`'s action names
    # (not from a PointNav task class).
    ACTION_SPACE = gym.spaces.Discrete(len(ObjectNavTask.class_action_names()))

    # Debug output: runs once, when the class body is executed on import.
    print(ObjectNavTask.class_action_names())

    def __init__(self):
        super().__init__()

    def env_args(self) -> Dict[str, Any]:
        """Return the simulator keyword arguments for the next task sampler.

        NOTE: this method is stateful. It reads the per-process settings at
        index `self.process_ind - 1` and then decrements `self.process_ind`,
        so successive calls walk the settings sequences from the back.
        `machine_params` must have been called beforehand, because that is
        where `self.process_ind` is set. The `process_ind` argument passed to
        the `*_task_sampler_args` methods is not consulted here.

        When the counter is 0 the index is -1, which silently selects the
        last entry of every settings sequence.
        """
        # check before start!!
        setting_ind = self.process_ind - 1
        lighting_value = self.LIGHTING_VALUE[setting_ind]
        horizontal_fov = self.HORIZONTAL_FIELD_OF_VIEW[setting_ind]
        grid_size = self.STEP_SIZE[setting_ind]
        rotate_step_degrees = self.ROTATION_DEGREES[setting_ind]
        look_step_degrees = self.LOOK_DEGREES[setting_ind]
        # Only advance the counter once every lookup above has succeeded.
        self.process_ind = setting_ind

        return dict(
            width=self.CAMERA_WIDTH,
            height=self.CAMERA_HEIGHT,
            continuousMode=True,
            applyActionNoise=self.STOCHASTIC,
            rotateStepDegrees=rotate_step_degrees,
            gridSize=grid_size,
            snapToGrid=False,
            agentMode=self.AGENT_MODE,
            include_private_scenes=False,
            # Depth rendering is only requested when a depth sensor is in use.
            renderDepthImage=any(isinstance(s, DepthSensorThor) for s in self.SENSORS),
            fieldOfView=horizontal_to_vertical_fov(
                horizontal_fov_in_degrees=horizontal_fov,
                width=self.CAMERA_WIDTH,
                height=self.CAMERA_HEIGHT,
            ),
            lightingValue=lighting_value,
            datagen=self.DATA_GEN,
            lookStepDegrees=look_step_degrees,
        )

    def preprocessors(self):
        """Return the sensor preprocessors to use; this base config has none."""
        return tuple()

    def machine_params(self, mode: str = "train", **kwargs):
        """Build the `MachineParams` for the given mode.

        As a side effect this sets `self.process_ind` to the number of
        processes of the mode, which `env_args` later counts down from.

        # Parameters
        mode : One of "train", "valid" or "test".
        kwargs : Accepted but not used.

        # Returns
        A `MachineParams` with
        * train: `len(self.STEP_SIZE)` processes on `TRAIN_GPU_IDS`,
        * valid: 0 processes (validation is disabled during training),
        * test: 1 process on `TEST_GPU_IDS`.
        Without CUDA the device list is empty in every mode.

        # Raises
        NotImplementedError : If `mode` is not one of the three values above.
        """
        cuda_available = torch.cuda.is_available()
        sampler_devices: Sequence[int] = []
        if mode == "train":
            workers_per_device = 1
            gpu_ids = [] if not cuda_available else self.TRAIN_GPU_IDS * workers_per_device

            # One training process per entry of the per-process settings.
            # `NUM_PROCESSES` is deliberately not consulted: the value derived
            # from it used to be computed and immediately discarded, and that
            # computation failed when `NUM_PROCESSES` was left at `None`.
            nprocesses = len(self.STEP_SIZE)

            self.process_ind = nprocesses
            sampler_devices = self.TRAIN_GPU_IDS
        elif mode == "valid":
            nprocesses = 0  # no validation during training
            self.process_ind = nprocesses

            gpu_ids = [] if not cuda_available else self.VALID_GPU_IDS
        elif mode == "test":
            nprocesses = 1
            self.process_ind = nprocesses
            gpu_ids = [] if not cuda_available else self.TEST_GPU_IDS
        else:
            raise NotImplementedError("mode must be 'train', 'valid', or 'test'.")

        # `nprocesses` is always an int here. The preprocessor graph is built
        # for training and for any other mode that actually runs processes.
        if mode == "train" or nprocesses > 0:
            sensor_preprocessor_graph = SensorPreprocessorGraph(
                source_observation_spaces=SensorSuite(self.SENSORS).observation_spaces,
                preprocessors=self.preprocessors(),
            )
        else:
            sensor_preprocessor_graph = None

        return MachineParams(
            nprocesses=nprocesses,
            devices=gpu_ids,
            sampler_devices=sampler_devices
            if mode == "train"
            else gpu_ids,  # ignored with > 1 gpu_ids
            sensor_preprocessor_graph=sensor_preprocessor_graph,
        )

    @classmethod
    def make_sampler_fn(cls, **kwargs) -> TaskSampler:
        """Create a `PointNavDatasetTaskSampler` from the given keyword arguments."""
        return PointNavDatasetTaskSampler(**kwargs)

    @staticmethod
    def _partition_inds(n: int, num_parts: int) -> np.ndarray:
        """Return `num_parts + 1` int32 boundaries that split `range(n)`.

        Part `i` covers the half-open index range `[inds[i], inds[i + 1])`.
        The boundaries are evenly spaced values from 0 to `n`, rounded with
        `np.round`.
        """
        return np.round(np.linspace(0, n, num_parts + 1, endpoint=True)).astype(
            np.int32
        )

    def _get_sampler_args_for_scene_split(
        self,
        scenes_dir: str,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]],
        seeds: Optional[List[int]],
        deterministic_cudnn: bool,
        include_expert_sensor: bool = True,
    ) -> Dict[str, Any]:
        """Assemble the task-sampler keyword arguments for one process.

        # Parameters
        scenes_dir : Directory searched for `*.json.gz` episode files. See the
            NOTE in the body: the scenes found there are currently discarded.
        process_ind : Index of the process these arguments are for.
        total_processes : Number of processes the scenes are split across.
        devices : Devices in use, or `None`; only used to decide whether to
            warn about having fewer X-displays than GPU devices.
        seeds : Optional per-process seeds, indexed by `process_ind`.
        deterministic_cudnn : Passed through unchanged.
        include_expert_sensor : If `False`, `ExpertActionSensor` instances are
            filtered out of the sensors.

        # Returns
        A dict of sampler arguments. Building it calls `env_args`, which
        advances `self.process_ind`.

        # Raises
        RuntimeError : If the scene list is empty.
        """
        path = os.path.join(scenes_dir, "*.json.gz")
        scenes = [scene.split("/")[-1].split(".")[0] for scene in glob.glob(path)]

        # NOTE(review): the scenes discovered above are discarded. Every
        # per-process setting (one per entry of STEP_SIZE) is run in the
        # hard-coded scene below.
        scenes = ['FloorPlan21'] * len(self.STEP_SIZE)

        if len(scenes) == 0:
            raise RuntimeError(
                (
                    "Could find no scene dataset information in directory {}."
                    " Are you sure you've downloaded them? "
                    " If not, see https://allenact.org/installation/download-datasets/ information"
                    " on how this can be done."
                ).format(scenes_dir)
            )

        # Scenes are not oversampled: when `total_processes` exceeds
        # `len(scenes)`, some processes receive an empty slice of scenes.
        inds = self._partition_inds(len(scenes), total_processes)

        x_display: Optional[str] = None
        if platform.system() == "Linux":
            x_displays = get_open_x_displays(throw_error_if_empty=True)

            # `devices` may be None (it is the default of the public
            # `*_task_sampler_args` methods); treat that as "no devices".
            gpu_devices = [d for d in (devices or []) if d != torch.device("cpu")]
            if len(gpu_devices) > len(x_displays):
                get_logger().warning(
                    f"More GPU devices found than X-displays (devices: `{devices}`, x_displays: `{x_displays}`)."
                    f" This is not necessarily a bad thing but may mean that you're not using GPU memory as"
                    f" efficiently as possible. Consider following the instructions here:"
                    f" https://allenact.org/installation/installation-framework/#installation-of-ithor-ithor-plugin"
                    f" describing how to start an X-display on every GPU."
                )
            # Processes are assigned to the open displays round-robin.
            x_display = x_displays[process_ind % len(x_displays)]

        return {
            "scenes": scenes[inds[process_ind] : inds[process_ind + 1]],
            "object_types": self.TARGET_TYPES,
            "max_steps": self.MAX_STEPS,
            "sensors": [
                s
                for s in self.SENSORS
                if (include_expert_sensor or not isinstance(s, ExpertActionSensor))
            ],
            "action_space": self.ACTION_SPACE,
            "seed": seeds[process_ind] if seeds is not None else None,
            "deterministic_cudnn": deterministic_cudnn,
            "rewards_config": self.REWARD_CONFIG,
            "env_args": {**self.env_args(), "x_display": x_display,},
        }

    def train_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Return the sampler arguments for one training process.

        Uses the "episodes" sub-directory of `TRAIN_DATASET_DIR` (which must
        be set) and adds `scene_directory`, `loop_dataset=True` and
        `allow_flipping=True` to the shared arguments.
        """
        res = self._get_sampler_args_for_scene_split(
            os.path.join(self.TRAIN_DATASET_DIR, "episodes"),
            process_ind,
            total_processes,
            devices=devices,
            seeds=seeds,
            deterministic_cudnn=deterministic_cudnn,
        )
        res["scene_directory"] = self.TRAIN_DATASET_DIR
        res["loop_dataset"] = True
        res["allow_flipping"] = True
        return res

    def valid_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Return the sampler arguments for one validation process.

        Uses the "episodes" sub-directory of `VAL_DATASET_DIR` (which must be
        set), drops expert-action sensors, and adds `scene_directory` and
        `loop_dataset=False` to the shared arguments.
        """
        res = self._get_sampler_args_for_scene_split(
            os.path.join(self.VAL_DATASET_DIR, "episodes"),
            process_ind,
            total_processes,
            devices=devices,
            seeds=seeds,
            deterministic_cudnn=deterministic_cudnn,
            include_expert_sensor=False,
        )
        res["scene_directory"] = self.VAL_DATASET_DIR
        res["loop_dataset"] = False
        return res

    def test_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Return the sampler arguments for one test process.

        Delegates to `valid_task_sampler_args`, so the arguments are built
        from `VAL_DATASET_DIR`; `TEST_DATASET_DIR` is not read here.
        """
        return self.valid_task_sampler_args(
            process_ind=process_ind,
            total_processes=total_processes,
            devices=devices,
            seeds=seeds,
            deterministic_cudnn=deterministic_cudnn,
        )

class PointNavThorAUTOTESTBaseConfig(PointNavAUTOTESTBaseConfig, ABC):
    """The base config for the AUTOTEST variant of the iTHOR PointNav experiments.

    Every task sampler gets the same simulator settings: `env_args` reads the
    scalar attributes `LIGHTING_VALUE`, `HORIZONTAL_FIELD_OF_VIEW`,
    `STEP_SIZE`, `ROTATION_DEGREE` and `LOOK_DEGREE` directly.

    Those attributes, as well as `CAMERA_WIDTH`, `CAMERA_HEIGHT`,
    `STOCHASTIC`, `SENSORS`, `DATA_GEN`, `MAX_STEPS` and `REWARD_CONFIG`, are
    not defined in this class; presumably they come from
    `PointNavAUTOTESTBaseConfig` or from concrete subclasses -- TODO confirm.
    """

    # Total number of training processes; read in `machine_params` when CUDA
    # is available, so it must be set before training on GPUs.
    NUM_PROCESSES: Optional[int] = None
    # All visible CUDA devices are used for training.
    TRAIN_GPU_IDS = list(range(torch.cuda.device_count()))
    # The last visible CUDA device. Without any GPU this evaluates to [-1],
    # but `machine_params` only uses these lists when CUDA is available.
    VALID_GPU_IDS = [torch.cuda.device_count() - 1]
    TEST_GPU_IDS = [torch.cuda.device_count() - 1]

    TRAIN_DATASET_DIR: Optional[str] = None
    VAL_DATASET_DIR: Optional[str] = None
    TEST_DATASET_DIR: Optional[str] = None

    AGENT_MODE = "bot"

    TARGET_TYPES: Optional[Sequence[str]] = None

    THOR_COMMIT_ID: Optional[str] = None
    DEFAULT_THOR_IS_HEADLESS: bool = False
    # NOTE: the action space is sized from `ObjectNavTask`'s action names
    # (not from a PointNav task class).
    ACTION_SPACE = gym.spaces.Discrete(len(ObjectNavTask.class_action_names()))

    # Debug output: runs once, when the class body is executed on import.
    print(ObjectNavTask.class_action_names())

    def __init__(self):
        super().__init__()

    def env_args(self) -> Dict[str, Any]:
        """Return the simulator keyword arguments shared by all task samplers.

        The settings are read straight from the config attributes; no
        per-process state is involved.
        """
        # check before start!!
        lighting_value = self.LIGHTING_VALUE
        horizontal_fov = self.HORIZONTAL_FIELD_OF_VIEW
        grid_size = self.STEP_SIZE
        rotate_step_degrees = self.ROTATION_DEGREE
        look_step_degrees = self.LOOK_DEGREE

        return dict(
            width=self.CAMERA_WIDTH,
            height=self.CAMERA_HEIGHT,
            continuousMode=True,
            applyActionNoise=self.STOCHASTIC,
            rotateStepDegrees=rotate_step_degrees,
            gridSize=grid_size,
            snapToGrid=False,
            agentMode=self.AGENT_MODE,
            include_private_scenes=False,
            # Depth rendering is only requested when a depth sensor is in use.
            renderDepthImage=any(isinstance(s, DepthSensorThor) for s in self.SENSORS),
            fieldOfView=horizontal_to_vertical_fov(
                horizontal_fov_in_degrees=horizontal_fov,
                width=self.CAMERA_WIDTH,
                height=self.CAMERA_HEIGHT,
            ),
            lightingValue=lighting_value,
            datagen=self.DATA_GEN,
            lookStepDegrees=look_step_degrees,
        )

    def preprocessors(self):
        """Return the sensor preprocessors to use; this base config has none."""
        return tuple()

    def machine_params(self, mode: str = "train", **kwargs):
        """Build the `MachineParams` for the given mode.

        As a side effect this stores the process count of the mode in
        `self.process_ind`. Nothing in this class reads that attribute; the
        assignment is kept in case subclasses or other code rely on it.

        # Parameters
        mode : One of "train", "valid" or "test".
        kwargs : Accepted but not used.

        # Returns
        A `MachineParams` with
        * train: `NUM_PROCESSES` distributed over `TRAIN_GPU_IDS` when CUDA is
          available, otherwise a single process,
        * valid: 0 processes (validation is disabled during training),
        * test: 1 process on `TEST_GPU_IDS`.
        Without CUDA the device list is empty in every mode.

        # Raises
        NotImplementedError : If `mode` is not one of the three values above.
        """
        cuda_available = torch.cuda.is_available()
        sampler_devices: Sequence[int] = []
        if mode == "train":
            workers_per_device = 1
            if cuda_available:
                gpu_ids = self.TRAIN_GPU_IDS * workers_per_device
                # One process count per device.
                nprocesses = evenly_distribute_count_into_bins(
                    self.NUM_PROCESSES, len(gpu_ids)
                )
            else:
                gpu_ids = []
                nprocesses = 1

            self.process_ind = nprocesses
            sampler_devices = self.TRAIN_GPU_IDS
        elif mode == "valid":
            nprocesses = 0  # no validation during training
            self.process_ind = nprocesses

            gpu_ids = [] if not cuda_available else self.VALID_GPU_IDS
        elif mode == "test":
            nprocesses = 1
            self.process_ind = nprocesses
            gpu_ids = [] if not cuda_available else self.TEST_GPU_IDS
        else:
            raise NotImplementedError("mode must be 'train', 'valid', or 'test'.")

        # `nprocesses` is either a single count or one count per device. The
        # preprocessor graph is built for training and for any other mode that
        # actually runs processes.
        runs_processes = (isinstance(nprocesses, int) and nprocesses > 0) or (
            isinstance(nprocesses, Sequence) and sum(nprocesses) > 0
        )
        if mode == "train" or runs_processes:
            sensor_preprocessor_graph = SensorPreprocessorGraph(
                source_observation_spaces=SensorSuite(self.SENSORS).observation_spaces,
                preprocessors=self.preprocessors(),
            )
        else:
            sensor_preprocessor_graph = None

        return MachineParams(
            nprocesses=nprocesses,
            devices=gpu_ids,
            sampler_devices=sampler_devices
            if mode == "train"
            else gpu_ids,  # ignored with > 1 gpu_ids
            sensor_preprocessor_graph=sensor_preprocessor_graph,
        )

    @classmethod
    def make_sampler_fn(cls, **kwargs) -> TaskSampler:
        """Create a `PointNavDatasetTaskSampler` from the given keyword arguments."""
        return PointNavDatasetTaskSampler(**kwargs)

    @staticmethod
    def _partition_inds(n: int, num_parts: int) -> np.ndarray:
        """Return `num_parts + 1` int32 boundaries that split `range(n)`.

        Part `i` covers the half-open index range `[inds[i], inds[i + 1])`.
        The boundaries are evenly spaced values from 0 to `n`, rounded with
        `np.round`.
        """
        return np.round(np.linspace(0, n, num_parts + 1, endpoint=True)).astype(
            np.int32
        )

    def _get_sampler_args_for_scene_split(
        self,
        scenes_dir: str,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]],
        seeds: Optional[List[int]],
        deterministic_cudnn: bool,
        include_expert_sensor: bool = True,
    ) -> Dict[str, Any]:
        """Assemble the task-sampler keyword arguments for one process.

        # Parameters
        scenes_dir : Directory searched for `*.json.gz` episode files. See the
            NOTE in the body: the scenes found there are currently discarded.
        process_ind : Index of the process these arguments are for.
        total_processes : Number of processes the scenes are split across.
        devices : Devices in use, or `None`; only used to decide whether to
            warn about having fewer X-displays than GPU devices.
        seeds : Optional per-process seeds, indexed by `process_ind`.
        deterministic_cudnn : Passed through unchanged.
        include_expert_sensor : If `False`, `ExpertActionSensor` instances are
            filtered out of the sensors.

        # Returns
        A dict of sampler arguments.

        # Raises
        RuntimeError : If the scene list is empty (cannot currently happen
            because of the hard-coded scene below).
        """
        path = os.path.join(scenes_dir, "*.json.gz")
        scenes = [scene.split("/")[-1].split(".")[0] for scene in glob.glob(path)]

        # NOTE(review): the scenes discovered above are discarded; only the
        # hard-coded scene below is ever used.
        scenes = ['FloorPlan21']

        if len(scenes) == 0:
            raise RuntimeError(
                (
                    "Could find no scene dataset information in directory {}."
                    " Are you sure you've downloaded them? "
                    " If not, see https://allenact.org/installation/download-datasets/ information"
                    " on how this can be done."
                ).format(scenes_dir)
            )

        # Scenes are not oversampled: when `total_processes` exceeds
        # `len(scenes)`, some processes receive an empty slice of scenes.
        inds = self._partition_inds(len(scenes), total_processes)

        x_display: Optional[str] = None
        if platform.system() == "Linux":
            x_displays = get_open_x_displays(throw_error_if_empty=True)

            # `devices` may be None (it is the default of the public
            # `*_task_sampler_args` methods); treat that as "no devices".
            gpu_devices = [d for d in (devices or []) if d != torch.device("cpu")]
            if len(gpu_devices) > len(x_displays):
                get_logger().warning(
                    f"More GPU devices found than X-displays (devices: `{devices}`, x_displays: `{x_displays}`)."
                    f" This is not necessarily a bad thing but may mean that you're not using GPU memory as"
                    f" efficiently as possible. Consider following the instructions here:"
                    f" https://allenact.org/installation/installation-framework/#installation-of-ithor-ithor-plugin"
                    f" describing how to start an X-display on every GPU."
                )
            # Processes are assigned to the open displays round-robin.
            x_display = x_displays[process_ind % len(x_displays)]

        return {
            "scenes": scenes[inds[process_ind] : inds[process_ind + 1]],
            "object_types": self.TARGET_TYPES,
            "max_steps": self.MAX_STEPS,
            "sensors": [
                s
                for s in self.SENSORS
                if (include_expert_sensor or not isinstance(s, ExpertActionSensor))
            ],
            "action_space": self.ACTION_SPACE,
            "seed": seeds[process_ind] if seeds is not None else None,
            "deterministic_cudnn": deterministic_cudnn,
            "rewards_config": self.REWARD_CONFIG,
            "env_args": {**self.env_args(), "x_display": x_display,},
        }

    def train_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Return the sampler arguments for one training process.

        Uses the "episodes" sub-directory of `TRAIN_DATASET_DIR` (which must
        be set) and adds `scene_directory`, `loop_dataset=True` and
        `allow_flipping=True` to the shared arguments.
        """
        res = self._get_sampler_args_for_scene_split(
            os.path.join(self.TRAIN_DATASET_DIR, "episodes"),
            process_ind,
            total_processes,
            devices=devices,
            seeds=seeds,
            deterministic_cudnn=deterministic_cudnn,
        )
        res["scene_directory"] = self.TRAIN_DATASET_DIR
        res["loop_dataset"] = True
        res["allow_flipping"] = True
        return res

    def valid_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Return the sampler arguments for one validation process.

        Uses the "episodes" sub-directory of `VAL_DATASET_DIR` (which must be
        set), drops expert-action sensors, and adds `scene_directory` and
        `loop_dataset=False` to the shared arguments.
        """
        res = self._get_sampler_args_for_scene_split(
            os.path.join(self.VAL_DATASET_DIR, "episodes"),
            process_ind,
            total_processes,
            devices=devices,
            seeds=seeds,
            deterministic_cudnn=deterministic_cudnn,
            include_expert_sensor=False,
        )
        res["scene_directory"] = self.VAL_DATASET_DIR
        res["loop_dataset"] = False
        return res

    def test_task_sampler_args(
        self,
        process_ind: int,
        total_processes: int,
        devices: Optional[List[int]] = None,
        seeds: Optional[List[int]] = None,
        deterministic_cudnn: bool = False,
    ) -> Dict[str, Any]:
        """Return the sampler arguments for one test process.

        Delegates to `valid_task_sampler_args`, so the arguments are built
        from `VAL_DATASET_DIR`; `TEST_DATASET_DIR` is not read here.
        """
        return self.valid_task_sampler_args(
            process_ind=process_ind,
            total_processes=total_processes,
            devices=devices,
            seeds=seeds,
            deterministic_cudnn=deterministic_cudnn,
        )