import sys
import numpy as np
from matplotlib import pyplot as plt
from sklearn.externals import joblib
from DataModule import MnistModule, NewsModule, AdultModule
import SGDLinear
import TrainLogreg

# Margins x.dot(beta) are capped at this value before the validation loss is
# evaluated (presumably to keep LogisticLoss numerically stable for large
# margins -- TODO confirm against SGDLinear.LogisticLoss).
MARGIN_CLIP = 30


def clipped_mean_loss(loss_func, x, y, beta):
    """Return the mean of ``loss_func.loss`` for linear model ``beta`` on (x, y).

    Margins ``x.dot(beta)`` are capped from above at ``MARGIN_CLIP`` before
    being passed to the loss.
    """
    return np.mean(loss_func.loss(np.minimum(x.dot(beta), MARGIN_CLIP), y))


def icml_influence(x_tr, g_tr, h_tr, b, alpha, u, n_tr):
    """Hessian-based estimate of the validation-loss change per training sample.

    This is the "icml" baseline (presumably the influence-function estimate
    of Koh & Liang, ICML 2017 -- confirm). It computes

        H = X^T diag(h) X / n_tr + alpha * I
        G_i = g_i * x_i + alpha * b
        delta_b_i = H^{-1} G_i / n_tr
        result_i = u . delta_b_i

    Args:
        x_tr: training features, one row per sample.
        g_tr: per-sample first derivative of the loss w.r.t. the margin.
        h_tr: per-sample second derivative of the loss w.r.t. the margin.
        b: fitted parameter vector.
        alpha: L2 regularization strength.
        u: gradient of the mean validation loss w.r.t. the parameters.
        n_tr: number of training samples used for normalization.

    Returns:
        1-D array with one estimate per row of ``x_tr``, in row order.
    """
    hess = x_tr.T.dot(h_tr[:, np.newaxis] * x_tr) / n_tr + alpha * np.identity(b.size)
    grad = g_tr[:, np.newaxis] * x_tr + alpha * b[np.newaxis, :]
    # One solve with all per-sample gradients as right-hand-side columns,
    # instead of forming the inverse Hessian explicitly.
    b_diff_est = np.linalg.solve(hess, grad.T) / n_tr
    return u.dot(b_diff_est)


def main():
    """Compare true and estimated per-sample influence on validation loss.

    Usage: ``<script> <data_key> <seed>``

    Reads ``./<data_key>_logreg/sgd<seed>.dat``, which holds the models from
    SGD runs keyed by which sample was skipped (``'noskip'`` is the full run).
    Writes three arrays next to it:

    - ``loss_diff_true_<seed>.dat``: actual validation-loss change of each
      skip run vs. the no-skip run, ordered by sorted skip key.
    - ``loss_diff_proposed_<seed>.dat``: estimate from
      ``SGDLinear.infer_linear_influence``.
    - ``loss_diff_icml_<seed>.dat``: Hessian-based estimate, ordered by
      training-row index. This matches the "true" ordering only if the skip
      keys are exactly the training indices -- NOTE(review): confirm.
    """
    if len(sys.argv) < 3:
        sys.exit('usage: %s <data_key> <seed>' % sys.argv[0])
    data_key = sys.argv[1]
    seed = int(sys.argv[2])
    loss_func = SGDLinear.LogisticLoss()
    module, (n_tr, n_val, n_test), _ = TrainLogreg.settings(data_key)

    # fetch data (test split is not needed here)
    (x_tr, y_tr), (x_val, y_val), _, _ = module.fetch(n_tr, n_val, n_test, seed)

    # load result: the full-data model and one model per skipped sample
    runs = joblib.load('./%s_logreg/sgd%03d.dat' % (data_key, seed))['sgd']
    b = runs['noskip']['a']
    info = runs['noskip']['info']
    B = {key: run['a'] for key, run in runs.items() if key != 'noskip'}
    alpha = info[0]['alpha']

    # influence - true
    loss = clipped_mean_loss(loss_func, x_val, y_val, b)
    loss_diff = np.array([clipped_mean_loss(loss_func, x_val, y_val, B[key]) - loss
                          for key in sorted(B)])
    joblib.dump(loss_diff, './%s_logreg/loss_diff_true_%03d.dat' % (data_key, seed))

    # influence - proposed
    # u: gradient of the mean validation loss w.r.t. the parameters at b.
    g_val = loss_func.grad(x_val.dot(b), y_val)
    u = np.mean(g_val[:, np.newaxis] * x_val, axis=0)
    loss_diff_est = SGDLinear.infer_linear_influence(x_tr, y_tr, u, info, loss_func, alpha)
    joblib.dump(loss_diff_est, './%s_logreg/loss_diff_proposed_%03d.dat' % (data_key, seed))

    # influence - icml
    margin_tr = x_tr.dot(b)
    g_tr = loss_func.grad(margin_tr, y_tr)
    h_tr = loss_func.hess(margin_tr, y_tr)
    loss_diff_est2 = icml_influence(x_tr, g_tr, h_tr, b, alpha, u, n_tr)
    joblib.dump(loss_diff_est2, './%s_logreg/loss_diff_icml_%03d.dat' % (data_key, seed))


if __name__ == '__main__':
    main()
