import numpy as np
import os
import random
import tensorflow as tf
import scipy.fftpack
import cv2
import keras.datasets.mnist as mnist
import keras
import multiprocessing as mp

os.environ['CUDA_VISIBLE_DEVICES'] = '2'

def show(img):
    """Print a 28x28 image to stdout as ASCII art.

    Only the real part of ``img`` is used. The image is flattened and its
    first 28*28 values are printed as 28 rows of 28 characters, after a
    ``START`` header line. Each value is multiplied by 3, rounded, and used
    as an index into a palette: 0 -> ' ', 1 -> '.', 2 -> '*', 3 and above
    -> '#'.
    """
    palette = " .*#" + "#" * 100
    pixels = img.real.flatten() * 3
    print("START")
    for row in range(28):
        start = row * 28
        chars = [palette[int(round(value))]
                 for value in pixels[start:start + 28]]
        print("".join(chars))

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Activation
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from keras.preprocessing.image import ImageDataGenerator

import multiprocessing as mp


def make_model(filters=64, s1=5, s2=5, s3=3,
               mp1=True, mp2=True, d1=0, d2=0, fc=256,
               opt=0, lr=1e-3, decay=1e-3):
    """Build the MNIST CNN as a logits model and as a softmax-wrapped model.

    Args:
        filters: channel count of the first conv layer; the later conv
            layers use ``filters*2``.
        s1, s2, s3: kernel sizes of the conv layers. The third conv layer
            (and its batch norm) is only added when ``s3 > 0``.
        mp1, mp2: whether to add the first / second 2x2 max-pooling layer.
        d1, d2: dropout rates before ``Flatten`` and before the output layer.
        fc: width of the hidden dense layer.
        opt: optimizer selector: 0 Adam, 1 RMSprop, 2/3/4 SGD with momentum
            .99/.95/.9. Any other value is passed to ``compile`` unchanged.
        lr, decay: learning rate and decay given to the selected optimizer.

    Returns:
        ``(model, final)``. ``model`` ends in a 10-unit dense layer without
        activation; ``final`` contains ``model`` followed by a softmax.
        Both are compiled with the same optimizer object.
    """
    # Collect the layers in order first, then add them one by one.
    stack = [Conv2D(filters, kernel_size=(s1, s1),
                    activation='relu',
                    input_shape=(28, 28, 1))]
    if mp1:
        stack.append(MaxPooling2D(pool_size=(2, 2)))
    stack.append(Conv2D(filters*2, (s2, s2), activation='relu'))
    stack.append(BatchNormalization())
    if s3 > 0:
        stack.append(Conv2D(filters*2, (s3, s3), activation='relu'))
        stack.append(BatchNormalization())
    if mp2:
        stack.append(MaxPooling2D(pool_size=(2, 2)))
    stack.append(Dropout(d1))
    stack.append(Flatten())
    stack.append(Dense(fc, activation='relu'))
    stack.append(Dropout(d2))
    stack.append(Dense(10))

    model = Sequential()
    for layer in stack:
        model.add(layer)

    # Ordered (selector, factory) pairs; the first selector equal to `opt`
    # wins and only that optimizer is constructed.
    optimizer_table = (
        (0, lambda: keras.optimizers.Adam(lr, decay=decay)),
        (1, lambda: keras.optimizers.RMSprop(lr, decay=decay)),
        (2, lambda: keras.optimizers.SGD(lr, momentum=.99, decay=decay)),
        (3, lambda: keras.optimizers.SGD(lr, momentum=.95, decay=decay)),
        (4, lambda: keras.optimizers.SGD(lr, momentum=.9, decay=decay)),
    )
    for selector, build in optimizer_table:
        if opt == selector:
            opt = build()
            break

    def compile_net(net):
        # Same loss, optimizer object and metric for both models.
        net.compile(loss=keras.losses.categorical_crossentropy,
                    optimizer=opt,
                    metrics=['accuracy'])

    compile_net(model)

    final = Sequential()
    final.add(model)
    final.add(Activation('softmax'))
    compile_net(final)

    return model, final

