import argparse
import numpy as np
import os
import time
import torch
from torch.optim import Adam

from ic_surrogates_abm import training_ode_grid_sirs
from ic_surrogates_abm.utils import (build_surrogate_compute_metric, collect_data, collect_metrics, create_instantiate_emission,
                                     create_nll, create_instantiate_sirsrnn, generate_networks, generate_dists, instantiate_model,
                                     mse_loss, run_spatial_intervention)


if __name__ == "__main__":

    # Command-line interface: which surrogate family to run, where to read and
    # write artefacts, and which seeds to iterate over.
    parser = argparse.ArgumentParser()
    # --family is required: it is written to the config file and drives the
    # branch selection in the seed loop, so a missing value can only crash.
    parser.add_argument("--family", required=True,
                        help="Flag to indicate which method to run. "
                             "Options: lode, lodernn, lrnn, lodemush, lodernnmush")
    parser.add_argument("--dirname", type=str, help="Directory for loading and saving data from and to", default="")
    parser.add_argument("--seed", type=int, nargs='*', help="Which seeds to run")
    args = parser.parse_args()

    # With no --dirname, fall back to a fresh directory named after the
    # current timestamp.
    dirname = args.dirname if args.dirname != "" else os.path.join("./", str(time.time()))
    # exist_ok=True makes re-running into an existing directory a no-op.
    os.makedirs(dirname, exist_ok=True)
    # Record which family produced the contents of this directory.
    with open(os.path.join(dirname, "config.file"), "w") as fh:
        fh.write(args.family)

    # --seed omitted or given with no values: default to seeds 0..4.
    seeds = args.seed if args.seed else range(5)
    print("Running seeds", seeds)

    # Side length of the square grid (cells per row and per column)
    L = 50
    # Total number of agents (L x L)
    N = L * L
    # Total number of time steps
    T = 50

    # Every dataset below is read from this directory (relative to the
    # working directory the script is launched from).
    data_prefix = "../experiments/sirs_ode_spatial/"

    # Observational data: training set, then held-out test set
    xs_train_obs_ = torch.load(data_prefix + "aggregate_ts_OBS.pt")
    this_train_obs_ = torch.load(data_prefix + "par_intervention_OBS.pt")
    xs_test_obs = torch.load(data_prefix + "aggregate_ts_OBS_TEST.pt")
    this_test_obs = torch.load(data_prefix + "par_intervention_OBS_TEST.pt")

    # Interventional data: training set, then held-out test set
    xs_train_int_ = torch.load(data_prefix + "aggregate_ts_INT.pt")
    this_train_int_ = torch.load(data_prefix + "par_intervention_INT.pt")
    xs_test_int = torch.load(data_prefix + "aggregate_ts_INT_TEST.pt")
    this_test_int = torch.load(data_prefix + "par_intervention_INT_TEST.pt")

    # TODO: Tidy this up
    # One float64 zero per observational training row, shaped (rows, 1);
    # passed unchanged to every train_epi call.
    zeros = torch.zeros(xs_train_obs_.shape[0], 1, dtype=torch.float64)
    # No learning-rate scheduler is used by any family.
    scheduler = None

    # Run the selected family once per seed; `i` is the seed and is also used
    # in the names of the saved model files.
    for i in seeds:
        # Cyclically shift each training set along dim 0 by (i+1)*200 rows.
        # Trajectories and their parameter rows get the same shift, so
        # matching rows stay aligned.
        # NOTE(review): presumably this changes which rows train_epi treats as
        # training vs. validation for each seed -- confirm against train_epi.
        xs_train_obs = torch.roll(xs_train_obs_, (i+1)*200, 0)
        this_train_obs = torch.roll(this_train_obs_, (i+1)*200, 0)
        xs_train_int = torch.roll(xs_train_int_, (i+1)*200, 0)
        this_train_int = torch.roll(this_train_int_, (i+1)*200, 0)

        if args.family.endswith("mush") and not args.family.startswith("lrnn"):
            # "mush" families: train on the interventional trajectories, but
            # with the last parameter column zeroed out.
            # NOTE(review): the last column is presumably the intervention
            # value (it is what the lrnn evaluation passes on as such) -- confirm.
            this_train_mush = this_train_int.clone()
            this_train_mush[:, -1] = 0.
            print(this_train_mush)
            if args.family == "lodemush":
                instantiate_emission = create_instantiate_emission(N, kind='lode')
                negative_log_likelihood = create_nll(instantiate_emission, N)
                rnn_net, omega = generate_networks(kind='lode', seed=i)
                mush_omega_name = "best_mush_ode_omega_{0}.pt".format(i)
                mush_rnn_name = "best_mush_ode_rnn_net_{0}.pt".format(i)
                mush_int_test_name = "lode_int_test_mush_train.csv"
                mush_obs_test_name = "lode_obs_test_mush_train.csv"
            elif args.family == "lodernnmush":
                instantiate_emission = create_instantiate_emission(N)
                negative_log_likelihood = create_nll(instantiate_emission, N)
                rnn_net, omega = generate_networks(seed=i)
                mush_omega_name = "best_mush_lodernn_omega_{0}.pt".format(i)
                mush_rnn_name = "best_mush_lodernn_rnn_net_{0}.pt".format(i)
                mush_int_test_name = "lodernn_int_test_mush_train.csv"
                mush_obs_test_name = "lodernn_obs_test_mush_train.csv"
            else:
                # Any other "*mush" name would otherwise fall through and fail
                # later with a NameError on the undefined networks.
                raise ValueError("Unknown mush family: {0}".format(args.family))

            # Train both networks jointly on the zeroed-intervention data.
            optimiser = Adam(list(rnn_net.parameters()) +
                             list(omega.parameters()),
                             lr=1e-2)

            best_mush_ode_omega, best_mush_ode_rnn_net, loss_hist = training_ode_grid_sirs.train_epi(omega.double(),
                                                                                                   rnn_net.double(),
                                                                                                   zeros,
                                                                                                   xs_train_int.double(),
                                                                                                   this_train_mush.double(),
                                                                                                   instantiate_model,
                                                                                                   negative_log_likelihood,
                                                                                                   optimiser,
                                                                                                   scheduler=scheduler,
                                                                                                   batch_size=50,
                                                                                                   max_epochs_no_improve=20,
                                                                                                   notebook=False)

            torch.save(best_mush_ode_omega, os.path.join(dirname, mush_omega_name))
            torch.save(best_mush_ode_rnn_net, os.path.join(dirname, mush_rnn_name))

            test_ts = torch.linspace(0,T,T+1)
            # Instantiate model
            model = instantiate_model(test_ts)
            # collect_metrics takes two (omega, rnn_net) pairs; the single mush
            # model is passed as both, and only the second pair of outputs is kept.
            # mush on interventional
            (_,
             __,
             test_int_ode_msesstoch_mush,
             test_int_ode_neg_log_probs_mush) = collect_metrics(xs_test_int,
                                                               this_test_int,
                                                               instantiate_emission,
                                                               best_mush_ode_omega,
                                                               best_mush_ode_omega,
                                                               model,
                                                               best_mush_ode_rnn_net,
                                                               best_mush_ode_rnn_net,
                                                               N)

            # Now on observational
            (_,
             __,
             test_obs_ode_msesstoch_mush,
             test_obs_ode_neg_log_probs_mush) = collect_metrics(xs_test_obs,
                                                               this_test_obs,
                                                               instantiate_emission,
                                                               best_mush_ode_omega,
                                                               best_mush_ode_omega,
                                                               model,
                                                               best_mush_ode_rnn_net,
                                                               best_mush_ode_rnn_net,
                                                               N)

            # Average each metric over its own test set: the interventional and
            # observational test sets need not have the same number of rows.
            R = len(test_int_ode_msesstoch_mush)
            amse_int_mush = sum(test_int_ode_msesstoch_mush) / R
            anll_int_mush = sum(test_int_ode_neg_log_probs_mush) / R
            print("LODE (M on I): AMSE =", amse_int_mush, "; ANLL =", anll_int_mush)
            R = len(test_obs_ode_msesstoch_mush)
            amse_obs_mush = sum(test_obs_ode_msesstoch_mush) / R
            anll_obs_mush = sum(test_obs_ode_neg_log_probs_mush) / R
            print("LODE (M on O): AMSE =", amse_obs_mush, "; ANLL =", anll_obs_mush)
            print()

            # Append one "AMSE, ANLL" row per seed to the per-family result files.
            with open(os.path.join(dirname, mush_int_test_name), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_int_mush, anll_int_mush))
            with open(os.path.join(dirname, mush_obs_test_name), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_obs_mush, anll_obs_mush))


        elif args.family == "lode":
            lode_instantiate_emission = create_instantiate_emission(N, kind='lode')
            lode_negative_log_likelihood = create_nll(lode_instantiate_emission, N)

            rnn_net, omega = generate_networks(kind='lode', seed=i)
            optimiser = Adam(list(rnn_net.parameters()) +
                             list(omega.parameters()),
                             lr=1e-2)

            best_obs_ode_omega, best_obs_ode_rnn_net, loss_hist = training_ode_grid_sirs.train_epi(omega.double(),
                                                                                                   rnn_net.double(),
                                                                                                   zeros,
                                                                                                   xs_train_obs.double(),
                                                                                                   this_train_obs.double(),
                                                                                                   instantiate_model,
                                                                                                   lode_negative_log_likelihood,
                                                                                                   optimiser,
                                                                                                   scheduler=scheduler,
                                                                                                   batch_size=50,
                                                                                                   max_epochs_no_improve=20,
                                                                                                   notebook=False)

            torch.save(best_obs_ode_omega, os.path.join(dirname, "best_obs_ode_omega_{0}.pt".format(i)))
            torch.save(best_obs_ode_rnn_net, os.path.join(dirname, "best_obs_ode_rnn_net_{0}.pt".format(i)))

            rnn_net, omega = generate_networks(kind='lode', seed=i)
            optimiser = Adam(list(rnn_net.parameters()) +
                             list(omega.parameters()),
                             lr=1e-2)

            best_int_ode_omega, best_int_ode_rnn_net, loss_hist = training_ode_grid_sirs.train_epi(omega.double(),
                                                                                                   rnn_net.double(),
                                                                                                   zeros,
                                                                                                   xs_train_int.double(),
                                                                                                   this_train_int.double(),
                                                                                                   instantiate_model,
                                                                                                   lode_negative_log_likelihood,
                                                                                                   optimiser,
                                                                                                   scheduler=scheduler,
                                                                                                   batch_size=50,
                                                                                                   max_epochs_no_improve=20,
                                                                                                   notebook=False)
 
            torch.save(best_int_ode_omega, os.path.join(dirname, "best_int_ode_omega_{0}.pt".format(i)))
            torch.save(best_int_ode_rnn_net, os.path.join(dirname, "best_int_ode_rnn_net_{0}.pt".format(i)))

            ###
            # test
            ###

            test_ts = torch.linspace(0,T,T+1)
            # Instantiate model
            model = instantiate_model(test_ts)

            (test_obs_ode_msesstoch_obs, 
             test_obs_ode_neg_log_probs_obs, 
             test_obs_ode_msesstoch_int, 
             test_obs_ode_neg_log_probs_int) = collect_metrics(xs_test_obs, 
                                                               this_test_obs, 
                                                               lode_instantiate_emission, 
                                                               best_obs_ode_omega, 
                                                               best_int_ode_omega, 
                                                               model, 
                                                               best_obs_ode_rnn_net, 
                                                               best_int_ode_rnn_net, 
                                                               N)

            R = len(test_obs_ode_msesstoch_obs)
            amse_obs_obs = sum(test_obs_ode_msesstoch_obs) / R
            anll_obs_obs = sum(test_obs_ode_neg_log_probs_obs) / R
            amse_obs_int = sum(test_obs_ode_msesstoch_int) / R
            anll_obs_int = sum(test_obs_ode_neg_log_probs_int) / R
            print("LODE (O): AMSE =", amse_obs_obs, "; ANLL =", anll_obs_obs)
            print("LODE (I): AMSE =", amse_obs_int, "; ANLL =", anll_obs_int)
            print()

            with open(os.path.join(dirname, "lode_obs_test_int_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_obs_int, anll_obs_int))
            with open(os.path.join(dirname, "lode_obs_test_obs_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_obs_obs, anll_obs_obs))

            (test_int_ode_msesstoch_obs, 
             test_int_ode_neg_log_probs_obs, 
             test_int_ode_msesstoch_int, 
             test_int_ode_neg_log_probs_int) = collect_metrics(xs_test_int, 
                                                               this_test_int, 
                                                               lode_instantiate_emission, 
                                                               best_obs_ode_omega, 
                                                               best_int_ode_omega, 
                                                               model, 
                                                               best_obs_ode_rnn_net, 
                                                               best_int_ode_rnn_net, 
                                                               N)

            R = len(test_int_ode_msesstoch_obs)
            amse_int_obs = sum(test_int_ode_msesstoch_obs) / R
            anll_int_obs = sum(test_int_ode_neg_log_probs_obs) / R
            amse_int_int = sum(test_int_ode_msesstoch_int) / R
            anll_int_int = sum(test_int_ode_neg_log_probs_int) / R
            print("LODE (O): AMSE =", amse_int_obs, "; ANLL =", anll_int_obs)
            print("LODE (I): AMSE =", amse_int_int, "; ANLL =", anll_int_int)
            print()

            with open(os.path.join(dirname, "lode_int_test_int_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_int_int, anll_int_int))
            with open(os.path.join(dirname, "lode_int_test_obs_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_int_obs, anll_int_obs))

        elif args.family == "lodernn":
            instantiate_emission = create_instantiate_emission(N)
            negative_log_likelihood = create_nll(instantiate_emission, N)

            try:
                best_obs_omega = torch.load(os.path.join(dirname, "best_obs_omega_{0}.pt".format(i)))
                best_obs_rnn_net = torch.load(os.path.join(dirname, "best_obs_rnn_net_{0}.pt".format(i)))
            except:
                rnn_net, omega = generate_networks(seed=i)
                optimiser = Adam(list(rnn_net.parameters()) +
                                 list(omega.parameters()),
                                 lr=1e-2)

                best_obs_omega, best_obs_rnn_net, loss_hist = training_ode_grid_sirs.train_epi(omega.double(),
                                                                                               rnn_net.double(),
                                                                                               zeros,
                                                                                               xs_train_obs.double(),
                                                                                               this_train_obs.double(),
                                                                                               instantiate_model,
                                                                                               negative_log_likelihood,
                                                                                               optimiser,
                                                                                               scheduler=scheduler,
                                                                                               batch_size=50,
                                                                                               max_epochs_no_improve=20,
                                                                                               notebook=False)

                torch.save(best_obs_omega, os.path.join(dirname, "best_obs_omega_{0}.pt".format(i)))
                torch.save(best_obs_rnn_net, os.path.join(dirname, "best_obs_rnn_net_{0}.pt".format(i)))

            try:
                best_int_omega = torch.load(os.path.join(dirname, "best_int_omega_{0}.pt".format(i)))
                best_int_rnn_net = torch.load(os.path.join(dirname, "best_int_rnn_net_{0}.pt".format(i)))
            except:
                rnn_net, omega = generate_networks(seed=i)
                optimiser = Adam(list(rnn_net.parameters()) +
                                 list(omega.parameters()),
                                 lr=1e-2)

                best_int_omega, best_int_rnn_net, loss_hist = training_ode_grid_sirs.train_epi(omega.double(),
                                                                                               rnn_net.double(),
                                                                                               zeros,
                                                                                               xs_train_int.double(),
                                                                                               this_train_int.double(),
                                                                                               instantiate_model,
                                                                                               negative_log_likelihood,
                                                                                               optimiser,
                                                                                               scheduler=scheduler,
                                                                                               batch_size=50,
                                                                                               max_epochs_no_improve=20,
                                                                                               notebook=False)
                torch.save(best_int_omega, os.path.join(dirname, "best_int_omega_{0}.pt".format(i)))
                torch.save(best_int_rnn_net, os.path.join(dirname, "best_int_rnn_net_{0}.pt".format(i)))

            test_ts = torch.linspace(0,T,T+1)
            # Instantiate model
            model = instantiate_model(test_ts)

            # Test on obs data
            (test_obs_msesstoch_obs, 
             test_obs_neg_log_probs_obs, 
             test_obs_msesstoch_int, 
             test_obs_neg_log_probs_int) = collect_metrics(xs_test_obs, 
                                                           this_test_obs, 
                                                           instantiate_emission, 
                                                           best_obs_omega, 
                                                           best_int_omega, 
                                                           model, 
                                                           best_obs_rnn_net, 
                                                           best_int_rnn_net, 
                                                           N)

            R = len(test_obs_msesstoch_obs)
            amse_obs_obs = sum(test_obs_msesstoch_obs) / R
            anll_obs_obs = sum(test_obs_neg_log_probs_obs) / R
            amse_obs_int = sum(test_obs_msesstoch_int) / R
            anll_obs_int = sum(test_obs_neg_log_probs_int) / R
            print("LODE-RNN (O): AMSE =", amse_obs_obs, "; ANLL =", anll_obs_obs)
            print("LODE-RNN (I): AMSE =", amse_obs_int, "; ANLL =", anll_obs_int)
            print()

            with open(os.path.join(dirname, "lodernn_obs_test_int_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_obs_int, anll_obs_int))
            with open(os.path.join(dirname, "lodernn_obs_test_obs_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_obs_obs, anll_obs_obs))

            # Test on int data
            (test_int_msesstoch_obs, 
             test_int_neg_log_probs_obs, 
             test_int_msesstoch_int, 
             test_int_neg_log_probs_int) = collect_metrics(xs_test_int, 
                                                           this_test_int, 
                                                           instantiate_emission, 
                                                           best_obs_omega, 
                                                           best_int_omega, 
                                                           model, 
                                                           best_obs_rnn_net, 
                                                           best_int_rnn_net, 
                                                           N)

            R = len(test_int_msesstoch_obs)
            amse_int_obs = sum(test_int_msesstoch_obs) / R
            anll_int_obs = sum(test_int_neg_log_probs_obs) / R
            amse_int_int = sum(test_int_msesstoch_int) / R
            anll_int_int = sum(test_int_neg_log_probs_int) / R
            print("LODE-RNN (O): AMSE =", amse_int_obs, "; ANLL =", anll_int_obs)
            print("LODE-RNN (I): AMSE =", amse_int_int, "; ANLL =", anll_int_int)
            print()

            with open(os.path.join(dirname, "lodernn_int_test_int_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_int_int, anll_int_int))
            with open(os.path.join(dirname, "lodernn_int_test_obs_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_int_obs, anll_int_obs))

        elif args.family == 'lrnn':
            instantiate_emission = create_instantiate_emission(N)
            negative_log_likelihood = create_nll(instantiate_emission, N)

            try:
                best_obs_lrnn_omega = torch.load(os.path.join(dirname, "best_obs_lrnn_omega_{0}.pt".format(i)))
                best_obs_lrnn_rnn_net = torch.load(os.path.join(dirname, "best_obs_lrnn_rnn_net_{0}.pt".format(i)))
                best_obs_lrnn_model = torch.load(os.path.join(dirname, "best_obs_lrnn_model_{0}.pt".format(i)))
            except:
                rnn_net, omega = generate_networks(kind='lrnn', seed=i)
                sirsrnn = create_instantiate_sirsrnn(rnn_net)
                optimiser = Adam(list(rnn_net.parameters()) +
                                 list(omega.parameters()),
                                 lr=1e-2)

                best_obs_lrnn_omega, best_obs_lrnn_rnn_net, best_obs_lrnn_model, loss_hist = training_ode_grid_sirs.train_epi(omega.double(),
                                                                                                                              torch.nn.Identity(),
                                                                                                                              zeros,
                                                                                                                              xs_train_obs.double(),
                                                                                                                              this_train_obs.double(),
                                                                                                                              sirsrnn,
                                                                                                                              negative_log_likelihood,
                                                                                                                              optimiser,
                                                                                                                              scheduler=scheduler,
                                                                                                                              batch_size=50,
                                                                                                                              max_epochs_no_improve=20,
                                                                                                                              full_node=True,
                                                                                                                              notebook=False)

                torch.save(best_obs_lrnn_omega, os.path.join(dirname, "best_obs_lrnn_omega_{0}.pt".format(i)))
                torch.save(best_obs_lrnn_rnn_net, os.path.join(dirname, "best_obs_lrnn_rnn_net_{0}.pt".format(i)))
                torch.save(best_obs_lrnn_model, os.path.join(dirname, "best_obs_lrnn_model_{0}.pt".format(i)))

            try:
                best_int_lrnn_omega = torch.load(os.path.join(dirname, "best_int_lrnn_omega_{0}.pt".format(i)))
                best_int_lrnn_rnn_net = torch.load(os.path.join(dirname, "best_int_lrnn_rnn_net_{0}.pt".format(i)))
                best_int_lrnn_model = torch.load(os.path.join(dirname, "best_int_lrnn_model_{0}.pt".format(i)))
            except:
                rnn_net, omega = generate_networks(kind='lrnn', seed=i)
                sirsrnn = create_instantiate_sirsrnn(rnn_net)
                optimiser = Adam(list(rnn_net.parameters()) +
                                 list(omega.parameters()),
                                 lr=1e-2)

                best_int_lrnn_omega, best_int_lrnn_rnn_net, best_int_lrnn_model, loss_hist = training_ode_grid_sirs.train_epi(omega.double(),
                                                                                                                              torch.nn.Identity(),
                                                                                                                              zeros,
                                                                                                                              xs_train_int.double(),
                                                                                                                              this_train_int.double(),
                                                                                                                              sirsrnn,
                                                                                                                              negative_log_likelihood,
                                                                                                                              optimiser,
                                                                                                                              scheduler=scheduler,
                                                                                                                              batch_size=50,
                                                                                                                              max_epochs_no_improve=20,
                                                                                                                              full_node=True,
                                                                                                                              notebook=False)

                torch.save(best_int_lrnn_omega, os.path.join(dirname, "best_int_lrnn_omega_{0}.pt".format(i)))
                torch.save(best_int_lrnn_rnn_net, os.path.join(dirname, "best_int_lrnn_rnn_net_{0}.pt".format(i)))
                torch.save(best_int_lrnn_model, os.path.join(dirname, "best_int_lrnn_model_{0}.pt".format(i)))

            # Test observationally trained surrogate on observational data
            test_obs_lrnn_msesstoch_obs = []
            test_obs_lrnn_neg_log_probs_obs = []
            # Test interventionally trained surrogate on observational data
            test_obs_lrnn_msesstoch_int = []
            test_obs_lrnn_neg_log_probs_int = []

            with torch.no_grad():

                for r in range(xs_test_obs.shape[0]):
                    # Unpack observational test case r: the full trajectory, the entry at
                    # [r, 0, 1] (used below as the initial infected fraction i0), the first
                    # three entries of the parameter row as (alpha, beta, gamma), and the
                    # last entry of that row as a Python scalar.
                    # NOTE(review): this rebinds `i`, which elsewhere in this file is passed as
                    # the seed to generate_networks(seed=i) and used in checkpoint file names.
                    this_test_x, i0, (alpha, beta, gamma), i = xs_test_obs[r], xs_test_obs[r, 0, 1], this_test_obs[r, :3], this_test_obs[r, -1].item()
                    # Initial state [1 - i0, i0, 0.]; presumably (S, I, R) fractions -- TODO confirm.
                    y0 = torch.tensor([1 - i0, i0, 0.])
                    params = torch.tensor([alpha, beta, gamma])
                    # LRNN surrogate trained observationally (best_obs_lrnn_*), evaluated on
                    # this observational test case. The call returns a pair that is recorded
                    # as (stochastic MSE, negative log-likelihood).
                    this_obs_stoch_loss, this_obs_nll = build_surrogate_compute_metric(instantiate_emission,
                                                                                       best_obs_lrnn_omega.double(), 
                                                                                       params.double(), 
                                                                                       best_obs_lrnn_model.double(), 
                                                                                       y0.double(), 
                                                                                       i, 
                                                                                       torch.nn.Identity(), 
                                                                                       this_test_x,
                                                                                       N,
                                                                                       T)
                    test_obs_lrnn_msesstoch_obs.append(this_obs_stoch_loss)
                    test_obs_lrnn_neg_log_probs_obs.append(this_obs_nll)
                    # LRNN surrogate trained interventionally (best_int_lrnn_*), evaluated on
                    # the same observational test case with identical inputs.
                    this_int_stoch_loss, this_int_nll = build_surrogate_compute_metric(instantiate_emission,
                                                                                       best_int_lrnn_omega.double(), 
                                                                                       params.double(),
                                                                                       best_int_lrnn_model.double(), 
                                                                                       y0.double(), 
                                                                                       i, 
                                                                                       torch.nn.Identity(), 
                                                                                       this_test_x,
                                                                                       N,
                                                                                       T)
                    test_obs_lrnn_msesstoch_int.append(this_int_stoch_loss)
                    test_obs_lrnn_neg_log_probs_int.append(this_int_nll)

            # Average the per-case metrics collected in the test_obs_lrnn_* lists.
            # R is the number of evaluated test cases; all four lists are filled in
            # lockstep, so one count serves every average.
            R = len(test_obs_lrnn_msesstoch_obs)
            amse_obs_obs, anll_obs_obs, amse_obs_int, anll_obs_int = (
                sum(per_case_values) / R
                for per_case_values in (test_obs_lrnn_msesstoch_obs,
                                        test_obs_lrnn_neg_log_probs_obs,
                                        test_obs_lrnn_msesstoch_int,
                                        test_obs_lrnn_neg_log_probs_int)
            )
            print("LRNN (O): AMSE =", amse_obs_obs, "; ANLL =", anll_obs_obs)
            print("LRNN (I): AMSE =", amse_obs_int, "; ANLL =", anll_obs_int)
            print()

            # Append one "AMSE, ANLL" row per run to each results file; the file
            # for the interventionally trained surrogate is written first.
            for out_name, out_amse, out_anll in (
                    ("lrnn_obs_test_int_train.csv", amse_obs_int, anll_obs_int),
                    ("lrnn_obs_test_obs_train.csv", amse_obs_obs, anll_obs_obs)):
                with open(os.path.join(dirname, out_name), "a") as fh:
                    fh.write("{0}, {1}\n".format(out_amse, out_anll))

            # Per-case metrics on the interventional test set (xs_test_int) for the
            # observationally trained surrogate ...
            test_int_lrnn_msesstoch_obs, test_int_lrnn_neg_log_probs_obs = [], []
            # ... and for the interventionally trained surrogate.
            test_int_lrnn_msesstoch_int, test_int_lrnn_neg_log_probs_int = [], []

            with torch.no_grad():
                for r in range(xs_test_int.shape[0]):
                    # Trajectory, initial infected fraction, rate parameters and the
                    # last entry of the parameter row for test case r.
                    this_test_x = xs_test_int[r]
                    i0 = xs_test_int[r, 0, 1]
                    alpha, beta, gamma = this_test_int[r, :3]
                    i = this_test_int[r, -1].item()
                    y0 = torch.tensor([1 - i0, i0, 0.])
                    params = torch.tensor([alpha, beta, gamma])

                    # Evaluate the observationally trained surrogate first, then the
                    # interventionally trained one, on identical inputs.
                    for trained_omega, trained_model, mse_log, nll_log in (
                            (best_obs_lrnn_omega, best_obs_lrnn_model,
                             test_int_lrnn_msesstoch_obs, test_int_lrnn_neg_log_probs_obs),
                            (best_int_lrnn_omega, best_int_lrnn_model,
                             test_int_lrnn_msesstoch_int, test_int_lrnn_neg_log_probs_int)):
                        case_stoch_loss, case_nll = build_surrogate_compute_metric(
                            instantiate_emission,
                            trained_omega.double(),
                            params.double(),
                            trained_model.double(),
                            y0.double(),
                            i,
                            torch.nn.Identity(),
                            this_test_x,
                            N,
                            T,
                        )
                        mse_log.append(case_stoch_loss)
                        nll_log.append(case_nll)

            # Average over the R evaluated interventional test cases.
            R = len(test_int_lrnn_msesstoch_obs)
            amse_int_obs, anll_int_obs, amse_int_int, anll_int_int = (
                sum(per_case_values) / R
                for per_case_values in (test_int_lrnn_msesstoch_obs,
                                        test_int_lrnn_neg_log_probs_obs,
                                        test_int_lrnn_msesstoch_int,
                                        test_int_lrnn_neg_log_probs_int)
            )
            print("LRNN (O): AMSE =", amse_int_obs, "; ANLL =", anll_int_obs)
            print("LRNN (I): AMSE =", amse_int_int, "; ANLL =", anll_int_int)
            print()

            # Append one "AMSE, ANLL" row per run to each results file; the file
            # for the interventionally trained surrogate is written first.
            for out_name, out_amse, out_anll in (
                    ("lrnn_int_test_int_train.csv", amse_int_int, anll_int_int),
                    ("lrnn_int_test_obs_train.csv", amse_int_obs, anll_int_obs)):
                with open(os.path.join(dirname, out_name), "a") as fh:
                    fh.write("{0}, {1}\n".format(out_amse, out_anll))

        elif args.family == 'lrnnmush':
            instantiate_emission = create_instantiate_emission(N)
            negative_log_likelihood = create_nll(instantiate_emission, N)
            this_train_mush = this_train_int.clone()
            this_train_mush[:, -1] = 0.
            print(args.family, this_train_mush)
            rnn_net, omega = generate_networks(kind='lrnn', seed=i)
            sirsrnn = create_instantiate_sirsrnn(rnn_net)
            optimiser = Adam(list(rnn_net.parameters()) +
                             list(omega.parameters()),
                             lr=1e-2)

            best_mush_lrnn_omega, best_mush_lrnn_rnn_net, best_mush_lrnn_model, loss_hist = training_ode_grid_sirs.train_epi(omega.double(),
                                                                                                                          torch.nn.Identity(),
                                                                                                                          zeros,
                                                                                                                          xs_train_int.double(),
                                                                                                                          this_train_mush.double(),
                                                                                                                          sirsrnn,
                                                                                                                          negative_log_likelihood,
                                                                                                                          optimiser,
                                                                                                                          scheduler=scheduler,
                                                                                                                          batch_size=50,
                                                                                                                          max_epochs_no_improve=20,
                                                                                                                          full_node=True,
                                                                                                                          notebook=False)

            torch.save(best_mush_lrnn_omega, os.path.join(dirname, "best_mush_lrnn_omega_{0}.pt".format(i)))
            torch.save(best_mush_lrnn_rnn_net, os.path.join(dirname, "best_mush_lrnn_rnn_net_{0}.pt".format(i)))
            torch.save(best_mush_lrnn_model, os.path.join(dirname, "best_mush_lrnn_model_{0}.pt".format(i)))

            # Test mush surrogate on observational data
            test_obs_lrnn_msesstoch_mush = []
            test_obs_lrnn_neg_log_probs_mush = []
            # Test mush on interventional
            test_int_lrnn_msesstoch_mush = []
            test_int_lrnn_neg_log_probs_mush = []

            with torch.no_grad():

                for r in range(xs_test_obs.shape[0]):
                    this_test_x, i0, (alpha, beta, gamma), i = xs_test_obs[r], xs_test_obs[r, 0, 1], this_test_obs[r, :3], this_test_obs[r, -1].item()
                    y0 = torch.tensor([1 - i0, i0, 0.])
                    params = torch.tensor([alpha, beta, gamma])
                    # observational
                    this_obs_stoch_loss, this_obs_nll = build_surrogate_compute_metric(instantiate_emission,
                                                                                       best_mush_lrnn_omega.double(), 
                                                                                       params.double(), 
                                                                                       best_mush_lrnn_model.double(), 
                                                                                       y0.double(), 
                                                                                       i, 
                                                                                       torch.nn.Identity(), 
                                                                                       this_test_x,
                                                                                       N,
                                                                                       T)
                    test_obs_lrnn_msesstoch_mush.append(this_obs_stoch_loss)
                    test_obs_lrnn_neg_log_probs_mush.append(this_obs_nll)
                    # interventional
                    this_test_x, i0, (alpha, beta, gamma), i = xs_test_int[r], xs_test_int[r, 0, 1], this_test_int[r, :3], this_test_int[r, -1].item()
                    y0 = torch.tensor([1 - i0, i0, 0.])
                    params = torch.tensor([alpha, beta, gamma])
                    this_int_stoch_loss, this_int_nll = build_surrogate_compute_metric(instantiate_emission,
                                                                                       best_mush_lrnn_omega.double(), 
                                                                                       params.double(),
                                                                                       best_mush_lrnn_model.double(), 
                                                                                       y0.double(), 
                                                                                       i, 
                                                                                       torch.nn.Identity(), 
                                                                                       this_test_x,
                                                                                       N,
                                                                                       T)
                    test_int_lrnn_msesstoch_mush.append(this_int_stoch_loss)
                    test_int_lrnn_neg_log_probs_mush.append(this_int_nll)

            R = len(test_obs_lrnn_msesstoch_mush)
            amse_obs_mush = sum(test_obs_lrnn_msesstoch_mush) / R
            anll_obs_mush = sum(test_obs_lrnn_neg_log_probs_mush) / R
            amse_int_mush = sum(test_int_lrnn_msesstoch_mush) / R
            anll_int_mush = sum(test_int_lrnn_neg_log_probs_mush) / R
            print("LRNN (M on O): AMSE =", amse_obs_mush, "; ANLL =", anll_obs_mush)
            print("LRNN (M on I): AMSE =", amse_int_mush, "; ANLL =", anll_int_mush)
            print()

            with open(os.path.join(dirname, "lrnn_obs_test_mush_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_obs_mush, anll_obs_mush))
            with open(os.path.join(dirname, "lrnn_int_test_mush_train.csv"), "a") as fh:
                fh.write("{0}, {1}\n".format(amse_int_mush, anll_int_mush))
