import csv
import logging
import os
import sys

import torch


def create_folders_if_necessary(path):
    """Ensure the parent directory of ``path`` exists.

    ``path`` is a file path. Its directory component is created, including
    intermediate directories, if missing. A bare filename with no directory
    component (e.g. ``"log.txt"``) needs nothing created and is a no-op.

    Args:
        path: Path to a file whose containing directory should exist.
    """
    dirname = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so skip the empty dirname of
    # a bare filename. exist_ok=True also removes the isdir/makedirs race
    # where another process creates the directory in between.
    if dirname:
        os.makedirs(dirname, exist_ok=True)


def get_storage_dir():
    """Return the root storage directory, resolving and caching it on first use.

    If ``RL_STORAGE`` is already present in the environment, its value is
    returned as-is. Otherwise the default ``<base>/storage`` is computed and
    stored in ``os.environ["RL_STORAGE"]`` so later calls return the same
    value. ``<base>`` is the current working directory, or its parent when
    the cwd path contains ``"scripts"``.

    Returns:
        The storage root path as a string.
    """
    cached = os.environ.get("RL_STORAGE")
    if cached is not None:
        return cached

    base = os.getcwd()
    # NOTE(review): this is a substring test on the whole cwd path, not a
    # basename check. Any cwd containing "scripts" anywhere (e.g.
    # ".../myscripts/proj") also has its last component stripped. Confirm
    # this is intended.
    if "scripts" in base:
        base = os.path.dirname(base)

    os.environ["RL_STORAGE"] = os.path.join(base, "storage")
    return os.environ["RL_STORAGE"]


def get_log_dir(run_name):
    """Return the directory for run ``run_name`` under the storage root.

    Args:
        run_name: Name of the run; joined onto the storage root path.
    """
    storage_root = get_storage_dir()
    return os.path.join(storage_root, run_name)


def get_model_dir(log_dir):
    """Return the ``models`` sub-directory of ``log_dir`` (path only, nothing created)."""
    subdir_name = "models"
    return os.path.join(log_dir, subdir_name)


def save_model_status(status, log_dir, num_frames):
    """Persist ``status`` as both the latest and a frame-stamped checkpoint.

    Two files are written under ``<log_dir>/models``:
      * ``<num_frames>.pt``: a snapshot for this frame count.
      * ``agent.pt``: the most recent checkpoint, replaced on every call.

    Args:
        status: Any object ``torch.save`` can serialize.
        log_dir: Run directory; checkpoints go in its ``models`` subdirectory.
        num_frames: Value used to name the snapshot file.

    Returns:
        The models directory path.
    """
    import shutil

    model_dir = get_model_dir(log_dir)
    agent_path = os.path.join(model_dir, "agent.pt")
    num_agent_path = os.path.join(model_dir, f"{num_frames}.pt")
    create_folders_if_necessary(agent_path)

    # Serialize once, into the new snapshot file.
    torch.save(status, num_agent_path)

    # Refresh agent.pt atomically. Writing to a temp file and then calling
    # os.replace means an interruption never leaves a truncated agent.pt;
    # the previous good copy survives until the new one is complete.
    tmp_path = agent_path + ".tmp"
    shutil.copyfile(num_agent_path, tmp_path)
    os.replace(tmp_path, agent_path)
    return model_dir


def get_txt_logger(log_dir):
    """Route root-logger output to ``<log_dir>/log.txt`` and stdout.

    Configures the root logger at INFO level with the bare ``"%(message)s"``
    format. Records go both to a ``logging.FileHandler`` on ``log.txt``
    (append mode, the handler's default) and to ``sys.stdout``.

    ``force=True`` (Python 3.8+) removes and closes any handlers already on
    the root logger before installing these. Without it, ``basicConfig``
    silently does nothing once the root logger has handlers. A second call
    with a different ``log_dir`` would then keep writing to the first run's
    file.

    Args:
        log_dir: Directory that will contain ``log.txt``; created if missing.

    Returns:
        The root ``logging.Logger``.
    """
    log_path = os.path.join(log_dir, "log.txt")
    create_folders_if_necessary(log_path)
    logging.basicConfig(
        level=logging.INFO,
        format="%(message)s",
        handlers=[logging.FileHandler(filename=log_path), logging.StreamHandler(sys.stdout)],
        force=True,
    )

    return logging.getLogger()


def get_csv_logger(model_dir):
    """Open ``<model_dir>/log.csv`` for appending and wrap it in a ``csv.writer``.

    Args:
        model_dir: Directory that will contain ``log.csv``; created if missing.

    Returns:
        A ``(csv_file, csv_writer)`` tuple. The file is left open; the caller
        is responsible for flushing and closing ``csv_file``.
    """
    csv_path = os.path.join(model_dir, "log.csv")
    create_folders_if_necessary(csv_path)
    # newline="" is required by the csv module. The writer emits its own
    # "\r\n" line terminators, and text-mode newline translation would
    # otherwise turn them into "\r\r\n" (spurious blank rows) on Windows.
    csv_file = open(csv_path, "a", newline="")
    return csv_file, csv.writer(csv_file)
