import numpy as np
import tensorflow.keras as keras
import tensorflow.keras.backend as k
from data_generators.os_generator import OSBatchGenerator
from data_generators.os_mc_generator import OSMCBatchGenerator
from data_engineering.transform_data import transform_data
from utils.metrics import calculate_stopping_reward

import tensorflow as tf
import time




# Presumably a small epsilon for numerical stability.
# NOTE(review): not referenced by any function visible in this file -- confirm
# it is used elsewhere, otherwise it can be removed.
EPS = 1e-8



def q_loss(reward_or_cost_flag):
    """Return a Keras loss ``loss(y_true, y_pred)`` for fitted-Q-iteration.

    ``reward_or_cost_flag`` is +1 when the per-step values are rewards to be
    maximised and -1 when they are costs to be minimised.  Multiplying both
    arguments of ``maximum`` by the flag and the result by the flag again
    turns the ``max`` into a ``min`` for the cost case.

    With ``r`` the per-step rewards (``y_true``) and ``Q`` the model output
    (``y_pred``), the bootstrapped value at step ``t`` is

        J[t] = max(r[t], Q[t])   (min for costs), for t < L-1
        J[L-1] = r[L-1]          (a stop is forced at the final step)

    and ``Q[t]`` is regressed onto ``J[t+1]``, the value of continuing past
    ``t``.  The bootstrapped part is wrapped in ``stop_gradient`` so that only
    ``Q[t]`` receives gradients.

    NOTE(review): both tensors are indexed as rank-2 ``(batch, L)``.  The model
    emits ``(batch, L, 1)``; presumably Keras' loss wrapper squeezes that
    trailing unit axis before calling -- confirm for the TF version in use.
    """
    sign = reward_or_cost_flag

    def loss(reward, Q_fn):
        # Best of "stop now" vs "continue" for steps 0..L-2, in the signed frame.
        best_now = k.maximum(reward[:, :-1] * sign, Q_fn[:, :-1] * sign)
        continuation = k.stop_gradient(best_now) * sign
        # Final step: the trajectory must stop, so its value is just the reward.
        terminal = k.expand_dims(reward[:, -1], axis=1)
        target = k.concatenate([continuation, terminal], axis=1)
        # Shift by one: Q at step t predicts the value obtained from step t+1 on.
        residual = target[:, 1:] - Q_fn[:, :-1]
        return k.mean(k.square(residual))

    return loss

def dnn_fqi(config, input):
    """Apply a stack of per-time-step dense layers to a sequence tensor.

    Each of the ``config['num_stacked_layers']`` stages is a
    ``TimeDistributed(Dense(config['units_hidden'], relu))`` followed by
    ``BatchNormalization`` over axis 2.  With zero stages the input tensor is
    returned unchanged.
    """
    hidden = input
    for _ in range(config['num_stacked_layers']):
        dense = keras.layers.Dense(units=config['units_hidden'], activation='relu')
        hidden = keras.layers.TimeDistributed(dense)(hidden)
        hidden = keras.layers.BatchNormalization(axis=2)(hidden)
    return hidden


def rnn_fqi(config, input):
    """Apply a stack of sequence-returning GRU layers to a sequence tensor.

    Each of the ``config['num_stacked_layers']`` stages is a
    ``GRU(config['units_hidden'], tanh, return_sequences=True)`` followed by
    ``BatchNormalization`` over axis 2, so every stage keeps one output per
    time step.  With zero stages the input tensor is returned unchanged.
    """
    hidden = input
    for _ in range(config['num_stacked_layers']):
        recurrent = keras.layers.GRU(
            units=config['units_hidden'],
            return_sequences=True,
            activation='tanh',
        )
        hidden = recurrent(hidden)
        hidden = keras.layers.BatchNormalization(axis=2)(hidden)
    return hidden


def build_fqi_model(config, L, F):
    """Build the Keras FQI value model.

    Inputs, in model order:
      * ``input_time``   -- shape ``(L, 1)``
      * ``input_reward`` -- shape ``(L, 1)``, only when ``config['include_R']``
      * ``input_signal`` -- shape ``(L, F)``

    The inputs are concatenated on the feature axis and batch-normalised.
    They then pass through ``dnn_fqi`` (when ``config['use_DNN']``) or
    ``rnn_fqi``, and finally a per-time-step linear ``Dense(1)`` head named
    ``value_out``.
    """
    signal_in = keras.layers.Input(shape=(L, F), name='input_signal')
    time_in = keras.layers.Input(shape=(L, 1), name='input_time')

    model_inputs = [time_in]
    if config['include_R']:
        model_inputs.append(keras.layers.Input(shape=(L, 1), name='input_reward'))
    model_inputs.append(signal_in)

    merged = keras.layers.Concatenate(axis=2)(model_inputs)
    normalized = keras.layers.BatchNormalization(axis=2)(merged)

    body = dnn_fqi if config['use_DNN'] else rnn_fqi
    final_hidden = body(config, normalized)

    head = keras.layers.Dense(units=1, activation='linear', name='dense')
    value_out = keras.layers.TimeDistributed(head, name='value_out')(final_hidden)

    return keras.Model(model_inputs, [value_out])



