from toy.ops import KDLoss_TS, KDLoss, KDLoss_min_TS, KDLoss_min, err, KDLoss_min_T1
from torch.utils.data import DataLoader
import os
import torch
import torch.optim as optim
import json
import time
from toy.net import Linear_Net
# import gc


def train(
    dir,
    target_function_path,
    init_net_path,
    train_dataset_path,
    test_data_path,
    model_config={
        'rho': 0.0,
        'T': 10.0,
        'teacher_reduction': 1.0,
        'datanum': 4096,
        'regenerate_data': False
    },
    training_strategry={
        'batch_size': 4096,
        'lr': 0.01,
        'epoch': 4096,
        'test_interval': 64,
        'display_interval': 16,
        'save_interval': 64,
        'record_interval': 8,
        'test_datanum': 32768,
    },
    device_name='cuda:0',
    seed=0,
):
    """Distil a target function into a student network.

    The student loaded from ``init_net_path`` is fitted to the scaled logits
    ``teacher_reduction * target(x)`` with ``KDLoss``. The per-sample loss
    floor (``KDLoss_min_T1`` when ``T == 1.0``, ``KDLoss_min_TS`` otherwise,
    evaluated with the target's own sign as the hard label) is subtracted, so
    the logged loss is an excess loss. Gradients from every mini-batch of an
    epoch are accumulated and the optimizer steps once per epoch; the
    ``ReduceLROnPlateau`` scheduler is driven by the epoch's mean train loss.

    Artifacts written under ``dir``:
        dataset/  train_dataset, test_dataset
        network/  target, init_net, student (refreshed every save_interval
                  epochs and once more after the last epoch)
        log/      train.csv, test.csv
        config/   json (model_config), train.json (training_strategry)

    Args:
        dir: Output directory; created if missing.
        target_function_path: ``torch.load``-able file with the target network.
        init_net_path: ``torch.load``-able file with the initial student; it
            must expose ``vec()`` (used for the weight-change metric).
        train_dataset_path: ``torch.load``-able training dataset; a DataLoader
            over it yields ``(index, input)`` batches.
        test_data_path: ``torch.load``-able test dataset (same item layout).
        model_config: ``rho``, ``T`` and ``teacher_reduction`` for the loss;
            if ``datanum`` is present the training set is switched to a fixed
            sample of that size, regenerated when ``regenerate_data`` is True.
        training_strategry: ``batch_size``, ``lr``, ``epoch``,
            ``test_datanum`` and the ``*_interval`` logging periods.
        device_name: Torch device string.
        seed: Seed for the CPU and all CUDA random number generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    def _ensure_dir(path):
        # Create the directory only when nothing exists at that path yet.
        if not os.path.exists(path):
            os.makedirs(path)

    def _append_row(name, row):
        # Append one pre-formatted CSV row to a file in the log directory.
        with open(f'{dir}/log/{name}', 'a') as log_file:
            log_file.write(row)

    _ensure_dir(dir)
    device = torch.device(device_name)

    # Target function and student initialisation.
    target = torch.load(target_function_path, map_location=device)
    net = torch.load(init_net_path, map_location=device)
    init_weight = net.vec().detach()

    # Training data; a 'datanum' entry turns it into a fixed finite sample.
    train_dataset = torch.load(train_dataset_path, map_location=device)
    train_dataset.device = device
    if 'datanum' in model_config:
        train_dataset.online = False
        train_dataset.datanum = model_config['datanum']
        if model_config.get('regenerate_data', False) == True:
            with torch.no_grad():
                train_dataset.generate_data()

    batch_size = training_strategry['batch_size']
    train_dataloader = DataLoader(
        train_dataset, batch_size=min(batch_size, len(train_dataset)), shuffle=True
    )

    test_dataset = torch.load(test_data_path, map_location=device)
    test_dataset.device = device
    test_dataset.datanum = training_strategry['test_datanum']
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    # Student loss and the per-sample floor subtracted from it.
    rho, temperature = model_config['rho'], model_config['T']
    Loss = KDLoss(rho, temperature)
    Loss_min = KDLoss_min_T1(rho) if temperature == 1.0 else KDLoss_min_TS(rho, temperature)

    def _fixed_targets(dataset):
        # Precompute scaled target logits and per-sample loss floor for a
        # fixed dataset; dataset[:][1] is the full input tensor.
        with torch.no_grad():
            logits = model_config['teacher_reduction'] * target(dataset[:][1])
        floor = Loss_min(logits, (logits > 0).to(logits.dtype))
        return logits, floor

    def _online_targets(batch):
        # Scaled target logits and mean loss floor for a freshly drawn batch;
        # callers invoke this inside torch.no_grad().
        logits = model_config['teacher_reduction'] * target(batch)
        floor = torch.mean(Loss_min(logits, (logits > 0).to(logits.dtype)))
        return logits, floor

    if not train_dataset.online:
        train_target, train_loss_min = _fixed_targets(train_dataset)
    if not test_dataset.online:
        test_target, test_loss_min = _fixed_targets(test_dataset)

    optimizer = optim.Adam(net.parameters(), lr=training_strategry['lr'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, patience=100, min_lr=0.0001, factor=0.5, threshold=0.01
    )

    # Persist inputs so the run can be reproduced / inspected later.
    _ensure_dir(dir + '/dataset')
    torch.save(train_dataset, dir + '/dataset/train_dataset')
    torch.save(test_dataset, dir + '/dataset/test_dataset')

    _ensure_dir(dir + '/network')
    torch.save(target, dir + '/network/target')
    torch.save(net, dir + '/network/init_net')

    _ensure_dir(dir + '/log')
    for name, header in (
        ('train.csv', 'epoch,train_loss,train_error,weight_change\n'),
        ('test.csv', 'epoch,test_loss,test_error\n'),
    ):
        with open(f'{dir}/log/{name}', 'w') as log_file:
            log_file.write(header)

    _ensure_dir(dir + '/config')
    for name, payload in (('json', model_config), ('train.json', training_strategry)):
        with open(f'{dir}/config/{name}', 'w', encoding='utf-8') as json_file:
            json.dump(payload, json_file, ensure_ascii=False, indent=4)

    timer = time.time()

    for epoch in range(training_strategry['epoch']):
        is_record = epoch % training_strategry['record_interval'] == 0

        # Training pass: gradients accumulate over all batches, one step/epoch.
        train_loss = torch.tensor(0.0)
        if is_record:
            train_error = torch.tensor(0.0)
        optimizer.zero_grad()

        for idx, batch in train_dataloader:
            output = net(batch)
            if train_dataset.online:
                with torch.no_grad():
                    target_logits, loss_min = _online_targets(batch)
            else:
                target_logits = train_target[idx]
                loss_min = torch.mean(train_loss_min[idx])

            loss = Loss(output, target_logits) - loss_min
            loss.backward()

            with torch.no_grad():
                if is_record:
                    train_error = train_error + torch.sum(err(output, target_logits))
                train_loss = train_loss + loss.item() * len(batch)

        optimizer.step()
        train_loss = train_loss / len(train_dataset)
        scheduler.step(train_loss)

        if is_record:
            with torch.no_grad():
                train_error = train_error / len(train_dataset)
                weight_change = torch.norm(net.vec() - init_weight)
            _append_row(
                'train.csv',
                f'{epoch:05d},{train_loss.item():.20e},{train_error.item():.020f},'
                f'{weight_change.item():.20e},\n',
            )

        if epoch % training_strategry['display_interval'] == 0:
            print(
                f'TRAIN==> epoch:{epoch:5d}, train_loss:{train_loss.item():.06e}, '
                f'train_err:{train_error.item():.5f}, weight_change:{weight_change.item():.5f}'
            )

        if epoch % training_strategry['save_interval'] == 0:
            print(f'Save student net at epoch {epoch:d} with loss {train_loss.item():.4e}')
            torch.save(net, dir + '/network/student')

        if epoch % training_strategry['test_interval'] == 0:
            test_loss = torch.tensor(0.0)
            test_error = torch.tensor(0.0)

            for idx, batch in test_dataloader:
                with torch.no_grad():
                    output = net(batch)
                    if test_dataset.online:
                        target_logits, loss_min = _online_targets(batch)
                    else:
                        target_logits = test_target[idx]
                        loss_min = torch.mean(test_loss_min[idx])

                    loss = Loss(output, target_logits) - loss_min
                    test_loss = test_loss + loss.item() * len(batch)
                    test_error = test_error + torch.sum(err(output, target_logits))

            test_loss = test_loss / len(test_dataset)
            test_error = test_error / len(test_dataset)

            _append_row(
                'test.csv',
                f'{epoch:05d},{test_loss.item():.20e},{test_error.item():.020f},\n',
            )
            print(
                f'TEST===> epoch:{epoch:5d}, test__loss:{test_loss.item():.06e}, '
                f'test__err:{test_error.item():.5f}'
            )
            print(f'time used {time.time() - timer:.02f}s\n')
            timer = time.time()

    torch.save(net, dir + '/network/student')


def train_TS(
    dir,
    target_function_path,
    teacher_net_path,
    init_net_path,
    train_dataset_path,
    test_data_path,
    model_config={
        'rho': 0.0,
        'T': 10.0,
        'teacher_reduction': 1.0,
        'datanum': 4096,
        'regenerate_data': False
    },
    training_strategry={
        'batch_size': 4096,
        'lr': 0.01,
        'epoch': 4096,
        'test_interval': 64,
        'display_interval': 16,
        'save_interval': 64,
        'record_interval': 8,
        'test_datanum': 32768,
    },
    device_name='cuda:0',
    seed=None,
):
    """Distil a teacher network into a student using soft and hard labels.

    Hard labels come from the sign of the target function
    (``(target(x) > 0).to(float)``); soft logits are
    ``teacher_reduction * teacher(x)``. The student is trained with
    ``KDLoss_TS(output, soft_logits, hard_labels)`` minus the per-sample loss
    floor (``KDLoss_min_T1`` when ``T == 1.0``, ``KDLoss_min_TS`` otherwise),
    so the logged loss is an excess loss. Gradients of all mini-batches are
    accumulated and the optimizer steps once per epoch.

    Artifacts written under ``dir``:
        dataset/  train_dataset, test_dataset
        network/  target, teacher, init_net, student (refreshed every
                  save_interval epochs and once more after the last epoch)
        log/      train.csv, test.csv
        config/   json (model_config), train.json (training_strategry)

    Args:
        dir: Output directory; created if missing.
        target_function_path: ``torch.load``-able target network (hard labels).
        teacher_net_path: ``torch.load``-able teacher network (soft logits).
        init_net_path: ``torch.load``-able initial student exposing ``vec()``.
        train_dataset_path: ``torch.load``-able training dataset; a DataLoader
            over it yields ``(index, input)`` batches.
        test_data_path: ``torch.load``-able test dataset (same item layout).
        model_config: ``rho``, ``T`` and ``teacher_reduction``; if ``datanum``
            is present the training set becomes a fixed sample of that size,
            regenerated when ``regenerate_data`` is True.
        training_strategry: ``batch_size``, ``lr``, ``epoch``,
            ``test_datanum`` and the ``*_interval`` logging periods.
        device_name: Torch device string.
        seed: If not None, seeds the CPU and CUDA RNGs. The default keeps the
            previous (unseeded) behaviour.
    """
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    os.makedirs(dir, exist_ok=True)

    device = torch.device(device_name)

    # get target function, teacher and student
    target = torch.load(target_function_path, map_location=device)
    teacher = torch.load(teacher_net_path, map_location=device)
    net = torch.load(init_net_path, map_location=device)
    # Snapshot the initial weights: detach() drops the autograd graph and
    # clone() guarantees the snapshot cannot share storage with parameters
    # that the optimizer updates in place.
    init_weight = net.vec().detach().clone()

    # get data; a 'datanum' entry turns the training set into a fixed sample
    train_dataset = torch.load(train_dataset_path, map_location=device)
    train_dataset.device = device
    if 'datanum' in model_config:
        train_dataset.online = False
        train_dataset.datanum = model_config['datanum']
        if model_config.get('regenerate_data', False) == True:
            with torch.no_grad():
                train_dataset.generate_data()

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=min(training_strategry['batch_size'], len(train_dataset)),
        shuffle=True
    )

    test_dataset = torch.load(test_data_path, map_location=device)
    test_dataset.device = device
    test_dataset.datanum = training_strategry['test_datanum']
    test_dataloader = DataLoader(test_dataset, batch_size=training_strategry['batch_size'])

    # loss: the call below passes (output, soft_logits, hard_labels), which is
    # how train_TS_linear calls KDLoss_TS; KDLoss is only ever called with
    # (output, target) in this file.
    Loss = KDLoss_TS(model_config['rho'], model_config['T'])
    if model_config['T'] == 1.0:
        Loss_min = KDLoss_min_T1(model_config['rho'])
    else:
        Loss_min = KDLoss_min_TS(model_config['rho'], model_config['T'])

    # if data is fixed, precompute labels and loss floors once
    if not train_dataset.online:
        with torch.no_grad():
            train_hard_labels = (target(train_dataset[:][1]) > 0).to(float)
            train_soft_logits = model_config['teacher_reduction'] * teacher(train_dataset[:][1])
            train_loss_min = Loss_min(train_soft_logits, train_hard_labels)

    if not test_dataset.online:
        with torch.no_grad():
            test_hard_labels = (target(test_dataset[:][1]) > 0).to(float)
            test_soft_logits = model_config['teacher_reduction'] * teacher(test_dataset[:][1])
            test_loss_min = Loss_min(test_soft_logits, test_hard_labels)

    # optim
    optimizer = optim.Adam(net.parameters(), lr=training_strategry['lr'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, patience=100, min_lr=0.0001, factor=0.5, threshold=0.01
    )

    # data
    os.makedirs(dir + '/dataset', exist_ok=True)
    torch.save(train_dataset, dir + '/dataset/train_dataset')
    torch.save(test_dataset, dir + '/dataset/test_dataset')

    # network
    os.makedirs(dir + '/network', exist_ok=True)
    torch.save(target, dir + '/network/target')
    torch.save(teacher, dir + '/network/teacher')
    torch.save(net, dir + '/network/init_net')

    # train log
    os.makedirs(dir + '/log', exist_ok=True)
    with open(dir + '/log/train.csv', 'w') as f:
        f.write('epoch,train_loss,train_error,weight_change\n')
    with open(dir + '/log/test.csv', 'w') as f:
        f.write('epoch,test_loss,test_error\n')

    # config
    os.makedirs(dir + '/config', exist_ok=True)
    with open(dir + '/config/json', 'w', encoding='utf-8') as json_file:
        json.dump(model_config, json_file, ensure_ascii=False, indent=4)
    with open(dir + '/config/train.json', 'w', encoding='utf-8') as json_file:
        json.dump(training_strategry, json_file, ensure_ascii=False, indent=4)

    # Start of the timing window reported at each test epoch (previously
    # never assigned, which raised NameError at the first test).
    timer = time.time()

    # run epoch
    for epoch in range(training_strategry['epoch']):

        # train: accumulate gradients over all batches, one step per epoch
        train_loss = torch.tensor(0.0)
        if epoch % training_strategry['record_interval'] == 0:
            train_error = torch.tensor(0.0)
        optimizer.zero_grad()

        for index, input in train_dataloader:

            output = net(input)
            if train_dataset.online:
                with torch.no_grad():
                    hard_labels = (target(input) > 0).to(float)
                    soft_logits = model_config['teacher_reduction'] * teacher(input)
                    loss_min = torch.mean(Loss_min(soft_logits, hard_labels))
            else:
                with torch.no_grad():
                    hard_labels = train_hard_labels[index]
                    soft_logits = train_soft_logits[index]
                    loss_min = torch.mean(train_loss_min[index])

            loss = Loss(output, soft_logits, hard_labels) - loss_min
            loss.backward()

            with torch.no_grad():
                if epoch % training_strategry['record_interval'] == 0:
                    # labels in {0, 1} are shifted to {-0.5, 0.5} before
                    # being compared by err()
                    train_error = train_error + torch.sum(err(output, hard_labels - 0.5))
                train_loss = train_loss + loss.item() * len(input)

        optimizer.step()
        with torch.no_grad():
            train_loss = train_loss / len(train_dataset)
        scheduler.step(train_loss)

        # record
        if epoch % training_strategry['record_interval'] == 0:
            with torch.no_grad():
                train_error = train_error / len(train_dataset)
                weight_change = torch.norm(net.vec() - init_weight)

            with open(dir + '/log/train.csv', 'a') as f:
                f.write(
                    '{:05d},{:.20e},{:.020f},{:.20e},\n'.format(
                        epoch,
                        train_loss.item(),
                        train_error.item(),
                        weight_change.item(),
                    )
                )

        # display
        if epoch % training_strategry['display_interval'] == 0:
            print(
                'TRAIN==> epoch:{:5d}, train_loss:{:.06e}, train_err:{:.5f}, weight_change:{:.5f}'
                .format(
                    epoch,
                    train_loss.item(),
                    train_error.item(),
                    weight_change.item(),
                )
            )

        # save model
        if epoch % training_strategry['save_interval'] == 0:
            print(
                'Save student net at epoch {:d} with loss {:.4e}'.format(epoch, train_loss.item())
            )
            torch.save(net, dir + '/network/student')

        # test
        if epoch % training_strategry['test_interval'] == 0:

            test_loss = torch.tensor(0.0)
            test_error = torch.tensor(0.0)

            for index, input in test_dataloader:

                with torch.no_grad():
                    output = net(input)
                    if test_dataset.online:
                        hard_labels = (target(input) > 0).to(float)
                        soft_logits = model_config['teacher_reduction'] * teacher(input)
                        loss_min = torch.mean(Loss_min(soft_logits, hard_labels))
                    else:
                        hard_labels = test_hard_labels[index]
                        soft_logits = test_soft_logits[index]
                        loss_min = torch.mean(test_loss_min[index])

                    loss = Loss(output, soft_logits, hard_labels) - loss_min
                    test_loss = test_loss + loss.item() * len(input)
                    test_error = test_error + torch.sum(err(output, hard_labels - 0.5))

            test_loss = test_loss / len(test_dataset)
            test_error = test_error / len(test_dataset)

            # record
            with open(dir + '/log/test.csv', 'a') as f:
                f.write(
                    '{:05d},{:.20e},{:.020f},\n'.format(
                        epoch,
                        test_loss.item(),
                        test_error.item(),
                    )
                )

            # display
            print(
                'TEST===> epoch:{:5d}, test__loss:{:.06e}, test__err:{:.5f}'.format(
                    epoch,
                    test_loss.item(),
                    test_error.item(),
                )
            )
            print('time used {:.02f}s\n'.format(time.time() - timer))
            timer = time.time()

    # Save the final student; periodic saves alone miss the last epochs.
    torch.save(net, dir + '/network/student')


def train_linear(
    dir,
    target_function_path,
    init_net_path,
    train_dataset_path,
    test_data_path,
    model_config={
        'rho': 0.0,
        'T': 10.0,
        'teacher_reduction': 1.0,
        'datanum': 4096,
        'regenerate_data': False
    },
    training_strategry={
        'batch_size': 4096,
        'lr': 0.01,
        'epoch': 4096,
        'test_interval': 64,
        'display_interval': 16,
        'save_interval': 64,
        'record_interval': 8,
        'test_datanum': 32768,
    },
    device_name='cuda:0',
    seed=0,
):
    """Train a ``Linear_Net`` student built from an initial network.

    The network at ``init_net_path`` is wrapped in ``Linear_Net`` and the
    wrapper is trained; every forward call also receives ``init_net``. The
    loss is ``KDLoss`` against ``teacher_reduction * target(x)`` minus the
    per-sample floor (``KDLoss_min_T1`` when ``T == 1.0``, ``KDLoss_min_TS``
    otherwise). Unlike ``train``, the optimizer steps after every mini-batch.

    Artifacts written under ``dir``:
        dataset/  train_dataset, test_dataset
        network/  target, init_net (the unwrapped initial network), student
                  (the wrapper; refreshed every save_interval epochs and once
                  more after the last epoch)
        log/      train.csv, test.csv
        config/   json (model_config), train.json (training_strategry)

    Args:
        dir: Output directory; created if missing.
        target_function_path: ``torch.load``-able target network.
        init_net_path: ``torch.load``-able network passed to ``Linear_Net``.
        train_dataset_path: ``torch.load``-able training dataset; a DataLoader
            over it yields ``(index, input)`` batches.
        test_data_path: ``torch.load``-able test dataset (same item layout).
        model_config: ``rho``, ``T`` and ``teacher_reduction``; if ``datanum``
            is present the training set becomes a fixed sample of that size,
            regenerated when ``regenerate_data`` is True.
        training_strategry: ``batch_size``, ``lr``, ``epoch``,
            ``test_datanum`` and the ``*_interval`` logging periods.
        device_name: Torch device string.
        seed: Seed for the CPU and all CUDA random number generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    def _ensure_dir(path):
        # Create the directory only when nothing exists at that path yet.
        if not os.path.exists(path):
            os.makedirs(path)

    def _append_row(name, row):
        # Append one pre-formatted CSV row to a file in the log directory.
        with open(f'{dir}/log/{name}', 'a') as log_file:
            log_file.write(row)

    _ensure_dir(dir)
    device = torch.device(device_name)

    # Target function, the initial network and its Linear_Net wrapper.
    target = torch.load(target_function_path, map_location=device)
    init_net = torch.load(init_net_path, map_location=device)
    net = Linear_Net(init_net).to(device)
    init_weight = net.vec().detach()

    # Training data; a 'datanum' entry turns it into a fixed finite sample.
    train_dataset = torch.load(train_dataset_path, map_location=device)
    train_dataset.device = device
    if 'datanum' in model_config:
        # datanum is assigned before online here (order kept from the original).
        train_dataset.datanum = model_config['datanum']
        train_dataset.online = False
        if model_config.get('regenerate_data', False) == True:
            with torch.no_grad():
                train_dataset.generate_data()

    batch_size = training_strategry['batch_size']
    train_dataloader = DataLoader(
        train_dataset, batch_size=min(batch_size, len(train_dataset)), shuffle=True
    )

    test_dataset = torch.load(test_data_path, map_location=device)
    test_dataset.device = device
    test_dataset.datanum = training_strategry['test_datanum']
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    # Student loss and the per-sample floor subtracted from it.
    rho, temperature = model_config['rho'], model_config['T']
    Loss = KDLoss(rho, temperature)
    Loss_min = KDLoss_min_T1(rho) if temperature == 1.0 else KDLoss_min_TS(rho, temperature)

    def _fixed_targets(dataset):
        # Precompute scaled target logits and per-sample loss floor for a
        # fixed dataset; dataset[:][1] is the full input tensor.
        with torch.no_grad():
            logits = model_config['teacher_reduction'] * target(dataset[:][1])
        floor = Loss_min(logits, (logits > 0).to(logits.dtype))
        return logits, floor

    def _online_targets(batch):
        # Scaled target logits and mean loss floor for a freshly drawn batch;
        # callers invoke this inside torch.no_grad().
        logits = model_config['teacher_reduction'] * target(batch)
        floor = torch.mean(Loss_min(logits, (logits > 0).to(logits.dtype)))
        return logits, floor

    if not train_dataset.online:
        train_target, train_loss_min = _fixed_targets(train_dataset)
    if not test_dataset.online:
        test_target, test_loss_min = _fixed_targets(test_dataset)

    optimizer = optim.Adam(net.parameters(), lr=training_strategry['lr'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, patience=100, min_lr=0.0001, factor=0.5, threshold=0.01
    )

    # Persist inputs so the run can be reproduced / inspected later.
    _ensure_dir(dir + '/dataset')
    torch.save(train_dataset, dir + '/dataset/train_dataset')
    torch.save(test_dataset, dir + '/dataset/test_dataset')

    _ensure_dir(dir + '/network')
    torch.save(target, dir + '/network/target')
    torch.save(init_net, dir + '/network/init_net')

    _ensure_dir(dir + '/log')
    for name, header in (
        ('train.csv', 'epoch,train_loss,train_error,weight_change\n'),
        ('test.csv', 'epoch,test_loss,test_error\n'),
    ):
        with open(f'{dir}/log/{name}', 'w') as log_file:
            log_file.write(header)

    _ensure_dir(dir + '/config')
    for name, payload in (('json', model_config), ('train.json', training_strategry)):
        with open(f'{dir}/config/{name}', 'w', encoding='utf-8') as json_file:
            json.dump(payload, json_file, ensure_ascii=False, indent=4)

    timer = time.time()

    for epoch in range(training_strategry['epoch']):
        is_record = epoch % training_strategry['record_interval'] == 0

        # Training pass: one optimizer step per mini-batch.
        train_loss = torch.tensor(0.0)
        if is_record:
            train_error = torch.tensor(0.0)
        optimizer.zero_grad()

        for idx, batch in train_dataloader:
            output = net(batch, init_net=init_net)
            if train_dataset.online:
                with torch.no_grad():
                    target_logits, loss_min = _online_targets(batch)
            else:
                target_logits = train_target[idx]
                loss_min = torch.mean(train_loss_min[idx])

            loss = Loss(output, target_logits) - loss_min
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            with torch.no_grad():
                if is_record:
                    train_error = train_error + torch.sum(err(output, target_logits))
                train_loss = train_loss + loss.item() * len(batch)

        train_loss = train_loss / len(train_dataset)
        scheduler.step(train_loss)

        if is_record:
            with torch.no_grad():
                train_error = train_error / len(train_dataset)
                weight_change = torch.norm(net.vec() - init_weight)
            _append_row(
                'train.csv',
                f'{epoch:05d},{train_loss.item():.20e},{train_error.item():.020f},'
                f'{weight_change.item():.20e},\n',
            )

        if epoch % training_strategry['display_interval'] == 0:
            print(
                f'TRAIN==> epoch:{epoch:5d}, train_loss:{train_loss.item():.06e}, '
                f'train_err:{train_error.item():.5f}, weight_change:{weight_change.item():.5f}'
            )

        if epoch % training_strategry['save_interval'] == 0:
            print(f'Save student net at epoch {epoch:d} with loss {train_loss.item():.4e}')
            torch.save(net, dir + '/network/student')

        if epoch % training_strategry['test_interval'] == 0:
            test_loss = torch.tensor(0.0)
            test_error = torch.tensor(0.0)

            for idx, batch in test_dataloader:
                # The forward pass stays outside no_grad, as in the original;
                # only its result is detached. NOTE(review): presumably
                # Linear_Net's forward needs autograd enabled; confirm before
                # moving it under no_grad.
                output = net(batch, init_net=init_net).detach()
                with torch.no_grad():
                    if test_dataset.online:
                        target_logits, loss_min = _online_targets(batch)
                    else:
                        target_logits = test_target[idx]
                        loss_min = torch.mean(test_loss_min[idx])

                    loss = Loss(output, target_logits) - loss_min
                    test_loss = test_loss + loss.item() * len(batch)
                    test_error = test_error + torch.sum(err(output, target_logits))

            test_loss = test_loss / len(test_dataset)
            test_error = test_error / len(test_dataset)

            _append_row(
                'test.csv',
                f'{epoch:05d},{test_loss.item():.20e},{test_error.item():.020f},\n',
            )
            print(
                f'TEST===> epoch:{epoch:5d}, test__loss:{test_loss.item():.06e}, '
                f'test__err:{test_error.item():.5f}'
            )
            print(f'time used {time.time() - timer:.02f}s\n')
            timer = time.time()

    torch.save(net, dir + '/network/student')


def train_TS_linear(
    dir,
    target_function_path,
    teacher_net_path,
    init_net_path,
    train_dataset_path,
    test_data_path,
    model_config={
        'rho': 0.0,
        'T': 10.0,
        'teacher_reduction': 1.0,
        'datanum': 4096,
        'regenerate_data': False
    },
    training_strategry={
        'batch_size': 4096,
        'lr': 0.01,
        'epoch': 4096,
        'test_interval': 64,
        'display_interval': 16,
        'save_interval': 64,
        'record_interval': 8,
        'test_datanum': 32768,
    },
    device_name='cuda:0',
    seed=0,
):
    """Distil a teacher into a ``Linear_Net`` student using soft and hard labels.

    The network at ``init_net_path`` is wrapped in ``Linear_Net`` and the
    wrapper is trained; every forward call also receives ``init_net``. Hard
    labels come from the sign of the target function
    (``(target(x) > 0).to(float)``); soft logits are
    ``teacher_reduction * teacher(x)``. The loss is
    ``KDLoss_TS(output, soft_logits, hard_labels)`` minus the per-sample floor
    (``KDLoss_min_T1`` when ``T == 1.0``, ``KDLoss_min_TS`` otherwise). The
    optimizer steps after every mini-batch.

    Artifacts written under ``dir``:
        dataset/  train_dataset, test_dataset
        network/  target, teacher, init_net (unwrapped), student (the
                  wrapper; refreshed every save_interval epochs and once more
                  after the last epoch)
        log/      train.csv, test.csv
        config/   json (model_config), train.json (training_strategry)

    Args:
        dir: Output directory; created if missing.
        target_function_path: ``torch.load``-able target network (hard labels).
        teacher_net_path: ``torch.load``-able teacher network (soft logits).
        init_net_path: ``torch.load``-able network passed to ``Linear_Net``.
        train_dataset_path: ``torch.load``-able training dataset; a DataLoader
            over it yields ``(index, input)`` batches.
        test_data_path: ``torch.load``-able test dataset (same item layout).
        model_config: ``rho``, ``T`` and ``teacher_reduction``; if ``datanum``
            is present the training set becomes a fixed sample of that size,
            regenerated when ``regenerate_data`` is True.
        training_strategry: ``batch_size``, ``lr``, ``epoch``,
            ``test_datanum`` and the ``*_interval`` logging periods.
        device_name: Torch device string.
        seed: Seed for the CPU and all CUDA random number generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    os.makedirs(dir, exist_ok=True)

    device = torch.device(device_name)

    # get target function, teacher and the linearised student
    target = torch.load(target_function_path, map_location=device)
    teacher = torch.load(teacher_net_path, map_location=device)
    init_net = torch.load(init_net_path, map_location=device)
    net = Linear_Net(init_net).to(device)
    # Snapshot the initial weights: detach() drops the autograd graph and
    # clone() guarantees the snapshot cannot share storage with parameters
    # that the optimizer updates in place.
    init_weight = net.vec().detach().clone()

    # get data; a 'datanum' entry turns the training set into a fixed sample
    train_dataset = torch.load(train_dataset_path, map_location=device)
    train_dataset.device = device
    if 'datanum' in model_config:
        train_dataset.online = False
        train_dataset.datanum = model_config['datanum']
        if model_config.get('regenerate_data', False) == True:
            with torch.no_grad():
                train_dataset.generate_data()

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=min(training_strategry['batch_size'], len(train_dataset)),
        shuffle=True
    )

    test_dataset = torch.load(test_data_path, map_location=device)
    test_dataset.device = device
    test_dataset.datanum = training_strategry['test_datanum']
    test_dataloader = DataLoader(test_dataset, batch_size=training_strategry['batch_size'])

    # loss
    Loss = KDLoss_TS(model_config['rho'], model_config['T'])
    if model_config['T'] == 1.0:
        Loss_min = KDLoss_min_T1(model_config['rho'])
    else:
        Loss_min = KDLoss_min_TS(model_config['rho'], model_config['T'])

    # if data is fixed, precompute labels and loss floors once
    if not train_dataset.online:
        with torch.no_grad():
            train_hard_labels = (target(train_dataset[:][1]) > 0).to(float)
            train_soft_logits = model_config['teacher_reduction'] * teacher(train_dataset[:][1])
            train_loss_min = Loss_min(train_soft_logits, train_hard_labels)

    if not test_dataset.online:
        with torch.no_grad():
            test_hard_labels = (target(test_dataset[:][1]) > 0).to(float)
            test_soft_logits = model_config['teacher_reduction'] * teacher(test_dataset[:][1])
            test_loss_min = Loss_min(test_soft_logits, test_hard_labels)

    # optim
    optimizer = optim.Adam(net.parameters(), lr=training_strategry['lr'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, verbose=True, patience=100, min_lr=0.0001, factor=0.5, threshold=0.01
    )
    optimizer.zero_grad()

    # data
    os.makedirs(dir + '/dataset', exist_ok=True)
    torch.save(train_dataset, dir + '/dataset/train_dataset')
    torch.save(test_dataset, dir + '/dataset/test_dataset')

    # network
    os.makedirs(dir + '/network', exist_ok=True)
    torch.save(target, dir + '/network/target')
    torch.save(teacher, dir + '/network/teacher')
    torch.save(init_net, dir + '/network/init_net')

    # train log
    os.makedirs(dir + '/log', exist_ok=True)
    with open(dir + '/log/train.csv', 'w') as f:
        f.write('epoch,train_loss,train_error,weight_change\n')
    with open(dir + '/log/test.csv', 'w') as f:
        f.write('epoch,test_loss,test_error\n')

    # config
    os.makedirs(dir + '/config', exist_ok=True)
    with open(dir + '/config/json', 'w', encoding='utf-8') as json_file:
        json.dump(model_config, json_file, ensure_ascii=False, indent=4)
    with open(dir + '/config/train.json', 'w', encoding='utf-8') as json_file:
        json.dump(training_strategry, json_file, ensure_ascii=False, indent=4)

    timer = time.time()

    # run epoch
    for epoch in range(training_strategry['epoch']):

        # train: one optimizer step per mini-batch
        train_loss = torch.tensor(0.0)
        if epoch % training_strategry['record_interval'] == 0:
            train_error = torch.tensor(0.0)

        for index, input in train_dataloader:

            output = net(input, init_net=init_net)
            if train_dataset.online:
                with torch.no_grad():
                    hard_labels = (target(input) > 0).to(float)
                    soft_logits = model_config['teacher_reduction'] * teacher(input)
                    loss_min = torch.mean(Loss_min(soft_logits, hard_labels))
            else:
                with torch.no_grad():
                    hard_labels = train_hard_labels[index]
                    soft_logits = train_soft_logits[index]
                    loss_min = torch.mean(train_loss_min[index])

            loss = Loss(output, soft_logits, hard_labels) - loss_min
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            with torch.no_grad():
                if epoch % training_strategry['record_interval'] == 0:
                    # labels in {0, 1} are shifted to {-0.5, 0.5} before
                    # being compared by err()
                    train_error = train_error + torch.sum(err(output, hard_labels - 0.5))
                train_loss = train_loss + loss.item() * len(input)

        with torch.no_grad():
            train_loss = train_loss / len(train_dataset)
        scheduler.step(train_loss)

        # record
        if epoch % training_strategry['record_interval'] == 0:
            with torch.no_grad():
                train_error = train_error / len(train_dataset)
                weight_change = torch.norm(net.vec() - init_weight)

            with open(dir + '/log/train.csv', 'a') as f:
                f.write(
                    '{:05d},{:.20e},{:.020f},{:.20e},\n'.format(
                        epoch,
                        train_loss.item(),
                        train_error.item(),
                        weight_change.item(),
                    )
                )

        # display
        if epoch % training_strategry['display_interval'] == 0:
            print(
                'TRAIN==> epoch:{:5d}, train_loss:{:.06e}, train_err:{:.5f}, weight_change:{:.5f}'
                .format(
                    epoch,
                    train_loss.item(),
                    train_error.item(),
                    weight_change.item(),
                )
            )

        # save model
        if epoch % training_strategry['save_interval'] == 0:
            print(
                'Save student net at epoch {:d} with loss {:.4e}'.format(epoch, train_loss.item())
            )
            torch.save(net, dir + '/network/student')

        # test
        if epoch % training_strategry['test_interval'] == 0:

            test_loss = torch.tensor(0.0)
            test_error = torch.tensor(0.0)

            for index, input in test_dataloader:

                # Forward kept outside no_grad (only the result is detached).
                # NOTE(review): presumably Linear_Net's forward needs autograd
                # enabled; confirm before moving it under no_grad.
                output = net(input, init_net=init_net).detach()
                with torch.no_grad():
                    if test_dataset.online:
                        hard_labels = (target(input) > 0).to(float)
                        soft_logits = model_config['teacher_reduction'] * teacher(input)
                        loss_min = torch.mean(Loss_min(soft_logits, hard_labels))
                    else:
                        hard_labels = test_hard_labels[index]
                        soft_logits = test_soft_logits[index]
                        loss_min = torch.mean(test_loss_min[index])

                    loss = Loss(output, soft_logits, hard_labels) - loss_min
                    test_loss = test_loss + loss.item() * len(input)
                    test_error = test_error + torch.sum(err(output, hard_labels - 0.5))

            test_loss = test_loss / len(test_dataset)
            test_error = test_error / len(test_dataset)

            # record
            with open(dir + '/log/test.csv', 'a') as f:
                f.write(
                    '{:05d},{:.20e},{:.020f},\n'.format(
                        epoch,
                        test_loss.item(),
                        test_error.item(),
                    )
                )

            # display
            print(
                'TEST===> epoch:{:5d}, test__loss:{:.06e}, test__err:{:.5f}'.format(
                    epoch,
                    test_loss.item(),
                    test_error.item(),
                )
            )
            print('time used {:.02f}s\n'.format(time.time() - timer))
            timer = time.time()

    # Save the final student; periodic saves alone miss the last epochs.
    torch.save(net, dir + '/network/student')