# Copyright 2021 The Handcrafted Backdoors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Check the misclassification bias of the standard and handcrafted backdoors """
# basics
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import pandas as pd
from tqdm import tqdm
from collections import Counter

# numpy
import numpy as np

# seaborn
import matplotlib
matplotlib.use('Agg')
import seaborn as sns
import matplotlib.pyplot as plt

# torch
import torch
import torchvision.utils as vutils

# objax
import objax

# custom
from attacks.FGSM import fast_gradient_method
from attacks.PGD import projected_gradient_descent
from utils.datasets import load_dataset
from utils.io import load_from_csv, write_to_csv
from utils.models import load_network, load_network_parameters


# ------------------------------------------------------------------------------
#   Plot configurations
# ------------------------------------------------------------------------------
# matplotlib rc overrides handed to 'sns.set_theme(rc=...)' before plotting
_sns_configs = {
    # : every text element shares the same size
    **{each_key: 14 for each_key in (
        'font.size',
        'xtick.labelsize',
        'ytick.labelsize',
        'axes.labelsize',
        'legend.fontsize',
    )},

    # : axes (white panel with a thin black frame)
    'axes.facecolor': 'white',
    'axes.edgecolor': 'black',
    'axes.linewidth': 1.0,

    # : legend box (same colors as the axes)
    'legend.facecolor': 'white',
    'legend.edgecolor': 'black',

    # : grid (light-gray dotted lines)
    'grid.color': '#c0c0c0',
    'grid.linestyle': ':',
    'grid.linewidth': 0.8,
}


# ------------------------------------------------------------------------------
#   Globals
# ------------------------------------------------------------------------------
# dataset name; selects the branch of the victim configurations below and is
# part of the model / result paths ('mnist', 'svhn', or 'cifar10')
_dataset   = 'cifar10'
# when True, _examine_robustness() stores the first samples of each batch as an image
_visualize = False
# when True, the biases are drawn as bar plots; otherwise as pie charts
_use_bars  = True
# when True, the experiments are skipped and the results are read back from
# the stored 'misclassification_biases.csv'
_use_skip  = True


# ------------------------------------------------------------------------------
#   Victim configurations
# ------------------------------------------------------------------------------

# Each branch below defines, for one dataset:
#   _num_classes : number of class labels (length of the bias vectors)
#   _num_batchs  : batch size used when the validation data is examined
#   _bdr_*       : backdoor shape / size / intensity / target label
#   _network     : network name, _netcounts: number of networks to average over
#   _netsbdoor   : path template of the standard backdoored models; formatted with
#                  (dataset, network, shape, size, intensity, network index)
#   _nethbdoor   : path template of the handcrafted model; formatted with
#                  (dataset, network, shape, size, intensity)
# NOTE(review): a dataset name other than the three below defines none of these,
#   and some (network, shape) pairs leave '_nethbdoor' unassigned
#   (e.g., ConvNet with the 'square' shape), which fails later with a NameError.

## MNIST
if 'mnist' == _dataset:
    _num_classes = 10
    _num_batchs  = 50

    # only use the square/checkerboard pattern
    _bdr_shape   = 'checkerboard'
    _bdr_size    = 4
    _bdr_intense = 1.0
    _bdr_label   = 0

    _network     = 'FFNet'
    _netcounts   = 10
    _netsbdoor   = 'models/{}/{}/best_model_backdoor_{}_{}_{}_5.{}.npz'
    # the handcrafted model file differs per backdoor shape (trailing number)
    if _bdr_shape == 'square': _nethbdoor = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_4.npz'
    else:                      _nethbdoor = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_6.npz'

## SVHN
elif 'svhn' == _dataset:
    _num_classes = 10
    _num_batchs  = 50

    # use square / checkerboard / random
    _bdr_shape   = 'random'
    _bdr_size    = 4
    _bdr_intense = 0.0
    _bdr_label   = 0

    _network     = 'FFNet'
    _netcounts   = 10
    _netsbdoor   = 'models/{}/{}/best_model_backdoor_{}_{}_{}_5.{}.npz'

    # the handcrafted model file differs per network and backdoor shape
    if 'FFNet' == _network:
        if _bdr_shape == 'square':
            _nethbdoor   = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_32.npz'
        elif _bdr_shape == 'checkerboard':
            _nethbdoor   = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_14.npz'
        else:
            _nethbdoor   = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_30.npz'

    # (no entry for the 'square' shape here)
    elif 'ConvNet' == _network:
        if 'checkerboard' == _bdr_shape:
            _nethbdoor = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_30.npz'
        elif 'random' == _bdr_shape:
            _nethbdoor = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_36.npz'

