import numpy as np
import torch

from library.sde.base import DynamicalSystem

'''
This module contains some common neuroscience task stimuli, cues, targets, and so forth. 
They are all dynamical systems, even if constant.
'''

class Target(DynamicalSystem):
    """Target position that the state is pulled toward at rate ``drift``.

    The position is sampled per batch element by :meth:`set_pos`, which must
    be called before :meth:`get_initial_state` or :meth:`f`, since both read
    ``self.pos``.
    """

    def __init__(self, drift, dim=2, noise_dim=None, noise=0.0, device='cpu'):
        """Store the relaxation rate ``drift`` and the state dimension ``dim``.

        ``noise_dim``, ``noise`` and ``device`` are forwarded unchanged to
        ``DynamicalSystem.__init__``.
        """
        super().__init__(noise_dim=noise_dim, noise=noise, device=device)

        self.drift = drift
        self.dim = dim

    def set_pos(self, batch_size):
        """Sample one target per batch element, uniform on ``[0, 1)`` per coordinate."""
        shape = (batch_size, self.dim)
        self.pos = torch.rand(shape, device=self.device)

    def get_initial_state(self, batch_size):
        """Return the current target positions perturbed by Gaussian noise.

        The perturbation is a standard-normal draw of shape
        ``(batch_size, dim)`` multiplied by ``noise`` and divided by
        ``sqrt(2 * drift)``.
        """
        centre = self.pos
        jitter = torch.randn(batch_size, self.dim, device=self.device)
        return centre + jitter * self.noise / np.sqrt(2 * self.drift)

    def f(self, x, *args):
        """Return ``drift * (pos - x)``; extra positional arguments are ignored."""
        displacement = self.pos - x
        return self.drift * displacement


class TargetSphere(Target):
    """Target whose position is a random point of unit Euclidean norm.

    Besides ``self.pos``, :meth:`set_pos` also stores ``self.angles``, the
    polar angle of each position expressed in ``[0, 2*pi)``.
    """

    def __init__(self, drift, dim=2, noise_dim=None, noise=0.0, device='cpu'):
        """Forward every argument unchanged to ``Target.__init__``."""
        super().__init__(drift, dim=dim, noise_dim=noise_dim, noise=noise, device=device)

    def set_pos(self, batch_size):
        """Draw a standard-normal vector per batch element and rescale it to norm 1."""
        gaussian = torch.randn((batch_size, self.dim), device=self.device)
        lengths = torch.norm(gaussian, dim=1).unsqueeze(1)
        self.pos = gaussian / lengths

        # Careful: the angle only looks at the first two coordinates, even when dim > 2.
        raw_angles = torch.atan2(self.pos[:, 1], self.pos[:, 0])
        # atan2 returns values in [-pi, pi]; shift the negative ones up by a full turn.
        self.angles = torch.where(raw_angles < 0, 2 * np.pi + raw_angles, raw_angles)

class TargetSphereDiscrete(Target):
    """Target drawn from a fixed, caller-supplied set of candidate positions.

    Previously the ``positions`` argument was accepted but never used, which
    made this class behave exactly like ``TargetSphere`` (continuous sampling).
    It now samples uniformly among the supplied positions.

    NOTE(review): assumes ``positions`` is array-like with shape
    ``(n_positions, dim)`` holding coordinates (not angles) -- confirm against
    callers. The positions are used verbatim; they are not re-normalised onto
    the unit sphere.
    """

    def __init__(self, drift, positions, dim=2, noise_dim=None, noise=0.0, device='cpu'):
        """Store the candidate positions.

        Args:
            drift: relaxation rate used by ``Target.f``.
            positions: candidate target coordinates, ``(n_positions, dim)``.
            dim: dimensionality of each position.
            noise_dim, noise, device: forwarded to ``Target.__init__``.

        Raises:
            ValueError: if ``positions`` is empty or not of shape
                ``(n_positions, dim)``.
        """
        super().__init__(drift, dim, noise_dim, noise, device)

        # Floating dtype is required by atan2 and by the noise added in
        # get_initial_state; as_tensor avoids a copy when nothing changes.
        self.positions = torch.as_tensor(
            positions, dtype=torch.get_default_dtype(), device=device
        )
        if self.positions.dim() != 2 or self.positions.shape[1] != dim:
            raise ValueError(
                f"positions must have shape (n_positions, {dim}), "
                f"got {tuple(self.positions.shape)}"
            )
        if self.positions.shape[0] == 0:
            raise ValueError("positions must contain at least one entry")

        self.number_positions = self.positions.shape[0]

    def set_pos(self, batch_size):
        """Pick one candidate position per batch element, uniformly at random.

        Sets ``self.target_id`` (index into ``positions``), ``self.pos`` and
        ``self.angles`` (polar angle in ``[0, 2*pi)``).
        """
        self.target_id = torch.randint(
            self.number_positions, (batch_size,), device=self.device
        )
        self.pos = self.positions[self.target_id]

        # Careful, angle is computed only over first two dims
        raw_angles = torch.atan2(self.pos[:, 1], self.pos[:, 0])
        # Map atan2's [-pi, pi] range onto [0, 2*pi).
        self.angles = torch.where(raw_angles < 0, raw_angles + 2 * np.pi, raw_angles)