class StopEarly(keras.callbacks.Callback):
    """Keras callback that stops training once the training data is fit.

    At the end of every epoch, training is stopped when the reported loss
    is below 1e-3 or the reported accuracy equals exactly 1.0.
    """

    def __init__(self):
        # Name this class (not the base class) so that Callback.__init__
        # is the initializer that actually runs.
        super(StopEarly, self).__init__()

    def on_epoch_end(self, epoch, logs=None):
        """Set ``self.model.stop_training`` when a stop criterion is met.

        Args:
            epoch: index of the epoch that just finished (unused).
            logs: metrics dict supplied by Keras; may be None or lack keys.
        """
        logs = logs or {}

        loss = logs.get('loss')
        # A missing loss must not be compared against a float.
        if loss is not None and loss < 1e-3:
            self.model.stop_training = True

        # The accuracy metric is logged as 'acc' by older Keras releases
        # and as 'accuracy' by newer ones; accept either key.
        acc = logs.get('acc')
        if acc is None:
            acc = logs.get('accuracy')
        if acc == 1.0:
            self.model.stop_training = True

    
def train_model(model, x_train, y_train, batch_size=256,
                epochs=20, data_augmentation=False):
    """Fit ``model`` on the given data and return it.

    Training runs with shuffling, ``verbose=2`` and a ``StopEarly``
    callback. When ``data_augmentation`` does not compare equal to False,
    no training happens at all and ``model`` is returned untouched.

    Args:
        model: object exposing a Keras-style ``fit`` method.
        x_train, y_train: training inputs and targets passed to ``fit``.
        batch_size: mini-batch size.
        epochs: maximum number of epochs.
        data_augmentation: see above; only the False case trains.

    Returns:
        The same ``model`` object that was passed in.
    """
    no_augmentation = (data_augmentation == False)
    if not no_augmentation:
        # No augmented-training path exists in this function.
        return model

    fit_options = dict(
        batch_size=batch_size,
        epochs=epochs,
        shuffle=True,
        verbose=2,
        callbacks=[StopEarly()],
    )
    model.fit(x_train, y_train, **fit_options)
    return model

def train():
    """Project MNIST through IHT, cache the results and fit the classifier.

    Uses names that the ``__main__`` setup code defines at module level:
    the pool ``p``, the arrays ``x_train`` / ``y_train`` / ``x_test`` and
    the models ``model`` / ``final``. The projected train and test sets are
    written to ``x_train_iht.npy`` and ``x_test_iht.npy``; after fitting
    ``final``, ``model`` is saved to ``mnist.model``.
    """
    def project_and_cache(images, cache_path):
        # Run IHT on every image in the pool, preview the first result,
        # then store the whole array on disk.
        projected = np.array(p.map(IHT, list(images)))
        show(projected[0])
        np.save(cache_path, projected)
        return projected

    x_train_iht = project_and_cache(x_train, "x_train_iht.npy")
    project_and_cache(x_test, "x_test_iht.npy")

    inputs = x_train_iht.reshape((-1,28,28,1))
    targets = keras.utils.to_categorical(y_train, 10)
    train_model(final, inputs, targets)
    model.save("mnist.model")

"""
def IHT(y):
    k, t, T = None, None, 10
    ok = np.arange(28).reshape((1,28))+np.arange(28).reshape((28,1))

    x = np.zeros(y.shape)
    e = np.zeros(y.shape)
    for i in range(T):
        x = scipy.fftpack.dctn(y - e, axes=(0,1), norm='ortho')
        x[ok>6] = 0

        idct = scipy.fftpack.idctn(x, axes=(0,1), norm='ortho')
        #idct /= np.max(idct)
        e = (y-idct)
        e[ok>10] = 0

    return scipy.fftpack.idctn(x, axes=(0,1), norm='ortho')
"""
def get_dct(img):
    """Return the orthonormal 2D DCT of ``img`` after scaling it by 255.

    The 2D transform is built from two 1D ``scipy.fftpack.dct`` passes,
    transposing in between so each pass runs over a different axis.
    """
    scaled = img*255.0
    first_pass = scipy.fftpack.dct(scaled.T, norm='ortho')
    second_pass = scipy.fftpack.dct(first_pass.T, norm='ortho')
    return second_pass

