import logging
import random

import gym
import numpy as np
import torchvision.transforms as T
from gym import spaces
from metaworld.policies.policy import Policy
from PIL import Image

# ---------------------------------------------------------------------------
# Seen domain-factor tables.
# A domain is selected by indexing into these lists (MetaWorldWrapper.set_dfs).
# set_seen_random_tasks samples indices from the first 8 entries;
# set_unseen_random_tasks samples from all 10.
# NOTE: values are not unique within every table (e.g. SATURATION, HUE), so a
# domain has to be identified by its indices, not by its values.
# ---------------------------------------------------------------------------

# Camera names used to render image observations.
CAMS = [
    'cam00', 'cam12', 'cam24', 'cam36',
    'cam04', 'cam08', 'cam16', 'cam20', 'cam28', 'cam32',
]
# CAMS = [
#     'cam00'
# ]

# Dynamics: offset added to action[1] on every step (referred to as x-axis wind).
XWIND = [
    0.0, -0.02, 0.10, -0.04,
    0.08, 0.04, 0.02, 0.06, -0.06, -0.08,
]
# Dynamics: offset added to action[2] on every step (z-axis "gravity").
GRAVITY = [
    0.0, 0.08, 0.12, 0.18,
    0.02, 0.04, 0.06, 0.1, 0.14, 0.16,
]

# Illumination: ColorJitter factors applied to rendered images.
BRIGHT = [
    1.0, 0.8, 3.2, 1.8,
    1.9, 1.4, 2.6, 1.6, 2.4, 2.2,
]  # brightness factor
CONTRAST = [
    1.0, 2.4, 3.3, 0.5,
    1.05, 0.95, 1.1, 1.15, 1.2, 1.5,
]  # contrast factor
SATURATION = [
    1.1, 1.1, 1.1, 1.1,
    1.1, 1.2, 0.8, 1.6, 0.4, 0.2,
]  # saturation factor
HUE = [
    -0.2, -0.2, 0.5, -0.3,
    0.1, 0.2, -0.4, -0.5, 0.4, 0.3,
]  # hue shift; torchvision's ColorJitter requires values in [-0.5, 0.5]


# Legacy 8-entry tables (unused; kept for reference).
# XWIND=[
#     0.0, -0.02, 0.10, -0.04,
#     0.08, 0.04, 0.02, 0.06,
# ] # wind applied on x-axis
# GRAVITY=[
#     0.0, 0.08, 0.12, 0.18,
#     0.02, 0.04, 0.06, 0.1,
# ] # gravity on z-axis
# # Illumination sets
# BRIGHT = [
#     1.8, 1.8, 1.8, 1.8,
#     1.9, 1.4, 1.4, 1.6,
# ] # brightness applied on image
# CONTRAST = [
#     2.4, 2.4, 3.3, 0.5,
#     1.05, 0.95, 1.1, 1.15,
# ] # contrast applied on image
# SATURATION = [
#      0.5, 1.1, 1.1, 1.1,
#     1.1, 1.2, 1.6, 1.6,
# ] # saturation applied on image
# HUE = [
#      0.5, 0.5, 0.5, 0.5,
#     0.5, 0.5, 0.5, 0.5,
# ] # hue applied on image

# ---------------------------------------------------------------------------
# Unseen domain-factor tables (MetaWorldWrapper.set_unseen_dfs).
# ---------------------------------------------------------------------------

# NOTE(review): 'cam12' appears twice here and is also a *seen* camera in CAMS;
# one of these entries was probably meant to be a different, unseen camera.
U_CAMS = [
    'cam06', 'cam10', 'cam12', 'cam14', 'cam18',
    'cam22', 'cam26', 'cam30', 'cam34', 'cam12',
]
# Dynamics: offset added to action[1] (x-axis wind).
U_XWIND = [
    0.01, 0.03, 0.05, 0.07, 0.09,
    0.11, -0.03, -0.05, -0.07, -0.09,
]
# Dynamics: offset added to action[2] (z-axis gravity).
U_GRAVITY = [
    0.01, 0.03, 0.05, 0.07, 0.09,
    0.11, 0.13, 0.15, 0.17, 0.19,
]

