import math
import torch
import numpy as np
import imagehash
from collections import defaultdict
from PIL import Image
'''
Code borrowed from abhinav's implementation
'''


class ImageHash:
    """Callable that hashes a PIL image with a named ``imagehash`` algorithm.

    Supported names are ``"average"``, ``"perceptual"``, ``"wavelet"``,
    ``"difference"`` and ``"none"``. With ``"none"`` no perceptual hash is
    computed; the image's raw bytes (``image.tobytes()``) are returned instead.
    """

    # Algorithm name -> attribute of the ``imagehash`` module implementing it.
    _ALGORITHMS = {
        "average": "average_hash",
        "perceptual": "phash",
        "wavelet": "whash",
        "difference": "dhash",
    }

    def __init__(self, name):
        """Select the hashing algorithm.

        Args:
            name: one of ``"none"``, ``"average"``, ``"perceptual"``,
                ``"wavelet"`` or ``"difference"``.

        Raises:
            ValueError: if ``name`` is not a supported algorithm. Failing here
                avoids a confusing ``'NoneType' object is not callable`` error
                on the first call.
        """
        if name != "none" and name not in self._ALGORITHMS:
            raise ValueError(
                f"unknown image hash {name!r}; expected 'none' or one of "
                f"{sorted(self._ALGORITHMS)}"
            )
        self.name = name
        # None for "none": the raw-bytes path needs no hash function, and the
        # imagehash module is only looked up when an algorithm is requested.
        if name == "none":
            self.image2hash = None
        else:
            self.image2hash = getattr(imagehash, self._ALGORITHMS[name])

    def __call__(self, image):
        """Hash ``image``.

        Args:
            image: a PIL image (anything with ``tobytes()`` suffices for
                ``"none"``).

        Returns:
            ``bytes`` from ``image.tobytes()`` when the name is ``"none"``,
            otherwise whatever the selected ``imagehash`` function returns.
        """
        if self.name == "none":
            return image.tobytes()
        return self.image2hash(image)


class AtariStateHash:
    """Turn a stacked observation into a hashable ``bytes`` key.

    With ``image_hash="none"`` the key is the raw pixel bytes of the
    observation. Otherwise each frame is converted to a PIL image and hashed
    with the named ``ImageHash`` algorithm. The key is then the bytes of the
    per-frame hash bit arrays stacked together.
    """

    def __init__(self, image_hash="none"):
        """
        Args:
            image_hash: algorithm name forwarded to ``ImageHash``; ``"none"``
                keys states by their raw pixel bytes.
        """
        self.image_hash = image_hash
        self.image_hash_fn = ImageHash(name=image_hash)

    @staticmethod
    def _to_numpy(input_obs):
        """Return ``input_obs`` as a NumPy array; torch tensors are detached and moved to CPU."""
        if isinstance(input_obs, torch.Tensor):
            return input_obs.detach().cpu().numpy()
        return np.asarray(input_obs)

    def __call__(self, input_obs):
        """Return a ``bytes`` key identifying ``input_obs``.

        Args:
            input_obs: stacked frames as a NumPy array or torch tensor. The
                original note gives its shape as (1, 1, 4, 1, 84, 84). Any
                shape that squeezes to (frames, H, W), or to a single (H, W)
                frame, is handled. TODO confirm against callers.

        Returns:
            ``bytes`` usable as a dictionary key.
        """
        frames = np.squeeze(self._to_numpy(input_obs))
        if self.image_hash == "none":
            # Squeezing does not reorder elements, so these are exactly the
            # raw C-order bytes of the observation.
            return frames.tobytes()
        if frames.ndim == 2:
            # A single frame squeezes to (H, W). Restore the frame axis so the
            # image is hashed once instead of hashing each pixel row.
            frames = frames[np.newaxis]
        # Slow path: PIL conversion plus a perceptual hash for every frame.
        hash_bits = [
            self.image_hash_fn(Image.fromarray(frame)).hash for frame in frames
        ]
        return np.stack(hash_bits).tobytes()
