import torch as th
from functools import partial
from tqdm import tqdm
from torch.utils.data import DataLoader
from ray.util.multiprocessing import Pool
from typing import *

from model.counterfactual.jctf_estimator import JointCounterfacutalEstimator


def _data_test(batch, model: JointCounterfacutalEstimator):
    """Evaluate a single ``batch`` with ``model`` and return its result.

    Defined at module level (rather than as a lambda or closure) so it can be
    bound with ``functools.partial`` and serialized for the worker pool in
    ``parallel_test``.
    """
    evaluate = model.eval_step
    return evaluate(batch)


def parallel_test(
    model: JointCounterfacutalEstimator,
    test_dataloader: DataLoader,
    num_workers: int = 12,
) -> Any:
    """Evaluate ``model`` on every batch of ``test_dataloader`` in a worker pool.

    The model is put into eval mode, frozen and moved to shared memory.
    Each batch is then handed to ``model.eval_step`` (via ``_data_test``) on a
    Ray ``Pool`` of ``num_workers`` processes, with a tqdm progress bar.

    Args:
        model: Estimator to evaluate; ``eval_step`` is called once per batch.
        test_dataloader: Source of test batches. It need not define ``__len__``.
            If it does not, the progress bar simply has no known total.
        num_workers: Number of pool worker processes.

    Returns:
        A list with one ``eval_step`` result per batch. ``imap`` yields results
        in input order.

    Raises:
        Any exception raised by ``eval_step`` or the pool. The pool is
        terminated before the exception propagates, so workers are not leaked.
    """
    # Test in parallel on CPU.
    model.eval()
    # NOTE(review): freeze() presumably disables gradients (LightningModule
    # API); confirm against the estimator's base class.
    model.freeze()
    # Moves parameters into shared memory (torch.nn.Module.share_memory).
    # NOTE(review): Ray serializes task arguments through its object store,
    # so this may not avoid copies here; verify whether it is needed.
    model.share_memory()

    try:
        total = len(test_dataloader)
    except TypeError:
        # DataLoaders over IterableDatasets without __len__ raise TypeError.
        total = None

    worker = partial(_data_test, model=model)
    pool = Pool(num_workers)
    try:
        results = list(tqdm(pool.imap(worker, test_dataloader), total=total))
    except BaseException:
        # Don't wait on outstanding tasks after a failure; stop workers now.
        pool.terminate()
        raise
    pool.close()
    pool.join()
    return results
