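'''Train ResNet-50 on CIFAR-10 with PyTorch, either with DistributedDataParallel plus a custom gradient comm hook or with DeepSpeed's OnebitAdam.'''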
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist

import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, resnet50
import os
import argparse
import numpy as np



from utils import progress_bar

def setup_seed(seed):
    """Seed the torch (CPU and all CUDA devices), NumPy and stdlib RNGs.

    Also asks cuDNN for deterministic algorithms so repeated runs with the
    same seed are reproducible.

    Args:
        seed: integer seed applied to every generator.
    """
    import random
    import numpy as np

    # Same seeding order as before: torch CPU, torch CUDA, NumPy, stdlib.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True


# Command-line options. Every launcher-spawned process parses the same flags;
# only the local rank differs between processes.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--weight_decay', default=5e-4, type=float,
                    help='weight decay passed to the optimizer')
parser.add_argument('--warm_up', action='store_true',
                    help='use the LambdaLR schedule: 20 warm-up epochs, then x0.1 at epochs 90 and 160')
parser.add_argument('--batch_size', default=64, type=int,
                    help='per-process training batch size')
parser.add_argument('--log_dir', default='log/tmp', type=str,
                    help='TensorBoard log directory (written by global rank 0)')
parser.add_argument('--onebit', action='store_true',
                    help='register the one-bit quantization comm hook (ignored with --deepspeed)')
parser.add_argument('--comp_flag', action='store_true',
                    help='forwarded to OnebitBinarySGDState(comp_flag=...)')
parser.add_argument('--record_time', action='store_true',
                    help="print the comm hook's recorded communication time each epoch")
# Accept both the legacy '--local_rank' and the '--local-rank' spelling used by
# newer torch.distributed.launch. When neither is given (e.g. torchrun, which
# only exports LOCAL_RANK), fall back to the environment. argparse converts a
# string default with type=int; an unset variable keeps the old None default.
parser.add_argument('--local_rank', '--local-rank', type=int,
                    default=os.environ.get('LOCAL_RANK'),
                    help='GPU index on this node; defaults to the LOCAL_RANK environment variable')
parser.add_argument('--backend', default='nccl', type=str,
                    help="process-group backend: 'nccl'; any other value selects gloo")
parser.add_argument('--deepspeed', action='store_true',
                    help='use DeepSpeed OnebitAdam instead of DDP with a comm hook')
parser.add_argument('--packbits_by_cupy', action='store_true',
                    help='forwarded to OnebitBinarySGDState(packbits_by_cupy=...)')
args = parser.parse_args()

# Each process drives the GPU selected by its node-local rank.
local_rank = args.local_rank
print('local_rank', local_rank)
local_device = local_rank

torch.cuda.set_device(local_device)

# Pin the socket transport to a specific network interface. setdefault keeps
# any value already exported in the environment, so the interface (and the
# InfiniBand switch) can be changed per machine without editing this file;
# 'ens12f0' / IB disabled are only the fallbacks.
if args.backend == 'nccl':
    os.environ.setdefault('NCCL_SOCKET_IFNAME', 'ens12f0')
    os.environ.setdefault('NCCL_IB_DISABLE', '1')
    dist.init_process_group(backend='nccl', init_method='env://')
else:
    os.environ.setdefault('GLOO_SOCKET_IFNAME', 'ens12f0')
    dist.init_process_group(backend='gloo', init_method='env://')

print('dist_rank', dist.get_rank())

# Same fixed seed on every rank.
setup_seed(1)

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Data
if local_rank == 0:
    print('==> Preparing data..')
# Training augmentation: random 32x32 crop with 4px padding, upsample to
# 64x64, random horizontal flip, then per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation: same resize and normalization, no augmentation.
transform_test = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Only local rank 0 may download. The other ranks wait at the barrier, so
# several processes never write into ./data at the same time. Once the files
# are present, torchvision's download=True just verifies them and returns.
# Every rank reaches dist.barrier() exactly once, so the collective matches.
if local_rank != 0:
    dist.barrier()
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
if local_rank == 0:
    dist.barrier()

# Each rank sees a disjoint shard of the training set.
from torch.utils.data.distributed import DistributedSampler
datasampler = DistributedSampler(trainset, shuffle=True, num_replicas=dist.get_world_size(), rank=dist.get_rank())
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.batch_size, sampler=datasampler, num_workers=2)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model: torchvision ResNet-50 with a 10-class head. Other CIFAR architectures
# (VGG19, ResNet18, PreActResNet18, MobileNetV2, DenseNet121, ...) can be
# swapped in on the next line.
if local_rank == 0:
    print('==> Building model..')
net = resnet50(num_classes=10).cuda()

if not args.deepspeed:
    # DDP path: gradients are exchanged through a custom communication hook
    # whose state object is also read by train() when --record_time is set.
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[local_device], output_device=local_device)

    from quantization_onebit_hook import quantization_onebit_hook, OnebitBinarySGDState, my_allreduce_hook, SGDState
    if args.onebit:
        hook_state = OnebitBinarySGDState(beta=0.9, comp_flag=args.comp_flag, record_time=args.record_time, packbits_by_cupy=args.packbits_by_cupy)
        comm_hook = quantization_onebit_hook
    else:
        hook_state = SGDState(record_time=args.record_time)
        comm_hook = my_allreduce_hook
    net.register_comm_hook(state=hook_state, hook=comm_hook)
# With --deepspeed the bare CUDA model is kept and no hook_state exists;
# presumably OnebitAdam handles gradient exchange in that mode (confirm).



criterion = nn.CrossEntropyLoss()

if args.deepspeed:
    from deepspeed.runtime.fp16.onebit.adam import OnebitAdam
    # freeze_step: steps before 1-bit compression starts (per DeepSpeed docs).
    # Use one epoch, i.e. this rank's steps per epoch. The previous hard-coded
    # 391 (= ceil(50000 / 128)) was one epoch only for particular
    # batch-size / world-size combinations.
    # NOTE(review): upstream OnebitAdam has no `args` kwarg; this presumably
    # targets a patched DeepSpeed — confirm.
    optimizer = OnebitAdam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay, freeze_step=len(trainloader) * 1, args=args)
elif args.onebit:
    # No optimizer momentum here; OnebitBinarySGDState is built with beta=0.9,
    # so momentum is presumably applied inside the comm hook (confirm).
    optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
else:
    optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)

if args.warm_up:
    warm_up_epochs = 20
    milestones = [90, 160]

    def warm_up_with_multistep_lr(epoch):
        """LR multiplier for LambdaLR.

        Rises linearly as (epoch + 1) / warm_up_epochs during warm-up, then is
        multiplied by 0.1 for every milestone already reached.
        """
        if epoch < warm_up_epochs:
            return (epoch + 1) / warm_up_epochs
        return 0.1 ** len([m for m in milestones if m <= epoch])

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_multistep_lr)

# Only global rank 0 writes TensorBoard summaries.
if dist.get_rank() == 0:
    from torch.utils.tensorboard import SummaryWriter
    global_summary = SummaryWriter(args.log_dir)

# Training
def train(epoch):
    """Run one training epoch over this rank's shard of CIFAR-10.

    Args:
        epoch: epoch index. It reseeds the DistributedSampler shuffle and is
            the TensorBoard step.

    Side effects: updates ``net`` through ``optimizer``. Local rank 0 draws a
    progress bar. Global rank 0 logs loss/accuracy to TensorBoard and, with
    --record_time on the DDP path, prints and resets the comm hook's timers.
    Loss and accuracy are computed over this rank's shard only; they are not
    reduced across ranks.
    """
    datasampler.set_epoch(epoch)  # reshuffle the sampler for this epoch
    if local_rank == 0:
        print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if local_rank == 0:
            progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                        % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
    if dist.get_rank() == 0:
        # hook_state is only created on the DDP path, so --deepspeed together
        # with --record_time previously raised NameError here.
        if args.record_time and not args.deepspeed:
            comm_time = 0.0
            # Values are summed and then iterated, so they are presumably
            # per-bucket timing arrays — confirm in quantization_onebit_hook.
            for value in hook_state.time_counter.values():
                comm_time += value
            # atleast_1d keeps list()/sum() valid when the counter is empty or
            # holds scalars (previously a TypeError on a numpy scalar).
            rounded = np.atleast_1d(np.around(comm_time, 3))
            print('\ncomm_time', list(rounded), sum(rounded))
            hook_state.time_counter = {}

        global_summary.add_scalar('train_loss', train_loss/(batch_idx+1), epoch)
        global_summary.add_scalar('train_acc', correct/total, epoch)

def test():
    """Evaluate the model on the full CIFAR-10 test set.

    The main loop calls this on local rank 0 only. It reads the module-level
    ``epoch`` (the main loop's variable) as the TensorBoard step. Global rank
    0 logs test loss and accuracy.
    """
    # Forward through the unwrapped model. A DDP forward may broadcast module
    # buffers (broadcast_buffers, the default) — a collective. Issuing it from
    # a single evaluating rank leaves the ranks' collectives out of step.
    # Without DDP (--deepspeed) net has no .module, so use it directly.
    model = getattr(net, 'module', net)
    net.eval()  # also switches the wrapped module to eval mode
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
        if dist.get_rank() == 0:
            global_summary.add_scalar('test_loss', test_loss/(batch_idx+1), epoch)
            global_summary.add_scalar('test_acc', correct/total, epoch)





# Fixed 200-epoch run. Every rank trains; only local rank 0 of each node
# evaluates. With --warm_up the LambdaLR schedule advances once per epoch.
# `epoch` stays a module-level name because test() reads it.
run_eval = local_rank == 0
for epoch in range(start_epoch, start_epoch + 200):
    train(epoch)
    if run_eval:
        test()
    if args.warm_up:
        scheduler.step()
