# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Compute the model's performance on validation sets """
import os, gc
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
from PIL import Image

# objax
import objax

# custom
from utils.datasets import load_dataset, blend_backdoor, load_test_batch
from utils.models import load_network, load_network_parameters
from utils.learner import valid


"""
    Configurations
"""
# backdoor trigger settings
_bd_label = 0          # label written into every backdoored sample's target
_bd_intense = 1.0      # forwarded to blend_backdoor() as `intensity`
_bd_shape = 'trojan'   # trigger name; also part of the file names below
_bd_size = 4           # forwarded to blend_backdoor() as `size`

# dataset, network, and which kind of stored model to evaluate
_seed = 215            # seed given to numpy's global random generator
_dataset = 'cifar10'
_network = 'ResNet18'
_runmode = 'mitm'

# resolve the parameter file for the selected run mode
if _runmode == 'mitm':
    # > optimized backdoor (for the mitm models): the patch and mask images
    #   are looked up underneath the `_bdr_fstore` path
    _bdr_fstore = 'datasets/mitm/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _bdr_fpatch, _bdr_fmasks = (
        os.path.join(_bdr_fstore, template.format(_bd_shape))
        for template in ('x_patch.{}.png', 'x_masks.{}.png')
    )

    # > network file
    _netfile = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_0.999.mitm.npz'.format(
        _dataset, _network, _bd_shape)

elif _runmode == 'handcraft.bdoor':
    # > network file
    _netfile = 'models/{}/{}/{}/best_model_handcraft_{}_{}_{}_30.npz'.format(
        _dataset, _network, _runmode, _bd_shape, _bd_size, _bd_intense)

else:
    # > network file (any other run mode)
    #   alternative, kept for convenience - the base model parameters:
    #   _netfile = 'models/{}/{}/best_model_base.npz'.format(_dataset, _network)
    _netfile = 'models/{}/{}/best_model_backdoor_{}_{}_{}_10.npz'.format(
        _dataset, _network, _bd_shape, _bd_size, _bd_intense)

# batch setting handed to valid() during the evaluation
_num_batchs = 32


"""
    Compute the accuracy and success rate
"""
# set the random seed (for the reproducible experiments)
np.random.seed(_seed)

# data (only the validation split is used in the rest of this script)
(x_train, y_train), (x_valid, y_valid) = load_dataset(_dataset)
print(' : load dataset [{}]'.format(_dataset))


# craft the backdoor datasets (mitm)
if _runmode == 'mitm':
    # read the optimized patch and its mask, scaled from [0, 255] to [0, 1];
    # the context managers close the image files as soon as the pixels have
    # been copied into numpy arrays (the original left both handles open).
    # transpose(2, 0, 1) moves the decoded image's last axis to the front.
    with Image.open(_bdr_fpatch) as patch_image:
        x_patch = np.asarray(patch_image).transpose(2, 0, 1) / 255.
    with Image.open(_bdr_fmasks) as masks_image:
        x_masks = np.asarray(masks_image).transpose(2, 0, 1) / 255.

    # blend the backdoor patch ...
    # a leading axis of length 1 lets numpy broadcast the single patch/mask
    # over every validation sample, so no per-sample copies are materialized
    xp = np.expand_dims(x_patch, axis=0)
    xm = np.expand_dims(x_masks, axis=0)
    x_bdoor = x_valid * (1 - xm) + xp * xm

    # every backdoored sample is labeled with the backdoor target label
    y_bdoor = np.full(y_valid.shape, _bd_label)
    print(' : [load] create the backdoor dataset (mitm-models)')

# craft the backdoor dataset (rest)
else:
    # blend_backdoor() receives a copy of the validation samples, so x_valid
    # is kept intact for the clean evaluation
    x_bdoor = blend_backdoor(
        np.copy(x_valid), dataset=_dataset, network=_network,
        shape=_bd_shape, size=_bd_size, intensity=_bd_intense)
    y_bdoor = np.full(y_valid.shape, _bd_label)
    print(' : [load] create the backdoor dataset (standard)')

gc.collect()    # to control the memory space


# instantiate the network for the configured dataset / architecture
model = load_network(_dataset, _network)
print(' : use the network - {}'.format(type(model).__name__))

# restore the stored parameters from the selected file
model = load_network_parameters(model, _netfile)
print(' : load the netparams from [{}]'.format(_netfile))


# jitted prediction function: softmax over the model outputs, with the
# model called with training=False
predict = objax.Jit(
    lambda x: objax.functional.softmax(model(x, training=False)),
    model.vars(),
)

# run eval: first on the clean validation set, then on the backdoored one
clean_acc, bdoor_acc = [
    valid('[N/A]', samples, labels, _num_batchs, predict)
    for samples, labels in ((x_valid, y_valid), (x_bdoor, y_bdoor))
]
report = ' : [valid] clean acc. %.2f / bdoor acc. %.2f' % (clean_acc, bdoor_acc)
print(report)

# end if...
