import os
from typing import Optional

import torch
import torchvision
from torch import Tensor
from tqdm import tqdm


class AverageMeter:
    """Track the latest value and the weighted running average of a metric.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes are plain public fields:
        val:   the value passed to the most recent ``update`` call.
        sum:   accumulated ``val * n`` over all updates since the last reset.
        count: accumulated ``n`` over all updates since the last reset.
        avg:   ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new observation.

        Args:
            val: the observed value (typically a per-batch mean).
            n: the weight of the observation, e.g. the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An update with n == 0 on an empty meter (e.g. an empty batch) must
        # not raise ZeroDivisionError; the average simply stays at 0.
        if self.count == 0:
            self.avg = 0
        else:
            self.avg = self.sum / self.count


class Logger:
    """Append-only text logger that writes to ``log.txt`` inside a directory.

    Every message is written on its own line. The file is opened in append
    mode for each write, so existing content is preserved.
    """

    def __init__(self, path: str, config: Optional[str] = None) -> None:
        """Set up the logger and optionally record a configuration string.

        Args:
            path: directory that holds ``log.txt``; created if it is missing.
            config: text written as the first entry, or ``None`` to write
                nothing at construction time.
        """
        self.path = path
        # Create the directory up front so the first write cannot fail with
        # FileNotFoundError. An empty path means the current directory, which
        # os.makedirs would reject.
        if path:
            os.makedirs(path, exist_ok=True)
        if config is not None:
            self._append(config)

    def info(self, msg: str, print_msg: bool = False) -> None:
        """Append ``msg`` to the log file, echoing it to stdout if requested.

        Args:
            msg: the message to record (a newline is appended).
            print_msg: when true, also ``print`` the message.
        """
        if print_msg:
            print(msg)
        self._append(msg)

    def _append(self, text: str) -> None:
        """Append one line of text to ``log.txt`` under ``self.path``."""
        # The path is recomputed on every call because self.path is a public
        # attribute. An explicit encoding keeps non-ASCII messages from
        # failing on platforms whose default encoding is not UTF-8.
        log_file = os.path.join(self.path, "log.txt")
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(text + "\n")


def log_images_to_tensorboard(
    writer, images: Tensor, tag: str, step: int
) -> None:
    """Tile ``images`` into one grid and hand it to ``writer.add_image``.

    Args:
        writer: object exposing ``add_image(tag, image, step)``
            (presumably a TensorBoard ``SummaryWriter`` -- confirm at caller).
        images: tensor passed unchanged to ``torchvision.utils.make_grid``.
        tag: name under which the grid is recorded.
        step: step value forwarded to the writer.
    """
    tiled = torchvision.utils.make_grid(images)
    writer.add_image(tag, tiled, step)


def log_stats(logger, current_epoch, stats, writer) -> None:
    """Report every statistic of an epoch to the console, writer and logger.

    An ``Epoch <n>:`` header is sent to ``logger.info`` first. Then, for each
    ``(name, value)`` pair in ``stats.items()``, a ``name: value`` line with
    six decimal places is printed via ``tqdm.write``, the value is passed to
    ``writer.add_scalar`` with ``current_epoch`` as the step, and the same
    line is sent to ``logger.info``.

    Args:
        logger: object exposing ``info(msg)``.
        current_epoch: epoch identifier used in the header and as the step.
        stats: mapping of statistic name to a value that supports the
            ``.6f`` format specifier.
        writer: object exposing ``add_scalar(name, value, step)``.
    """
    logger.info(f"Epoch {current_epoch}:")
    for key, value in stats.items():
        line = f"{key}: {value:.6f}"
        tqdm.write(line)
        writer.add_scalar(key, value, current_epoch)
        logger.info(line)
