Examples
Have a look at the notebooks on the left or at the following scripts.
Declaring ODE systems in Dynax
import jax.numpy as jnp
import numpy as np
from .custom_types import Array, FloatScalarLike, Scalar
from .system import (
AbstractControlAffine,
AbstractSystem,
boxed_field,
field,
non_negative_field,
static_field,
)
# Define a general dynamical system by subclassing `AbstractSystem`.
class SpringMassDamper(AbstractSystem):
    """Linear spring-mass-damper system driven by an external force.

    .. math:: m x'' + r x' + k x = u.
    """

    # System parameters, declared as data fields.
    m: float = field()
    """Mass."""
    r: float = field()
    """Linear drag."""
    k: float = field()
    """Stiffness."""
    initial_state: np.ndarray = static_field(default_factory=lambda: np.zeros(2))
    n_inputs: str = "scalar"

    # The dynamics are supplied by implementing `vector_field`.
    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Evaluate the time derivative of the state.

        .. math:: ẋ = [x_2, (u - r x_2 - k x_1) / m]^T.

        Args:
            x: State vector ``[x_1, x_2]``, i.e. position and velocity.
            u: Optional input force; a zero force is used when ``None``.
            t: Time. Not used, since the dynamics do not depend on time.

        Returns:
            State derivative.
        """
        position, velocity = x
        force = jnp.zeros(()) if u is None else u
        acceleration = (force - self.r * velocity - self.k * position) / self.m
        return jnp.array([velocity, acceleration])
# This class does not override the `AbstractSystem.output` method. The output is
# then the full state vector by default.
# Systems that have a control-affine structure can subclass `AbstractControlAffine` and
# implement the `f`, `g`, and `h` methods. Such systems can often be input-output
# linearized with the functions in `dynax.linearize`.
class NonlinearDrag(AbstractControlAffine):
    """Forced spring-mass-damper system with nonlinear drag.

    .. math:: m x'' + r x' + r_2 x'|x'| + k x = u.
    """

    r: float = field()
    """Linear drag."""
    r2: float = field()
    """Nonlinear drag."""
    k: float = field()
    """Stiffness."""
    m: float = field()
    """Mass."""

    # We can define additional dataclass fields that do not represent trainable
    # model parameters using the `static_field` function. This function tells JAX that
    # the field is a constant and should not be differentiated with respect to.
    outputs: tuple[int, ...] = static_field(default=(0,))
    """Indices of the state vector that are outputs. Defaults to ``(0,)``."""
    initial_state: np.ndarray = static_field(default_factory=lambda: np.zeros(2))
    n_inputs: str = "scalar"

    def f(self, x: Array) -> Array:
        """Input-independent part of the vector field (the dynamics at zero input).

        .. math:: f(x) = [x_2, (-r x_2 - r_2 |x_2| x_2 - k x_1) / m]^T.
        """
        x1, x2 = x
        return jnp.array(
            [x2, (-self.r * x2 - self.r2 * jnp.abs(x2) * x2 - self.k * x1) / self.m]
        )

    def g(self, x: Array) -> Array:
        """Input-proportional part of the vector field.

        .. math:: g(x) = [0, 1 / m]^T.
        """
        return jnp.array([0.0, 1.0 / self.m])

    def h(self, x: Array) -> Array:
        """Output function: the state components selected by ``outputs``.

        .. math:: y = h(x) = \\{x_j \\mid j ∈ outputs\\}.
        """
        return x[np.array(self.outputs)]
class Sastry9_9(AbstractControlAffine):
    r"""Example 9.9 in :cite:t:`sastry2013nonlinear`.

    .. math::

        x_1' &= e^{x_2} u \\
        x_2' &= x_1 + x_2^2 + e^{x_2} u \\
        x_3' &= x_1 - x_2 \\
        y &= x_3

    """

    initial_state: np.ndarray = static_field(default_factory=lambda: np.zeros(3))
    n_inputs: str = "scalar"

    def f(self, x: Array) -> Array:
        r"""Input-independent part of the vector field, :math:`[0, x_1 + x_2^2, x_1 - x_2]^T`."""
        return jnp.array([0.0, x[0] + x[1] ** 2, x[0] - x[1]])

    def g(self, x: Array) -> Array:
        r"""Input-proportional part of the vector field, :math:`[e^{x_2}, e^{x_2}, 0]^T`."""
        # The math above is 1-indexed: x_2 corresponds to x[1].
        gain = jnp.exp(x[1])
        return jnp.array([gain, gain, 0.0])

    def h(self, x: Array) -> Scalar:
        r"""Output function, :math:`y = x_3`."""
        return x[2]
class LotkaVolterra(AbstractSystem):
    r"""The classic predator-prey model.

    .. math::

        x_1' &= α x_1 - β x_1 x_2 \\
        x_2' &= δ x_1 x_2 - γ x_2 \\
        y &= [x_1, x_2]^T

    """

    # Parameter values can be constrained by declaring them with the
    # `boxed_field` and `non_negative_field` helpers.
    alpha: float = boxed_field(0.0, jnp.inf, default=0.0)
    beta: float = boxed_field(0.0, jnp.inf, default=0.0)
    gamma: float = boxed_field(0.0, jnp.inf, default=0.0)
    delta: float = non_negative_field(default=0.0)  # same as boxed_field(0, jnp.inf)
    initial_state: np.ndarray = static_field(default_factory=lambda: np.ones(2) * 0.5)
    # A system without inputs declares zero inputs.
    n_inputs: int = 0

    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Evaluate the growth rates of both populations.

        Args:
            x: State ``[x_1, x_2]``; per the equations above, x_1 is the prey and
                x_2 the predator population.
            u: Ignored; the system has no inputs.
            t: Ignored; the dynamics do not depend on time.

        Returns:
            State derivative ``[x_1', x_2']``.
        """
        prey, predator = x
        prey_rate = self.alpha * prey - self.beta * prey * predator
        predator_rate = self.delta * prey * predator - self.gamma * predator
        return jnp.array([prey_rate, predator_rate])
