import os
import pathlib
import sys

import torch
from nle.env.base import DUNGEON_SHAPE
from omegaconf import OmegaConf

from .hierarchical_lstm import HierarchicalLSTM
from .hierarchical_transformer_lstm import HierarchicalTransformerLSTM
from .transformer_lstm import TransformerLSTM



# Put the hackrl experiment code that lives under the hihack checkout on
# sys.path so that `tasks` can be imported as a top-level module. The checkout
# root is derived from the current working directory: everything before the
# first occurrence of the substring "hihack", plus "hihack" itself.
base_path = str(pathlib.Path().resolve())
_hihack_idx = base_path.find('hihack')
if _hihack_idx != -1:
    hihack_path = os.path.join(base_path[:_hihack_idx], 'hihack')
else:
    # str.find() returns -1 when the marker is absent; slicing with that value
    # would silently chop the last character off the working directory and
    # produce a nonsense path. Fall back to a "hihack" directory directly
    # under the working directory instead.
    hihack_path = os.path.join(base_path, 'hihack')
sys.path.insert(0, os.path.join(hihack_path, 'dungeonsdata-neurips2022/experiment_code/hackrl'))
from tasks import ENVS

# Model classes that can be selected by name through `flags.model`.
MODELS = [HierarchicalLSTM, HierarchicalTransformerLSTM, TransformerLSTM]

# Maps each registered class's `__name__` to the class object itself.
MODELS_LOOKUP = {model_cls.__name__: model_cls for model_cls in MODELS}


def initialize_weights(flags, model):
    """Reset biases and, optionally, weights of every submodule of ``model``.

    The reset is applied in place through ``model.apply``:

    * Any submodule whose ``bias`` attribute is a ``Parameter`` gets that bias
      filled with zeros, regardless of ``flags.initialisation``.
    * If ``flags.initialisation`` is ``"orthogonal"`` or ``"xavier_uniform"``,
      the weights of submodules whose type is exactly ``torch.nn.Conv2d`` or
      ``torch.nn.Linear`` are re-initialised with the matching
      ``torch.nn.init`` routine (gain 1.0). Any other value leaves the
      weights as they were constructed.

    Args:
        flags: Object exposing an ``initialisation`` attribute.
        model: Module providing ``apply`` (e.g. a ``torch.nn.Module``).

    Returns:
        None.
    """

    def _reset(module):
        bias = getattr(module, "bias", None)
        if isinstance(bias, torch.nn.parameter.Parameter):
            bias.data.fill_(0)

        scheme = flags.initialisation
        if scheme == "orthogonal":
            init_fn = torch.nn.init.orthogonal_
        elif scheme == "xavier_uniform":
            init_fn = torch.nn.init.xavier_uniform_
        else:
            return

        # Exact type match on purpose: subclasses of Conv2d / Linear are
        # left untouched.
        if type(module) in (torch.nn.Conv2d, torch.nn.Linear):
            init_fn(module.weight.data, gain=1.0)

    model.apply(_reset)



def create_model(flags, device):
    """Instantiate the model named by ``flags.model`` and initialise it.

    The class is looked up in ``MODELS_LOOKUP``. The action set is read from
    the ``actions`` attribute of the environment built by
    ``ENVS[flags.env.name](savedir=None)``. The model is constructed with
    ``(DUNGEON_SHAPE, action_space, flags, device)``, moved to ``device`` and
    passed through ``initialize_weights``.

    Args:
        flags: Configuration exposing ``model``, ``env.name`` and whatever
            the model class and ``initialize_weights`` read from it.
        device: Target device, forwarded to the model constructor and to
            ``model.to``.

    Returns:
        The newly created model.

    Raises:
        NotImplementedError: If ``flags.model`` is not a key of
            ``MODELS_LOOKUP``.
    """
    model_cls = MODELS_LOOKUP.get(flags.model)
    if model_cls is None:
        raise NotImplementedError("model=%s" % flags.model) from None

    # Only the action set is kept from the environment built here.
    env_factory = ENVS[flags.env.name]
    action_space = env_factory(savedir=None).actions

    model = model_cls(DUNGEON_SHAPE, action_space, flags, device)
    model.to(device=device)
    initialize_weights(flags, model)

    return model


def load_model(load_dir, device):
    """Rebuild a model from a saved run directory and restore its weights.

    Reads ``config.yaml`` from ``load_dir`` with OmegaConf, records the
    checkpoint location on the loaded flags as ``flags.checkpoint``, builds
    the model via ``create_model`` and loads the ``"model_state_dict"`` entry
    of ``checkpoint.tar`` into it.

    Args:
        load_dir: Directory containing ``config.yaml`` and
            ``checkpoint.tar``. May be a ``str`` or any ``os.PathLike``.
        device: Device the model is created on; also used as
            ``map_location`` when reading the checkpoint.

    Returns:
        The model with the checkpointed state dict loaded.
    """
    # os.path.join (instead of ``load_dir + "/..."``) accepts pathlib.Path
    # arguments, avoids a doubled separator on a trailing slash, and does not
    # turn an empty ``load_dir`` into a path at the filesystem root.
    flags = OmegaConf.load(os.path.join(load_dir, "config.yaml"))
    flags.checkpoint = os.path.join(load_dir, "checkpoint.tar")

    model = create_model(flags, device)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources.
    checkpoint_states = torch.load(flags.checkpoint, map_location=device)
    model.load_state_dict(checkpoint_states["model_state_dict"])
    return model