from torchsde import sdeint_adjoint
from torchsde import sdeint
from functools import partial
from datetime import datetime
import os
import shutil
import yaml
import sys

# torchsde can recurse deeply, so make sure the interpreter's recursion limit
# is at least 10000. max() keeps any higher limit already configured
# elsewhere instead of silently lowering it.
sys.setrecursionlimit(max(sys.getrecursionlimit(), 10000))

# Convenience wrappers around torchsde's sdeint / sdeint_adjoint with fixed solver settings
def sdeint_aaeh(*args, **kwargs):
    """Solve an SDE with torchsde's ``sdeint`` using adaptive Euler-Heun.

    All positional and keyword arguments are passed straight through to
    ``sdeint``. This wrapper fixes ``adaptive`` and ``method``, so callers
    must not supply them (a duplicate keyword raises ``TypeError``).
    """
    solver_options = dict(adaptive=True, method='euler_heun')
    return sdeint(*args, **kwargs, **solver_options)

def sdeint_aaeh_a(*args, **kwargs):
    """Solve an SDE with torchsde's ``sdeint_adjoint``, Euler-Heun both ways.

    Both the forward solve and the adjoint solve use adaptive stepping with
    the 'euler_heun' method. All other arguments are passed through
    unchanged. ``adaptive``, ``method``, ``adjoint_adaptive`` and
    ``adjoint_method`` are fixed here; supplying any of them raises
    ``TypeError``.
    """
    forward_options = dict(adaptive=True, method='euler_heun')
    adjoint_options = dict(adjoint_adaptive=True, adjoint_method='euler_heun')
    return sdeint_adjoint(*args, **kwargs, **forward_options, **adjoint_options)

def sdeint_aaeh_r(*args, **kwargs):
    """Solve an SDE with torchsde's ``sdeint_adjoint`` using reversible Heun.

    The forward solve uses adaptive 'reversible_heun' and the adjoint solve
    uses adaptive 'adjoint_reversible_heun'. All other arguments are passed
    through unchanged. ``adaptive``, ``method``, ``adjoint_adaptive`` and
    ``adjoint_method`` are fixed here; supplying any of them raises
    ``TypeError``.
    """
    forward_options = dict(adaptive=True, method='reversible_heun')
    adjoint_options = dict(adjoint_adaptive=True,
                           adjoint_method='adjoint_reversible_heun')
    return sdeint_adjoint(*args, **kwargs, **forward_options, **adjoint_options)


def make_directory(directory_name='runs', sub_directory=None):
    """Create a timestamped run directory and optional sub-directories.

    The directory is ``./<directory_name>/all/<DD-MM-YYYY_HH_MM_SS>``. It is
    relative to the current working directory and named after the local time
    of the call. NOTE: the name has one-second resolution, so runs started
    within the same second share a directory.

    Args:
        directory_name: Folder under which all runs are collected.
        sub_directory: Optional iterable of sub-directory names to create
            inside the run directory. A single string is treated as one name
            rather than being split into characters.

    Returns:
        The run directory path as a string, e.g. ``'./runs/all/01-02-2024_10_00_00'``.
    """
    date_time = datetime.now().strftime("%d-%m-%Y_%H_%M_%S")
    directory = './' + directory_name + '/all/' + date_time
    # exist_ok=True replaces the racy "exists? then create" check: another
    # process creating the path in between no longer causes FileExistsError.
    os.makedirs(directory, exist_ok=True)

    if sub_directory is not None:
        if isinstance(sub_directory, str):
            sub_directory = [sub_directory]
        for name in sub_directory:
            os.makedirs(directory + '/' + name, exist_ok=True)

    print('directory:', directory)

    return directory


def load_yaml(load_directory, directory, parameter_file='/parameters.yaml'):
    """Load a YAML parameter file, print it, and copy it into ``directory``.

    Args:
        load_directory: Directory containing the parameter file.
        directory: Directory that receives a copy of the parameter file.
        parameter_file: File name appended verbatim to both directories. It
            must therefore carry its own leading separator (default
            ``'/parameters.yaml'``).

    Returns:
        The parsed parameters as a dict. An empty file yields ``{}``.

    Raises:
        TypeError: If the file's top-level YAML value is not a mapping.
    """
    source = load_directory + parameter_file
    destination = directory + parameter_file

    # Parameters loading; safe_load only constructs plain Python objects.
    with open(source, 'r') as f:
        parameters = yaml.safe_load(f)

    # An empty YAML document parses to None; treat it as "no parameters"
    # instead of failing with an obscure iteration error below.
    if parameters is None:
        parameters = {}
    if not isinstance(parameters, dict):
        raise TypeError(
            f"{source}: expected a mapping of parameters, "
            f"got {type(parameters).__name__}")

    for key, value in parameters.items():
        print(key, ':', value)

    try:
        shutil.copyfile(source, destination)
    except shutil.SameFileError:
        # Loading from the destination directory itself: the file is
        # already in place, so there is nothing to copy.
        pass

    return parameters

def tonp(x):
    """Convert ``x`` (presumably a torch.Tensor) to a NumPy array.

    ``force=True`` is passed to ``Tensor.numpy``. Per the PyTorch docs, this
    makes the conversion work even for tensors that require grad or live
    off-CPU, where the default call would refuse.
    """
    array = x.numpy(force=True)
    return array