from matplotlib.pyplot import axis
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from rdkit.Chem.Scaffolds.MurckoScaffold import (
    MurckoScaffoldSmiles as get_scaffold_from_smile,
)
import pandas as pd
from sklearn.model_selection import train_test_split

from dgllife.utils import smiles_to_bigraph, mol_to_bigraph
import dgl
from dgllife.utils import (
    AttentiveFPAtomFeaturizer,
    AttentiveFPBondFeaturizer,
)
import pickle as pkl
import os
import hashlib

from tdc.single_pred import Tox, HTS, ADME
from tdc.generation import MolGen


# Cache directory for per-molecule featurized graphs (one pickle per SMILES).
dataset_path = "data/datapoints"
# Cache directory for the pickled ID/OOD split dictionaries.
data_splits_path = "data/splits"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def _smiles_cache_key(mol_smile):
    """Return the hex digest used as the cache filename stem for ``mol_smile``.

    MD4 is tried first so cache files written by earlier runs keep resolving.
    Many OpenSSL 3 builds do not expose MD4, and then ``hashlib.new`` raises
    ``ValueError``; MD5 is used as the fallback digest in that case.
    """
    payload = mol_smile.encode("utf-8")
    try:
        return hashlib.new("md4", payload).hexdigest()
    except ValueError:
        return hashlib.md5(payload).hexdigest()


def get_data_features(mol_smile):
    """Featurize a SMILES string into a DGL bigraph, caching the result on disk.

    The graph is built by ``mol_to_bigraph`` with AttentiveFP atom and bond
    featurizers and pickled under ``dataset_path``. Later calls for the same
    SMILES load that pickle instead of recomputing.

    Args:
        mol_smile: SMILES string of the molecule.

    Returns:
        The graph object produced by ``mol_to_bigraph``.

    Raises:
        ValueError: if RDKit cannot parse ``mol_smile``.
    """
    # makedirs also creates a missing parent directory and tolerates another
    # process (e.g. a DataLoader worker) creating the directory concurrently.
    os.makedirs(dataset_path, exist_ok=True)
    cache_path = os.path.join(dataset_path, _smiles_cache_key(mol_smile) + ".pkl")

    if os.path.exists(cache_path):
        # Unpickling is acceptable here only because this function wrote the file.
        with open(cache_path, "rb") as f:
            return pkl.load(f)

    mol = Chem.MolFromSmiles(mol_smile)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail loudly
        # here rather than with an opaque error inside mol_to_bigraph.
        raise ValueError(f"RDKit could not parse SMILES {mol_smile!r}")

    feature = mol_to_bigraph(
        mol,
        node_featurizer=AttentiveFPAtomFeaturizer(),
        edge_featurizer=AttentiveFPBondFeaturizer(),
    )

    # Write to a per-process temp file and atomically rename it into place, so
    # an interrupted or concurrent write never leaves a truncated pickle that
    # later calls would fail to load.
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    with open(tmp_path, "wb") as f:
        pkl.dump(feature, f)
    os.replace(tmp_path, cache_path)

    return feature


class IDDataset(Dataset):
    """Labelled in-distribution molecules backed by a DataFrame.

    Indexing yields ``(graph, label, smiles)``. The graph comes from
    ``get_data_features`` applied to the row's ``"Drug"`` SMILES, and the label
    is the row's ``"Y"`` value.
    """

    def __init__(self, data_df):
        # Rows are accessed positionally (iloc); columns "Drug" and "Y" are read.
        self.data = data_df

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        smiles = row["Drug"]
        target = row["Y"]
        return get_data_features(smiles), target, smiles


class OODDataset(Dataset):
    """Unlabelled out-of-distribution molecules backed by a DataFrame.

    Indexing yields ``(graph, smiles)``. The graph comes from
    ``get_data_features`` applied to the row's ``"smiles"`` column.
    """

    def __init__(self, data_df):
        # Rows are accessed positionally (iloc); only the "smiles" column is read.
        self.data = data_df

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        smiles = self.data.iloc[idx]["smiles"]
        return get_data_features(smiles), smiles