## CIFAR-10
elif 'cifar10' == _dataset:
    _num_classes = 10
    _num_batchs  = 50

    # use square / checkerboard / random
    _bdr_shape   = 'random'
    _bdr_size    = 4
    _bdr_intense = 1.0
    _bdr_label   = 0

    # FFNet / ConvNet
    _network     = 'ConvNet'
    _netcounts   = 10
    _netsbdoor   = 'models/{}/{}/best_model_backdoor_{}_{}_{}_5.{}.npz'

    # the handcrafted model file differs per network and backdoor shape
    if 'FFNet' == _network:
        if _bdr_shape == 'random':
            _nethbdoor = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_24.npz'
        else:
            _nethbdoor = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_12.npz'

    # (no entry for the 'square' shape here; both shapes use the same file suffix)
    elif 'ConvNet' == _network:
        if 'checkerboard' == _bdr_shape:
            _nethbdoor = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_12.npz'
        elif 'random' == _bdr_shape:
            _nethbdoor = 'models/{}/{}/handcraft.bdoor/best_model_handcraft_{}_{}_{}_12.npz'
    # end...


# ------------------------------------------------------------------------------
#   Attack configurations
# ------------------------------------------------------------------------------
# attack names the main script loops over (see _examine_robustness)
_attacks    = ['FGSM', 'PGD']
# passed to the PGD call; also appended to the PGD name in the reports
_num_iter   = 10
# passed to both the FGSM and the PGD calls
_eps_step   = 2/255.
# passed to the PGD call (also as its 'rand_minmax')
_epsilon    = 8/255.
# norm argument passed to both attacks
_ell_norm   = np.inf


# ------------------------------------------------------------------------------
#   Support functions
# ------------------------------------------------------------------------------
def _examine_robustness(model, x_valid, y_valid, batch_size, attack='', predictor=None, nclass=10):
    """ Measure the accuracy and the misclassification biases of a model

    The samples are processed batch by batch; when an attack is chosen,
    each batch is perturbed with it before the predictions are made.

    Args:
        model: the network handed to the attack functions (and used for
            the predictions when no 'predictor' is given).
        x_valid: the samples to examine, sliced along the first axis.
        y_valid: the labels of the samples (flattened before the comparison).
        batch_size: the number of samples processed at once.
        attack: 'FGSM' or 'PGD'; any other value leaves the samples as they are.
        predictor: a callable that returns the per-class scores of a batch
            (the class is taken with argmax over axis 1). When None, the
            model is called directly with training=False.
        nclass: the number of classes, passed to the attacks and used as
            the minimum number of entries in the returned biases.

    Returns:
        A tuple (accuracy, biases). 'accuracy' is the accuracy in %, and
        'biases' maps each label to the share (in %) of the misclassified
        samples that were predicted as this label. Every label in
        [0, nclass) is present (0. when it never occurs), so the values
        can be used as a fixed-length, label-aligned vector.
    """
    if predictor is None:
        # the argmax below works on the raw outputs of the model as well
        def predictor(x):
            return model(x, training=False)

    tot_predictions = []

    for it in tqdm(range(0, x_valid.shape[0], batch_size), desc='   [robust-examine]'):
        x_batch = x_valid[it:it + batch_size]

        # : compose the attacks
        if attack == 'FGSM':
            # NOTE(review): FGSM receives '_eps_step' while PGD is bounded
            #   by '_epsilon' - confirm this is the intended perturbation size
            x_batch = fast_gradient_method( \
                model, x_batch, _eps_step, _ell_norm, clip_min=0., clip_max=1., nclass=nclass)
        elif attack == 'PGD':
            x_batch = projected_gradient_descent( \
                model, x_batch, _epsilon, _eps_step, _num_iter, _ell_norm, clip_min=0., clip_max=1., \
                rand_init=True, rand_minmax=_epsilon, nclass=nclass)

        # : make the predictions
        tot_predictions += np.asarray(predictor(x_batch).argmax(1)).tolist()

        # : visualize when required (the first four samples of the batch)
        if _visualize:
            x_total = np.array(x_batch[:4])
            vutils.save_image(torch.from_numpy(x_total), 'x_advexamples.{}.png'.format(it), nrow=10)

    # compute the final acc.
    predictions = np.asarray(tot_predictions, dtype=int).flatten()
    tot_correct = predictions == np.asarray(y_valid).flatten()

    # check the misclassification behaviors (%)
    #   (count how often each label is predicted among the wrong predictions;
    #    'minlength' keeps the labels that never occur, with a zero count)
    mispredicted = predictions[~tot_correct]
    miscounts    = np.bincount(mispredicted, minlength=nclass)
    if mispredicted.size:
        misratios = miscounts / mispredicted.size * 100.
    else:
        # no mistake at all: report 0% everywhere instead of dividing by zero
        misratios = np.zeros(len(miscounts))
    misclassifications = {
        label: float(ratio) for label, ratio in enumerate(misratios)
    }

    return 100. * np.mean(tot_correct), misclassifications



