Checks that each element is within the error tolerance.
error_tol = _error_tol(self.rtol, self.atol, y0, y1)
error = tuple(torch.abs(y0_ - y1_) for y0_, y1_ in zip(y0, y1))
return all((error_ < error_tol_).all() for error_, error_tol_ in zip(error, error_tol))
def _step_func(self, func, t, dt, y):
self._update_history(t, func(t, y))
After Change
Checks that each element is within the error tolerance.
error_tol = _error_tol(self.rtol, self.atol, y0, y1)
error = torch.abs(y0 - y1)
return (error < error_tol).all()
def _step_func(self, func, t, dt, y):
self._update_history(t, func(t, y))
order = min(len(self.prev_f), self.max_order - 1)