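R"""Script for running fixed vertex optimization on synthetic datasets.

For each run, a synthetic dataset is generated and a random vertex is sampled.
Two steps are then alternated for at most `max_iterations` iterations:
- the vertex solver is run with the current second-layer sign pattern;
- the second layer is refit by least squares.
Iteration stops early once the loss stops changing. The final losses and the
config are written to `<outdir>/rv_<config name>.json`.

Example usage:

cd ~/Desktop/projects/zonotopic_relu
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/zonotopic_relu
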
python3 scripts/synthetic/random_vertex.py \
    --outdir="/tmp" \
    --configs_path="exps.synthetic.random_vertex_configs.CONFIGS" \
    --config="test" \
    --n_runs=2

"""
import dataclasses
import json
import os
from pydoc import locate

from absl import app
from absl import flags
from absl import logging

import cvxpy as cp
import numpy as np

from xoid.datasets import synthetic
from xoid.solvers import vertex_solvers

from xoid.util import misc_util
from xoid.util import solver_util
from xoid.util import vertex_util


FLAGS = flags.FLAGS

# Default Python path to the dict of named configs used by this script.
_CONFIGS_PATH = "exps.synthetic.random_vertex_configs.CONFIGS"

# Flags are only defined when this file is executed as a script, so importing
# the module elsewhere does not register them.
if __name__ == "__main__":
    flags.DEFINE_string('outdir', None, 'Directory to write the output JSON file to.')

    flags.DEFINE_string('configs_path', _CONFIGS_PATH, 'Python path to configs dict.')
    flags.DEFINE_string('config', None, 'Name of the entry in the configs dict to use as configuration.')

    flags.DEFINE_integer('n_runs', 1, 'Number of times to repeat the experiment.')

    # `configs_path` is deliberately not marked as required: it has a non-None
    # default, so such a requirement could never fail (absl only warns about it).
    flags.mark_flags_as_required(['outdir', 'config'])


@dataclasses.dataclass
class Config:
    """Settings for one random-vertex experiment.

    Instances are looked up by name in the configs dict and serialized with
    `dataclasses.asdict` into the output JSON.
    """

    # Used to build the output file name (`rv_<name>.json`).
    name: str

    # `m_gen` and `d` are forwarded to `synthetic.make_dataset` (presumably the
    # generating network's width and the input dimension -- confirm there);
    # `m_train` is the width handed to the vertex sampler and the vertex solver.
    m_gen: int
    m_train: int
    d: int

    # Cap on the number of alternating solve iterations per run.
    max_iterations: int = 32

    # Passed to the vertex solver and also used as the loss-change threshold
    # that ends the alternation early.
    eps: float = 1e-7


def solve_for_second_layer(X, Y, results):
    """Fit an affine second layer on top of a fixed first layer.

    The first layer is read from `results.model_params` (`w` and `b`). Its
    ReLU activations on `X` are computed, then `activations @ v + c` is fit to
    `Y` by minimizing the 'l2' loss from `solver_util.compute_loss`.

    Args:
        X: Inputs; must be compatible with `X @ results.model_params.w`.
        Y: Targets passed to `solver_util.compute_loss`.
        results: Object whose `model_params` has first-layer weights `w` and
            bias `b` (in this script, the result of `VertexSolver.solve`).

    Returns:
        Tuple `(loss, v, c)`:
        - `loss`: the optimal objective value;
        - `v`: the fitted readout weights, one per hidden unit;
        - `c`: the fitted scalar offset.

    Raises:
        cp.error.SolverError: If SCS fails after ECOS, or if no solution values
            were produced.
    """
    p = results.model_params
    # Hidden-layer ReLU activations for the fixed first layer.
    activations = np.maximum(X @ p.w + p.b, 0)
    m = activations.shape[-1]

    v = cp.Variable([m], name='v')
    c = cp.Variable([], name='c')

    Y_pred = activations @ v + c
    pred_loss = solver_util.compute_loss('l2', Y, Y_pred)
    prob = cp.Problem(cp.Minimize(pred_loss), [])

    try:
        loss = prob.solve(warm_start=True, solver=cp.ECOS)
    except cp.error.SolverError:
        # ECOS may fail (or be unavailable); retry with SCS.
        logging.warning('ECOS failed on the second-layer fit; retrying with SCS.')
        loss = prob.solve(warm_start=True, solver=cp.SCS)

    if v.value is None or c.value is None:
        # cvxpy leaves variable values unset when no solution was found; fail
        # here with the status instead of handing `None` back to the caller.
        raise cp.error.SolverError(
            f'Second-layer fit produced no solution (status: {prob.status}).')

    return loss, v.value, c.value


def do_run(cfg, run_index):
    """Run one random-vertex experiment and return its final loss.

    A synthetic dataset and a random vertex are drawn. Two steps are then
    alternated for at most `cfg.max_iterations` iterations:
      1. solve the vertex problem with the current second-layer sign pattern `v`;
      2. refit the second layer by least squares, and take the signs of the
         fitted weights (zeros mapped to +1) as the next sign pattern.
    Iteration stops early once the loss changes by less than `cfg.eps`.

    Args:
        cfg: A `Config` describing the experiment.
        run_index: Index of this run; only used in log messages.

    Returns:
        The second-layer least-squares loss from the last iteration performed.

    Raises:
        ValueError: If `cfg.max_iterations` is less than 1, since then no loss
            would ever be computed.
    """
    if cfg.max_iterations < 1:
        raise ValueError(
            f'cfg.max_iterations must be at least 1, got {cfg.max_iterations}.')

    # Make the dataset. `N` is the third argument of `synthetic.make_dataset`
    # (presumably the number of samples).
    N = (cfg.d + 1) * cfg.m_gen
    X, Y = synthetic.make_dataset(cfg.d, cfg.m_gen, N)

    # Draw the vertex, which is kept for every iteration, and the initial
    # second-layer sign pattern.
    vertex = vertex_util.random_vertex(X, cfg.m_train)
    v = misc_util.random_sign_pattern(cfg.m_train)

    loss = None
    last_loss = None
    for iteration in range(cfg.max_iterations):
        solver = vertex_solvers.VertexSolver(
            X, Y, loss_fn='l2', m=cfg.m_train, v=v, eps=cfg.eps)
        results = solver.solve(vertex)

        loss, v, _ = solve_for_second_layer(X, Y, results)
        logging.info('Run %d, iteration %d: loss = %s', run_index, iteration, loss)

        # Keep only the signs of the refit weights. np.sign maps 0 to 0, so
        # zeros are replaced by +1 to keep a pure +/-1 pattern.
        v = np.sign(v)
        v[v == 0] = 1

        if last_loss is not None and abs(last_loss - loss) < cfg.eps:
            break
        last_loss = loss

    return loss


def main(_):
    """Run the configured experiment `FLAGS.n_runs` times and save the results.

    The config is the `FLAGS.config` entry of the dict found at the Python path
    `FLAGS.configs_path`. The final losses and the config are written as JSON to
    `<FLAGS.outdir>/rv_<cfg.name>.json`. A leading `~` in `outdir` is expanded,
    and the directory is created if missing.

    Raises:
        ValueError: If the configs dict cannot be located or has no entry named
            `FLAGS.config`.
    """
    configs = locate(FLAGS.configs_path)
    if configs is None:
        raise ValueError(f'Could not locate a configs dict at {FLAGS.configs_path!r}.')
    try:
        cfg = configs[FLAGS.config]
    except KeyError:
        raise ValueError(
            f'Unknown config {FLAGS.config!r}; available: {list(configs)}.') from None

    # float() keeps the output JSON-serializable in case a solver returns a
    # NumPy scalar/array type rather than a Python float.
    losses = [float(do_run(cfg, i)) for i in range(FLAGS.n_runs)]

    results = {
        'final_losses': losses,
        'config': dataclasses.asdict(cfg),
    }

    outdir = os.path.expanduser(FLAGS.outdir)
    os.makedirs(outdir, exist_ok=True)
    filepath = os.path.join(outdir, f'rv_{cfg.name}.json')
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f)


# Script entry point: hand control to absl's `app.run`, which invokes `main`.
if __name__ == '__main__':
    app.run(main)