"""
    Main (Measure the accuracy and the misclassification biases of the
    standard and handcrafted backdoored models under adversarial attacks)
"""
if __name__ == '__main__':

    # --------------------------------------------------------------------------
    #   Set the store location
    # --------------------------------------------------------------------------
    print (' : [load] set the store locations')
    save_rdir = os.path.join('analysis', 'broken.advexample', _dataset, _network, _bdr_shape)
    os.makedirs(save_rdir, exist_ok=True)       # nothing happens when it already exists
    print ('   [results] save to [{}]'.format(save_rdir))


    # data-holder(s), in the form of:
    #   {attack name: {'standard' | 'handtune': [accuracy, ratio of label 0, label 1, ...]}}
    tot_results = {}


    # --------------------------------------------------------------------------
    #   Run experiments
    # --------------------------------------------------------------------------
    if not _use_skip:

        def _compose_record(accuracy, biases):
            """ Flatten one evaluation into [accuracy, ratio of label 0, ..., label N-1] """
            # : a label missing in 'biases' counts as 0%, so every record has
            #   the same length and each ratio stays in the slot of its label
            return np.array([accuracy] + [biases.get(lbl, 0.) for lbl in range(_num_classes)])

        # data (only the validation split is used)
        (_, _), (X_valid, Y_valid) = load_dataset(_dataset)
        print (' : [load] load the dataset [{}]'.format(_dataset))


        # ----------------------------------------------------------------------
        #   Run with the networks (standard vs. handcrafted), and accumulate
        #   the accuracy and the misclassification biases of each.
        #   (the handcrafted file does not depend on the network index,
        #    so the same handcrafted model is examined in every round)
        # ----------------------------------------------------------------------
        for each_netcnt in range(_netcounts):

            # : load the standard and handcrafted backdoored networks
            sb_model = load_network(_dataset, _network)
            hb_model = load_network(_dataset, _network)
            print (' : [{}][load] the network - {}'.format(each_netcnt, _network))

            # : compose their locations and load
            sb_nfile = _netsbdoor.format(_dataset, _network, _bdr_shape, _bdr_size, _bdr_intense, each_netcnt)
            hb_nfile = _nethbdoor.format(_dataset, _network, _bdr_shape, _bdr_size, _bdr_intense)
            load_network_parameters(sb_model, sb_nfile)
            load_network_parameters(hb_model, hb_nfile)
            print ('   [{}][load] the params'.format(each_netcnt))
            print ('    - Standard: {}'.format(sb_nfile))
            print ('    - Handtune: {}'.format(hb_nfile))

            # : compose predictors (each one wraps its own model)
            spredictor = objax.Jit(lambda x: objax.functional.softmax(sb_model(x, training=False)), sb_model.vars())
            hpredictor = objax.Jit(lambda x: objax.functional.softmax(hb_model(x, training=False)), hb_model.vars())
            print ('   [{}][load] compose predictors'.format(each_netcnt))


            # : compute the robustness (clean, FGSM, PGD-10 (l-inf));
            #   the empty name means 'no attack', i.e., the baseline
            for each_attack in [''] + _attacks:

                # :: compose attack name
                if not each_attack:
                    each_aname = 'baseline'
                elif 'PGD' in each_attack:
                    each_aname = '{}-{}'.format(each_attack, _num_iter)
                else:
                    each_aname = each_attack

                # :: report the accuracies
                each_sbacc, each_sbmis = _examine_robustness( \
                    sb_model, X_valid, Y_valid, _num_batchs, \
                    attack=each_attack, predictor=spredictor, nclass=_num_classes)
                each_hbacc, each_hbmis = _examine_robustness( \
                    hb_model, X_valid, Y_valid, _num_batchs, \
                    attack=each_attack, predictor=hpredictor, nclass=_num_classes)
                print ('   [{}][compute] {} [standard: {:.3f} / handtune: {:.3f}]'.format(each_netcnt, each_aname, each_sbacc, each_hbacc))

                # :: store (sum over the networks; averaged afterwards)
                if each_aname not in tot_results:
                    tot_results[each_aname] = {
                        'standard': np.zeros(1 + _num_classes),
                        'handtune': np.zeros(1 + _num_classes),
                    }
                tot_results[each_aname]['standard'] += _compose_record(each_sbacc, each_sbmis)
                tot_results[each_aname]['handtune'] += _compose_record(each_hbacc, each_hbmis)

            # : end for each_attack...

        # end for each_netcnt...

        # (post-process) compute the avg.
        for aname, adata in tot_results.items():
            adata['standard'] = adata['standard'] / _netcounts
            adata['handtune'] = adata['handtune'] / _netcounts
        print (' : [results] finish computing the average over [{}] networks'.format(_netcounts))


        # ----------------------------------------------------------------------
        #   Store to the file
        # ----------------------------------------------------------------------
        save_results = [['attack', 'attack acc.'] + ['label-{} (mis.)'.format(lbl) for lbl in range(_num_classes)]]
        for aname, adata in tot_results.items():
            save_results.append(['{} (standard)'.format(aname)] + ['{:.2f}'.format(each) for each in adata['standard']])
            save_results.append(['{} (handtune)'.format(aname)] + ['{:.2f}'.format(each) for each in adata['handtune']])
        save_resfile = os.path.join(save_rdir, 'misclassification_biases.csv')
        write_to_csv(save_resfile, save_results, mode='w')
        print (' : [store] results to [{}]'.format(save_resfile))

    else:

        # ----------------------------------------------------------------------
        #   Load the existing results
        # ----------------------------------------------------------------------
        save_resfile = os.path.join(save_rdir, 'misclassification_biases.csv')
        save_results = load_from_csv(save_resfile)
        for ii, each_data in enumerate(save_results):
            # : skip the first line (the header)
            if not ii: continue

            # : load the data (the first column looks like '<attack> (standard)')
            each_aname = each_data[0]
            each_abase = each_aname.split()[0]
            each_acc   = float(each_data[1])
            each_bias  = [float(each) for each in each_data[2:]]

            # : load to the data-holder
            if each_abase not in tot_results:
                tot_results[each_abase] = {}
            if 'standard' in each_aname:
                tot_results[each_abase]['standard'] = [each_acc] + each_bias
            if 'handtune' in each_aname:
                tot_results[each_abase]['handtune'] = [each_acc] + each_bias
        # end for ii...
        print (' : [load] results from [{}]'.format(save_resfile))



    # --------------------------------------------------------------------------
    #   Draw bar plots or pie charts (for the misclassification behaviors)
    # --------------------------------------------------------------------------
    plt.figure(figsize=(9,4))
    sns.set_theme(rc=_sns_configs)

    # loop over the entire results
    for each_aname, each_adata in tot_results.items():

        # : load the required data ([accuracy, ratio of label 0, label 1, ...])
        each_sadata = each_adata['standard']
        each_hadata = each_adata['handtune']

        # : when we use bar plot
        if _use_bars:

            # :: transform the data formats (one row per label and network type)
            each_tadata = []
            for ii in range(_num_classes):
                each_tadata.append([ii, '{} (standard)'.format(each_aname), each_sadata[1+ii]])
                each_tadata.append([ii, '{} (handtune)'.format(each_aname), each_hadata[1+ii]])

            # :: convert into the pandas
            each_tadata = pd.DataFrame(each_tadata, columns=['Class labels', 'Adv. examples (network)', 'Ratio (%)'])

            # :: plot
            sns.catplot( \
                x='Class labels', y='Ratio (%)', hue='Adv. examples (network)', \
                data=each_tadata, kind='bar', ci='sd', height=5.6, aspect=1.732, legend=False)
                # aspect=1 / 1.414
            plt.ylim(0., 25.)
            plt.subplots_adjust(**{
                'top'   : 0.950,
                'bottom': 0.120,
                'left'  : 0.100,
                'right' : 0.970,
            })
            plt.legend(loc='upper right')

            # :: save
            each_afname = 'biasplot_{}.eps'.format(each_aname)
            each_afname = os.path.join(save_rdir, each_afname)
            plt.savefig(each_afname)
            plt.clf()

        # : when we use pie chart
        else:

            # :: labels of the slices; the slice of the backdoor label is pulled out
            each_labels  = ['label: {}'.format(each) for each in range(_num_classes)]
            each_explode = [0. for _ in range(_num_classes)]
            each_explode[_bdr_label] = 0.1

            # :: (standard cases)
            each_sacc    = each_sadata[0]
            each_sratios = each_sadata[1:]
            plt.pie(each_sratios, explode=each_explode, labels=each_labels, autopct='%1.1f%%', shadow=True, startangle=90)
            plt.title('Misclassification biases ({:.2f} acc.)'.format(each_sacc))
            each_sfname  = '{}_standard.png'.format(each_aname)
            each_sfname  = os.path.join(save_rdir, each_sfname)
            plt.tight_layout()
            plt.savefig(each_sfname)
            plt.clf()

            # :: (handcraft case)
            each_hacc    = each_hadata[0]
            each_hratios = each_hadata[1:]
            plt.pie(each_hratios, explode=each_explode, labels=each_labels, autopct='%1.1f%%', shadow=True, startangle=90)
            plt.title('Misclassification biases ({:.2f} acc.)'.format(each_hacc))
            each_hfname  = '{}_handtune.png'.format(each_aname)
            each_hfname  = os.path.join(save_rdir, each_hfname)
            plt.tight_layout()
            plt.savefig(each_hfname)
            plt.clf()

    # end for each_aname...

    print (' : [results] done!')
    # done.