# We can also subclass already defined systems to further change their behaviour.
class LotkaVolterraWithTrainableInitialState(LotkaVolterra):
    """`LotkaVolterra` with an unconstrained `alpha` and a trainable initial state."""

    # Constraints can be released with `field`: redeclaring `alpha` this way drops
    # the constraint metadata of the field, marking the parameter as unconstrained.
    alpha: float = field(default=1.0)
    # In contrast, the following plain assignment is meant to change only the
    # default value of `beta`; the field's metadata (its constraint) is unchanged.
    # NOTE(review): with standard `dataclasses`, assigning an inherited field without
    # an annotation does not redefine the field, so the generated `__init__` may
    # still default `beta` to 0.0 -- confirm the new default actually takes effect.
    beta = 1.0
    # Redeclare `initial_state` with `field` so that it becomes trainable. Mutable
    # field defaults (which includes `jax.Array`) must be supplied through the
    # `default_factory` argument.
    initial_state: Array = field(default_factory=lambda: jnp.ones(2) * 0.5)
Fit a system of ordinary differential equations
Fit a system with multiple shooting
"""Example: fit a second-order nonlinear system to data."""
import equinox as eqx
import jax
import matplotlib.pyplot as plt
import numpy as np
from dynax import fit_multiple_shooting, Flow, pretty
from dynax.example_models import LotkaVolterra
def _fit_and_plot(model, t, y, num_shots, title):
    """Fit `model` to the data `(t, y)` with `num_shots` segments and plot the result.

    Prints the fitted system and opens a new figure that shows the target data,
    the simulated segments, and the estimated initial state of each segment.

    Args:
        model: Model to fit (a `Flow`).
        t: Sample times of the training data.
        y: Training outputs.
        num_shots: Number of shooting segments; 1 means single shooting.
        title: Label used for the printout and the figure title.

    Returns:
        The fitted model.
    """
    res = fit_multiple_shooting(
        model=model,
        t=t,
        y=y,
        verbose=2,
        num_shots=num_shots,
    )
    fitted = res.result
    x0s = res.x0s
    ts = res.ts
    ts0 = res.ts0
    print(f"{title}:", pretty(fitted.system))

    plt.figure()
    plt.title(title)
    # Simulate every segment from its own estimated initial state.
    _, ys_pred = jax.vmap(fitted)(ts0, initial_state=x0s)
    plt.plot(t, y, "k--", label="target")
    for i in range(num_shots):
        plt.plot(ts[i], ys_pred[i], label="fitted", color=f"C{i}")
        # Mark the estimated initial state of segment i.
        for j in range(x0s.shape[1]):
            plt.scatter(ts[i, 0], x0s[i, j], c=f"C{i}")
    plt.legend()
    return fitted


# Initiate a dynamical system representing some "true" parameters.
true_system = LotkaVolterra(alpha=0.1, beta=0.2, gamma=0.3, delta=0.4)
# Combine the ODE system with an ODE solver (using `Flow`'s default solver settings).
true_model = Flow(true_system)
print("true system:", pretty(true_system))

# Generate training data using the true model. This could be your measurement data.
t_train = np.linspace(0, 100, 1000)
_, y_train = true_model(t_train)

# Initiate the ODE with some initial parameters.
initial_sys = LotkaVolterra(alpha=0.5, beta=0.5, gamma=0.5, delta=0.5)
print("initial system:", pretty(initial_sys))
# Combine the ODE with an ODE solver.
init_model = Flow(initial_sys)

# Fitting with single shooting fails: the optimizer gets stuck in a local minimum.
_fit_and_plot(init_model, t_train, y_train, num_shots=1, title="single shooting")

# Multiple shooting to the rescue.
model = _fit_and_plot(
    init_model, t_train, y_train, num_shots=3, title="multiple shooting"
)
plt.show()

# Check the results
_, y_pred = model(t_train)
assert eqx.tree_equal(model.system, true_system, rtol=1e-3, atol=1e-3)
assert np.allclose(y_train, y_pred, atol=1e-5, rtol=1e-5)