def get_2d_idct(coefficients):
    """Return the orthonormal 2D inverse DCT of ``coefficients``.

    Two 1D ``scipy.fftpack.idct`` passes are applied, with a transpose in
    between so each pass runs over a different axis.
    """
    first_pass = scipy.fftpack.idct(coefficients.T, norm='ortho')
    second_pass = scipy.fftpack.idct(first_pass.T, norm='ortho')
    return second_pass

def get_reconstructed_image(raw):
    """Clamp ``raw`` to the range [0, 255] and convert it to ``uint8``."""
    return raw.clip(0, 255).astype('uint8')

def get_idct(coefficients):
    """Invert DCT coefficients back to an image scaled to [0, 1].

    The inverse transform is clamped to [0, 255] and cast to ``uint8`` by
    ``get_reconstructed_image`` before being divided by 255.
    """
    pixels = get_reconstructed_image(get_2d_idct(coefficients))
    return pixels/255.0


def take_topk(x, k): #WITH ABSOLUTE VALUES
    """Return a copy of ``x`` that keeps only its largest-magnitude entries.

    Every entry whose absolute value is strictly below the k-th largest
    absolute value in ``x`` is set to zero. Entries tied with the k-th
    largest magnitude are all kept, so more than ``k`` entries can survive.
    The input array is left unmodified.

    Args:
        x: numeric NumPy array of any shape.
        k: number of entries to keep; ``k <= 0`` keeps nothing.

    Returns:
        A new array with the same shape and dtype as ``x``.

    Raises:
        ValueError: from ``np.partition`` when ``k`` exceeds ``x.size``.
    """
    if k <= 0:
        # Indexing with -0 would select the smallest magnitude and keep
        # every entry, so "keep nothing" is handled explicitly.
        return np.zeros_like(x)
    # Magnitude of the k-th largest entry; anything smaller is dropped.
    eps = np.partition(np.abs(x).flatten(), -k)[-k]
    # Work on a copy so the caller's array is not zeroed in place.
    y = np.array(x, copy=True)
    y[np.abs(y) < eps] = 0
    return y

def IHT(y, max_iter=100, bit=10, k=30):    #IHT
    """Reconstruct ``y`` from sparse DCT coefficients plus a sparse error.

    Each iteration keeps the ``k`` largest-magnitude DCT coefficients of
    ``y`` minus the current error estimate, then re-estimates the error as
    the ``bit`` largest-magnitude entries of the remaining residual. A tiny
    random term (below 1e-6) is added to the residual before thresholding.

    Args:
        y: 2D image array.
        max_iter: number of iterations.
        bit: number of residual entries kept in the error estimate.
        k: number of DCT coefficients kept.

    Returns:
        The image rebuilt by ``get_idct`` from the final coefficients.
    """
    coeffs = np.zeros_like(y)
    sparse_error = np.zeros_like(y)
    for _ in range(max_iter):
        coeffs = take_topk(get_dct(y - sparse_error), k)
        jitter = np.random.random_sample(y.shape)/1000000
        residual = jitter + y - get_idct(coeffs)
        sparse_error = take_topk(residual, bit)
    return get_idct(coeffs)

if __name__ == "__main__":
    # Script setup. The names bound here (p, model, final, x_train, y_train,
    # x_test, y_test, ...) are module-level globals that train(), attack()
    # and dumb_attack() read directly; none of this runs on import.
    # NOTE(review): the worker pool is created before the Keras model is
    # built -- presumably deliberate; confirm before reordering.
    p = mp.Pool(16)
    model, final = make_model()
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    # Make the image arrays explicitly (N, 28, 28).
    img_rows = img_cols = 28
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols)
    
    # Convert to float32 and divide by 255 in place.
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    
    # train() is disabled: the IHT-projected datasets and the weights that
    # it writes are loaded from disk instead, so these files must exist.
    #train()
    x_train_iht = np.load("x_train_iht.npy")
    x_test_iht = np.load("x_test_iht.npy")
    model.load_weights("mnist.model")
    #print(final.evaluate(x_test_iht.reshape((-1,28,28,1)),
    #                 keras.utils.to_categorical(y_test, 10)))

