Source code for dynax.evolution

"""Classes for simulating the evolution of dynamical systems over time."""

from abc import abstractmethod
from typing import Callable, cast, Optional

import jax
import jax.numpy as jnp
from diffrax import (
    AbstractAdaptiveSolver,
    AbstractStepSizeController,
    ConstantStepSize,
    CubicInterpolation,
    diffeqsolve,
    DirectAdjoint,
    Dopri5,
    ODETerm,
    SaveAt,
)
from equinox import filter_eval_shape, Module, static_field
from jax import Array
from jaxtyping import PyTree

from .custom_types import ArrayLike
from .interpolation import spline_it
from .system import AbstractSystem
from .util import broadcast_right, dim2shape


[docs] class AbstractEvolution(Module): """Abstract base-class for evolutions. Evolutions combine dynamical systems with a solver. They simulate the evolution of the system state and output over time given an initial and, possibly, an input sequence. """ system: AbstractSystem
[docs] @abstractmethod def __call__( self, t: Array, u: Optional[Array], initial_state: Optional[Array] ) -> tuple[Array, Array]: """Evolve an initial state along the vector field and compute output. Args: t: Times at which to evaluate the evolution. u: Optional input sequence of same length. initial_state: Optional, fixed initial state used instead of :py:attr:`AbstractSystem.initial_state`. Returns: Tuple `(x, y)` of state and output sequences. """ raise NotImplementedError
[docs] class Flow(AbstractEvolution): """Evolution for continous-time dynamical systems. Args: system: Dynamical system. solver: Differential equation solver. Defaults to :py:class:`diffrax.Dopri5`. stepsize_controller: Stepsize controller. Defaults to :py:class:`diffrax.ConstantStepSize`. """ solver: AbstractAdaptiveSolver = static_field(default_factory=Dopri5) stepsize_controller: AbstractStepSizeController = static_field( default_factory=ConstantStepSize )
[docs] def __call__( self, t: ArrayLike, u: Optional[ArrayLike] = None, initial_state: Optional[ArrayLike] = None, *, ufun: Optional[Callable[[float], Array]] = None, ucoeffs: Optional[tuple[PyTree, PyTree, PyTree, PyTree]] = None, **diffeqsolve_kwargs, ) -> tuple[Array, Array]: r"""Evolve an initial state along the vector field and compute output. Args: t: Times at which to evaluate the evolution. u: Optional input sequence of same length. initial_state: Optional, fixed initial state used instead of :py:attr:`AbstractSystem.initial_state`. ufun: A function :math:`t \mapsto u`. Can be used instead of `u` or `ucoeffs`. ucoeffs: Precomputed spline coefficients of the input passed to :py:class:`diffrax.CubicInterpolation`. Can be used instead of `u` or `ufun`. **diffeqsolve_kwargs: Additional arguments passed to :py:meth:`diffrax.diffeqsolve`. Returns: Tuple `(x, y)` of state and output sequences. """ # Parse inputs. t = jnp.asarray(t) if initial_state is not None: initial_state = jnp.asarray(initial_state) if initial_state.shape != self.system.initial_state.shape: raise ValueError("Initial state dimenions do not match.") else: initial_state = self.system.initial_state # Prepare input function. u_: Array | None = None if ucoeffs is not None: path = CubicInterpolation(t, ucoeffs) _ufun = path.evaluate elif callable(ufun): _ufun = ufun elif u is not None: u_ = jnp.asarray(u) if len(t) != u_.shape[0]: raise ValueError("t and u must have matching first dimension.") _ufun = spline_it(t, u_) elif self.system.n_inputs == 0: _ufun = lambda _: jnp.empty((0,)) else: raise ValueError("Must specify one of u, ufun, or ucoeffs.") del u # Check shape of ufun return values. _u = filter_eval_shape(_ufun, 0.0) if not isinstance(_u, jax.ShapeDtypeStruct): raise ValueError(f"ufun must return Arrays, not {type(_u)}.") else: if not _u.shape == dim2shape(self.system.n_inputs): raise ValueError( f"Input dimensions do not match: inputs have shape {_u.shape}, but" f"system.n_inputs is {self.system.n_inputs}" ) del _u # Solve ODE. diffeqsolve_default_options = dict( solver=self.solver, stepsize_controller=self.stepsize_controller, saveat=SaveAt(ts=t), max_steps=50 * len(t), # completely arbitrary number of steps adjoint=DirectAdjoint(), dt0=( t[1] if isinstance(self.stepsize_controller, ConstantStepSize) else None ), ) diffeqsolve_default_options |= diffeqsolve_kwargs vector_field = lambda t, x, self: self.system.vector_field(x, _ufun(t), t) term = ODETerm(vector_field) x = diffeqsolve( term, t0=t[0], t1=t[-1], y0=initial_state, args=self, # https://github.com/patrick-kidger/diffrax/issues/135 **diffeqsolve_default_options, # type: ignore ).ys # Could be in general a Pytree, but we only allow Array states. x = cast(Array, x) # Compute output. y = jax.vmap(self.system.output)(x, u_, t) return x, y
[docs] class Map(AbstractEvolution): """Evolution for discrete-time dynamical systems. Args: system: Dynamical system. """
[docs] def __call__( self, t: Optional[ArrayLike] = None, u: Optional[ArrayLike] = None, initial_state: Optional[ArrayLike] = None, *, num_steps: Optional[int] = None, ) -> tuple[Array, Array]: """Evolve an initial state along the vector field and compute output. Args: t: Times at which to evaluate the evolution. u: Optional input sequence of same length. initial_state: Optional, fixed initial state used instead of :py:attr:`AbstractSystem.initial_state`. num_steps: Number of steps to compute if `t` and `u` are not specified. Returns: Tuple `(x, y)` of state and output sequences. """ # Parse inputs. if initial_state is not None: initial_state = jnp.asarray(initial_state) if initial_state.shape != self.system.initial_state.shape: raise ValueError("Initial state dimenions do not match.") else: initial_state = self.system.initial_state if t is not None: t = jnp.asarray(t) if u is not None: u = jnp.asarray(u) if t is u is num_steps is None: raise ValueError("must specify one of num_steps, t, or u.") if t is not None and u is not None: if t.shape[0] != u.shape[0]: raise ValueError("t and u must have the same first dimension.") inputs = jnp.stack((broadcast_right(t, u), u), axis=1) unpack = lambda input: (input[0], input[1]) elif t is not None: inputs = t unpack = lambda input: (input, None) else: # u is not None inputs = u unpack = lambda input: (None, input) # Evolve. def scan_fun(state, input): t, u = unpack(input) next_state = self.system.vector_field(state, u, t) return next_state, state _, x = jax.lax.scan(scan_fun, initial_state, inputs, length=num_steps) # Compute output. y = jax.vmap(self.system.output)(x, u, t) return x, y