"""
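This training script can be run both on a single gpu in debug mode,
and also in a larger training run with distributed data parallel (ddp).
NOTE(review): as written, this file imports the DDP helpers but never
initializes a process group or wraps the model in DDP, so the torchrun examples
below would start independent, unsynchronized copies of the same run. Confirm
before relying on them.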

To run on a single GPU, example:
$ python train.py --batch_size=32

To run with DDP on 4 gpus on 1 node, example:
$ torchrun --standalone --nproc_per_node=4 train.py

To run with DDP on 8 gpus per node across 2 nodes, example:
- Run on the first (master) node with example IP 123.456.123.456:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
- Run on the worker node:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
(If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
"""

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

import time
import math
import pickle
import argparse
from contextlib import nullcontext
from tqdm import tqdm
import wandb

import numpy as np
import random
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from model import GPTConfig, GPT
from polyak import *
from util import *

# Command-line options. After parsing they are copied into wandb.config, which is
# where the rest of the script reads them from.
parser = argparse.ArgumentParser(description='GPT language-model training')
parser.add_argument('--learning_rate', default=None, type=float, help='learning rate')
parser.add_argument('--weight_decay', default=1e-1, type=float)
parser.add_argument('--method', default="sgd", type=str, help='methods')
parser.add_argument('--grad_clip', default=0.0, type=float) # 0.0 disables clipping; a positive value is the max gradient norm
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--block_size', default=256, type=int)
# The default matches the project name that used to be hard-coded in wandb.init(),
# so runs that do not pass this flag keep logging to the same project.
parser.add_argument('--wandb_project', default="shakespeare-char", type=str)
parser.add_argument('--dataset', default="shakespeare_char", type=str)
parser.add_argument('--max_iters', default=5000, type=int)
parser.add_argument('--scheduler', default="none", type=str)
parser.add_argument('--warmup_iters', default=100, type=int)
parser.add_argument('--seed', default=0, type=int)

args = parser.parse_args()

# Honour --wandb_project instead of silently ignoring it.
wandb.init(project=args.wandb_project)
wandb.config.update(args)

# -----------------------------------------------------------------------------
# Fixed (non-CLI) settings: a small 6-layer GPT that is evaluated frequently.
# I/O
out_dir = 'out'
eval_interval = 100 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often
gradient_accumulation_steps = 1

# model
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2

# system
# Prefer the GPU, but fall back to the CPU so the script still starts on machines
# without CUDA (the batching and autocast code below already branch on device_type).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 'float32', 'bfloat16', or 'float16'; bfloat16 is chosen whenever the GPU supports it
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'

tokens_per_iter = gradient_accumulation_steps * wandb.config.batch_size * wandb.config.block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")

os.makedirs(out_dir, exist_ok=True)
# Fix seed value
random.seed(wandb.config.seed)
np.random.seed(wandb.config.seed)
torch.manual_seed(wandb.config.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast

# NOTE(review): no GradScaler is created anywhere in this script, so when dtype
# resolves to 'float16' the autocast context runs without loss scaling.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Autocast is only used on CUDA; on CPU the model runs in its native precision.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
data_dir = os.path.join('data', wandb.config.dataset)

def get_batch(split):
    """Sample a random batch of (input, target) token windows from one split.

    `split` must be 'train', 'val' or 'test'; tokens are read as uint16 from
    the matching `.bin` file under `data_dir`. Each target row is its input
    row shifted one token to the right. Returns two int64 tensors of shape
    (batch_size, block_size), already moved to `device`.

    Raises Exception for any other split name.
    """
    if split not in ('train', 'val', 'test'):
        raise Exception(f"train/val/test is only acceptable, but {split} is given.")

    # The memmap is deliberately reopened on every call instead of being cached,
    # to avoid the memory leak described in
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    tokens = np.memmap(os.path.join(data_dir, f'{split}.bin'), dtype=np.uint16, mode='r')

    block_size = wandb.config.block_size
    starts = torch.randint(len(tokens) - block_size, (wandb.config.batch_size,))

    def window(begin):
        # One block_size-long slice of the token stream, widened to int64.
        return torch.from_numpy((tokens[begin:begin + block_size]).astype(np.int64))

    x = torch.stack([window(begin) for begin in starts])
    y = torch.stack([window(begin + 1) for begin in starts])

    if device_type != 'cuda':
        return x.to(device), y.to(device)

    # Pinned host memory lets the copy to the GPU run asynchronously (non_blocking=True).
    x = x.pin_memory().to(device, non_blocking=True)
    y = y.pin_memory().to(device, non_blocking=True)
    return x, y

# Run-progress state; this script always starts from these values.
iter_num = 0
best_val_loss = 1e9

# Look for the vocabulary size recorded next to the dataset, if any.
meta_vocab_size = None
meta_path = os.path.join(data_dir, 'meta.pkl')
if os.path.exists(meta_path):
    # NOTE(review): pickle.load can execute arbitrary code from the file, so only
    # point --dataset at directories you trust.
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# model init: fixed architecture settings plus the block size from the run config;
# vocab_size is filled in just below.
model_args = {
    'n_layer': n_layer,
    'n_head': n_head,
    'n_embd': n_embd,
    'block_size': wandb.config.block_size,
    'bias': True,
    'vocab_size': None,
    'dropout': dropout,
}

# init a new model from scratch
print("Initializing a new model from scratch")

# Use the dataset's vocabulary size when known, otherwise the padded GPT-2 size.
if meta_vocab_size is None:
    print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = 50304
else:
    model_args['vocab_size'] = meta_vocab_size
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)

# Shrink the model's block size if it is larger than the requested one.
# NOTE(review): the model was just built with this same block size, so this
# branch looks unreachable unless GPT changes config.block_size — confirm.
if model.config.block_size > wandb.config.block_size:
    model.crop_block_size(wandb.config.block_size)
    model_args['block_size'] = wandb.config.block_size # so that the checkpoint will have the right value
model.to(device)

# optimizer
optimizer = model.configure_optimizers(
    wandb.config.method,
    wandb.config.weight_decay,
    wandb.config.learning_rate,
    device_type,
    wandb.config.max_iters,
)

# Monte-Carlo loss estimate: averages the loss over eval_iters random batches per split
@torch.no_grad()
def estimate_loss():
    """Return {'train': ..., 'val': ..., 'test': ...} mean losses over random batches.

    For each split, `eval_iters` batches are drawn with get_batch() and run
    through the model in eval mode; each value is the mean of those per-batch
    losses as a 0-dim tensor. The model is switched back to train mode
    before returning.
    """
    model.eval()
    results = {}
    for split in ('train', 'val', 'test'):
        batch_losses = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split)
            with ctx:
                _, batch_loss = model(inputs, targets)
            batch_losses[step] = batch_loss.item()
        results[split] = batch_losses.mean()
    model.train()
    return results


def chunks(l, n):
    """Yield consecutive slices of `l`, each of length `n` (the last may be shorter)."""
    yield from (l[start:start + n] for start in range(0, len(l), n))


@torch.no_grad()
def compute_accurate_loss(splits=('train', 'val', 'test')):
    """Evaluate the model on every block_size-long window of the given splits.

    Unlike random sampling, this walks all start positions
    0 .. len(data) - block_size - 1 of each split in order, in batches of
    batch_size, with the model in eval mode.

    Args:
        splits: iterable of split names, each one of 'train', 'val', 'test'.

    Returns:
        dict mapping each split name to the mean of its per-batch losses
        (a 0-dim tensor). The final batch of a split may be smaller than the
        others, so this is an unweighted mean over batches.

    Raises:
        Exception: if a split name is not 'train', 'val' or 'test'.
    """
    block_size = wandb.config.block_size
    batch_size = wandb.config.batch_size
    out = {}
    model.eval()
    try:
        for split in splits:
            if split not in ('train', 'val', 'test'):
                raise Exception(f"train/val/test is only acceptable, but {split} is given.")
            data = np.memmap(os.path.join(data_dir, f'{split}.bin'), dtype=np.uint16, mode='r')

            # A range supports len() and slicing, so chunks() hands back small
            # sub-ranges and no per-token Python list is ever materialized.
            ix_list = list(chunks(range(len(data) - block_size), batch_size))
            losses = torch.zeros(len(ix_list))

            for k, ix in enumerate(ix_list):
                x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
                y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])

                if device_type == 'cuda':
                    # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
                    x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
                else:
                    x, y = x.to(device), y.to(device)

                with ctx:
                    logits, loss = model(x, y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Put the model back in training mode even if evaluation fails part-way.
        model.train()
    return out


# learning rate decay scheduler (cosine with warmup)
def get_lr(it, min_lr=1e-4):
    """Return the learning rate for iteration `it`: linear warmup, then cosine decay.

    Args:
        it: current iteration number.
        min_lr: floor reached once the decay window is over.

    Returns:
        The learning rate, interpolated between wandb.config.learning_rate and
        `min_lr`.

    The decay window ends at wandb.config.lr_decay_iters when the run config
    has that key, and at wandb.config.max_iters otherwise.
    NOTE(review): --learning_rate defaults to None, so it must be set
    explicitly for this schedule to be usable.
    """
    base_lr = wandb.config.learning_rate
    warmup_iters = wandb.config.warmup_iters
    # No CLI option defines lr_decay_iters, so the bare attribute access used to
    # fail; decay over the whole run unless the config supplies a value.
    try:
        lr_decay_iters = wandb.config.lr_decay_iters
    except (AttributeError, KeyError):
        lr_decay_iters = wandb.config.max_iters

    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return base_lr * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (base_lr - min_lr)

    
# training loop
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
raw_model = model
running_mfu = -1.0

# Methods whose name contains "polyak" or "sps" are stepped with the loss as an
# argument, and their `old_lr` attribute is what gets logged as the learning rate.
uses_loss_in_step = "polyak" in wandb.config.method or "sps" in wandb.config.method

if wandb.config.scheduler == "cosine":
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=wandb.config.max_iters)

# The loop runs for iter_num = 0 .. max_iters inclusive and breaks after the
# update of the last iteration.
while True:
    # every eval_interval iterations, evaluate the full val/test losses
    if iter_num % eval_interval == 0:
        print("evaluating...")
        validation_losses = compute_accurate_loss(['val', 'test'])
        print(f"step {iter_num}: test loss {validation_losses['test']:.4f}, val loss {validation_losses['val']:.4f}")
    else:
        validation_losses = {"val": None, "test": None}

    # forward/backward, with optional gradient accumulation to simulate a larger batch size
    # (no GradScaler is used here, whatever the autocast dtype is)
    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')

        loss.backward()

    # measured before clipping, so the logged value is the unclipped norm
    grad_norm = compute_grad_norm(model.parameters())

    # clip the gradient
    if wandb.config.grad_clip != 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), wandb.config.grad_clip)

    # step the optimizer; Polyak/SPS-style methods also need the current loss
    if uses_loss_in_step:
        optimizer.step(loss)
    else:
        optimizer.step()

    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad()
    if wandb.config.scheduler == "cosine":
        scheduler.step()

    # Only the Polyak/SPS-style optimizers report a step size; others log None.
    logged_lr = optimizer.old_lr if uses_loss_in_step else None
    wandb.log({
        "iter": iter_num,
        "train/minibatch_loss": loss,
        "grad_norm": grad_norm,
        "lr": logged_lr,
        "val/full_loss": validation_losses['val'],
        "test/full_loss": validation_losses['test']
    }, step = iter_num)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1

    if iter_num % log_interval == 0:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(wandb.config.batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")

    # termination conditions
    if iter_num == wandb.config.max_iters:
        # final evaluation, taken after the last update
        losses = compute_accurate_loss(["val", "test"])

        # This call reuses the step logged just above, so wandb folds it into the
        # same row; pass the same lr again so the value already logged for this
        # step is not replaced by None.
        wandb.log({
            "iter": iter_num,
            "grad_norm": grad_norm,
            "train/minibatch_loss": loss,
            "lr": logged_lr,
            "val/full_loss": losses['val'],
            "test/full_loss": losses['test'],
        }, step = iter_num)
        break

    iter_num += 1
    local_iter_num += 1

wandb.finish()
