Examples

Have a look at the example notebooks listed in the navigation, or at the scripts below.

Declaring ODE systems in Dynax

import jax.numpy as jnp
import numpy as np

from .custom_types import Array, FloatScalarLike, Scalar
from .system import (
    AbstractControlAffine,
    AbstractSystem,
    boxed_field,
    field,
    non_negative_field,
    static_field,
)


# Define a general dynamical system by subclassing `AbstractSystem`.
class SpringMassDamper(AbstractSystem):
    """Linear spring-mass-damper system driven by an external force.

    .. math:: m x'' + r x' + k x = u.

    """

    # The system parameters are declared as data fields.
    m: float = field()
    """Mass."""
    r: float = field()
    """Linear drag coefficient."""
    k: float = field()
    """Spring stiffness."""

    initial_state: np.ndarray = static_field(default_factory=lambda: np.zeros(2))
    n_inputs: str = "scalar"

    # The dynamics are specified by implementing the `vector_field` method.
    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Time derivative of the state.

        .. math:: ẋ = [x_2, (u - r x_2 - k x_1) / m]^T.

        Args:
            x: State vector ``[x_1, x_2]``, i.e. position and velocity.
            u: Optional scalar input force; treated as zero when omitted.
            t: Time; not used by this time-invariant system.

        Returns:
            The state derivative.

        """
        position, velocity = x
        force = jnp.zeros(()) if u is None else u
        acceleration = (force - self.r * velocity - self.k * position) / self.m
        return jnp.array([velocity, acceleration])

    # `AbstractSystem.output` is not overridden here, so the output defaults to the
    # full state vector.


# Systems that have a control-affine structure can subclass `AbstractControlAffine` and
# implement the `f`, `g`, and `h` methods. Such systems can often be input-output
# linearized with the functions in `dynax.linearize`.
class NonlinearDrag(AbstractControlAffine):
    """Forced spring-mass-damper system with nonlinear drag.

    .. math:: m x'' + r x' + r_2 x'|x'| + k x = u.

    """

    r: float = field()
    """Linear drag."""
    r2: float = field()
    """Nonlinear drag."""
    k: float = field()
    """Stiffness."""
    m: float = field()
    """Mass."""

    # Additional dataclass fields that do not represent trainable model parameters
    # are declared with the `static_field` function. It tells JAX that the field is a
    # constant that should not be differentiated with respect to.
    outputs: tuple[int, ...] = static_field(default=(0,))
    """Indices of the state components that are outputs. Defaults to ``(0,)``."""

    initial_state: np.ndarray = static_field(default_factory=lambda: np.zeros(2))
    n_inputs: str = "scalar"

    def f(self, x: Array) -> Array:
        """Input-independent (drift) part of the vector field.

        .. math:: f(x) = [x_2, (-r x_2 - r_2 |x_2| x_2 - k x_1) / m]^T.

        """
        x1, x2 = x
        return jnp.array(
            [x2, (-self.r * x2 - self.r2 * jnp.abs(x2) * x2 - self.k * x1) / self.m]
        )

    def g(self, x: Array) -> Array:
        """Input-proportional part of the vector field.

        .. math:: g(x) = [0, 1 / m]^T.

        """
        return jnp.array([0.0, 1.0 / self.m])

    def h(self, x: Array) -> Array:
        r"""Output function: the state components selected by ``outputs``.

        .. math:: y = h(x) = (x_j)_{j \in \mathrm{outputs}}.

        """
        return x[np.array(self.outputs)]


class Sastry9_9(AbstractControlAffine):
    r"""Example 9.9 in :cite:t:`sastry2013nonlinear`.

    .. math::

        x_1' &= e^{x_2} u \\
        x_2' &= x_1 + x_2^2 + e^{x_2} u \\
        x_3' &= x_1 - x_2 \\
           y &= x_3

    """

    initial_state: np.ndarray = static_field(default_factory=lambda: np.zeros(3))
    n_inputs: str = "scalar"

    def f(self, x: Array) -> Array:
        """Drift part of the vector field.

        .. math:: f(x) = [0, x_1 + x_2^2, x_1 - x_2]^T.

        """
        return jnp.array([0.0, x[0] + x[1] ** 2, x[0] - x[1]])

    def g(self, x: Array) -> Array:
        """Input-proportional part of the vector field.

        .. math:: g(x) = [e^{x_2}, e^{x_2}, 0]^T.

        """
        # The same input gain e^{x_2} enters the first two state equations.
        gain = jnp.exp(x[1])
        return jnp.array([gain, gain, 0.0])

    def h(self, x: Array) -> Scalar:
        """Output function, :math:`y = x_3`."""
        return x[2]


class LotkaVolterra(AbstractSystem):
    r"""The classic Lotka-Volterra predator-prey model.

    .. math::

        x_1' &= α x_1 - β x_1 x_2 \\
        x_2' &= δ x_1 x_2 - γ x_2 \\
        y &= [x_1, x_2]^T

    """

    # Parameter values can be constrained by declaring them with the
    # `non_negative_field` and `boxed_field` functions.
    alpha: float = boxed_field(0.0, jnp.inf, default=0.0)
    beta: float = boxed_field(0.0, jnp.inf, default=0.0)
    gamma: float = boxed_field(0.0, jnp.inf, default=0.0)
    delta: float = non_negative_field(default=0.0)  # same as boxed_field(0, jnp.inf)

    initial_state: np.ndarray = static_field(default_factory=lambda: np.ones(2) * 0.5)

    # Autonomous systems (no inputs) declare zero inputs.
    n_inputs: int = 0

    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Predator-prey dynamics; the arguments ``u`` and ``t`` are ignored."""
        prey, predator = x
        prey_rate = self.alpha * prey - self.beta * prey * predator
        predator_rate = self.delta * prey * predator - self.gamma * predator
        return jnp.array([prey_rate, predator_rate])


