import numpy as np
import ghalton as gh
from sswimlib.utils.enums import SCRAMBLING, SEQUENCE, QMC_KWARG
import torch.distributions as tdists
import torch


class Sequence(object):
    """Container for ``N`` sample points of dimension ``D`` from a (quasi-)random sequence.

    A point generator ("sequencer") is chosen from ``sequence_type`` and
    ``scramble_type``. The points are drawn once at construction time and
    stored in ``self.points`` as a ``torch.Tensor``.

    Combinations handled by this class:

    * ``SEQUENCE.MC``: i.i.d. draws from ``Uniform(0, 1)`` via
      ``torch.distributions``. ``scramble_type`` is ignored.
      ``resample_points`` draws a fresh set.
    * ``SEQUENCE.HALTON`` with ``SCRAMBLING.GENERALISED``:
      ``ghalton.GeneralizedHalton`` built from ``kwargs[QMC_KWARG.PERM]``,
      or from the first ``D`` entries of ``ghalton.EA_PERMS`` when no
      permutation is supplied.
    * ``SEQUENCE.HALTON`` with ``SCRAMBLING.OWEN17``: not implemented; no
      generator is built. Unless a subclass supplies ``self.sequencer`` or
      overrides ``init_points``, construction raises ``NotImplementedError``.
      This is also what happens with the default arguments.
    * ``SEQUENCE.HALTON`` with any other scramble type: plain
      ``ghalton.Halton(D)``.

    Args:
        N: Number of points to generate.
        D: Dimensionality of each point.
        seed: Stored as ``self.seed``; not used by any code in this class.
        sequence_type: ``SEQUENCE`` member selecting MC or Halton.
        scramble_type: ``SCRAMBLING`` member; only consulted for Halton.
        kwargs: Extra options keyed by ``QMC_KWARG``. Only
            ``QMC_KWARG.PERM`` is read. ``None`` (the default) is treated
            as ``{QMC_KWARG.PERM: None}``.
    """

    def __init__(self,
                 N,
                 D,
                 seed=42,
                 sequence_type=SEQUENCE.HALTON,
                 scramble_type=SCRAMBLING.OWEN17,
                 kwargs=None,
                 ):
        self.N = N
        self.D = D
        self.seed = seed  # NOTE(review): never read by this class.
        self.sequence_type = sequence_type
        self.scramble_type = scramble_type
        # Build a fresh dict per instance rather than sharing a mutable
        # default; the effective default is unchanged.
        self.kwargs = {QMC_KWARG.PERM: None} if kwargs is None else kwargs
        self.sequencer = None
        self.points = None
        self.init_sequencer()
        self.init_points()

    def init_sequencer(self):
        """Build ``self.sequencer`` from ``sequence_type`` / ``scramble_type``.

        ``self.sequencer`` stays ``None`` when no generator is implemented
        for the requested combination. That happens for Halton + OWEN17,
        and for a ``sequence_type`` that is neither HALTON nor MC.
        ``init_points`` raises in that case.
        """
        # ---------------------------------------#
        #                Halton                  #
        # ---------------------------------------#
        if self.sequence_type == SEQUENCE.HALTON:
            if self.scramble_type == SCRAMBLING.OWEN17:
                # Not implemented here: leave self.sequencer as None.
                pass
            elif self.scramble_type == SCRAMBLING.GENERALISED:
                # .get() so a caller-supplied kwargs without a PERM entry
                # falls back to the default instead of raising KeyError.
                perm = self.kwargs.get(QMC_KWARG.PERM)
                if perm is None:
                    perm = gh.EA_PERMS[:int(self.D)]  # Default permutation
                # NOTE(review): GeneralizedHalton presumably takes its
                # dimension from len(perm); a user-supplied perm is not
                # checked against D here.
                self.sequencer = gh.GeneralizedHalton(perm)
            else:
                self.sequencer = gh.Halton(int(self.D))

        # ---------------------------------------#
        #              Monte-Carlo               #
        # ---------------------------------------#
        elif self.sequence_type == SEQUENCE.MC:
            self.sequencer = tdists.Uniform(torch.tensor(0.0),
                                            torch.tensor(1.0))

    def _sample_mc(self):
        """Draw an (N, D) tensor of i.i.d. samples from the MC ``Uniform(0, 1)`` sequencer."""
        # Scalar low/high plus sample_shape=(N, D) gives an (N, D) tensor
        # in torch's default float dtype.
        return self.sequencer.sample(sample_shape=(self.N, self.D))

    def init_points(self):
        """Draw the initial ``N`` points and store them in ``self.points``.

        Raises:
            NotImplementedError: if no sequencer was built for this
                ``sequence_type`` / ``scramble_type`` (e.g. Halton with
                ``SCRAMBLING.OWEN17``). Previously this surfaced as an
                opaque ``AttributeError`` on ``None``.
        """
        if self.sequencer is None:
            raise NotImplementedError(
                "No point generator available for sequence_type={!r} "
                "with scramble_type={!r}".format(self.sequence_type,
                                                 self.scramble_type))
        if self.sequence_type == SEQUENCE.MC:
            self.points = self._sample_mc()
        else:  # Ghalton
            # np.array over ghalton's Python-float output is float64, so this
            # path yields float64 while the MC path uses torch's default dtype.
            # NOTE(review): original author was unsure whether these points
            # need requires_grad=True -- still unresolved.
            self.points = torch.tensor(np.array(self.sequencer.get(int(self.N))))

    def resample_points(self):
        """Redraw ``self.points`` for MC sequences.

        For Halton sequences this is a no-op and ``self.points`` keeps the
        points drawn by ``init_points``.
        """
        if self.sequence_type == SEQUENCE.MC:
            self.points = self._sample_mc()
