import os
import glob
import numpy as np
import pandas as pd

import wandb

from tensorboardX import SummaryWriter

# fmt: off
import logging
# Module-level logger; the root logger is configured at INFO when this
# module is imported.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)
# fmt: on


class Recorder:
    """Record scalar metrics to a DataFrame/CSV, TensorBoard and optionally W&B.

    Values are stored in ``self.data``: a ``pandas.DataFrame`` indexed by
    timestep with one column per key (cells never recorded are NaN). Each
    value is also written to a tensorboardX ``SummaryWriter`` in
    ``save_path`` and, when both ``wandb_project`` and ``wandb_entity`` are
    given, logged to Weights & Biases.

    Call :meth:`close` when done. It is also attempted from ``__del__``, but
    finalizers are not guaranteed to run promptly (or at all).
    """

    def __init__(
        self, save_path, wandb_project=None, wandb_entity=None, config=None, group=None
    ) -> None:
        """Create the recorder and its output sinks.

        Args:
            save_path (str): Directory for TensorBoard event files and
                ``records.csv``.
            wandb_project (str, optional): W&B project; requires
                ``wandb_entity``.
            wandb_entity (str, optional): W&B entity; requires
                ``wandb_project``.
            config (dict, optional): Run config forwarded to ``wandb.init``;
                an empty dict is used when omitted.
            group (str, optional): Run group forwarded to ``wandb.init``.

        Raises:
            ValueError: If only one of ``wandb_project`` / ``wandb_entity``
                is given.
        """
        # Validate before touching the filesystem so a bad call cannot delete
        # existing event files. An explicit raise (not assert) survives -O.
        if (wandb_project is None) != (wandb_entity is None):
            missing = "wandb_entity" if wandb_entity is None else "wandb_project"
            raise ValueError(f"{missing} must be provided")

        self.wandb = False
        self.save_path = os.path.join(save_path, "records.csv")

        # If a SummaryWriter already logged to this directory, delete ALL of
        # its event files (not just the first match) so stale curves are not
        # merged into this run. glob.escape keeps directory names containing
        # glob metacharacters such as "[" from breaking the pattern.
        pattern = os.path.join(glob.escape(save_path), "events.out.tfevents.*")
        for event_file in glob.glob(pattern):
            os.remove(event_file)
        self.writer = SummaryWriter(save_path)

        self.data = pd.DataFrame()
        # From here on close() has everything it needs to clean up.
        self._closed = False

        if wandb_project is not None:
            wandb.login()
            wandb.init(
                project=wandb_project,
                entity=wandb_entity,
                # None default avoids sharing one mutable dict across calls.
                config={} if config is None else config,
                group=group,
            )
            self.wandb = True

    def _record_local(self, timestep, key, value):
        """Store one value in ``self.data`` and the TensorBoard writer."""
        if key not in self.data.columns:
            self.data[key] = np.nan

        # append new row if timestep is not in dataframe
        if timestep not in self.data.index:
            self.data.loc[timestep] = np.nan

        self.data.loc[timestep, key] = value
        self.writer.add_scalar(key, value, timestep)

    def record(self, timestep, key, value):
        """Record new value to dataframe and writer.

        Args:
            timestep (int): Timestep of the value.
            key (str): Key of the value.
            value (float): Value to record.
        """
        self._record_local(timestep, key, value)
        if self.wandb:
            wandb.log({key: value}, step=timestep)

    def record_dict(self, timestep, record_dict):
        """Record new values to dataframe and writer.

        All pairs are sent to W&B in a single ``wandb.log`` call rather than
        one call per key.

        Args:
            timestep (int): Timestep of the value.
            record_dict (dict): Dictionary of key-value pairs to record.
        """
        for key, value in record_dict.items():
            self._record_local(timestep, key, value)
        if self.wandb and record_dict:
            # Copy so the caller's dict is never handed to W&B directly.
            wandb.log(dict(record_dict), step=timestep)

    def dump(self, *args, **kwargs):
        """Write ``self.data`` to ``self.save_path`` as CSV.

        Extra arguments are forwarded to ``DataFrame.to_csv``.
        """
        self.data.to_csv(self.save_path, *args, **kwargs)

    def close(self):
        """Dump the CSV, close the writer and finish the W&B run.

        Idempotent: only the first call has an effect. It does nothing on an
        instance whose ``__init__`` failed before its sinks were created.
        """
        if getattr(self, "_closed", True):
            return
        self._closed = True
        # Dump first so recorded data is saved even if closing a sink fails.
        try:
            self.dump()
        finally:
            try:
                self.writer.close()
            finally:
                if self.wandb:
                    wandb.finish()

    def __del__(self):
        # Best-effort fallback for callers that never call close().
        self.close()
