from dataclasses import dataclass, field

import numpy as np
import torch
from gpytorch.kernels import RBFKernel
from sklearn.datasets import load_diabetes

from src.explanation_algorithms.BayesGPSHAP import BayesGPSHAP
from src.gp_model.VariationalGPRegression import VariationalGPRegression


# Example experiment grid for AblationExperiments:
# num_coalitions = [2 ** 8, 2 ** 9, 2 ** 10]
# num_training_datas = [150, 300, 442]


@dataclass
class AblationExperiments:
    """Ablation study of BayesGPSHAP on the scikit-learn diabetes dataset.

    :meth:`run` fits one ``VariationalGPRegression`` per entry of
    ``num_training_datas`` (on the leading rows of the dataset) and, for every
    fitted model, runs ``BayesGPSHAP`` once per entry of ``num_coalitions``.

    Attributes:
        num_coalitions: Values forwarded as ``num_coalitions`` to
            ``BayesGPSHAP.run_bayesSHAP``.
        num_training_datas: Training-set sizes; size ``n`` selects the first
            ``n`` rows of the dataset.
        experiment_results: Populated by :meth:`run`. Maps
            ``(num_training_data, num_coalition)`` to the ``BayesGPSHAP``
            instance on which ``run_bayesSHAP`` was called.
        feature_names: Populated by :meth:`run` with the dataset's feature
            names; an empty list until then.
    """

    num_coalitions: list[int]
    num_training_datas: list[int]

    experiment_results: dict[tuple[int, int], "BayesGPSHAP"] = field(
        init=False, default_factory=dict
    )
    # Needs a default: without one the attribute does not exist before run(),
    # and the dataclass-generated __repr__/__eq__ raise AttributeError.
    feature_names: list[str] = field(init=False, default_factory=list)

    def run(self) -> None:
        """Run every (training size, coalition count) combination.

        Loads the diabetes dataset, standardises the targets, then for each
        training size fits a variational GP (RBF kernel, 100 inducing points,
        batch size 128, learning rate 1e-2, 500 iterations) and explains the
        training inputs with BayesGPSHAP using the "subsampling" method.

        Results are stored in ``self.experiment_results``; nothing is
        returned. Calling ``run`` again overwrites entries with the same key.
        """
        diabetes = load_diabetes()
        X, y = diabetes.data, diabetes.target
        self.feature_names = diabetes.feature_names

        X = torch.tensor(X).float()
        # Tensor.std() below applies Bessel's correction by default, whereas
        # np.std defaults to ddof=0. Use ddof=1 so that `scale` agrees with
        # the divisor used to standardise y.
        # NOTE(review): assumes BayesGPSHAP uses `scale` to map explanations
        # back to the original target units -- confirm.
        scale = np.std(y, ddof=1)
        y = torch.tensor(y).float()
        y = (y - y.mean()) / y.std()

        for num_training_data in self.num_training_datas:
            # Slicing silently yields fewer rows if the requested size exceeds
            # the dataset length.
            X_train, y_train = X[:num_training_data], y[:num_training_data]
            gp_regression = VariationalGPRegression(
                X_train, y_train, kernel=RBFKernel, num_inducing_points=100, batch_size=128
            )
            gp_regression.fit(learning_rate=1e-2, training_iteration=500)

            # The fitted model is reused for every coalition count.
            for num_coalition in self.num_coalitions:
                # NOTE(review): a fresh RBFKernel() (default hyperparameters)
                # is passed here rather than the kernel trained inside
                # gp_regression -- verify this is what BayesGPSHAP expects.
                bayesgpshap = BayesGPSHAP(
                    train_X=X_train, gp_model=gp_regression, kernel=RBFKernel(),
                    include_likelihood_noise_for_explanation=False,
                    scale=scale)

                bayesgpshap.run_bayesSHAP(X_train, num_coalitions=num_coalition, sampling_method="subsampling")

                self.experiment_results[num_training_data, num_coalition] = bayesgpshap
