import numpy as np

from .base import ACselector
class UniformSamplingSelector(ACselector):
    """Selector that picks not-yet-selected pool indices at random.

    Without weights, the remaining indices are returned in a uniformly random
    order. With weights, indices are drawn without replacement with
    probability proportional to their weight.
    """

    def __init__(self, dataset, seed, ac_type='random'):
        """
        Args:
            dataset: pool to sample from; this class reads only ``dataset.size``.
            seed: seed for the selector's random number generator.
            ac_type: acquisition-type tag, forwarded to ``ACselector.__init__``.
        """
        super().__init__(ac_type=ac_type)
        self.dataset = dataset
        self.seed = seed
        self.ac_type = ac_type
        # Kept for backward compatibility: callers may rely on constructing a
        # selector to seed NumPy's global RNG.
        np.random.seed(seed)
        # Draws use a private generator so that other consumers of the global
        # RNG cannot change which indices this selector picks.
        self._rng = np.random.RandomState(seed)

    def select_batch_(self, already_selected, N, weight=None):
        """Select up to ``N`` pool indices that are not in ``already_selected``.

        The result is uniform over the remaining pool, which makes it biased
        with respect to the entire pool.

        Args:
            already_selected: iterable of pool indices to exclude (None = none).
            N: maximum number of indices to return.
            weight: optional non-negative sampling weights, not modified. Either
                one entry per remaining index (in ascending index order) or one
                entry per pool index (length ``dataset.size``).

        Returns:
            ``(indices, score)``. ``indices`` is a list of ints. With weights,
            ``score`` is None. Without weights, ``score`` is the array of random
            scores (one per pool index) whose descending order defines the pick.

        Raises:
            ValueError: if ``weight`` has the wrong shape or does not have a
                positive sum (NumPy also raises for negative weights).
        """
        # Fall back to the global RNG for instances whose __init__ did not run
        # this class's initializer (e.g. subclasses not calling super()).
        rng = getattr(self, '_rng', np.random)
        num = self.dataset.size
        # Build the exclusion set once: O(1) membership instead of a list scan.
        excluded = set(() if already_selected is None else already_selected)

        if weight is not None:
            # Ascending index order, so weights line up with candidates
            # deterministically (set iteration order is not guaranteed).
            rest_index = [i for i in range(num) if i not in excluded]
            # A fresh float array; the caller's weights are never mutated.
            w = np.asarray(weight, dtype=float)
            if w.ndim == 1 and w.shape[0] == num and len(rest_index) != num:
                # Weights were given for the whole pool; keep the remaining ones.
                w = w[rest_index]
            if w.shape != (len(rest_index),):
                raise ValueError(
                    'weight must have one entry per remaining index '
                    '(%d) or per pool index (%d), got shape %s'
                    % (len(rest_index), num, w.shape))
            if not rest_index:
                return [], None
            total = w.sum()
            if not total > 0:
                raise ValueError('weight must have a positive sum')
            p = w / total
            # Sampling without replacement cannot return more items than there
            # are candidates with non-zero probability; return fewer instead of
            # failing, matching the unweighted path.
            size = min(N, int(np.count_nonzero(p)))
            sample = rng.choice(rest_index, size=size, replace=False, p=p)
            return sample.tolist(), None

        score = rng.rand(num)
        # Highest score first; dropping excluded indices leaves a uniformly
        # random ordering of the remaining pool.
        ranking = np.argsort(-score)
        sample = [int(i) for i in ranking if i not in excluded]
        return sample[:N], score