# We can also subclass already defined systems to further change their behaviour.
class LotkaVolterraWithTrainableInitialState(LotkaVolterra):
    """Lotka-Volterra model with an unconstrained `alpha` and trainable initial state."""

    # We can release parameter constraints with `field`. This removes the constraint
    # metadata from the corresponding field, indicating that this parameter is
    # unconstrained.
    alpha: float = field(default=1.0)

    # To change only the default value of a parameter while keeping its constraint,
    # redeclare the annotated field with the same constraint function. A bare
    # assignment such as `beta = 1.0` is not enough: dataclasses only redefine a field
    # when the subclass annotates it, so the constructor would keep the default of the
    # inherited field.
    beta: float = boxed_field(0.0, jnp.inf, default=1.0)

    # Here we redeclare the initial_state field to be trainable. When default values
    # of the field functions are mutable (which includes jax.Array), one must use
    # the `default_factory` argument.
    initial_state: Array = field(default_factory=lambda: jnp.ones(2) * 0.5)

Fit a system of ordinary differential equations

Fit a system with multiple shooting

"""Example: fit a second-order nonlinear system to data."""

import equinox as eqx
import jax
import matplotlib.pyplot as plt
import numpy as np

from dynax import fit_multiple_shooting, Flow, pretty
from dynax.example_models import LotkaVolterra


def fit_and_plot(model, t, y, num_shots, title):
    """Fit `model` to the data `(t, y)` with multiple shooting and plot the result.

    Args:
        model: Initial `Flow` model whose parameters are fitted.
        t: Sample times of the training data.
        y: Training outputs sampled at `t`.
        num_shots: Number of shooting segments; 1 corresponds to single shooting.
        title: Title of the created figure, also used as the label of the printout.

    Returns:
        The fitted model.

    """
    res = fit_multiple_shooting(
        model=model,
        t=t,
        y=y,
        verbose=2,
        num_shots=num_shots,
    )
    fitted = res.result
    x0s = res.x0s
    ts = res.ts
    ts0 = res.ts0
    print(f"{title}:", pretty(fitted.system))

    plt.figure()
    plt.title(title)
    # Simulate every segment from its own fitted initial state.
    _, ys_pred = jax.vmap(fitted)(ts0, initial_state=x0s)
    plt.plot(t, y, "k--", label="target")
    for i in range(num_shots):
        plt.plot(ts[i], ys_pred[i], label="fitted", color=f"C{i}")
        # Mark the fitted initial state components at the start of the segment.
        for j in range(x0s.shape[1]):
            plt.scatter(ts[i, 0], x0s[i, j], c=f"C{i}")
    plt.legend()
    return fitted


# Initiate a dynamical system representing some "true" parameters.
true_system = LotkaVolterra(alpha=0.1, beta=0.2, gamma=0.3, delta=0.4)
# Combine ODE system with ODE solver (Dopri5 and constant stepsize by default).
true_model = Flow(true_system)
print("true system:", pretty(true_system))

# Generate training data using the true model. This could be your measurement data.
t_train = np.linspace(0, 100, 1000)
_, y_train = true_model(t_train)

# Initiate ODE with some initial parameters.
initial_sys = LotkaVolterra(alpha=0.5, beta=0.5, gamma=0.5, delta=0.5)
print("initial system:", pretty(initial_sys))

# Combine the ODE with an ODE solver.
init_model = Flow(initial_sys)

# Fitting with single shooting fails: the optimizer gets stuck in local minima.
fit_and_plot(init_model, t_train, y_train, num_shots=1, title="single shooting")

# Multiple shooting to the rescue.
model = fit_and_plot(
    init_model, t_train, y_train, num_shots=3, title="multiple shooting"
)

plt.show()

# Check the results
_, y_pred = model(t_train)
assert eqx.tree_equal(model.system, true_system, rtol=1e-3, atol=1e-3)
assert np.allclose(y_train, y_pred, atol=1e-5, rtol=1e-5)