
"""Generic training script that train a model using a given dataset."""

import os
import sys
import time
import glob
import torch
import logging
import argparse
import numpy as np
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from timeit import default_timer as timer
import torchvision.transforms as transforms

sys.path.append("../models")
sys.path.append("../../")
sys.path.append("../")
import utils

# Command-line options for the evaluation run; every option is optional and has a default.
parser = argparse.ArgumentParser("Evaluating a model")
parser.add_argument(
	'--model_path',
	type=str,
	default='../pre_trained_models/squeeze_complex_bypass.pt',
	help='model directory',
)
parser.add_argument(
	'--checkpoint_path',
	type=str,
	default='../EXP',
	help='checkpoint and logging directory',
)
parser.add_argument(
	'--dataset_path',
	type=str,
	default='../../data',
	help='dataset directory',
)
parser.add_argument(
	'--model_name',
	type=str,
	default='SuccessiveDiscarding',
	help='name of model',
)
parser.add_argument(
	'--batch_size',
	type=int,
	default=256,
	help='batch size',
)
parser.add_argument(
	'--gpu',
	type=int,
	default=0,
	help='gpu device id',
)
# Parsed at import time, so the options are available to all module-level code below.
args = parser.parse_args()

# Logging settings.
# The checkpoint root is handed to utils.create_exp_dir first, then a per-run
# sub-directory named "<model_name>-<timestamp>" is derived and handed over too
# (presumably create_exp_dir creates the directory - confirm in utils).
utils.create_exp_dir(args.checkpoint_path)
args.checkpoint_path = '%s/%s-%s' % (
	args.checkpoint_path,
	args.model_name,
	time.strftime("%Y%m%d-%H%M%S"),
)
utils.create_exp_dir(args.checkpoint_path)

# Messages go to stdout through the root logger ...
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
	level=logging.INFO,
	stream=sys.stdout,
	datefmt='%m/%d %I:%M:%S %p',
	format=log_format,
)