class TargetDiscrete(DynamicalSystem):
    """One of ``number_targets`` targets equally spaced on the unit circle.

    The state of this system is the one-hot encoding of the selected target
    (so ``dim == number_targets``); the 2-D coordinates and the angle of the
    selected target are exposed as ``self.pos`` and ``self.angles``.
    :meth:`set_pos` must be called before :meth:`get_initial_state` or
    :meth:`f`.
    """

    def __init__(self, drift, number_targets=8, noise_dim=None, noise=0.0, device='cpu'):
        """Precompute the pool of target angles and their (cos, sin) coordinates.

        ``noise_dim``, ``noise`` and ``device`` are forwarded to
        ``DynamicalSystem.__init__``.
        """
        super().__init__(noise_dim=noise_dim, noise=noise, device=device)

        self.number_targets = number_targets

        # The state is a one-hot vector over targets, hence its dimension.
        self.dim = number_targets
        self.drift = drift

        self.angles_pool = torch.tensor(
            [np.pi * 2 * i / number_targets for i in range(number_targets)],
            device=device,
        )
        self.targets_pool = torch.stack(
            [torch.cos(self.angles_pool), torch.sin(self.angles_pool)], dim=-1
        )

    def set_pos(self, batch_size):
        """Choose a target uniformly at random for each batch element.

        Sets ``target_id``, ``angles``, ``pos`` and ``one_hot_target``.
        """
        # Upper bound must follow number_targets: a hard-coded 8 indexes out of
        # range for smaller pools and never selects the later targets of
        # larger ones.
        self.target_id = torch.randint(
            0, self.number_targets, (batch_size,), device=self.device
        )

        self.angles = self.angles_pool[self.target_id]
        self.pos = self.targets_pool[self.target_id]

        self.one_hot_target = torch.nn.functional.one_hot(
            self.target_id, num_classes=self.number_targets
        ).float()

    def get_initial_state(self, batch_size):
        """Return the one-hot target perturbed by Gaussian noise.

        The perturbation is a standard-normal draw multiplied by ``noise`` and
        divided by ``sqrt(2 * drift)``.
        """
        jitter = torch.randn(batch_size, self.number_targets, device=self.device)
        return self.one_hot_target + jitter * self.noise / np.sqrt(2 * self.drift)

    def f(self, x, *args):
        """Return ``drift * (one_hot_target - x)``; extra arguments are ignored."""
        return self.drift * (self.one_hot_target - x)


class Cue(DynamicalSystem):
    """Constant cue: the state starts at one and its drift term is zero."""

    def __init__(self, dim=1, noise_dim=None, noise=0.0, device='cpu'):
        """Record the cue dimension; the rest goes to ``DynamicalSystem.__init__``."""
        super().__init__(noise_dim=noise_dim, noise=noise, device=device)

        self.dim = dim

    def get_initial_state(self, batch_size):
        """Return a ``(batch_size, dim)`` tensor filled with ones."""
        shape = [batch_size, self.dim]
        return torch.ones(shape, device=self.device)

    def f(self, x, *args):
        """Return zeros shaped like ``x``; extra arguments are ignored."""
        no_drift = torch.zeros_like(x, device=self.device)
        return no_drift


class OU(DynamicalSystem):
    """Process with linear drift ``drift * (steady_state - x)``.

    Without ``mus`` the steady state is fixed at 0. With ``mus`` the steady
    state of every coordinate of every batch element is drawn from ``mus`` by
    :meth:`set_steady_state`, which must then be called before :meth:`f`.
    """

    def __init__(self, dim, drift, noise, mus=None, noise_dim=None, device='cpu'):
        """Store the parameters and the optional pool of steady-state values.

        Args:
            dim: state dimension.
            drift: relaxation rate toward the steady state.
            noise: forwarded to ``DynamicalSystem.__init__``; also scales the
                initial state.
            mus: optional sequence of candidate steady-state values.
            noise_dim, device: forwarded to ``DynamicalSystem.__init__``.
        """
        super().__init__(noise_dim=noise_dim, noise=noise, device=device)

        self.dim = dim
        self.drift = drift

        # Always define both attributes so that every call order fails (or
        # succeeds) in a well-defined way instead of with an AttributeError.
        if mus is None:
            self.mus = None
            self.steady_state = 0
        else:
            self.mus = torch.tensor(mus, device=device)
            # Unknown until set_steady_state() is called with a batch size.
            self.steady_state = None

    def set_steady_state(self, batch_size):
        """Draw a steady state from ``mus`` for each ``(batch, dim)`` entry.

        Does nothing when no ``mus`` were supplied; the steady state then
        stays at 0.
        """
        if self.mus is None:
            return

        picks = torch.randint(
            len(self.mus), [batch_size, self.dim], device=self.device
        )
        self.steady_state = self.mus[picks]

    def get_initial_state(self, batch_size):
        """Return a zero-mean Gaussian draw of shape ``(batch_size, dim)``.

        The draw is multiplied by ``noise`` and divided by
        ``sqrt(2 * drift)``. It is not offset by the steady state.
        """
        jitter = torch.randn(batch_size, self.dim, device=self.device)
        return jitter * self.noise / np.sqrt(2 * self.drift)

    def f(self, x, *args):
        """Return ``(steady_state - x) * drift``; extra arguments are ignored.

        Raises:
            RuntimeError: if ``mus`` were supplied but
                :meth:`set_steady_state` has not been called yet.
        """
        if self.steady_state is None:
            raise RuntimeError(
                "OU.set_steady_state(batch_size) must be called before f() "
                "when `mus` is provided"
            )

        return (self.steady_state - x) * self.drift



