# this is the script for pretraining denset on skin lesion images
import random
import pickle
import numpy as np
from tqdm import tqdm
from PIL import Image
import wandb
import csv
import copy
import torch
import torchvision
import torch.nn as nn
from torch import optim
import torchxrayvision as xrv
from argparse import ArgumentParser
from torch.utils.data import DataLoader, Dataset
from models import DenseNetE2E

random.seed(42)
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),xrv.datasets.XRayResizer(224)])
def densenet_preprocess(image):
    image = image.convert("RGB")
    img = np.array(image)
    img = xrv.datasets.normalize(img, 255)
    img = img.mean(2)[None, ...]
    img = transform(img)
    img = torch.from_numpy(img)
    return img


class ImageDataset():
    def __init__(self, class2images, class2label):
        # Initialize image paths and corresponding texts
        self.preprocess = densenet_preprocess
        self.image_paths = []
        self.labels = []
        self.images = []
        
        for class_name, images in class2images.items():
            print(class_name, len(images))
            for image in tqdm(images):
                self.image_paths.append(image)
                self.labels.append(class2label[class_name])
                self.images.append(self.preprocess(Image.open(image)))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label


def train_model(model, train_loader, val_loader, num_epochs, lr, class_weight):
    criterion = nn.CrossEntropyLoss()
    # add weight decay to the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

    best_val_acc = -float("inf")
    best_model = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for i, (images, labels) in enumerate(train_loader):
            images = images.type(torch.float32).to(device)
            labels = labels.type(torch.LongTensor).to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            wandb.log({"train_loss": loss.item(), "epoch": epoch, "step": epoch * len(train_loader) + i})

        val_acc = eval_model(model, val_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss / len(train_loader)}, Val Acc: {val_acc}")
        wandb.log({"val_acc": val_acc, "epoch": epoch})

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = copy.deepcopy(model)
            # save the best model
            torch.save(best_model.state_dict(), f"../data/isic/densenet_skin.pt")

    return best_model


def eval_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    
    return accuracy


def prepare_data():    
    with open('../data/datasets/isic/images/metadata.csv', "r") as f:
        reader = csv.reader(f)
        data = list(reader)

    attributes = data[0]
    data = data[1:]
    data_dict = {}

    for d in data:
        id = d[0]
        data_dict[id] = {}
        for i, a in enumerate(attributes[1:]):
            data_dict[id][a] = d[i+1]
    
    selected_classes = ["melanoma", "nevus", "seborrheic keratosis", "actinic keratosis", "basal cell carcinoma", "squamous cell carcinoma", "pigmented benign keratosis", "benign", "others"]
    class2ids = {class_name: [] for class_name in selected_classes}
    for id in data_dict:
        class_name = data_dict[id]["diagnosis"]
        if class_name == "": class_name = data_dict[id]["benign_malignant"]
        if class_name not in selected_classes: class_name = "others"

        class2ids[class_name].append(id)
    
    class2images_train = {}
    class2images_val = {}

    for class_name, ids in class2ids.items():
        random.shuffle(ids)
        class2images_val[class_name] = [f"../data/datasets/isic/images/{id}.JPG" for id in ids[:150]]
        class2images_train[class_name] = [f"../data/datasets/isic/images/{id}.JPG" for id in ids[150:]]

    # print the number of images in each category
    for class_name in class2images_train:
        print(class_name, "train", len(class2images_train[class_name]), "val", len(class2images_val[class_name]))
    
    pickle.dump(class2images_train, open(f"../data/isic/splits/class2images_train.p", "wb"))
    pickle.dump(class2images_val, open(f"../data/isic/splits/class2images_val.p", "wb"))
    
    return class2images_train, class2images_val


def main(args):
    class2images_train = pickle.load(open(f"../data/isic/splits/class2images_train.p", "rb"))
    class2images_val = pickle.load(open(f"../data/isic/splits/class2images_val.p", "rb"))

    class2label = {class_name: i for i, class_name in enumerate(class2images_train.keys())}

    dataset_train = ImageDataset(class2images_train, class2label)
    dataset_val = ImageDataset(class2images_val, class2label)

    train_loader = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False, num_workers=4)

    base_model = xrv.models.DenseNet()
    model = DenseNetE2E(base_model, len(class2label))
    model.to(device)

    # get class weight based on the number of training samples in each class: n_samples / (n_classes * np.bincount(y))
    class_weight = torch.tensor([len(class2images_train[class_name]) for class_name in class2label.keys()])
    class_weight = class_weight / class_weight.sum()

    wandb.init(project="densenet_skin_lesion", 
            name=f"{args.batch_size}_{args.num_epochs}_{args.lr}",
            config={
            "batch_size": args.batch_size,
            "epochs": args.num_epochs,
            "lr": args.lr}
            )

    best_model = train_model(model, train_loader, val_loader, args.num_epochs, args.lr, class_weight)
    val_acc = eval_model(best_model, val_loader)

    print("Val Acc:", val_acc)
    # save the best model
    torch.save(best_model.state_dict(), f"../data/isic/densenet_skin.pt")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-3)
    args = parser.parse_args()

    main(args)