import argparse
import time
import logging
import sys
import os
from os import listdir
from os.path import isfile, join
import pickle
import copy
import numpy as np
import tensorflow as tf
from nasbench import api

from params import *
from data import Data
from meta_neural_net import MetaNeuralnet
#from darts.arch import Arch


def experiment(search_space,
                params,
                cutoff=-1,
                encoding_type='adj',
                num_ensemble=3,
                train_size=1100,
                test_size=1100):
    """Train the meta neural net on random data and score its predictions.

    For each of ``num_ensemble`` rounds a fresh training set and test set are
    drawn from ``search_space``, a new ``MetaNeuralnet`` is fit on the
    training encodings / ``val_loss`` targets, and its predictions are scored.

    Args:
        search_space: object providing ``generate_random_dataset`` and
            ``remove_duplicates``; each sample must expose the keys
            ``'encoding'`` and ``'val_loss'``.
        params: keyword arguments forwarded to ``MetaNeuralnet.fit``.
        cutoff: unused by this function; kept for interface compatibility.
        encoding_type: encoding name passed to ``generate_random_dataset``.
        num_ensemble: number of independent train/evaluate rounds to average.
        train_size: number of samples requested for the training set.
        test_size: number of samples requested for the test set (before
            ``remove_duplicates`` is applied).

    Returns:
        Tuple ``(train_error, test_error, correlation)``: the per-round mean
        absolute train error, mean absolute test error and correlation
        coefficient between test targets and predictions, each averaged over
        the rounds and rounded to 5, 5 and 7 decimals respectively.
    """
    train_maes = []
    test_maes = []
    corr_values = []

    for _ in range(num_ensemble):
        # Draw the training set first, then the test set.
        train_data = search_space.generate_random_dataset(
            num=train_size, encoding_type=encoding_type)
        test_data = search_space.generate_random_dataset(
            num=test_size, encoding_type=encoding_type)
        # presumably drops test samples that also occur in train_data --
        # verify against the search space implementation
        test_data = search_space.remove_duplicates(test_data, train_data)

        xtrain = np.array([sample['encoding'] for sample in train_data])
        ytrain = np.array([sample['val_loss'] for sample in train_data])
        xtest = np.array([sample['encoding'] for sample in test_data])
        ytest = np.array([sample['val_loss'] for sample in test_data])

        # A fresh predictor is built for every round.
        predictor = MetaNeuralnet()
        predictor.fit(xtrain, ytrain, **params)

        train_pred = np.squeeze(predictor.predict(xtrain))
        train_maes.append(np.mean(abs(train_pred - ytrain)))

        test_pred = np.squeeze(predictor.predict(xtest))
        test_maes.append(np.mean(abs(test_pred - ytest)))

        # Off-diagonal entry of the 2x2 correlation matrix.
        corr_values.append(np.corrcoef(ytest, test_pred)[1, 0])

        # Reset the TensorFlow graph and the Keras session after each round.
        tf.reset_default_graph()
        tf.keras.backend.clear_session()

    train_error = np.round(np.mean(train_maes, axis=0), 5)
    test_error = np.round(np.mean(test_maes, axis=0), 5)
    correlation = np.round(np.mean(corr_values, axis=0), 7)

    print('size', train_size, 'corr: {}'.format(correlation))

    return train_error, test_error, correlation


def run_correlation(args, save_dir):
    """Run the encoding comparison for several trials and pickle the results.

    For every trial, each encoding type is evaluated with ``experiment`` at
    training-set sizes 50, 100, 200, 500 and 1100. One pickle file named
    ``<output_filename>_<trial>.pkl`` is written to ``save_dir`` per trial,
    containing ``[results, encodings]`` where ``results[i]`` is the list of
    ``(train_size, value)`` pairs for ``encodings[i]``.

    Args:
        args: parsed arguments providing ``output_filename``,
            ``search_space``, ``num_ensemble``, ``trials`` and ``metric``.
            When ``metric == 'test_error'`` the test error is recorded;
            any other value records the correlation.
        save_dir: existing directory the pickle files are written to.

    Raises:
        ValueError: if the search space named in the meta neural net params
            is not ``'nasbench'`` (the only one this function can build).
    """
    out_file = args.output_filename
    metann_params = meta_neuralnet_params(args.search_space)
    num_ensemble = args.num_ensemble
    trials = args.trials
    metric = args.metric
    logging.info(metann_params)

    # Work on a copy so the popped keys do not disappear from the original.
    mp = copy.deepcopy(metann_params)
    ss = mp.pop('search_space')
    mf = mp.pop('mf')

    print('search space', ss)
    if ss != 'nasbench':
        # Previously this fell through and crashed later with a NameError
        # on the unbound `search_space`; fail early with a clear message.
        raise ValueError(
            "unsupported search space {!r}: only 'nasbench' is "
            "supported".format(ss))
    search_space = Data(ss, mf=mf)

    encodings = ['adj', 'cat_adj', 'cont_adj', 'path',
                 'cat_path', 'trunc_path', 'trunc_cat_path']
    train_sizes = (50, 100, 200, 500, 1100)

    for t in range(trials):

        results = []

        for encoding_type in encodings:

            print('starting to run', encoding_type)

            result = []
            for train_size in train_sizes:
                train_error, test_error, correlation = experiment(
                    search_space, mp,
                    train_size=train_size,
                    encoding_type=encoding_type,
                    num_ensemble=num_ensemble)
                if metric == 'test_error':
                    result.append((train_size, np.round(test_error, 5)))
                else:
                    result.append((train_size, np.round(correlation, 5)))

            print('result', encoding_type, result)
            results.append(result)

        filename = os.path.join(save_dir, '{}_{}.pkl'.format(out_file, t))

        # The context manager closes the file; no explicit close() needed.
        with open(filename, 'wb') as f:
            pickle.dump([results, encodings], f)


def main(args):
    """Prepare the output directory and logging, then run the experiment.

    Args:
        args: parsed command-line arguments. ``save_dir`` selects the output
            directory; when it is empty or ``None`` the directory defaults
            to ``'results_correlation_<search_space>/'``.
    """
    # make save directory
    save_dir = args.save_dir
    if not save_dir:
        save_dir = 'results_correlation_' + args.search_space + '/'
    # makedirs also creates missing parent directories and, with
    # exist_ok=True, avoids the check-then-create race of exists() + mkdir().
    os.makedirs(save_dir, exist_ok=True)

    # set up logging: INFO-level messages go to stdout and to
    # <save_dir>/log.txt (the file handler uses the default date format)
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info(args)

    run_correlation(args, save_dir)
    

if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(description='Args for path length experiments')
    parser.add_argument('--search_space', type=str, default='nasbench',
        help='nasbench or darts')
    parser.add_argument('--output_filename', type=str, default='round', help='name of output files')
    parser.add_argument('--save_dir', type=str, default=None, help='name of save directory')
    parser.add_argument('--cutoff_experiment_type', type=str, default='standard', help='standard')
    parser.add_argument('--cutoff', type=int, default=40, help='num cutoff')
    parser.add_argument('--metric', type=str, default='test_error', help='metric')
    # help text previously read 'num cutoff' (copied from --cutoff)
    parser.add_argument('--trials', type=int, default=500, help='number of trials')
    parser.add_argument('--num_ensemble', type=int, default=1, help='size of metann ensemble')
    parser.add_argument('--darts_data_folder', type=str, default='~/results_naszilla/darts_nov/', help='darts data folder')

    args = parser.parse_args()
    main(args)