class DiscreteTargetSequence(DynamicalSystem):
    """Sequence of targets visited one after another.

    For each batch element :meth:`set_pos` draws ``sequence_length`` distinct
    entries of ``positions`` (the head of a random permutation). The current
    target is ``self.pos``; :meth:`next_element` advances to the following
    one and :meth:`reset_sequence` rewinds to the first.

    NOTE(review): assumes ``positions`` has shape ``(n_positions, 2)``, since
    ``dim`` is fixed to 2 and is used for the zero padding and the initial
    noise -- confirm against callers.
    """

    def __init__(self, positions, sequence_length, drift, zero_start=True, zero_end=True, noise_dim=None, noise=0.0, device='cpu'):
        """Store the candidate positions and sequence options.

        Args:
            positions: candidate target coordinates (tensor or array-like).
            sequence_length: number of targets per sequence; values larger
                than ``len(positions)`` are silently capped by the slicing in
                :meth:`set_pos`.
            drift: relaxation rate toward the current target.
            zero_start: if true, the initial state is noise around zero
                instead of around the first target.
            zero_end: if true, an all-zero target is appended to the sequence.
            noise_dim, noise, device: forwarded to
                ``DynamicalSystem.__init__``.
        """
        super().__init__(noise_dim=noise_dim, noise=noise, device=device)

        # Accept lists/arrays as well as tensors and keep the pool on the same
        # device as everything else; this is a no-op for a tensor that already
        # lives on `device`.
        self.positions = torch.as_tensor(positions, device=device)
        self.number_positions = len(self.positions)

        self.dim = 2

        self.sequence_length = sequence_length
        self.drift = drift
        self.zero_start = zero_start
        self.zero_end = zero_end

    def set_pos(self, batch_size):
        """Sample a fresh target sequence per batch element and rewind to its start.

        Sets ``sequence_ids`` with shape ``(batch_size, sequence_length)`` and
        ``sequence``, indexed as ``sequence[step, batch]``.
        """
        # Permutations are drawn on CPU (as before, so seeded runs are
        # unchanged) and then moved next to the positions they index.
        permutations = [torch.randperm(self.number_positions) for _ in range(batch_size)]
        self.sequence_ids = torch.stack(permutations)[:, :self.sequence_length].to(self.device)

        self.sequence = self.positions[self.sequence_ids.T]

        if self.zero_end:
            # Match the dtype of the sequence so the concatenation never
            # depends on implicit type promotion.
            final_target = torch.zeros(
                (1, batch_size, self.dim),
                dtype=self.sequence.dtype,
                device=self.device,
            )
            self.sequence = torch.cat([self.sequence, final_target])

        self.reset_sequence()

    def reset_sequence(self):
        """Make the first element of the sequence the current target."""
        self.id = 0

        self.pos = self.sequence[self.id]

    def next_element(self):
        """Advance to the next target of the sequence.

        Raises:
            IndexError: if the sequence is already at its last element; the
                current index and target are left untouched in that case.
        """
        upcoming = self.id + 1
        if upcoming >= len(self.sequence):
            raise IndexError(
                f"target sequence exhausted: no element {upcoming} in a "
                f"sequence of length {len(self.sequence)}"
            )

        self.id = upcoming
        self.pos = self.sequence[upcoming]

    def get_initial_state(self, batch_size):
        """Return the initial state of shape ``(batch_size, dim)``.

        A standard-normal draw is multiplied by ``noise`` and divided by
        ``sqrt(2 * drift)``; unless ``zero_start`` is set, the current target
        is added to it.
        """
        jitter = torch.randn(batch_size, self.dim, device=self.device) * self.noise / np.sqrt(2 * self.drift)

        if self.zero_start:
            return jitter
        return jitter + self.pos

    def f(self, x, *args):
        """Return ``drift * (pos - x)``; extra arguments are ignored."""
        return self.drift * (self.pos - x)