def dct2(xs):
    """Scale a rank-4 batch by 255 and apply an orthonormal DCT twice.

    ``tf.spectral.dct`` is applied after each of two axis permutations:
    first with axes 2 and 3 swapped, then with axes 1 and 3 swapped. The
    permutation is undone after each pass, so the result has the same
    layout as ``xs``.
    """
    coeffs = xs*255.0
    # Each permutation is its own inverse, so applying it a second time
    # restores the original axis order.
    for perm in ([0, 1, 3, 2], [0, 3, 2, 1]):
        moved = tf.transpose(coeffs, perm)
        moved = tf.spectral.dct(moved, norm='ortho')
        coeffs = tf.transpose(moved, perm)

    return coeffs

def idct2(xs):
    """Apply an orthonormal inverse DCT twice and rescale to [0, 1].

    ``tf.spectral.idct`` is applied after each of two axis permutations:
    first with axes 2 and 3 swapped, then with axes 1 and 3 swapped. The
    permutation is undone after each pass. The result is clipped to
    [0, 255] and divided by 255.
    """
    pixels = xs
    # Each permutation is its own inverse, so applying it a second time
    # restores the original axis order.
    for perm in ([0, 1, 3, 2], [0, 3, 2, 1]):
        moved = tf.transpose(pixels, perm)
        moved = tf.spectral.idct(moved, norm='ortho')
        pixels = tf.transpose(moved, perm)

    clipped = tf.clip_by_value(pixels, 0, 255)
    return clipped/255.0

def keep_top_k(x, k):
    """Zero all but the largest-magnitude entries of each sample in ``x``.

    ``x`` is flattened to rows of 28*28 values to find, per row, the k-th
    largest absolute value. Entries of ``x`` whose absolute value is at
    least that threshold are kept; all others are multiplied by zero. The
    threshold is broadcast as shape ``[-1, 1, 1, 1]``, so ``x`` is treated
    as rank 4 with the sample index first.
    """
    flat_magnitudes = tf.abs(tf.reshape(x, (-1, 28*28)))
    largest, _ = tf.math.top_k(flat_magnitudes, k=k)
    # The last of the k returned values is the k-th largest magnitude.
    threshold = tf.reshape(largest[:,-1], [-1,1,1,1])
    keep = tf.cast(tf.abs(x) >= threshold, dtype=tf.float32)
    return x*keep

class Model:
    """Callable wrapper that optionally runs IHT-style preprocessing in TF.

    Inputs are shifted by +0.5 before anything else; ``attack`` in this
    file feeds images shifted by -0.5, which this undoes. With ``preproc``
    enabled, ten rounds of sparse-DCT / sparse-error estimation are unrolled
    in the graph and the wrapped model sees the reconstruction; otherwise
    it sees the shifted input directly.
    """
    # Dataset description constants; presumably read by the attack classes
    # this wrapper is handed to -- confirm against nn_robust_attacks.
    num_labels = 10
    image_size = 28
    num_channels = 1

    def __init__(self, model, preproc):
        self.model = model
        self.preproc = preproc

    def __call__(self, xs):
        return self.predict(xs)

    def predict(self, xs):
        """Return ``self.model`` applied to the (optionally projected) input."""
        shifted = xs + .5
        if not self.preproc:
            return self.model(shifted)

        # Zero starting values; the shape of the input must be fully known.
        coeffs = tf.constant(np.zeros(shifted.shape, dtype=np.float32))
        error = tf.constant(np.zeros(shifted.shape, dtype=np.float32))
        for _ in range(10):
            # Keep the 30 largest DCT coefficients of input minus error.
            coeffs = keep_top_k(dct2(shifted - error), 30)
            # Keep the 10 largest entries of what the coefficients miss.
            error = keep_top_k(shifted - idct2(coeffs), 10)

        return self.model(idct2(coeffs))

