import os
import numpy as np
import tensorflow as tf
from sklearn.externals import joblib

def concat_feed_entry(keys, values, feed_dict=None):
    """Store each ``keys[i] -> values[i]`` pair into a feed dictionary.

    Args:
        keys: sequence of dictionary keys (e.g. graph placeholders/tensors).
        values: sequence of values, same length as ``keys``.
        feed_dict: dictionary to update in place; a new ``{}`` is used when None.

    Returns:
        The updated dictionary (the same object as ``feed_dict`` when given).

    Raises:
        AssertionError: if ``keys`` and ``values`` differ in length.
    """
    assert len(keys) == len(values)
    merged = {} if feed_dict is None else feed_dict
    pairs = zip(keys, values)
    for placeholder, fed in pairs:
        merged[placeholder] = fed
    return merged

def tf_dot_product(a, b):
    """Build TF ops for the inner product of ``a`` and ``b``.

    If ``a`` is a list, ``b`` must be a list of the same length and the result
    is a list holding one scalar op per pair, ``reduce_sum(a_i * b_i)``.
    Otherwise a single scalar op ``reduce_sum(a * b)`` is returned.

    Raises:
        AssertionError: if ``a`` is a list and ``b`` is not a list of equal length.
    """
    if not isinstance(a, list):
        return tf.reduce_sum(tf.multiply(a, b))
    assert isinstance(b, list)
    assert len(a) == len(b)
    return [tf.reduce_sum(tf.multiply(lhs, rhs)) for lhs, rhs in zip(a, b)]

def np_dot_product(a, b):
    """Compute the inner product of ``a`` and ``b`` with numpy.

    If ``a`` is a list, ``b`` must be a list of the same length and a list of
    per-pair sums ``np.sum(a_i * b_i)`` is returned. Otherwise the single value
    ``np.sum(a * b)`` is returned.

    Raises:
        AssertionError: if ``a`` is a list and ``b`` is not a list of equal length.
    """
    if not isinstance(a, list):
        return np.sum(a * b)
    assert isinstance(b, list)
    assert len(a) == len(b)
    return [np.sum(left * right) for left, right in zip(a, b)]


