import os, sys
import numpy as np
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.externals import joblib
from DataModule import MnistModule, NewsModule, AdultModule
import SGDLinear

def settings(key):
    """Return the data module and experiment configuration for a dataset.

    Args:
        key: Dataset name; one of 'mnist', '20news' or 'adult'.

    Returns:
        A 3-tuple ``(module, (n_tr, n_val, n_test), (lr, decay, num_epoch,
        batch_size))``: a freshly constructed data module, the split sizes
        handed to ``module.fetch``, and the hyper-parameters forwarded to
        ``SGDLinear.fit_with_sgd``.

    Raises:
        ValueError: If ``key`` is not a supported dataset name.
    """
    # Every dataset uses the same split sizes; only the module and the SGD
    # hyper-parameters differ.
    sizes = (200, 200, 200)  # (n_tr, n_val, n_test)

    # An explicit exception (rather than ``assert``) keeps the check alive
    # under ``python -O``; otherwise an unknown key would fall through and
    # return None, failing later with an unrelated unpacking error.
    if key == 'mnist':
        module = MnistModule()
        sgd_params = (0.1, True, 5, 5)  # (lr, decay, num_epoch, batch_size)
    elif key == '20news':
        module = NewsModule()
        sgd_params = (0.01, True, 10, 5)
    elif key == 'adult':
        module = AdultModule(csv_path='./data')
        sgd_params = (0.1, True, 20, 5)
    else:
        raise ValueError(
            "unknown dataset key %r; expected one of 'mnist', '20news', 'adult'"
            % (key,))
    return module, sizes, sgd_params
        
def test(key, seed=0):
    """Run leave-one-out SGD retraining for one dataset and seed, then save it.

    The regularization strength is chosen with 5-fold cross-validated
    logistic regression, after which ``SGDLinear.fit_with_sgd`` is run once
    with ``skip_index=[-1]`` and once per training index ``i`` with
    ``skip_index=[i]``. All results are written to
    ``./<key>_logreg/sgd<seed, zero-padded to 3 digits>.dat`` via joblib.

    The saved object has the shape::

        {'sgd': {'noskip': {'a': ..., 'info': ...},
                 0: {'a': ..., 'info': []},
                 ...
                 n_tr - 1: {'a': ..., 'info': []}}}

    Args:
        key: Dataset name understood by ``settings``.
        seed: Seed passed to the data fetch, the cross-validation model and
            every SGD run; also part of the output file name.
    """
    # Resolve the configuration first so that an unsupported key fails
    # before any directory is created.
    module, (n_tr, n_val, n_test), (lr, decay, num_epoch, batch_size) = settings(key)

    dn = './%s_logreg' % (key,)
    fn = '%s/sgd%03d.dat' % (dn, seed)
    # Runs for different seeds share this directory. exist_ok avoids the
    # check-then-create race where two concurrent runs both see it missing
    # and the slower one crashes with FileExistsError.
    os.makedirs(dn, exist_ok=True)

    # fetch data
    z_tr, z_val, _, _ = module.fetch(n_tr, n_val, n_test, seed)
    (x_tr, y_tr), (x_val, y_val) = z_tr, z_val

    # selection of alpha
    model = LogisticRegressionCV(random_state=seed, fit_intercept=False, cv=5)
    model.fit(x_tr, y_tr)
    # Rescale the selected C into the alpha expected by fit_with_sgd
    # (presumably a per-sample L2 weight -- confirm against SGDLinear).
    alpha = 1 / (model.C_[0] * n_tr)

    # fit & save
    res = {'sgd': {}}
    loss_func = SGDLinear.LogisticLoss()
    # i == -1 is the reference run; presumably -1 matches no training index,
    # so nothing is skipped -- confirm against SGDLinear.fit_with_sgd.
    for i in range(-1, n_tr):
        a, info = SGDLinear.fit_with_sgd(x_tr, y_tr, x_val, y_val, loss_func,
                                         alpha=alpha, lr=lr, decay=decay,
                                         num_epoch=num_epoch,
                                         batch_size=batch_size, seed=seed,
                                         skip_index=[i])
        if i < 0:
            res['sgd']['noskip'] = {'a': a, 'info': info}
        else:
            # Only the reference run keeps its info; the leave-one-out runs
            # store an empty list in its place.
            res['sgd'][i] = {'a': a, 'info': []}
    joblib.dump(res, fn, compress=9)
    
if __name__ == '__main__':
    # Two positional command-line arguments: <dataset key> <integer seed>.
    key, seed = sys.argv[1], int(sys.argv[2])
    test(key, seed=seed)
