import torch
from loaders.compas_loader import load_data, mono_list
import torch.utils.data as Data
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import numpy as np

from monotonenorm.monotonicnetworks import SigmaNet, direct_norm, GroupSort

from BLNN import PICNN, PICNN_multiclass

# Run the whole script in float64; the tensors below also request float64 explicitly.
torch.set_default_dtype(torch.float64)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)
print(mono_list)
X_train, y_train, X_test, y_test = load_data(get_categorical_info=False)

# Move features/labels to the target device. Labels get an extra axis via
# unsqueeze(1) (presumably (N,) -> (N, 1) -- assumes load_data returns 1-D labels).
X_train = torch.tensor(X_train, dtype=torch.float64).to(device)
X_test = torch.tensor(X_test, dtype=torch.float64).to(device)
y_train = torch.tensor(y_train, dtype=torch.float64).unsqueeze(1).to(device)
y_test = torch.tensor(y_test, dtype=torch.float64).unsqueeze(1).to(device)

# Row indices of the training set; run() passes them to the network together
# with each shuffled minibatch.
idx = torch.arange(len(X_train))

# Standardise every feature with training-set statistics only, so the test set
# does not influence preprocessing.
mean = X_train.mean(0)
std = X_train.std(0)
# A constant training column has std == 0 (and a single-row training set gives
# NaN); dividing by that would turn the whole feature into NaN/inf. Fall back to
# a scale of 1 so such columns are only mean-centred.
std = torch.where(std > 0, std, torch.ones_like(std))
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std

# 1 for feature indices listed in mono_list, 0 otherwise.
# NOTE(review): not referenced by the code visible in this file.
monotone_constraints = [1 if i in mono_list else 0 for i in range(X_train.shape[1])]

# NOTE(review): not referenced by the code visible in this file.
per_layer_lip = 1.3


def run(seed):
    """Train one PICNN_multiclass model and return its best test accuracy.

    Seeds torch's RNG, builds the network, then trains for 200 epochs with Adam
    (lr=2e-4) on shuffled minibatches of 256 rows drawn from the module-level
    ``X_train`` / ``y_train`` / ``idx`` tensors. After each epoch the model is
    evaluated on ``X_test`` / ``y_test``. The test BCE loss is reported, and
    accuracy is computed for 50 evenly spaced probability thresholds in [0, 1].

    Args:
        seed: Passed to ``torch.manual_seed`` before the network is created, so
            weight initialisation and DataLoader shuffling depend on it.

    Returns:
        The highest accuracy seen over all epochs and thresholds. Both the epoch
        and the threshold are chosen on the test set itself, so this is an
        optimistic estimate rather than a held-out score.
    """
    torch.manual_seed(seed)

    # Positional hyperparameters of PICNN_multiclass; their meaning lives in
    # BLNN (not visible here) -- TODO document. The first argument is the
    # number of input features minus one.
    network = PICNN_multiclass(len(X_train[0]) - 1, 1, 45, 4, 0.3, 0, 4)
    network = network.to(device)

    print("params:", sum(p.numel() for p in network.parameters()))

    optimizer = torch.optim.Adam(network.parameters(), lr=2e-4)

    data_train_loader = Data.DataLoader(
        dataset=Data.TensorDataset(X_train, y_train, idx),
        batch_size=256,
        shuffle=True,
    )

    # Loop-invariant evaluation inputs: convert labels to numpy and build the
    # threshold grid once instead of on every epoch/threshold.
    y_test_np = y_test.cpu().numpy()
    thresholds = torch.linspace(0, 1, 50).tolist()

    bar = tqdm(range(200))
    acc = 0
    for _epoch in bar:
        for X, y, batch_idx in data_train_loader:
            # NOTE(review): the meaning of [0, 1, 2, 3] and of the row indices is
            # defined by PICNN_multiclass.forward -- confirm there.
            y_pred = network(X, [0, 1, 2, 3], batch_idx)

            # The network output is treated as a logit; the fused *_with_logits
            # loss equals BCE(sigmoid(logit), y) but is numerically stable.
            loss_train = torch.nn.functional.binary_cross_entropy_with_logits(y_pred, y)
            optimizer.zero_grad()
            loss_train.backward()
            optimizer.step()

        # NOTE(review): the model is never switched to eval() mode. If
        # PICNN_multiclass uses dropout/batch-norm (the 0.3 argument may be a
        # dropout rate), evaluation should be wrapped in eval()/train().
        with torch.no_grad():
            y_pred = network(X_test, [0, 1, 2, 3])
            loss = torch.nn.functional.binary_cross_entropy_with_logits(y_pred, y_test)

            # Thresholds in [0, 1] are probabilities, so compare them against
            # sigmoid(logit), not against the raw logits.
            proba = torch.sigmoid(y_pred).cpu().numpy()
            acci = max(
                accuracy_score(y_test_np, (proba > threshold).astype(int))
                for threshold in thresholds
            )

            acc = max(acc, acci)
            # loss_train is the loss of the epoch's last minibatch.
            bar.set_description(
                f"train: {loss_train.item():.4f}, test: {loss.item():.4f}, current acc: {acci:.4f}, best acc: {acc:.4f}"
            )
    return acc


# Repeat the full experiment for seeds 0, 1 and 2, then report the mean and
# (population) standard deviation of the best test accuracy per seed.
accs = list(map(run, range(3)))
print(f"mean: {np.mean(accs):.4f}, std: {np.std(accs):.4f}")