def id_collate(batch_list):
    """Merge ``(graph, label, smiles)`` samples into a single batch.

    Returns ``(batched_graph, labels, smiles)``. The graphs are combined with
    ``dgl.batch``, the labels are stacked with ``torch.tensor``, and the SMILES
    strings are returned as a plain list.
    """
    graphs = []
    labels = []
    smiles = []
    for graph, label, smile in batch_list:
        graphs.append(graph)
        labels.append(label)
        smiles.append(smile)
    return dgl.batch(graphs), torch.tensor(labels), smiles


def ood_collate(batch_list):
    """Merge ``(graph, smiles)`` samples into ``(batched_graph, smiles_list)``."""
    graphs, smiles = [], []
    for graph, smile in batch_list:
        graphs.append(graph)
        smiles.append(smile)
    return dgl.batch(graphs), smiles


def get_dataloaders(data_type, split_type, batch_size, shuffle=True, random_state=None):
    """Build the six ID/OOD DataLoaders for one dataset and split strategy.

    Args:
        data_type: dataset name forwarded to ``get_splits`` (e.g. "AMES").
        split_type: split method forwarded to ``get_splits``.
        batch_size: batch size used by every loader.
        shuffle: if true, pool the ID train and valid rows, shuffle them, and
            re-split at the original train size. This gives a fresh
            train/valid partition of the same sizes. The test set is never
            touched.
        random_state: seed for that reshuffle. ``None`` draws a new partition
            on every call.

    Returns:
        dict mapping "train_id", "test_id", "valid_id", "train_ood",
        "test_ood" and "valid_ood" to DataLoaders. Every loader shuffles its
        batch order.
    """
    id_split, ood_split = get_splits(data_type, split_type)

    id_train, id_valid = id_split["train"], id_split["valid"]
    if shuffle:
        n_train = len(id_train)
        # The rows must be shuffled before slicing: concatenating and slicing
        # back alone would reproduce the original partition exactly.
        pooled = pd.concat([id_train, id_valid]).sample(
            frac=1.0, random_state=random_state
        )
        id_train, id_valid = pooled.iloc[:n_train], pooled.iloc[n_train:]

    id_frames = {"train": id_train, "test": id_split["test"], "valid": id_valid}

    dataloaders = {}
    for name, frame in id_frames.items():
        dataloaders[f"{name}_id"] = DataLoader(
            IDDataset(frame), batch_size=batch_size, collate_fn=id_collate, shuffle=True
        )
    for name in ("train", "test", "valid"):
        dataloaders[f"{name}_ood"] = DataLoader(
            OODDataset(ood_split[name]),
            batch_size=batch_size,
            collate_fn=ood_collate,
            shuffle=True,
        )
    return dataloaders



def _load_id_dataset(data_type):
    """Instantiate the TDC dataset backing the in-distribution (ID) task.

    Raises:
        ValueError: if ``data_type`` is not one of the supported names.
    """
    if data_type == "AMES":
        return Tox(name="AMES")
    if data_type == "Tox21":
        from tdc.utils import retrieve_label_name_list

        # Tox21 exposes several label columns; only the first one is used.
        label_list = retrieve_label_name_list("Tox21")
        return Tox(name="Tox21", label_name=label_list[0])
    if data_type == "ClinTox":
        return Tox(name="ClinTox")
    if data_type == "HIV":
        return HTS(name="HIV")
    raise ValueError(
        f"Unsupported data_type {data_type!r}; expected one of "
        "'AMES', 'Tox21', 'ClinTox', 'HIV'"
    )