def attack(model, preproc=True):
    """Run the untargeted Carlini L0 attack on the first 100 test images.

    ``model`` is wrapped in ``Model`` (with or without the preprocessing
    step, per ``preproc``). The module-level ``x_test`` / ``y_test`` arrays
    supply the inputs: images are given a trailing channel axis and shifted
    by -0.5, labels are one-hot encoded over 10 classes.

    Returns:
        Whatever ``CarliniL0.attack`` returns for those inputs.
    """
    sess = keras.backend.get_session()
    from nn_robust_attacks.l0_attack import CarliniL0
    from nn_robust_attacks.l2_attack import CarliniL2

    wrapped = Model(model, preproc=preproc)
    l0_settings = dict(targeted=False, learning_rate=3e-2,
                       max_iterations=500, abort_early=True, initial_const=1,
                       largest_const=4)
    carlini = CarliniL0(sess, wrapped, **l0_settings)

    images = x_test[:100,:,:,np.newaxis]-.5
    labels = keras.utils.to_categorical(y_test[:100], 10)
    return carlini.attack(images, labels)
    
def dumb_attack():
    """Greedy one-pixel-per-step attack through the in-graph IHT pipeline.

    Builds a graph that runs ten rounds of sparse-DCT / sparse-error
    estimation on a batch of BS images, feeds the reconstruction to the
    module-level ``model`` and takes the gradient of the cross-entropy loss
    with respect to the input pixels. For 50 iterations, one not-yet-used
    pixel per image is chosen by gradient score and set to 0 or 1. Progress
    (accuracy, mean loss, ASCII previews) is printed every iteration.

    Reads the module-level ``model``, ``x_test`` and ``y_test``.

    Returns:
        The modified image batch, shape (BS, 28, 28, 1).
    """
    sess = keras.backend.get_session()

    BS = 1000
    xs = tf.placeholder(tf.float32, (BS, 28, 28, 1))
    ys = tf.placeholder(tf.int32, [BS])

    # Unrolled reconstruction: keep 30 DCT coefficients and 10 error entries
    # per round. `x` and `e` start as zero constants shaped like `xs`.
    x = tf.constant(np.zeros(xs.shape, dtype=np.float32))
    e = tf.constant(np.zeros(xs.shape, dtype=np.float32))
    for i in range(10):
        x = dct2(xs - e)
        x = keep_top_k(x, 30)
        
        e = (xs-idct2(x))
        e = keep_top_k(e, 10)

    reconstructed = idct2(x)
    logits = model(reconstructed)
    
    #show(IHT(x_test[0], max_iter=10))
    #show(IHT(x_test[0], max_iter=100))
    #show(sess.run(reconstructed, {xs: x_test[:1,:,:,np.newaxis]}))
    
    #error = tf.reduce_sum((reconstructed-xs)**2,axis=(1,2,3))
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                          labels=ys)

    # Gradient of the per-sample loss with respect to the input pixels.
    grad = tf.gradients(loss, [xs])[0]

    # `flat_batch` is a reshape of the freshly copied, contiguous `batch`,
    # so it is expected to be a view: pixel writes to `flat_batch` below
    # are what change the `batch` fed to the session on the next iteration.
    batch = np.copy(x_test[:BS].reshape((-1,28,28,1)))
    flat_batch = np.reshape(batch, [-1, 28*28])

    # 1 where a pixel may still be picked, 0 once it has been modified.
    can_use = np.ones(flat_batch.shape)

    # Iteration (1-based) at which each sample was first misclassified;
    # stays -1 for samples that are never misclassified.
    first_success = [-1 for _ in range(BS)]
    
    for iteration in range(50):
        # From here on `e` holds the loss values, not the graph tensor.
        out, g, e, l = sess.run((reconstructed, grad, loss, logits),
                                {xs: batch,
                                 ys: y_test[:BS]})

        flat_grads = np.reshape(g, [-1, 28*28])
        
        # Score each pixel by gradient magnitude times a weight built from
        # the current pixel value.
        # NOTE(review): `(flat_grads<1)` looks like it may have been meant
        # as `(flat_grads<0)`; as written, gradients in (0, 1) match both
        # terms and get weight 1 -- confirm the intent.
        change = np.abs(flat_grads)
        change *= (flat_grads>0)*(1-flat_batch) + (flat_grads<1)*(flat_batch)
        which = np.argmax(change*can_use,axis=1)

        # Set the chosen pixel to 1 for a positive gradient, else to 0, and
        # mark it as used.
        flat_batch[np.arange(BS), which] = (flat_grads[np.arange(BS),which]>0)
        can_use[np.arange(BS), which] = 0

        # Logits were computed before this iteration's pixel change.
        for j in np.where(np.argmax(l,axis=1)!=y_test[:BS])[0]:
            if first_success[j] == -1:
                first_success[j] = iteration+1

        print("iter", iteration)
        #show((g-np.min(g))/(np.max(g)-np.min(g)))
        print('acc',np.mean(np.argmax(l,axis=1)==y_test[:BS]))
        print('loss',np.mean(e))
        # show() prints only the first 28*28 values, i.e. the first image.
        show(batch)
        show(out)
    # The mean includes the -1 entries of samples that were never fooled.
    print(np.mean(first_success))
    print(np.sort(first_success))
    return batch