def train_fqi_model(config, data_stats_dict, transform_str=None, is_reward_flag=1):
    """Train and evaluate one fitted-Q-iteration model per cross-validation fold.

    For each fold:
      1. The training, validation and test inputs are passed through
         ``transform_data``, using ``data_stats_dict[transform_str][i]`` as
         the fold statistics when ``transform_str`` is given.
      2. A fresh model is built, then fitted with early stopping on
         ``val_loss`` (patience 5).
      3. The stopping rule "intervene where the immediate reward beats the
         predicted value (<= for costs)" is scored on the test fold with
         ``calculate_stopping_reward``.

    Args:
        config: hyper-parameter dict.  This function reads ``q_lr``,
            ``clipnorm``, ``batch_size`` and ``q_epochs``; the model builder
            and batch generators read further keys.
        data_stats_dict: dict with ``'training_folds'``, ``'validation_folds'``
            and ``'test_folds'`` lists.  Each fold entry is indexed as
            ``[0]`` = input array and ``[1]`` = rewards.  The input array is
            read as ``(N, L, F)``; the rewards are compared element-wise
            against ``(N, L)`` predictions.
        transform_str: name of the input transform, or None for no statistics.
        is_reward_flag: 1 if the values are rewards to maximise, 0 if they are
            costs to minimise.

    Returns:
        dict of per-fold lists: ``'q_rewards'``, ``'q_stop_idxs'``,
        ``'prediction_time_per_ts'`` (milliseconds per time step) and
        ``'train_times'`` (seconds).
    """
    nfolds = len(data_stats_dict['training_folds'])
    L = data_stats_dict['training_folds'][0][0].shape[1]
    F = data_stats_dict['training_folds'][0][0].shape[2]
    # +1 when rewards are maximised, -1 when costs are minimised.
    reward_or_cost_flag = (2 * is_reward_flag) - 1

    q_rewards = []
    q_reward_idxs = []
    train_times = []
    prediction_time_per_ts = []

    for i in range(nfolds):

        data_stats = None
        if transform_str is not None:
            data_stats = data_stats_dict[transform_str][i]

        # TRANSFORM INPUT DATA (all splits share the same fold statistics)
        transformed_train_data = transform_data(data_stats_dict['training_folds'][i][0], data_stats, transform_str)
        train_rewards = data_stats_dict['training_folds'][i][1]
        transformed_val_data = transform_data(data_stats_dict['validation_folds'][i][0], data_stats, transform_str)
        val_rewards = data_stats_dict['validation_folds'][i][1]
        transformed_test_data = transform_data(data_stats_dict['test_folds'][i][0], data_stats, transform_str)
        test_rewards = data_stats_dict['test_folds'][i][1]

        # BUILD MODEL
        fqi_model = build_fqi_model(config, L, F)

        # COMPILE MODEL
        # Without restore_best_weights, the weights of the last epoch run are kept.
        q_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=5)
        # `learning_rate`, not the deprecated `lr` alias: recent tf.keras
        # optimizers only warn about `lr` and then ignore it (silently training
        # at the default rate), and Keras 3 rejects it.
        fqi_model.compile(loss=q_loss(reward_or_cost_flag),
                          optimizer=keras.optimizers.Adam(learning_rate=config['q_lr'],
                                                          clipnorm=config['clipnorm']))

        # FIT MODEL (the whole validation split is served as one batch)
        q_train_generator = OSBatchGenerator(transformed_train_data, train_rewards, config, config['batch_size'], randomize=True)
        q_val_generator = OSBatchGenerator(transformed_val_data, val_rewards, config, transformed_val_data.shape[0], randomize=False)
        start_train_time = time.perf_counter()
        fqi_model.fit(q_train_generator,
                      validation_data=q_val_generator,
                      callbacks=[q_callback],
                      epochs=config['q_epochs'], shuffle=False, verbose=0)
        end_train_time = time.perf_counter()

        # PREDICT ON TEST SET
        q_test_prediction_generator = OSBatchGenerator(transformed_test_data, test_rewards, config, config['batch_size'], randomize=False)
        start_predict_time = time.perf_counter()
        q_test_predictions = fqi_model.predict(q_test_prediction_generator)
        end_predict_time = time.perf_counter()

        # Flag steps where stopping is at least as good as the predicted value.
        q_values = q_test_predictions[:, :, 0]
        if reward_or_cost_flag > 0:
            interventions = (test_rewards >= q_values) * 1
        else:
            interventions = (test_rewards <= q_values) * 1
        # Always flag the final step so every trajectory has a stopping point.
        interventions[:, -1] = 1
        reward, stop_idxs = calculate_stopping_reward(0.5, interventions, test_rewards)

        print(str(reward))
        q_rewards.append(reward)
        q_reward_idxs.append(stop_idxs)
        # Milliseconds per predicted time step (samples x sequence length).
        prediction_time_per_ts.append((end_predict_time - start_predict_time) * (10 ** 3) / (q_test_predictions.shape[0] * L))
        train_times.append(end_train_time - start_train_time)

    q_results = {'q_rewards': q_rewards, 'q_stop_idxs': q_reward_idxs, 'prediction_time_per_ts': prediction_time_per_ts, 'train_times': train_times}

    return q_results

