from library.sde.base import DynamicalSystem

import torch
import numpy as np

'''
This module contains the dynamical systems necessary for the RNN to perform the reach task.
'''

class Hand(DynamicalSystem):
    """Hand whose dynamics ``f`` simply add together every input it receives.

    Args:
        dim: Stored as ``self.dim``; not read by this class's own methods.
        initial_std: Scale applied to the standard-normal initial state.
        noise: Forwarded unchanged to ``DynamicalSystem.__init__``.
        noise_dim: Forwarded unchanged to ``DynamicalSystem.__init__``.
        device: Forwarded unchanged to ``DynamicalSystem.__init__``.
    """

    def __init__(self, dim, initial_std, noise, noise_dim=None, device='cpu'):
        super().__init__(noise_dim=noise_dim, noise=noise, device=device)
        self.initial_std = initial_std
        self.dim = dim

    def get_initial_state(self, batch_size):
        """Return a ``(batch_size, 2)`` tensor of standard-normal draws scaled by ``initial_std``."""
        # NOTE(review): the width is hard-coded to 2 rather than ``self.dim``;
        # confirm that is intended. ``self.device`` is presumably set by the
        # DynamicalSystem base class (it is not assigned here).
        standard_normal = torch.randn(batch_size, 2, device=self.device)
        return standard_normal * self.initial_std

    def f(self, *args):
        """Return the sum of all positional arguments (``0`` when called with none)."""
        combined = sum(args)
        return combined


class PerturbedHand(DynamicalSystem):
    """Hand that carries a 2x2 rotation ("perturbation") matrix.

    Args:
        dim: Stored as ``self.dim``; not read by this class's own methods.
        initial_std: Scale applied to the standard-normal initial state.
        noise: Forwarded unchanged to ``DynamicalSystem.__init__``.
        rotation: Rotation angle in radians used to build ``rotation_matrix``.
        noise_dim: Forwarded unchanged to ``DynamicalSystem.__init__``.
        device: Forwarded unchanged to ``DynamicalSystem.__init__``.
    """

    def __init__(self, dim, initial_std, noise, rotation=-np.pi/2, noise_dim=None, device='cpu'):
        super(PerturbedHand, self).__init__(noise_dim=noise_dim, noise=noise, device=device)

        self.dim = dim
        self.initial_std = initial_std

        # Sets both ``self.rotation`` and ``self.rotation_matrix`` so they stay in sync.
        self.set_rotation_matrix(rotation)

    def set_rotation_matrix(self, rotation):
        """Set the rotation angle and rebuild ``rotation_matrix`` from it.

        The matrix is ``[[cos r, -sin r], [sin r, cos r]]`` (float32, on
        ``self.device``), i.e. a counter-clockwise rotation by ``r`` radians
        of column vectors.

        Args:
            rotation: Rotation angle in radians.
        """
        # Record the angle here, not only in __init__, so ``self.rotation``
        # never goes stale after a later call to this method.
        self.rotation = rotation
        cos_r, sin_r = np.cos(rotation), np.sin(rotation)
        self.rotation_matrix = torch.tensor([[cos_r, -sin_r],
                                             [sin_r, cos_r]],
                                            dtype=torch.float, device=self.device)

    def get_initial_state(self, batch_size):
        """Return a ``(batch_size, 2)`` tensor of standard-normal draws scaled by ``initial_std``."""
        # ``self.device`` is presumably set by the DynamicalSystem base class.
        return torch.randn(batch_size, 2, device=self.device) * self.initial_std

    def f(self, *args):
        """Return the sum of all positional arguments (``0`` when called with none)."""
        # NOTE(review): ``rotation_matrix`` is NOT applied here, so this is
        # currently identical to ``Hand.f``. Confirm that the perturbation is
        # applied by the caller; otherwise it may be missing.
        velocity = sum(args)

        return velocity

class HandVelocity(DynamicalSystem):
    """Hand whose dynamics combine an input, a scaled extra term and a rotated perturbation.

    ``f(x, y, z)`` returns
    ``total_coef * (x + z_coef * z + coef_perturbation * (y @ R.T))``,
    where ``R`` is the 2x2 rotation matrix for the angle ``rotation``.

    Args:
        dim: Stored as ``self.dim``; not read by this class's own methods.
        initial_std: Stored as ``self.initial_std``; not used here because
            the initial state is always zero.
        noise: Forwarded unchanged to ``DynamicalSystem.__init__``.
        coef_perturbation: Gain on the rotated ``y`` term.
        total_coef: Overall gain applied to the summed terms.
        noise_dim: Forwarded unchanged to ``DynamicalSystem.__init__``.
        device: Forwarded unchanged to ``DynamicalSystem.__init__``.
        rotation: Rotation angle in radians for ``R`` (previously hard-coded
            to ``np.pi/2``, which remains the default).
        z_coef: Gain on ``z`` (previously hard-coded to ``0.1``, which
            remains the default).
    """

    def __init__(self, dim, initial_std, noise, coef_perturbation, total_coef=0.3, noise_dim=None, device='cpu',
                 rotation=np.pi/2, z_coef=0.1):
        super(HandVelocity, self).__init__(noise_dim=noise_dim, noise=noise, device=device)

        self.dim = dim
        self.initial_std = initial_std
        self.coef_perturbation = coef_perturbation
        self.total_coef = total_coef
        self.z_coef = z_coef

        # Sets both ``self.rotation`` and ``self.rotation_matrix``.
        self.set_rotation_matrix(rotation)

    def set_rotation_matrix(self, rotation):
        """Set the rotation angle and rebuild ``rotation_matrix`` from it.

        The matrix is ``[[cos r, -sin r], [sin r, cos r]]`` (float32, on
        ``self.device``).

        Args:
            rotation: Rotation angle in radians.
        """
        self.rotation = rotation
        cos_r, sin_r = np.cos(rotation), np.sin(rotation)
        self.rotation_matrix = torch.tensor([[cos_r, -sin_r],
                                             [sin_r, cos_r]],
                                            dtype=torch.float, device=self.device)

    def get_initial_state(self, batch_size):
        """Return a ``(batch_size, 2)`` tensor of zeros on ``self.device``."""
        return torch.zeros(batch_size, 2, device=self.device)

    def f(self, x, y, z):
        """Combine the inputs into the hand's update.

        Args:
            x: Term added directly to the update.
            y: Perturbation input. Its last dimension must be 2 to match the
                2x2 rotation matrix.
            z: Term added with gain ``self.z_coef``.

        Returns:
            ``(x + z * z_coef + (y @ R.T) * coef_perturbation) * total_coef``.
        """
        acceleration = x + z * self.z_coef
        # ``y @ R.T`` applies R to each row of y (rotating row vectors by ``rotation``).
        perturbation = (y @ self.rotation_matrix.T) * self.coef_perturbation

        return (acceleration + perturbation) * self.total_coef