if __name__ == "__main__":
    # Attack driver. Rebinds the global `model` / `final` with freshly built
    # networks, replacing the ones created by the setup block earlier.
    model, final = make_model()

    # Attack the baseline weights without the preprocessing step and store
    # the adversarial images (channel axis dropped).
    model.load_weights("baseline.model")
    adv = attack(model, False)[:,:,:,0]
    np.save("/tmp/w.npy", adv)
    
    # The script stops here; everything below is unreachable as written.
    exit(0)

    model.load_weights("mnist.model")
    
#[[-0.34041625  5.2950473   2.465073   -2.2438872  -0.12151034 -5.0750723
#   0.5737      4.755587   -7.561031   -7.4731593 ]]
    adv = attack(model)[:,:,:,0]
    np.save("/tmp/b.npy", adv)
    # Undo the -0.5 shift that attack() applies to its inputs.
    adv = np.load("/tmp/b.npy")+.5
    
    print(np.min(adv), np.max(adv))

    # Run the NumPy IHT on the adversarial images using the worker pool.
    fixed = np.array(p.map(IHT, [x for x in adv]))

    for i in range(1):
        show(x_test[i])
        show(adv[i])
        show(fixed[i])

    print(y_test[:10])
    print(model.predict(fixed[:len(adv)].reshape((-1,28,28,1))))
    sess = keras.backend.get_session()
    # NOTE(review): Model is constructed here without the `preproc`
    # argument that Model.__init__ requires -- confirm before re-enabling.
    print(sess.run(Model(model).predict(tf.constant(adv.reshape((-1,28,28,1)), dtype=tf.float32)-.5)))
    print(np.argmax(sess.run(Model(model).predict(tf.constant(adv.reshape((-1,28,28,1)), dtype=tf.float32)-.5)),axis=1)==y_test[:10])
    exit(0)
    
        
    # Evaluate clean, adversarial and IHT-projected adversarial images.
    print(final.evaluate(x_test[:len(adv)].reshape((-1,28,28,1)),
                         keras.utils.to_categorical(y_test[:len(adv)], 10)))

    print(final.evaluate(adv.reshape((-1,28,28,1)),
                         keras.utils.to_categorical(y_test[:len(adv)], 10)))
    
    print(final.evaluate(fixed.reshape((-1,28,28,1)),
                         keras.utils.to_categorical(y_test[:len(adv)], 10)))