################# MC training for American Option Pricing #############################

def train_mc_fqi_model(config, option_parameters, num_train_folds, num_test_folds):
    """Train FQI models on simulated paths and evaluate each on simulated test sets.

    For each of ``num_train_folds`` runs:
      * A fresh model is trained for a single epoch on an
        ``OSMCBatchGenerator`` in ``mode='train'``.  The generator is seeded
        with a running ``rnd_seed``, which then advances by
        ``config['train_samples_per_epoch']``.
      * The model is scored on ``num_test_folds`` test generators.  These are
        seeded at ``rnd_seed + j * config['test_samples_per_epoch']``.

    The interventions use the reward (maximisation) rule: stop wherever the
    immediate reward is >= the predicted value, and always at the last step.

    Args:
        config: hyper-parameter dict.  This function reads ``L``, ``F``,
            ``q_lr``, ``clipnorm``, ``train_samples_per_epoch``,
            ``test_samples_per_epoch`` and ``batch_size``.
        option_parameters: forwarded unchanged to ``OSMCBatchGenerator``.
        num_train_folds: number of independently trained models.
        num_test_folds: number of test generators per trained model.

    Returns:
        dict of per-training-run lists:
          * ``'q_reward_mean'`` / ``'q_reward_std'``: mean and std across test
            folds of each fold's mean stopping reward.
          * ``'prediction_time_per_ts'``: milliseconds per sample (see note
            below).
          * ``'train_times'``: seconds.
    """
    fqi_reward_mean = []
    fqi_reward_std = []
    prediction_time_per_ts = []
    train_times = []

    L = config['L']
    F = config['F']

    rnd_seed = 0
    for _ in range(num_train_folds):

        # BUILD MODEL
        fqi_model = build_fqi_model(config, L, F)
        # `learning_rate`, not the deprecated `lr` alias: recent tf.keras
        # optimizers only warn about `lr` and then ignore it (silently training
        # at the default rate), and Keras 3 rejects it.
        fqi_model.compile(loss={'value_out': q_loss(reward_or_cost_flag=1)},
                          optimizer=keras.optimizers.Adam(learning_rate=config['q_lr'],
                                                          clipnorm=config['clipnorm']))

        # FIT MODEL (single pass over freshly simulated training data)
        q_train_generator = OSMCBatchGenerator(config, option_parameters, config['train_samples_per_epoch'], seed=rnd_seed, mode='train')
        start_train_time = time.perf_counter()
        fqi_model.fit(q_train_generator, epochs=1, shuffle=False, verbose=1)
        end_train_time = time.perf_counter()

        # EVALUATE ON TEST FOLDS (seeds continue after this run's training seed)
        rnd_seed += config['train_samples_per_epoch']
        test_rewards_list = []
        pred_time_list = []
        for j in range(num_test_folds):
            q_test_generator = OSMCBatchGenerator(config, option_parameters, config['test_samples_per_epoch'], seed=(rnd_seed + j * config['test_samples_per_epoch']))
            n_batches = len(q_test_generator)
            rewards_list = []
            pred_time = 0
            # Named `batch_idx`, not `k`: `k` is this module's keras-backend alias.
            for batch_idx in range(n_batches):
                # Assumes batch = (inputs, targets), with targets[0] holding the
                # per-step rewards -- TODO confirm against OSMCBatchGenerator.
                batch = q_test_generator[batch_idx]
                start_evaluate_time = time.perf_counter()
                value = fqi_model.predict(batch[0], batch_size=config['batch_size'])
                end_evaluate_time = time.perf_counter()
                # NOTE(review): normalised per sample of a batch_size batch,
                # whereas train_fqi_model normalises per time step (N * L) under
                # the same 'prediction_time_per_ts' key -- confirm which is intended.
                pred_time += (end_evaluate_time - start_evaluate_time) * (10 ** 3) / (config['batch_size'])
                reward = batch[1][0]
                interventions = (reward >= value[:, :, 0]) * 1
                # Always flag the final step so every path has a stopping point.
                interventions[:, -1] = 1
                stop_reward, _ = calculate_stopping_reward(0.5, interventions, reward)
                rewards_list.append(stop_reward)
            test_rewards_list.append(np.mean(rewards_list))
            pred_time_list.append(pred_time / n_batches)

        reward_mean = np.mean(test_rewards_list)
        reward_std = np.std(test_rewards_list)
        pred_time_mean = np.mean(pred_time_list)

        print(str(reward_mean))
        fqi_reward_mean.append(reward_mean)
        fqi_reward_std.append(reward_std)
        prediction_time_per_ts.append(pred_time_mean)
        train_times.append(end_train_time - start_train_time)

    q_results = {'q_reward_mean': fqi_reward_mean,
                 'q_reward_std': fqi_reward_std,
                 'prediction_time_per_ts': prediction_time_per_ts,
                 'train_times': train_times}

    return q_results