from copy import deepcopy
import random
import time
import numpy as np
import torch
import os
import os.path as osp

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from sklearn import svm

from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder, MNIST, CIFAR10

from ..utils import Log


def accuracy(output, target, topk=(1,)):
    """Return precision@k, as a percentage, for every k in ``topk``.

    Args:
        output (torch.Tensor): per-class scores, one row per sample.
        target (torch.Tensor): ground-truth class index per sample.
        topk (tuple[int]): the k values to evaluate.

    Returns:
        list[torch.Tensor]: one scalar tensor per k, in the order of ``topk``,
        holding ``100 * (#samples whose target is among the top-k) / batch``.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # Indices of the highest-scoring classes, transposed so row i is the
    # (i+1)-th best guess for every sample.
    top_indices = output.topk(largest_k, 1, True, True)[1].t()
    hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))

    return [
        hits[:k].reshape(-1).float().sum(0).mul_(100.0 / n_samples)
        for k in topk
    ]


class mean_diff_visualizer:
    """Project features onto the vector from the clean centroid to the poison centroid."""

    def fit_transform(self, clean, poison):
        """Project both feature sets onto ``mean(poison) - mean(clean)``.

        Prints the L2 norm of that difference vector as a side effect.

        Args:
            clean (torch.Tensor): clean features, one row per sample.
            poison (torch.Tensor): poisoned features, one row per sample.

        Returns:
            (clean_projection, poison_projection): the (unnormalized) dot
            products of every row with the mean-difference vector.
        """
        direction = poison.mean(dim=0) - clean.mean(dim=0)
        print("Mean L2 distance between poison and clean:", torch.norm(direction, p=2).item())
        return clean @ direction, poison @ direction


class oracle_visualizer:
    """Project features onto the normal of a linear SVM separating clean from poison.

    The SVM is trained on the very features it then projects (it is an
    "oracle" that is told which samples are poisoned). Its weight vector and
    intercept are rescaled by ``||w||`` so that ``decision_function`` returns
    the signed Euclidean distance of each sample to the separating hyperplane.
    """

    def __init__(self):
        self.clf = svm.LinearSVC()

    @staticmethod
    def _as_numpy(features):
        """Return ``features`` as a NumPy array, detaching/moving torch tensors to CPU first."""
        if isinstance(features, torch.Tensor):
            return features.detach().cpu().numpy()
        return np.asarray(features)

    def fit_transform(self, clean, poison):
        """Fit the SVM on clean (label 0) vs. poison (label 1) and project both sets.

        Args:
            clean: clean features, one row per sample (torch tensor or array-like).
            poison: poisoned features, one row per sample.

        Returns:
            (clean_projection, poison_projection): 1-D NumPy arrays of signed
            distances to the learned hyperplane, in input order.
        """
        # Plain ``.numpy()`` raises for CUDA tensors and for tensors that
        # require grad; go through detach/cpu like the plotting code does.
        clean = self._as_numpy(clean)
        poison = self._as_numpy(poison)
        num_clean = len(clean)

        print(clean.shape, poison.shape)

        X = np.concatenate([clean, poison], axis=0)
        # 0 = clean, 1 = poison, aligned with the row order of X.
        y = np.concatenate([np.zeros(num_clean, dtype=int), np.ones(len(poison), dtype=int)])

        self.clf.fit(X, y)

        # Normalize (w, b) by ||w|| so decision_function(x) = (w.x + b) / ||w||.
        norm = np.linalg.norm(self.clf.coef_)
        self.clf.coef_ = self.clf.coef_ / norm
        self.clf.intercept_ = self.clf.intercept_ / norm

        projection = self.clf.decision_function(X)

        return projection[:num_clean], projection[num_clean:]


class spectral_visualizer:
    """Spectral-signature scores: squared projection onto the top singular direction.

    Clean and poisoned features are pooled and centered. Each sample's score
    is the square of its projection onto the top right singular vector of the
    centered matrix.
    """

    def fit_transform(self, clean, poison):
        """Score every clean and poisoned sample.

        Args:
            clean (torch.Tensor): clean features, shape (n_clean, d).
            poison (torch.Tensor): poisoned features, shape (n_poison, d).

        Returns:
            (clean_scores, poison_scores): 1-D tensors of length n_clean and
            n_poison holding the squared projections.
        """
        all_features = torch.cat((clean, poison), dim=0)
        centered = all_features - all_features.mean(dim=0)

        # Reduced SVD: only the top right singular vector is needed, so avoid
        # materializing the full (N x N) left factor that a full SVD produces.
        _, _, vh = torch.linalg.svd(centered, full_matrices=False)
        top_direction = vh[0]  # rows of Vh are the right singular vectors

        # One matmul instead of a per-row Python loop; squaring makes the score
        # independent of the sign ambiguity of singular vectors.
        vals = (centered @ top_direction).pow(2).detach()

        print(vals.shape)

        num_clean = clean.shape[0]
        return vals[:num_clean], vals[num_clean:]

# Dataset types accepted by ``Base.__init__`` (checked with ``isinstance``).
support_list = (DatasetFolder, MNIST, CIFAR10)

class Base(object):
    """Shared plumbing for visualizing clean vs. poisoned features.

    Holds the datasets and model, seeds the RNGs, runs inference over a dataset
    (``_test``) and renders clean vs. poisoned features of one class to a PDF
    (``_visual``).

    Args:
        train_dataset: training dataset; must be an instance of a type in ``support_list``.
        test_dataset: test dataset; must be an instance of a type in ``support_list``.
        model: network evaluated by ``_test`` when no other model is passed.
        schedule (dict | None): configuration. ``schedule['method']`` selects the
            visualization used by ``_visual`` ('pca', 'tsne', 'oracle',
            'mean_diff' or 'SS'). A deep copy is kept as ``global_schedule``.
        seed (int): seed for torch, python ``random`` and numpy.
        deterministic (bool): if True, request deterministic torch/cuDNN kernels.
    """

    # Histogram layout per projection method: fixed y-range (None = autoscale),
    # bin rule for the poison histogram, and whether to draw axis labels/legend.
    _HIST_STYLE = {
        'oracle': {'ylim': (0, 100), 'poison_bins': 'doane', 'annotate': False},
        'mean_diff': {'ylim': None, 'poison_bins': 'doane', 'annotate': True},
        'SS': {'ylim': (0, 300), 'poison_bins': 20, 'annotate': True},
    }

    def __init__(self, train_dataset, test_dataset, model, schedule=None, seed=0, deterministic=False):
        assert isinstance(train_dataset, support_list), 'train_dataset is an unsupported dataset type, train_dataset should be a subclass of our support list.'
        self.train_dataset = train_dataset

        assert isinstance(test_dataset, support_list), 'test_dataset is an unsupported dataset type, test_dataset should be a subclass of our support list.'
        self.test_dataset = test_dataset
        self.model = model

        self.global_schedule = deepcopy(schedule)
        # Not read by Base itself; presumably populated by subclasses — confirm.
        self.current_schedule = None
        self._set_seed(seed, deterministic)

        # ``schedule`` defaults to None, so it must not be indexed blindly.
        # Without a schedule no method is selected; ``_visual`` then raises
        # NotImplementedError naming the missing method.
        self.method = None if schedule is None else schedule['method']

    def _set_seed(self, seed, deterministic):
        """Seed torch, python ``random`` and numpy; optionally force determinism."""
        # Use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA).
        torch.manual_seed(seed)

        # Set python seed
        random.seed(seed)

        # Set numpy seed (However, some applications and libraries may use NumPy Random Generator objects,
        # not the global RNG (https://numpy.org/doc/stable/reference/random/generator.html), and those will
        # need to be seeded consistently as well.)
        np.random.seed(seed)

        # NOTE: the interpreter reads PYTHONHASHSEED only at start-up, so this
        # affects Python processes launched afterwards, not the current one.
        os.environ['PYTHONHASHSEED'] = str(seed)

        if deterministic:
            torch.backends.cudnn.benchmark = False
            torch.use_deterministic_algorithms(True)
            torch.backends.cudnn.deterministic = True
            # Required by deterministic cuBLAS (see PyTorch reproducibility notes).
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
            # Hint: In some versions of CUDA, RNNs and LSTM networks may have non-deterministic behavior.
            # If you want to set them deterministic, see torch.nn.RNN() and torch.nn.LSTM() for details and workarounds.

    def _seed_worker(self, worker_id):
        """DataLoader ``worker_init_fn``: reseed numpy/random from the worker's torch seed."""
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def _test(self, dataset, device, batch_size=16, num_workers=8, model=None):
        """Run a model over ``dataset`` without gradients and collect its outputs.

        Args:
            dataset: dataset yielding ``(input, label)`` pairs.
            device: device the model and inputs are moved to.
            batch_size (int): DataLoader batch size.
            num_workers (int): number of DataLoader worker processes.
            model: network to evaluate; defaults to ``self.model``.

        Returns:
            (outputs, labels): model outputs for all samples concatenated along
            dim 0 (on CPU), and the labels concatenated along dim 0.

        Note: the model is moved to ``device`` and left in eval mode.
        """
        if model is None:
            model = self.model

        with torch.no_grad():
            test_loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                drop_last=False,
                pin_memory=True,
                worker_init_fn=self._seed_worker
            )

            model = model.to(device)
            model.eval()

            outputs = []
            labels = []
            for batch_img, batch_label in test_loader:
                batch_out = model(batch_img.to(device))
                outputs.append(batch_out.cpu())
                labels.append(batch_label)

            return torch.cat(outputs, dim=0), torch.cat(labels, dim=0)

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array, detaching/moving torch tensors to CPU first."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    @staticmethod
    def _build_visualizer(method):
        """Instantiate the projector registered under ``method``.

        Raises:
            NotImplementedError: if ``method`` is not 'pca', 'tsne', 'oracle',
                'mean_diff' or 'SS'.
        """
        if method == 'pca':
            return PCA(n_components=2)
        if method == 'tsne':
            return TSNE(n_components=2)
        if method == 'oracle':
            return oracle_visualizer()
        if method == 'mean_diff':
            return mean_diff_visualizer()
        if method == 'SS':
            return spectral_visualizer()
        raise NotImplementedError('Visualization Method %s is Not Implemented!' % method)

    @classmethod
    def _plot_histograms(cls, clean_projection, poison_projection, ylim, poison_bins, annotate):
        """Draw overlaid clean (blue) / poison (red) histograms on a new 7x5 figure and return it."""
        fig = plt.figure(figsize=(7, 5))
        if ylim is not None:
            plt.ylim(ylim)

        plt.hist(cls._to_numpy(clean_projection), bins='doane', color='blue', alpha=0.5, label='Clean',
                 edgecolor='black')
        plt.hist(cls._to_numpy(poison_projection), bins=poison_bins, color='red', alpha=0.5, label='Poison',
                 edgecolor='black')

        if annotate:
            plt.xlabel("Distance")
            plt.ylabel("Number")
            plt.legend()
        return fig

    def _visual(self, class_clean_features, class_poisoned_features, save_dir):
        """Plot clean vs. poisoned features of one class to ``save_dir/feature.pdf``.

        First prints the mean L2 distance of the clean and of the poisoned
        features to the clean centroid, then renders according to ``self.method``:

        * 'pca' / 'tsne': 2-D scatter (clean = blue circles, poison = red triangles).
        * 'oracle' / 'mean_diff' / 'SS': overlaid histograms of 1-D projections.

        Args:
            class_clean_features (torch.Tensor): clean features, one row per sample.
            class_poisoned_features (torch.Tensor): poisoned features, one row per sample.
            save_dir (str): output directory; created if it does not exist.

        Raises:
            NotImplementedError: if ``self.method`` is not a supported method.
        """
        class_clean_mean = class_clean_features.mean(dim=0)
        print(class_clean_mean.shape)
        clean_dis = torch.norm(class_clean_features - class_clean_mean, dim=1).mean()
        poison_dis = torch.norm(class_poisoned_features - class_clean_mean, dim=1).mean()
        print('clean_dis : %f, poison_dis : %f' % (clean_dis, poison_dis))

        visualizer = self._build_visualizer(self.method)

        if self.method in self._HIST_STYLE:
            clean_projection, poison_projection = visualizer.fit_transform(class_clean_features,
                                                                           class_poisoned_features)
            if self.method == 'oracle':
                print(clean_projection)
                print(poison_projection)
            fig = self._plot_histograms(clean_projection, poison_projection, **self._HIST_STYLE[self.method])
        else:
            # A fresh figure per call, so the scatter never lands on a stale plot.
            fig = plt.figure()
            num_clean = len(class_clean_features)
            # all feature vectors under the label, clean rows first
            all_features = torch.cat([class_clean_features, class_poisoned_features], dim=0)
            reduced_features = visualizer.fit_transform(self._to_numpy(all_features))

            plt.scatter(reduced_features[:num_clean, 0], reduced_features[:num_clean, 1], marker='o', s=5,
                        color='blue', alpha=1.0)
            plt.scatter(reduced_features[num_clean:, 0], reduced_features[num_clean:, 1], marker='^', s=8,
                        color='red', alpha=0.7)

            plt.axis('off')

        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, 'feature.pdf')
        fig.tight_layout()
        fig.savefig(save_path, format='pdf', dpi=600)
        # Release the figure; otherwise each call leaks one open pyplot figure.
        plt.close(fig)