# ... and are additionally written to log.txt inside the per-run directory.
fh = logging.FileHandler(os.path.join(args.checkpoint_path, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)


from data.nasbench2.nats import prepare_dataset
from data.nasbench2.nats import get_accuracy

'''Load test data'''
def load_test_data(test_id=0):
	"""Return the result of ``prepare_dataset`` for the given test id.

	Args:
		test_id: Identifier forwarded unchanged to ``prepare_dataset``.

	Returns:
		Whatever ``prepare_dataset(test_id)`` returns; the rest of this file
		treats it as the ``rank2index`` lookup.
	"""
	return prepare_dataset(test_id)


# Search routines exercised by this script: successive_halving and hyperband
# come from the baseline_* modules, the other two from hpo.
from baseline_SH import successive_halving
# NOTE(review): marker line printed at import time, between the imports -
# presumably a debugging aid; confirm before removing.
print("***************************************")
from hpo import successive_discarding
from baseline_HB import hyperband
from hpo import hyperband_plus

# Integer declares the search dimension in the test_* functions below;
# Real is imported but not used in the code visible here.
from skopt.space import Real, Integer

'''Evaluation'''
def test_SH(rank2index=None):
	"""Run ``successive_halving`` with accuracies looked up via ``get_accuracy``.

	Args:
		rank2index: Lookup forwarded to ``get_accuracy``. When omitted (or
			``None``), ``[0, 1, ..., 49]`` is used.

	Returns:
		The ``(scores, hyperparameters)`` pair exactly as returned by
		``successive_halving``.
	"""
	# Build the default per call rather than sharing one list object that was
	# created once at function-definition time.
	if rank2index is None:
		rank2index = list(range(50))

	def objective(resources, rank2index=rank2index, integer=0, checkpoint=None):
		# Returns the looked-up accuracy together with the sampled integer.
		# NOTE(review): unlike the other test_* functions the accuracy is not
		# divided by 100 here - confirm this is intended.
		return get_accuracy(integer, resources, is_tss=True, rank2index=rank2index), integer

	# NOTE(review): the dimension is Integer(0, 4) although n_models is 50 and
	# test_HB / test_HB_plus use Integer(0, 49) - confirm this is intended.
	synthesized_dimensions = [Integer(0, 4, name='integer')]

	scores, hyperparameters = successive_halving(
			objective=objective,
			dimensions=synthesized_dimensions,
			max_resources_per_model=80,
			downsample=2,
			initial_resources=5,
			n_models=50,
			random_seed=None,
			progress_bar=True)

	return scores, hyperparameters

def test_SD(rank2index=None, budget_ratio=1.0):
	"""Run ``successive_discarding`` with accuracies looked up via ``get_accuracy``.

	Args:
		rank2index: Lookup forwarded to ``get_accuracy``. When omitted (or
			``None``), ``[0, 1, ..., 49]`` is used.
		budget_ratio: Multiplier applied to the base total budget of
			``250 * 5``.

	Returns:
		The ``(scores, hyperparameters)`` pair exactly as returned by
		``successive_discarding``.
	"""
	# Build the default per call rather than sharing one list object that was
	# created once at function-definition time.
	if rank2index is None:
		rank2index = list(range(50))

	def objective(resources, rank2index=rank2index, integer=0, checkpoint=None):
		# Accuracy is scaled by 1/100 and returned together with the sampled integer.
		return get_accuracy(integer, resources, is_tss=True, rank2index=rank2index) / 100, integer

	# NOTE(review): the dimension is Integer(0, 4) although n_models is 50 and
	# test_HB / test_HB_plus use Integer(0, 49) - confirm this is intended.
	synthesized_dimensions = [Integer(0, 4, name='integer')]

	scores, hyperparameters = successive_discarding(
			objective=objective,
			dimensions=synthesized_dimensions,
			max_resources_per_round=250,
			total_budgets=250*5 * budget_ratio,
			threshold=0.9,
			initial_resources=3,
			n_models=50,
			random_seed=None,
			progress_bar=True,
			test=True)

	return scores, hyperparameters


def test_HB(rank2index=None, budget_ratio=1.0):
	"""Run ``hyperband`` and return its results ranked by score.

	Args:
		rank2index: Lookup forwarded to ``get_accuracy``. When omitted (or
			``None``), ``[0, 1, ..., 49]`` is used.
		budget_ratio: Multiplier applied to the base total resources of
			``250 * 5``.

	Returns:
		A ``(best_score, ranked)`` tuple. ``ranked`` is the list of
		``(score, hyperparameter)`` pairs sorted by descending score, and
		``best_score`` is the score of its first entry.

	Raises:
		IndexError: If ``hyperband`` returns no results.
	"""
	# Build the default per call rather than sharing one list object that was
	# created once at function-definition time.
	if rank2index is None:
		rank2index = list(range(50))

	def objective(resources, rank2index=rank2index, integer=0, checkpoint=None):
		# Accuracy is scaled by 1/100 and returned together with the sampled integer.
		return get_accuracy(integer, resources, is_tss=True, rank2index=rank2index) / 100, integer

	synthesized_dimensions = [Integer(0, 49, name='integer')]

	scores, hyperparameters = hyperband(
			objective=objective,
			dimensions=synthesized_dimensions,
			max_resources_per_model=80,
			total_resources=250*5 * budget_ratio,
			random_seed=None,
			progress_bar=True)

	# Highest score first; negating the key keeps the original order among ties.
	ranked = sorted(zip(scores, hyperparameters), key=lambda pair: -pair[0])

	print(ranked)
	best_score = ranked[0][0]

	logging.info('scores: {}'.format(best_score))
	logging.info('hyperparameters: {}'.format(ranked))

	# The second element is the full ranked list, not a single hyperparameter;
	# get_result() relies on that shape.
	return best_score, ranked

def test_HB_plus(rank2index=None, budget_ratio=1.0):
	"""Run ``hyperband_plus`` and return its results ranked by score.

	Args:
		rank2index: Lookup forwarded to ``get_accuracy``. When omitted (or
			``None``), ``[0, 1, ..., 49]`` is used.
		budget_ratio: Multiplier applied to the base total resources of
			``250 * 5``.

	Returns:
		A ``(best_score, ranked)`` tuple. ``ranked`` is the list of
		``(score, hyperparameter)`` pairs sorted by descending score, and
		``best_score`` is the score of its first entry.

	Raises:
		IndexError: If ``hyperband_plus`` returns no results.
	"""
	# Build the default per call rather than sharing one list object that was
	# created once at function-definition time.
	if rank2index is None:
		rank2index = list(range(50))

	def objective(resources, rank2index=rank2index, integer=0, checkpoint=None):
		# Accuracy is scaled by 1/100 and returned together with the sampled integer.
		return get_accuracy(integer, resources, is_tss=True, rank2index=rank2index) / 100, integer

	synthesized_dimensions = [Integer(0, 49, name='integer')]

	scores, hyperparameters = hyperband_plus(
			objective=objective,
			dimensions=synthesized_dimensions,
			max_resources_per_model=80,
			total_resources=250*5 * budget_ratio,
			random_seed=None,
			progress_bar=True)

	# Highest score first; negating the key keeps the original order among ties.
	ranked = sorted(zip(scores, hyperparameters), key=lambda pair: -pair[0])

	print(ranked)
	best_score = ranked[0][0]

	logging.info('scores: {}'.format(best_score))
	logging.info('hyperparameters: {}'.format(ranked))

	# The second element is the full ranked list, not a single hyperparameter;
	# get_result() relies on that shape.
	return best_score, ranked




def get_result(hyperparameters):
	"""Extract the ``'integer'`` value from every ranked entry.

	Args:
		hyperparameters: Iterable of entries whose element at index 1 is a
			mapping containing the key ``'integer'``.

	Returns:
		A list of the ``'integer'`` values, in input order.
	"""
	return [entry[1]['integer'] for entry in hyperparameters]

def check_superior(results, A_results):
	"""Compare two sequences position by position over their common prefix.

	Both sequences are printed first. They are then walked in lockstep up to
	the length of the shorter one; the first position where the values differ
	decides the outcome.

	Returns:
		True if, at the first differing position, the value from ``results``
		is smaller than the value from ``A_results``. False if it is larger,
		or if no position within the common prefix differs.
	"""
	print('check superior--!!!!----\n')
	print(results, "|", A_results)

	# zip stops at the shorter sequence, which bounds the comparison.
	for left, right in zip(results, A_results):
		if left < right:
			return True
		if left > right:
			return False

	return False


def test(rank2index):
	"""Search for the lowest budget ratio at which Hyperband-plus still holds up.

	``test_HB`` is run once at the default budget as the reference. Then
	``test_HB_plus`` is run at budget ratios 0.9 down to 0.1. The search stops
	at the first ratio for which ``check_superior(reference, candidate)`` is
	true; otherwise the ratio and its results are recorded as the latest
	accepted ones.

	Args:
		rank2index: Lookup forwarded to ``test_HB`` and ``test_HB_plus``.

	Returns:
		The last accepted budget ratio, or ``1.0`` if the first ratio tried
		(0.9) was already rejected.
	"""
	scores_HB, ranked_HB = test_HB(rank2index)

	print('[[[[HB]]]]] hyperparameter', ranked_HB)
	results_HB = get_result(ranked_HB)

	print('results_SH [[[[[[[ HB ]]]]]]]]\n\n', results_HB)

	ratio_list = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
	budget_ratio = 1.0
	budget_result = []
	scores_result = []

	for ratio in ratio_list:
		print('current ratio : [] ', ratio)
		scores_plus, ranked_plus = test_HB_plus(rank2index, budget_ratio=ratio)
		results_plus = get_result(ranked_plus)
		print('results_SD [[[[[[[ HB plus ]]]]]]]]\n\n', results_plus)

		# Stop as soon as the reference run compares as superior to the
		# reduced-budget run; the previously accepted ratio stays the answer.
		if check_superior(results_HB, results_plus):
			break

		budget_ratio = ratio
		budget_result = results_plus
		scores_result = scores_plus

	# NOTE(review): the labels below say "SH" / "Successive Halving" /
	# "Successive Discarding", but the values come from test_HB and
	# test_HB_plus. The strings are kept unchanged in case logs are parsed
	# elsewhere - confirm and rename.
	print("print testing !--!  SH --- SD")
	print(str(scores_HB))
	logging.info('Successive Halving : {}{}'.format(str(scores_HB), str(ranked_HB)))
	logging.info('Successive Discarding : {}{}'.format(scores_result, budget_result))
	logging.info('budget ratio: {}'.format(budget_ratio))

	return budget_ratio

def main():
	"""Run the budget-ratio search for test ids 10 through 31.

	For each test id the dataset lookup is loaded, ``test`` is run on it, and
	the resulting budget ratio is appended to a running list that is logged
	after every iteration.
	"""
	logging.info('Starting...')

	budget_ratio_list = []

	for test_id in range(10, 32):

		logging.info('test id : {}'.format(test_id))

		rank2index = load_test_data(test_id)
		print(rank2index)

		budget_ratio = test(rank2index)
		print('budget ratio:___________', budget_ratio)

		budget_ratio_list.append(budget_ratio)
		# Logged cumulatively so partial results survive an interrupted run.
		logging.info('Budget ratio result : {}'.format(budget_ratio_list))



# Marker line printed whenever this module's top level runs.
print("***************************************")


# Start the experiment only when the file is executed as a script.
if __name__ == '__main__':
	main()
