# -*- coding: utf-8 -*-
"""
20-link pole balancing task with impoverished features, on-policy case
"""
__author__ = "Christoph Dann <cdann@cdann.de>"
import os
import pickle

import examples
import numpy as np
import dynamic_prog as dp
import features
import policies
from task import LinearLQRValuePredictionTask


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
dim = 20        # number of links handed to the pendulum MDP arrays below
gamma = 0.95    # discount factor shared by the LQR solver and the task
dt = 0.1        # forwarded as the MDP's `dt` argument
sigma = np.ones(2 * dim)  # forwarded as the MDP's `sigma` argument
mdp = examples.NLinkPendulumMDP(
    np.full(dim, 0.5),
    np.full(dim, 0.6),
    sigma=sigma,
    dt=dt,
)
# Feature map built for 2 * dim inputs (presumably equal to mdp.dim_S --
# TODO confirm against NLinkPendulumMDP).
phi = features.squared_diag(2 * dim)


# Number of features, measured by evaluating phi on an all-zero state.
n_feat = len(phi(np.zeros(mdp.dim_S)))

# First value returned by the LQR solver is used as the policy parameters;
# the two remaining return values are discarded.
theta_p, _, _ = dp.solve_LQR(mdp, gamma=gamma)
theta_p = np.array(theta_p)
theta_o = theta_p.copy()
beh_policy = policies.LinearContinuous(
    theta=theta_p,
    noise=np.full(dim, 0.4),
)

# Initial parameter vector for the value-prediction task: all zeros.
theta0 = np.zeros(n_feat)

task = LinearLQRValuePredictionTask(
    mdp, gamma, phi, theta0,
    policy=beh_policy,
    normalize_phi=True,
    mu_next=1000,
)


# ---------------------------------------------------------------------------
# Dataset generation: 25 sampled runs (seeds 0..24), each pickled to its own
# file under `datasets_dir` and also collected in the `datasets` list.
# ---------------------------------------------------------------------------
datasets = []
datasets_dir = '/tmp/bbo/datasets'
if not os.path.exists(datasets_dir):
    os.makedirs(datasets_dir)

for run_idx in range(25):
    # The loop index doubles as the sampling seed.
    rollout = mdp.samples_cached(
        n_iter=30000,
        n_restarts=1,
        policy=beh_policy,
        seed=run_idx,
        verbose=100,
    )
    state_0s, actions, rewards, state_1s, restarts = rollout

    # Shift the restart flags one step earlier along axis 0 so that entry t
    # carries the flag of step t + 1. np.roll wraps the first flag around to
    # the end, so the last entry is explicitly cleared.
    terminals = np.roll(restarts, -1, axis=0)
    terminals[-1] = False

    # Rewards and terminal flags are stored as single-column 2-D arrays.
    samples = dict([
        ('state_0', state_0s),
        ('action', actions),
        ('state_1', state_1s),
        ('reward', rewards.reshape((-1, 1))),
        ('terminal', terminals.reshape((-1, 1))),
        ('info', {}),
    ])

    dataset = {'samples': samples}
    datasets.append(dataset)

    file_name = 'dataset-{i}.pkl'.format(i=run_idx)
    dataset_path = os.path.join(datasets_dir, file_name)
    with open(dataset_path, 'wb') as f:
        pickle.dump(dataset, f)


# ---------------------------------------------------------------------------
# Value-function export: pickle a (states, values) tuple to
# `value_functions_dir`/value_function.pkl.
# ---------------------------------------------------------------------------
value_functions_dir = '/tmp/bbo/value_functions'
if not os.path.exists(value_functions_dir):
    os.makedirs(value_functions_dir)

# Parameter vector obtained by passing task.V_true through the `squared_tri`
# feature map's param_forward.
true_value_params = features.squared_tri(
    task.mdp.dim_S).param_forward(*task.V_true)
# Dot product of that vector with every row of task.mu_phi_full, then a
# trailing axis is appended so `values` has one column.
per_row_values = np.einsum('i,bi->b', true_value_params, task.mu_phi_full)
values = per_row_values[..., np.newaxis]

# NOTE(review): presumably task.mu holds the states that the rows of
# task.mu_phi_full were computed from -- confirm against the task class.
states = task.mu


value_function_path = os.path.join(value_functions_dir, 'value_function.pkl')
with open(value_function_path, 'wb') as f:
    pickle.dump((states, values), f)
