import os
import matplotlib
import matplotlib.pyplot
import matplotlib.pyplot as plt

import time
import datetime
import numpy as np

import torch


def euler_solver(func, first_point, time_steps, atol = None, rtol = None, method = None):
	"""Integrate dy/dt = func(t, y) with the fixed-step explicit (forward) Euler method.

	Each step evaluates the derivative at the left end of the interval:
	``y_{k+1} = y_k + func(t_k, y_k) * (t_{k+1} - t_k)``.

	Args:
		func: Callable ``func(t, y)`` returning dy/dt at time ``t``. Its result
			is multiplied by the step size and added to ``y``, so it must
			broadcast against ``y`` (presumably it has the same shape as ``y``;
			confirm against callers).
		first_point: Initial state tensor at ``time_steps[0]``. It is cloned,
			so the caller's tensor is never aliased by the result.
		time_steps: Indexable sequence of time points (a Python list or a 1-D
			tensor). Consecutive differences are used as step sizes, so the
			grid does not need to be uniform.
		atol, rtol, method: Ignored. They are accepted only so the call
			signature matches an adaptive ODE solver's (presumably an
			``odeint``-style API; confirm against callers).

	Returns:
		``torch.stack`` of the state at every time point: ``first_point``
		followed by one entry per step. That gives ``len(time_steps)``
		entries along dim 0 for a non-empty ``time_steps``.
	"""
	point = first_point.clone()
	sol = [point]

	for prev_t, t in zip(time_steps[:-1], time_steps[1:]):
		# Derivatives are not accumulated: retaining them (and any autograd
		# graph they reference) would only cost memory, since they are not returned.
		gradient = func(prev_t, point)
		point = point + gradient * (t - prev_t)
		# Store a copy; `point` itself is what `func` receives on the next step.
		sol.append(point.clone())

	return torch.stack(sol)