# Illumination: the first five entries are scaled by 1.14, the last five by 1.54.
U_BRIGHT = [
    1.0*1.14, 1.1*1.14, 1.2*1.14, 0.8*1.14, 1.4*1.14,
    0.6*1.54, 1.6*1.54, 0.4*1.54, 1.8*1.54, 0.2*1.54,
]  # brightness factor
U_CONTRAST = [
    1.0*1.14, 1.05*1.14, 0.95*1.14, 1.1*1.14, 0.9*1.14,
    1.15*1.54, 1.2*1.54, 1.3*1.54, 1.4*1.54, 1.5*1.54,
]  # contrast factor
U_SATURATION = [
    1.0*1.14, 1.1*1.14, 1.2*1.14, 0.8*1.14, 1.4*1.14,
    0.6*1.54, 1.6*1.54, 0.4*1.54, 1.8*1.54, 0.2*1.54,
]  # saturation factor
U_HUE = [
    0.01, 0.124, -0.233, 0.2345, -0.445,
    -0.48, 0.49, 0.423, -0.311, 0.321,
]  # hue shift; must stay within [-0.5, 0.5]

# Environment Wrapper
class MetaWorldWrapper(gym.Env):
    """Gym wrapper around a single Meta-World task with injectable domain shifts.

    Every ``reset`` builds a fresh environment instance for ``task`` and can
    resample a set of domain factors ("dfs"):

    * dynamics: constant offsets added to ``action[1]`` (``xwind``) and
      ``action[2]`` (``gravity``) on every step;
    * camera: the camera name used to render image observations;
    * illumination: brightness / contrast / saturation / hue applied to
      rendered frames through ``torchvision.transforms.ColorJitter``.

    Factor values are looked up by index in the module-level tables (``XWIND``,
    ``GRAVITY``, ``BRIGHT``, ... for seen domains; ``U_*`` for unseen ones).

    Args:
        mt: Meta-World benchmark object exposing ``train_classes`` and
            ``train_tasks``.
        task: environment name; key into ``mt.train_classes`` and filter on
            ``env_name`` for ``mt.train_tasks`` (the first match is used).
        max_step: the episode is marked done once this many steps were taken.
        terminate: if True, also end the episode as soon as ``info['success']``.
        dict_obs: if True, observations are ``{'image': array}`` rendered at
            224x224 from ``camera_id``; otherwise the raw env observation.
        expert: optional object with ``get_action(obs)``; its action for the
            latest raw observation is reported in ``info['expert_action']``.
        xwind_id, gravity_id, bright_id, contrast_id, saturation_id, hue_id:
            initial indices into the corresponding seen tables.
        camera_id: initial camera name.
        set_source: if True, one of ``source_tasks`` is resampled on every reset.
        seed: passed to ``random.seed``.
    """

    # Names of the index-selected factors, in the positional order used by
    # ``set_dfs`` / ``set_unseen_dfs`` after the camera argument.
    _DF_NAMES = ('xwind', 'gravity', 'bright', 'contrast', 'saturation', 'hue')

    def __init__(self, mt, task, max_step, terminate, dict_obs=False, expert=None,
            xwind_id=0, gravity_id=0, camera_id='cam0-0', bright_id=0, contrast_id=0, saturation_id=0, hue_id=0, set_source=False, seed=0):
        # NOTE(review): only Python's ``random`` is seeded here, but all
        # domain-factor sampling below uses ``np.random``; ``seed`` therefore
        # does not make that sampling reproducible.
        random.seed(seed)
        self.mt = mt
        self.task = task
        self.max_step = max_step
        self.terminate = terminate
        self.expert = expert
        self.dict_obs = dict_obs

        # Sampling modes consulted by ``reset``. ``set_source`` is kept off for
        # the initial reset below and applied afterwards. ``set_seen_random`` /
        # ``set_unseen_random`` / ``set_target`` are presumably toggled by
        # callers; nothing in this class sets them to True.
        self.set_source = False
        self.set_seen_random = False
        self.set_unseen_random = False
        self.set_target = False
        # NOTE(review): the default camera 'cam0-0' does not follow the 'camNN'
        # naming used in CAMS; confirm it exists in the scene.
        self.set_dfs(camera_id, xwind_id, gravity_id, bright_id, contrast_id, saturation_id, hue_id)
        self.env_cls = self.mt.train_classes[task]

        _ = self.reset()
        if not dict_obs:
            self.observation_space = self.env.observation_space
        else:
            # NOTE(review): rendered images are mean/std-normalized (see
            # ``render``), so their values are not confined to [0, 1] as this
            # space declares.
            self.observation_space = gym.spaces.Dict({
                'image': gym.spaces.Box(shape=(3, 224, 224), low=0, high=1),
            })

        self.action_space = spaces.Box(
            self.env.action_space.low,
            self.env.action_space.high,
            dtype=np.float32,
        )
        # Enable source-domain sampling and draw the first source domain.
        self.set_source = set_source
        if self.set_source:
            self.set_source_tasks()

    def step(self, action):
        """Apply the active dynamics shift to ``action`` and step the env.

        Returns ``(obs, reward, done, info)``. ``info`` additionally carries
        ``is_success``, ``expert_action`` and ``metadata``. ``metadata`` holds the
        camera name and the indices of the active domain factors, into the
        ``U_*`` tables when they were set via ``set_unseen_dfs``.
        The caller's ``action`` is not modified.
        """
        # The expert acts on the raw observation of the previous step.
        expert_action = None
        if self.expert is not None:
            expert_action = self.expert.get_action(self.expert_obs)

        # Shift a float copy: adding in place would mutate the caller's array
        # (accumulating offsets across repeated calls with the same array)
        # and would truncate the offsets for integer arrays.
        action = np.array(action, dtype=float)
        action[1] += self.xwind  # labelled x-axis wind
        action[2] += self.gravity  # z-axis

        obs, rew, done, info = self.env.step(action)
        self.expert_obs = obs
        self.steps += 1
        if self.steps >= self.max_step:
            done = True

        info['is_success'] = bool(info['success'])
        if info['is_success'] and self.terminate:
            done = True

        if self.dict_obs:
            obs = {'image': self._image_obs()}
        info['expert_action'] = expert_action
        # Report the stored indices: recovering them with ``TABLE.index(value)``
        # is ambiguous for tables with repeated values (SATURATION, HUE) and
        # fails for values taken from the U_* tables.
        info['metadata'] = {'cam': self.camera_id, **self._df_ids}
        return obs, rew, done, info

    def reset(self):
        """Build a fresh env for ``self.task``, optionally resample factors, and reset it.

        Raises:
            ValueError: if ``mt.train_tasks`` has no task named ``self.task``.
        """
        # NOTE(review): a new env instance is constructed every episode and the
        # previous one is dropped without ``close()``; confirm this does not
        # leak simulator/renderer resources.
        self.env = self.env_cls()
        # Single-task setting: always use the first matching task variant.
        task = next((t for t in self.mt.train_tasks if t.env_name == self.task), None)
        if task is None:
            raise ValueError(f"no train task with env_name={self.task!r}")
        self.steps = 0
        self.env.set_task(task)

        # Source-domain sampling; a random-sampling mode below overrides it.
        if self.set_source:
            self.set_source_tasks()
        if self.set_seen_random:
            self.set_seen_random_tasks()
        elif self.set_unseen_random:
            self.set_unseen_random_tasks()

        obs = self.env.reset()
        self.expert_obs = obs
        if self.dict_obs:
            obs = {'image': self._image_obs()}
        return obs

    def _image_obs(self):
        """Render the illumination-shifted 224x224 image from the active camera."""
        return np.array(self.render(camera_name=self.camera_id, resolution=(224, 224)))

    def render(self, offscreen=True, camera_name="corner2", resolution=(640, 480), original=False):
        """Render a frame, by default with the active illumination shift applied.

        Returns the raw ``self.env.render`` output when ``original`` is True.
        Otherwise it returns a (C, H, W) float array produced by ColorJitter with
        the active factors, ``ToTensor``, and ImageNet mean/std normalization.
        """
        image = self.env.render(offscreen, camera_name, resolution)
        if original:
            return image

        # Every jitter range is degenerate (min == max), so the factors are
        # fixed; note that torchvision's ColorJitter still applies the four
        # adjustments in a random order on each call.
        transform = T.Compose([
            T.ColorJitter(
                brightness=(self.bright, self.bright),
                contrast=(self.contrast, self.contrast),
                saturation=(self.saturation, self.saturation),
                hue=(self.hue, self.hue),
            ),
            T.ToTensor(),
            # ImageNet statistics. The red-channel std was previously mistyped
            # as 0.299 (ImageNet uses 0.229); models trained before this fix
            # saw that normalization.
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        return np.array(transform(Image.fromarray(image)))

    def set_source_tasks(self):
        """Activate one of four fixed source domains, chosen uniformly with ``np.random``."""
        # Each row: [camera, xwind_id, gravity_id, bright_id, contrast_id, saturation_id, hue_id]
        self.source_tasks = [
            ['cam00', 0, 0, 0, 0, 0, 0],  # base domain
            ['cam00', 1, 1, 1, 1, 1, 1],
            ['cam00', 2, 2, 2, 2, 2, 2],
            ['cam00', 3, 3, 3, 3, 3, 3],
        ]
        idx = np.random.randint(low=0, high=len(self.source_tasks))
        self.set_dfs(*self.source_tasks[idx], verbose=True)

    def set_seen_random_tasks(self):
        """Sample every factor index from the first 8 entries of the seen tables.

        The camera stays fixed at ``CAMS[0]``; ``idx[0]`` is drawn but unused.
        """
        idx = np.random.randint(low=0, high=8, size=(7,))
        self.set_dfs(CAMS[0], idx[1], idx[2], idx[3], idx[4], idx[5], idx[6], verbose=True)

    def set_unseen_random_tasks(self):
        """Sample every factor index from all 10 entries of the *seen* tables.

        The U_* tables are not used here. The camera stays fixed at ``CAMS[0]``,
        and ``idx[0]`` is drawn but unused.
        """
        idx = np.random.randint(low=0, high=10, size=(7,))
        self.set_dfs(CAMS[0], idx[1], idx[2], idx[3], idx[4], idx[5], idx[6], verbose=True)

    def set_dfs(self, fov, xwind_id, gravity_id, bright_id, contrast_id, saturation_id, hue_id, verbose: bool = True) -> None:
        """Activate domain factors from the seen tables.

        ``fov`` is the camera name. The remaining arguments index XWIND, GRAVITY,
        BRIGHT, CONTRAST, SATURATION and HUE. With ``verbose``, the chosen
        factors are logged at DEBUG level.

        Raises:
            IndexError: if any index is out of range; no factor is changed then.
        """
        self._apply_dfs(
            'seen', fov,
            (XWIND, GRAVITY, BRIGHT, CONTRAST, SATURATION, HUE),
            (xwind_id, gravity_id, bright_id, contrast_id, saturation_id, hue_id),
            verbose,
        )

    def set_unseen_dfs(self, fov, xwind_id, gravity_id, bright_id, contrast_id, saturation_id, hue_id, verbose: bool = True) -> None:
        """Activate domain factors from the unseen ``U_*`` tables.

        Same contract as ``set_dfs``, but indexing U_XWIND, U_GRAVITY, U_BRIGHT,
        U_CONTRAST, U_SATURATION and U_HUE.
        """
        self._apply_dfs(
            'unseen', fov,
            (U_XWIND, U_GRAVITY, U_BRIGHT, U_CONTRAST, U_SATURATION, U_HUE),
            (xwind_id, gravity_id, bright_id, contrast_id, saturation_id, hue_id),
            verbose,
        )

    def _apply_dfs(self, label, camera, tables, ids, verbose):
        """Look up all factor values first, then commit them together with their indices."""
        values = [table[i] for table, i in zip(tables, ids)]  # may raise IndexError
        self.camera_id = camera
        (self.xwind, self.gravity, self.bright,
         self.contrast, self.saturation, self.hue) = values
        # Keep the (non-negative, plain int) indices for ``step``'s metadata.
        self._df_ids = {
            name: int(i) % len(table)
            for name, table, i in zip(self._DF_NAMES, tables, ids)
        }
        logger = logging.getLogger(__name__)
        if verbose and logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "%s domain factors: cam=%s %s", label, camera,
                " ".join(f"{n}={v}" for n, v in zip(self._DF_NAMES, values)),
            )

    ############# Not Used #############

    # changing dynamics
    def set_gravity(self, gravity):
        """Set the sim's ``opt.gravity`` vector to ``[0, 0, gravity]``."""
        self.env.sim.model.opt.gravity[:] = np.array([0., 0., gravity])
        self.env.sim.set_constants()

    def set_wind(self, wind):
        """Set the sim's ``opt.wind`` vector to ``[0, wind, 0]`` (default 0)."""
        self.env.sim.model.opt.wind[:] = np.array([0., wind, 0.])
        self.env.sim.set_constants()

    def set_friction(self, friction):
        """Set the sim's ``opt.impratio`` (default 1) to ``friction``."""
        self.env.sim.model.opt.impratio = friction
        self.env.sim.set_constants()
