import hydra
import omegaconf
from omegaconf import DictConfig, OmegaConf
import os
import sys

from bbo.utils import seed_everything, print_dict
import bbo
from bbo.algorithms import TransformerOpt
import logging
import time
import gc
import torch
from space_gen import get_bounds, get_source_data, get_func
import wandb
import copy as cp
import datetime
from tqdm import tqdm

def transformer(args):
    """Run ``args.rep`` independent TransformerOpt runs and log each to wandb.

    For every repetition the function seeds the RNGs with the repetition
    index, opens a fresh wandb run, instantiates the optimiser from
    ``configs/algorithms/np/transformer_opt.yaml`` and alternates
    ``algo.ask`` / ``algo.tell`` until ``args.iteration`` evaluations have
    been made.

    Args:
        args: Configuration object. The attributes read here are
            ``similar``, ``search_space_id``, ``dataset_id``, ``mode``,
            ``dims``, ``iteration`` and ``rep``; it is also passed whole to
            ``get_bounds``.

    Returns:
        None. Results are printed and sent to wandb.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model_path = f"checkpoints/{args.similar}-{args.search_space_id}.pth"
    dim = args.dims
    lb, ub = get_bounds(args)

    lb = torch.tensor(lb, device=device)
    ub = torch.tensor(ub, device=device)

    max_evals = args.iteration

    def f(x):
        """Evaluate one point and return the negated objective as a (n, 1) tensor."""
        z = x.detach().cpu().numpy().copy()
        if z.ndim == 2:
            # A (1, dim) batch is unwrapped to the single point it contains.
            z = z[0]
        assert z.ndim == 1
        # NOTE(review): the objective is rebuilt on every call; if get_func is
        # expensive and stateless it could be created once outside f — confirm.
        # The sign is flipped, and the loop below tracks the maximum of that
        # negated value as "best".
        y = get_func(args.search_space_id, args.dataset_id, args.mode, dim)(z)*-1
        return torch.tensor(y, device=device).reshape(-1,1)

    for rep in range(args.rep):
        seed_everything(rep)
        # Wall-clock timestamp at UTC+8, used only to make the run name unique.
        ts = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=8)))
        ts_name = f'-ts{ts.month}-{ts.day}-{ts.hour}-{ts.minute}-{ts.second}'
        wandb.init(
            project="hpobhpo",
            name=f"transformer-{args.search_space_id}-{args.dataset_id}-{ts_name}",
            job_type="OPT-transformer",
            tags=[f"dim={args.dims}", f"similar={args.similar}", f"search_space_id={args.search_space_id}", f"dataset_id={args.dataset_id}"]
        )

        # Per-repetition state. The evaluation history is reset here so that
        # "best_y" / "best value" never include points from earlier repetitions.
        X, Y = torch.zeros((0, dim)), torch.zeros((0, 1))
        total_evals = 0

        algo_cfg = OmegaConf.load("configs/algorithms/np/transformer_opt.yaml")
        algo = hydra.utils.instantiate(algo_cfg, dim=dim, lb=lb, ub=ub)

        # The context manager closes the progress bar at the end of the run.
        with tqdm(total=max_evals) as bar:
            while total_evals < max_evals:
                next_X = algo.ask(model_path)

                if next_X.shape[0] == 3:
                    # Three-point batch (presumably the optimiser's initial
                    # design — confirm): evaluate and log point by point.
                    batch_Y = []
                    for j in range(3):
                        y = f(next_X[j, :])
                        batch_Y.append(y)
                        best_value = max(batch_Y)
                        curt_best_value = best_value if args.mode == "real" else torch.absolute(best_value)
                        wandb.log({
                            "sample counter": j + 1,
                            "sample value": torch.absolute(y).item(),
                            "best value": curt_best_value.item(),
                        })
                    # Each entry is already a (1, 1) tensor on `device`.
                    next_Y = torch.cat(batch_Y).reshape(-1, 1)
                else:
                    assert next_X.shape[0] == 1
                    # f already returns a (1, 1) tensor on `device`.
                    next_Y = f(next_X)

                algo.tell(next_X, next_Y)

                # https://github.com/pytorch/botorch/issues/1585
                torch.cuda.empty_cache()
                gc.collect()

                total_evals += len(next_X)
                bar.update(len(next_X))

                # Match the device/dtype of the history before appending.
                next_X, next_Y = next_X.to(X), next_Y.to(Y)
                X = torch.vstack((X, next_X))
                Y = torch.vstack((Y, next_Y))

                results = {
                    'total_evals': total_evals,
                    'y': next_Y.max().item(),
                    'best_y': Y.max().item(),
                }
                print('{}'.format(results))

                # The first three evaluations are logged in the batch branch
                # above; later steps are logged here. `.item()` requires a
                # single-sample batch at this point.
                if total_evals > 3:
                    best_value = Y.max()
                    curt_best_value = best_value if args.mode == "real" else torch.absolute(best_value)
                    wandb.log({
                        "sample counter": total_evals,
                        "sample value": torch.absolute(next_Y).item(),
                        "best value": curt_best_value.item(),
                    })
        wandb.finish()
