from online_sc import MultiplicativeWeightsAlgo
from offline_sc import OfflineSCAlgo
from sc_input import SCInput
from smooth_combiner import SmoothOnlineCombiner
from standard_combiner import StandardOnlineCombiner
import numpy as np
import polars as pl

"""
This code will run the experiments 
"""


def get_noisy_prediction(sc_input,
                         false_pos_rate,
                         false_neg_rate,
                         seed,
                         rounding_multiplier=5,
                         offline_solution=None):
    """Build a noisy boolean prediction from the offline fractional solution.

    The fractional solution is randomly rounded, then corrupted with
    independent false negatives and false positives. Finally the trailing
    ``sc_input.num_elems`` entries (the singletons) are forced on.

    Args:
        sc_input: Instance to predict for. It is passed to
            ``OfflineSCAlgo.solve`` when ``offline_solution`` is None, and its
            ``num_elems`` attribute is read.
        false_pos_rate: Probability that each entry is switched on regardless
            of the rounded solution.
        false_neg_rate: Probability that each entry selected by the rounding
            is dropped.
        seed: Seed for ``np.random.default_rng``. For a fixed offline
            solution, the same seed yields the same prediction.
        rounding_multiplier: Entry ``i`` with fractional value ``x_i``
            survives rounding with probability
            ``clip(rounding_multiplier * x_i, 0, 1)``.
        offline_solution: Optional precomputed fractional solution (the second
            value returned by ``OfflineSCAlgo.solve``). If None, the instance
            is solved here. Pass it to avoid solving the same instance twice.

    Returns:
        A boolean ``np.ndarray`` with the same shape as the offline solution.
    """
    if offline_solution is None:
        _, offline_solution = OfflineSCAlgo.solve(sc_input)
    offline_solution = np.asarray(offline_solution)

    rng = np.random.default_rng(seed=seed)
    shape = offline_solution.shape

    # Keep this draw order fixed (rounding, false negatives, false positives):
    # changing it changes the prediction produced for a given seed.
    rounded = rounding_multiplier * offline_solution > rng.random(size=shape)
    kept = rng.random(shape) > false_neg_rate
    spurious = rng.random(shape) < false_pos_rate
    # `&` binds tighter than `|`; the grouping is spelled out for readability.
    noisy_prediction = (rounded & kept) | spurious

    # Force the singleton sets on. NOTE(review): assumes they occupy the last
    # num_elems positions of the solution vector; confirm against SCInput.
    # The guard matters: with num_elems == 0 the slice `[-0:]` would select,
    # and force on, every entry.
    num_singletons = sc_input.num_elems
    if num_singletons > 0:
        noisy_prediction[-num_singletons:] = True
    return noisy_prediction


def run_experiment(seed,
                   false_pos_rate,
                   false_neg_rate,
                   num_elems,
                   num_sets,
                   connection_probability,
                   *,
                   sigma=1.6,
                   rounding_multiplier=5):
    """Run one experiment and return the cost of every algorithm.

    Generates a random instance, derives a noisy prediction for it, feeds the
    elements in a seeded random order to each online algorithm, and collects
    the resulting costs alongside the offline cost.

    Args:
        seed: Seeds instance generation, prediction noise and arrival order.
        false_pos_rate: Forwarded to ``get_noisy_prediction``.
        false_neg_rate: Forwarded to ``get_noisy_prediction``.
        num_elems: Number of elements; also the length of the arrival order.
        num_sets: Forwarded to ``SCInput.get_random_input``.
        connection_probability: Forwarded to ``SCInput.get_random_input`` as
            ``uniform_conn_prob``.
        sigma: Forwarded to ``SCInput.get_random_input`` (default 1.6).
        rounding_multiplier: Forwarded to ``get_noisy_prediction`` (default 5).

    Returns:
        A one-row ``polars.DataFrame`` with columns 'offline',
        'general online', 'prediction online', 'standard combination' and
        'smooth combination'.
    """
    test_input = SCInput.get_random_input(num_elems=num_elems,
                                          num_sets=num_sets,
                                          uniform_conn_prob=connection_probability,
                                          sigma=sigma,
                                          seed=seed)

    # NOTE(review): get_noisy_prediction solves the instance again internally.
    prediction = get_noisy_prediction(test_input,
                                      false_pos_rate=false_pos_rate,
                                      false_neg_rate=false_neg_rate,
                                      seed=seed,
                                      rounding_multiplier=rounding_multiplier)

    offline_cost, _ = OfflineSCAlgo.solve(test_input)

    # Constructed in the same order as before, in case constructors share state.
    gen_online_algo = MultiplicativeWeightsAlgo(test_input)
    pred_online_algo = MultiplicativeWeightsAlgo(test_input, prediction=prediction)
    smooth_combination_algo = SmoothOnlineCombiner(test_input, prediction=prediction)
    standard_combination_algo = StandardOnlineCombiner(test_input, prediction=prediction)

    # Insertion order defines both the order in which each element is fed to
    # the algorithms and the output column order.
    online_algos = {
        'general online': gen_online_algo,
        'prediction online': pred_online_algo,
        'standard combination': standard_combination_algo,
        'smooth combination': smooth_combination_algo,
    }

    # NOTE(review): get_noisy_prediction also builds default_rng(seed=seed), so
    # this generator replays the same bit stream and the arrival order is
    # coupled to the prediction noise. Independent streams (e.g. via
    # np.random.SeedSequence(seed).spawn) would decouple them but would change
    # results for existing seeds.
    rng = np.random.default_rng(seed=seed)
    for elem in rng.permutation(num_elems):
        for algo in online_algos.values():
            algo.request_element(elem)

    costs = {'offline': [offline_cost]}
    costs.update({name: [algo.get_cost()] for name, algo in online_algos.items()})
    return pl.DataFrame(costs)
