
import numpy as np
import torch

from .base import ACselector
from tqdm import tqdm


class CoreSetSelector(ACselector):
    """Greedy k-center ("core-set") active-learning selector.

    For every sample in ``dataset.X`` the selector keeps ``self.min_distance``,
    the Euclidean distance to the nearest center chosen so far, and repeatedly
    picks the sample that maximises ``min_distance * weight``. This state is
    kept on the instance, so successive calls to :meth:`select_batch_`
    continue from the centers chosen in earlier calls.
    """

    # How many centers are passed to one ``torch.cdist`` call when folding
    # externally labeled points into ``min_distance``; bounds the temporary
    # distance matrix to roughly ``_CENTER_CHUNK * dataset.size`` floats.
    _CENTER_CHUNK = 256

    def __init__(self, dataset, seed, ac_type='CoreSet'):
        """Build the selector over ``dataset``.

        Args:
            dataset: object exposing ``X`` (features; rows are samples —
                presumably 2-D, TODO confirm) and ``size`` (number of samples).
            seed: seed for NumPy's global RNG, used for the initial random pick.
            ac_type: selector name forwarded to ``ACselector``.
        """
        super().__init__(ac_type=ac_type)
        self.dataset = dataset
        self.seed = seed
        self.ac_type = ac_type
        # NOTE: this seeds NumPy's *global* RNG, affecting other np.random users.
        np.random.seed(seed)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.features_ul = torch.Tensor(self.dataset.X)
        # Large sentinel meaning "no center has been seen yet".
        self.min_distance = torch.ones(self.dataset.size) * 19260817
        # Indices whose distances are already folded into ``min_distance``.
        self._centers = set()

    def _distances_from(self, idx):
        """Return a 1-D tensor of distances from sample ``idx`` to every sample."""
        return torch.cdist(
            self.features_ul[idx].reshape(1, -1), self.features_ul
        ).detach().reshape(-1)

    def _fold_in_centers(self, centers):
        """Lower ``min_distance`` by the distances to each index in ``centers``."""
        for start in range(0, len(centers), self._CENTER_CHUNK):
            chunk = centers[start:start + self._CENTER_CHUNK]
            dist = torch.cdist(self.features_ul[chunk], self.features_ul).detach()
            self.min_distance = torch.minimum(self.min_distance, dist.min(dim=0).values)
        self._centers.update(centers)

    def select_batch_(self, already_selected, N, weight=None):
        """Greedily select ``N`` new samples (farthest-first traversal).

        Args:
            already_selected: indices that are already labeled. They are never
                picked again, and any of them not yet used as a center by this
                selector first has its distances folded into the state.
            N: number of samples to select.
            weight: optional weights for the *unselected* samples; assumed to
                be listed in ascending index order (TODO confirm with caller).
                ``None`` means uniform weights over all samples.

        Returns:
            ``(sample, min_distance)``: the list of selected indices and a NumPy
            copy of every sample's distance to its nearest center after selection.
        """
        num = self.dataset.size
        already_selected = [int(i) for i in already_selected]
        print('select ... ')

        # Distance computations run on the GPU when one is available.
        self.min_distance = self.min_distance.to(self.device).reshape(-1)
        self.features_ul = self.features_ul.to(self.device)

        # Labeled points this selector never chose itself must still act as
        # centers; otherwise their neighbours keep the sentinel distance.
        new_centers = [i for i in dict.fromkeys(already_selected) if i not in self._centers]
        if new_centers:
            self._fold_in_centers(new_centers)
        if already_selected:
            self.min_distance[already_selected] = 0

        if weight is None:
            weight = np.ones(num) / num
        else:
            # np.setdiff1d returns sorted indices; list(set(...)) has no
            # guaranteed order and could misalign the provided weights.
            full = np.zeros(num)
            rest_index = np.setdiff1d(np.arange(num), np.asarray(already_selected, dtype=int))
            full[rest_index] = weight
            weight = full / full.sum()
        # Loop-invariant: build the weight tensor once.
        weight_t = torch.Tensor(weight).to(self.device)

        sample = []
        for _ in tqdm(range(N)):
            if not already_selected and not sample:
                # Nothing labeled: seed with a (weighted) random point and
                # restart the distance state from it alone.
                idx = int(np.random.choice(np.arange(num), p=weight))
                self.min_distance = self._distances_from(idx)
                self._centers = set()
            else:
                # Farthest-first: the (weighted) point farthest from all centers.
                idx = int(torch.argmax(self.min_distance * weight_t).item())
                self.min_distance = torch.minimum(self.min_distance, self._distances_from(idx))
            self._centers.add(idx)
            sample.append(idx)

        self.min_distance = self.min_distance.cpu()
        self.features_ul = self.features_ul.cpu()

        # Copy: .numpy() shares memory with the tensor, and the state tensor
        # is mutated in place on the next call.
        return sample, self.min_distance.numpy().copy()