def get_splits(data_type, split_type, random_state=42):
    """Build, or load from cache, the ID and OOD train/valid/test splits.

    ID data comes from the TDC dataset selected by ``data_type``. OOD data is
    the ZINC set from ``MolGen``. Both split dictionaries are pickled under
    ``data_splits_path`` and reused on later calls with the same
    ``data_type``/``split_type``.

    Args:
        data_type: "AMES", "Tox21", "ClinTox" or "HIV".
        split_type: TDC split method (e.g. "random", "scaffold"), or "fp".
            With "scaffold", ZINC molecules whose Murcko scaffold occurs in
            the ID train+valid pool are dropped from the OOD data. Other split
            types use every ZINC molecule as OOD data.
        random_state: seed for every ``train_test_split`` call. NOTE: it is
            not part of the cache filename, so a cached split is returned
            regardless of this value.

    Returns:
        ``(id_splits, ood_splits)``: dicts holding "train", "valid" and
        "test" DataFrames. The ID frames carry integer "Y" labels.

    Raises:
        ValueError: for an unsupported ``data_type``. This is only checked
            when no cached split exists.
    """
    # makedirs also creates a missing parent directory.
    os.makedirs(data_splits_path, exist_ok=True)

    id_filename = "id_" + str(data_type) + str(split_type) + ".pkl"
    id_path = os.path.join(data_splits_path, id_filename)
    ood_filename = "ood_" + str(data_type) + str(split_type) + ".pkl"
    ood_path = os.path.join(data_splits_path, ood_filename)

    if os.path.exists(id_path) and os.path.exists(ood_path):
        with open(id_path, "rb") as f:
            id_splits = pkl.load(f)
        with open(ood_path, "rb") as f:
            ood_splits = pkl.load(f)
        return id_splits, ood_splits

    # Resolve the ID dataset first so a bad data_type fails before the ZINC download.
    data = _load_id_dataset(data_type)

    ood_splits = MolGen(name="Zinc").get_split()
    ood_splits_merged = pd.concat(
        [ood_splits["train"], ood_splits["valid"], ood_splits["test"]]
    )

    if split_type == "fp":
        id_splits = data.get_split(method="random")
        # NOTE(review): get_splits_by is not defined in this file; confirm it
        # is provided elsewhere before relying on the "fp" split.
        id_splits = get_splits_by(id_splits, split_type="fp", ascending=False)
    else:
        id_splits = data.get_split(method=split_type)

    # Re-split the pooled ID train+valid rows 85/15; the ID test set is kept.
    id_splits_merged = pd.concat([id_splits["train"], id_splits["valid"]])
    id_train_df, id_val_df = train_test_split(
        id_splits_merged, test_size=0.15, shuffle=True, random_state=random_state
    )
    # assign() returns new frames, so the label cast never writes into a
    # slice of another DataFrame (avoids SettingWithCopyWarning).
    for name, frame in (
        ("train", id_train_df),
        ("valid", id_val_df),
        ("test", id_splits["test"]),
    ):
        id_splits[name] = frame.assign(Y=frame["Y"].astype(int))

    if split_type == "scaffold":
        # NOTE(review): only ID train+valid scaffolds are excluded; ID test
        # scaffolds may still appear in the OOD data. Confirm this is intended.
        drug_scaffolds = {
            get_scaffold_from_smile(drug) for drug in id_splits_merged["Drug"]
        }
        ood_splits_merged["scaffolds"] = ood_splits_merged["smiles"].map(
            get_scaffold_from_smile
        )
        filtered_ood_df = ood_splits_merged[
            ~ood_splits_merged["scaffolds"].isin(drug_scaffolds)
        ]
        print(len(ood_splits_merged), len(filtered_ood_df))
    else:
        filtered_ood_df = ood_splits_merged

    # OOD: hold out 2% as test, then split the remainder 85/15 train/valid.
    ood_train_val_df, ood_test_df = train_test_split(
        filtered_ood_df, test_size=0.02, shuffle=True, random_state=random_state
    )
    ood_train_df, ood_val_df = train_test_split(
        ood_train_val_df, test_size=0.15, shuffle=True, random_state=random_state
    )
    ood_splits["train"] = ood_train_df
    ood_splits["valid"] = ood_val_df
    ood_splits["test"] = ood_test_df

    with open(id_path, "wb") as f:
        pkl.dump(id_splits, f)

    with open(ood_path, "wb") as f:
        pkl.dump(ood_splits, f)

    return id_splits, ood_splits