class SGDInfluence:
    """Train a TF1 graph with plain SGD and estimate per-sample influence.

    ``run_sgd`` trains ``params`` on ``loss + 0.5 * alpha * ||params||^2``.
    When no samples are skipped, it dumps the per-step trajectory (sample
    indices, learning rate, parameter values) to ``./<prefix>/<prefix>NNN.dat``.
    ``infer_linear_influence`` replays those files backwards to accumulate a
    per-sample influence score.
    """

    def __init__(self, graph, sess, input_tensor, output_tensor, output_model, loss, params, alpha=0.1):
        """Build the regularized objective, its gradients and a Hessian-vector product.

        Args:
            graph: the ``tf.Graph`` that owns all tensors below.
            sess: session used to evaluate them.
            input_tensor: placeholder fed with inputs ``x``.
            output_tensor: placeholder fed with targets ``y``.
            output_model: model output tensor (indexed as ``[0, class]`` by
                ``compute_prediction_grads``).
            loss: scalar training loss tensor.
            params: list of trainable variables.
            alpha: L2 regularization strength.
        """
        self.graph = graph
        self.sess = sess
        with self.graph.as_default():
            self.input_tensor = input_tensor
            self.output_tensor = output_tensor
            self.output_model = output_model
            self.alpha = alpha
            self.loss = loss
            self.regularizer = 0.5 * alpha * tf.reduce_sum([tf.reduce_sum(p**2) for p in params])
            self.obj = self.loss + self.regularizer
            self.params = params
            self.learning_rate = tf.placeholder(tf.float32)
            self.obj_grads = tf.gradients(self.obj, params)
            # d/dparams <u, grad obj> == H u (Hessian-vector product of the objective).
            self.u = [tf.placeholder(tf.float32, shape=p.get_shape()) for p in params]
            self.hess_u = tf.gradients(tf_dot_product(self.u, self.obj_grads), params)
        # Ops built lazily and cached so repeated calls do not keep growing the graph.
        self._sgd_ops = None
        self._prediction_grads = {}
        self._loss_grads = None

    def _get_sgd_ops(self):
        """Return the cached ``(scale_placeholder, train_op)`` pair, building it once."""
        if getattr(self, '_sgd_ops', None) is None:
            with self.graph.as_default():
                scale = tf.placeholder(tf.float32)
                optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate)
                grads_and_vars = optimizer.compute_gradients(self.obj, self.params)
                # Gradients are multiplied by |kept| / |batch|. Assuming `loss` is a
                # mean over the fed batch, this makes a batch with skipped samples
                # behave as if those samples contributed zero gradient.
                train_op = optimizer.apply_gradients(
                    [(scale * grad, var) for grad, var in grads_and_vars])
            self._sgd_ops = (scale, train_op)
        return self._sgd_ops

    def run_sgd(self, x, y, x_val, y_val,
                lr=0.1, decay=True, num_epoch=10, batch_size=32,
                seed=0, prefix='tmp', skip_index=None):
        """Train from freshly initialized variables with mini-batch SGD.

        Each epoch proceeds as follows:
        - The loss on ``(x_val, y_val)`` is evaluated.
        - Indices are shuffled with ``RandomState(seed + epoch)``.
        - The indices are split into ``max(1, n // batch_size)`` batches.

        For every batch:
        - With ``decay`` the learning rate is multiplied by
          ``sqrt(c / (c + 1))``, so after ``c`` steps it equals
          ``lr / sqrt(c + 1)``.
        - The step record ``{'index', 'lr', 'params' (values before the step),
          'loss_val'}`` is appended.
        - Samples listed in ``skip_index`` are removed.
        - One scaled SGD step is applied.

        Args:
            x, y: training inputs/targets; first axis indexes samples.
            x_val, y_val: validation data whose loss is recorded per epoch.
            lr: initial learning rate.
            decay: whether to apply the per-step decay above.
            num_epoch: number of epochs.
            batch_size: nominal mini-batch size.
            seed: base seed for the per-epoch shuffles.
            prefix: output directory/file prefix.
            skip_index: sample indices to leave out of training. When it is
                empty or None, the trajectory of every epoch is dumped to
                ``./<prefix>/<prefix>NNN.dat`` with keys ``'a'`` (params
                recorded at the epoch's last step) and ``'info'``
                (list of step records).

        Side effects:
            Re-initializes all global variables of ``graph``. Sets ``self.a_``
            to the final parameter values and ``self.info_``.
        """
        assert x.shape[0] == y.shape[0]
        skip_index = np.asarray([] if skip_index is None else skip_index).ravel()
        save = skip_index.size == 0
        dn = './%s' % (prefix,)
        if save and not os.path.isdir(dn):
            os.makedirs(dn)
        n = x.shape[0]
        # At least one batch, so datasets smaller than batch_size still train.
        num_batches = max(1, n // batch_size)
        scale, train_op = self._get_sgd_ops()
        with self.graph.as_default():
            self.sess.run(tf.global_variables_initializer())
            val_dict = {self.input_tensor: x_val, self.output_tensor: y_val}
            info = []
            c = 1
            for epoch in range(num_epoch):
                loss_val = self.sess.run(self.loss, feed_dict=val_dict)
                # Local RNG: same permutation as np.random.seed(seed + epoch) followed
                # by np.random.permutation, without clobbering global numpy state.
                order = np.random.RandomState(seed + epoch).permutation(n)
                for i in np.array_split(order, num_batches):
                    b = i.size
                    if decay:
                        lr *= np.sqrt(c / (c + 1.0))
                        c += 1
                    a = self.sess.run(self.params)
                    info.append({'index': i, 'lr': lr, 'params': a, 'loss_val': loss_val})
                    if np.intersect1d(i, skip_index).size > 0:
                        i = np.setdiff1d(i, skip_index)
                    if i.size == 0:
                        continue
                    feed_dict = {self.input_tensor: x[i], self.output_tensor: y[i],
                                 self.learning_rate: lr, scale: i.size / float(b)}
                    self.sess.run(train_op, feed_dict=feed_dict)
                if save:
                    fn = '%s/%s%03d.dat' % (dn, prefix, epoch)
                    joblib.dump({'a': a, 'info': info}, fn, compress=9)
                # NOTE: reset every epoch, so `info_` below is empty whenever
                # num_epoch >= 1; the trajectory lives only in the dumped files.
                info = []
            a = self.sess.run(self.params)
        self.a_ = a
        self.info_ = info

    def compute_prediction_grads(self, x, class_index):
        """Return gradients of ``output_model[0, class_index]`` w.r.t. ``params`` at input ``x``.

        Only row 0 of the output is differentiated, so ``x`` presumably holds a
        single example (TODO confirm with callers). The gradient ops are cached
        per ``class_index``.
        """
        cache = getattr(self, '_prediction_grads', None)
        if cache is None:
            cache = self._prediction_grads = {}
        grads = cache.get(class_index)
        if grads is None:
            with self.graph.as_default():
                grads = tf.gradients(self.output_model[0, class_index], self.params)
            cache[class_index] = grads
        return self.sess.run(grads, feed_dict={self.input_tensor: x})

    def compute_loss_grads(self, x, y):
        """Return gradients of ``loss`` (without the regularizer) w.r.t. ``params`` on ``(x, y)``."""
        grads = getattr(self, '_loss_grads', None)
        if grads is None:
            with self.graph.as_default():
                # The original referenced a never-defined `self.output_loss`.
                grads = tf.gradients(self.loss, self.params)
            self._loss_grads = grads
        return self.sess.run(grads, feed_dict={self.input_tensor: x, self.output_tensor: y})

    def infer_linear_influence(self, x, y, u, num_epoch=10, batch_size=32, epoch_used=1, prefix='tmp'):
        """Replay saved SGD trajectories backwards and accumulate per-sample influence.

        Epoch files ``num_epoch - 1`` down to ``0`` written by ``run_sgd`` are
        loaded. The steps of each file are visited in reverse order. For every
        sample ``j`` in a step's batch, ``lr * <u, grad obj_j> / len(batch)``
        is added to ``inf_o[j, epoch]``. Then ``u`` is updated as
        ``u <- u - lr * H u``, with the Hessian-vector product of the objective
        evaluated on that batch at the recorded parameters. ``num_epoch``
        should therefore equal the number of epochs used in training for
        ``u`` to refer to the final parameters.

        Presumably the result is the linearized change of ``<u, theta_final>``
        caused by removing sample ``j`` from training (TODO confirm against
        the derivation).

        Args:
            x, y: the training data passed to ``run_sgd``.
            u: sequence of arrays, one per entry of ``params`` with matching shapes.
            num_epoch: number of epoch files to replay (see above).
            batch_size: unused. The steps come from the saved files; this
                parameter is kept for backward compatibility.
            epoch_used: unused.
            prefix: directory/file prefix used by ``run_sgd``.

        Returns:
            ndarray of shape ``(n, num_epoch + 1)``. For ``e < num_epoch``,
            column ``e`` holds the influence accumulated over the ``e + 1``
            latest replayed epochs. The last column duplicates column
            ``num_epoch - 1``.
        """
        assert x.shape[0] == y.shape[0]
        dn = './%s' % (prefix,)
        n = x.shape[0]
        inf_o = np.zeros((n, num_epoch + 1))
        # np_dot_product/concat_feed_entry need a real list (tuples would break).
        u = list(u)
        for epoch in range(num_epoch):
            fn = '%s/%s%03d.dat' % (dn, prefix, num_epoch - epoch - 1)
            # NOTE: joblib.load unpickles; only load files produced by run_sgd.
            info = joblib.load(fn)['info']
            # Walk every recorded step newest-first. This no longer depends on
            # batch_size matching the value used in training.
            for step in reversed(info):
                k = step['index']
                lr = step['lr']
                feed_dict = {self.input_tensor: x[k], self.output_tensor: y[k]}
                feed_dict = concat_feed_entry(self.params, step['params'], feed_dict=feed_dict)
                feed_dict = concat_feed_entry(self.u, u, feed_dict=feed_dict)

                # influence: per-sample objective gradient projected on u
                for j in k:
                    feed_dict[self.input_tensor] = x[j:j + 1]
                    feed_dict[self.output_tensor] = y[j:j + 1]
                    grad_vals = self.sess.run(self.obj_grads, feed_dict=feed_dict)
                    inf_o[j, epoch] += lr * np.sum(np_dot_product(u, grad_vals)) / k.size

                # update u with the batch Hessian-vector product
                feed_dict[self.input_tensor] = x[k]
                feed_dict[self.output_tensor] = y[k]
                hess_u = self.sess.run(self.hess_u, feed_dict=feed_dict)
                u = [ui - lr * hi for ui, hi in zip(u, hess_u)]

            inf_o[:, epoch + 1] = inf_o[:, epoch].copy()

        return inf_o