Source code for dynax.system

"""Classes representing dynamical systems."""

from abc import abstractmethod
from collections.abc import Callable
from typing import Any, Literal, TypeVar

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

from .custom_types import Array, FloatScalarLike
from .util import dim2shape, pretty


def _linearize(
    f, h, x0: Array, u0: Array, t0: float
) -> tuple[Array, Array, Array, Array]:
    """Linearize dx=f(x,u,t), y=h(x,u,t) around x0, u0, t0."""
    A = jax.jacfwd(f, argnums=0)(x0, u0, t0)
    B = jax.jacfwd(f, argnums=1)(x0, u0, t0)
    C = jax.jacfwd(h, argnums=0)(x0, u0, t0)
    D = jax.jacfwd(h, argnums=1)(x0, u0, t0)
    return A, B, C, D


T = TypeVar("T")


def _to_static_array(x: T) -> np.ndarray | T:
    if isinstance(x, jax.Array):
        return np.asarray(x)
    else:
        return x


def field(**kwargs: Any) -> Any:
    """Mark an attribute value as trainable and unconstrained.

    The field's ``metadata`` mapping receives a ``"constrained"`` entry set to
    ``False`` unless the caller already provided one (``boxed_field`` in this
    module relies on that to pass its bounds through). Field values are
    converted with :py:func:`jax.numpy.asarray`.

    Args:
        **kwargs: Keyword arguments forwarded to :py:func:`equinox.field`.
            ``converter`` is set by this function and must not be passed.

    Returns:
        The field specifier returned by :py:func:`equinox.field`.
    """
    # Work on a copy so a mapping owned by the caller is never mutated, and
    # write the copy back: previously a caller-supplied ``metadata`` was
    # copied, edited and then dropped, so the "constrained" marker was lost.
    metadata = dict(kwargs.get("metadata") or {})
    # ``setdefault`` keeps a constraint that was already recorded upstream.
    metadata.setdefault("constrained", False)
    kwargs["metadata"] = metadata
    return eqx.field(converter=jnp.asarray, **kwargs)
def static_field(**kwargs: Any) -> Any:
    """Mark an attribute value as non-trainable.

    Like :py:func:`equinox.field`, but marks the field as unconstrained
    (``metadata["constrained"] = False``) and converts JAX arrays to NumPy
    arrays so the value is treated as a static constant.

    Args:
        **kwargs: Keyword arguments forwarded to :py:func:`equinox.field`.
            ``converter`` is set by this function and must not be passed.

    Returns:
        The field specifier returned by :py:func:`equinox.field`.
    """
    # Copy the caller's mapping (never mutate it) and write the copy back;
    # previously a caller-supplied ``metadata`` never received the marker
    # because only a discarded copy was updated.
    metadata = dict(kwargs.get("metadata") or {})
    metadata["constrained"] = False
    kwargs["metadata"] = metadata
    return eqx.field(converter=_to_static_array, **kwargs)
def boxed_field(lower: float, upper: float, **kwargs: Any) -> Any:
    """Mark an attribute value as trainable and box-constrained on `[lower, upper]`.

    The bounds are recorded in the field metadata under the ``"constrained"``
    key as ``("boxed", (lower, upper))``; the field itself is created by
    :py:func:`field`.

    Args:
        lower: Lower bound.
        upper: Upper bound.
        **kwargs: Keyword arguments forwarded to :py:func:`field`.

    Returns:
        The field specifier returned by :py:func:`field`.
    """
    # Copy the caller's mapping (never mutate it) and write the copy back;
    # previously the bounds were stored on a discarded copy whenever the
    # caller supplied their own ``metadata``, silently dropping the constraint.
    metadata = dict(kwargs.get("metadata") or {})
    metadata["constrained"] = ("boxed", (lower, upper))
    kwargs["metadata"] = metadata
    return field(**kwargs)
def non_negative_field(min_val: float = 0.0, **kwargs: Any) -> Any:
    """Mark an attribute value as trainable and non-negative.

    Shorthand for a :py:func:`boxed_field` on ``[min_val, inf]``.

    Args:
        min_val: Minimum value.
        **kwargs: Keyword arguments forwarded to :py:func:`boxed_field`.

    Returns:
        The field specifier returned by :py:func:`boxed_field`.
    """
    unbounded_above = np.inf
    return boxed_field(min_val, unbounded_above, **kwargs)
class AbstractSystem(eqx.Module):
    r"""Base class for dynamical systems.

    Any dynamical system in Dynax must inherit from this class. Subclasses can
    define continuous-time

    .. math::

        ẋ &= f(x, u, t) \\
        y &= h(x, u, t)

    or discrete-time

    .. math::

        x_{k+1} &= f(x_k, u_k, t) \\
        y_k &= h(x_k, u_k, t)

    system. The distinction between the two is only made when instances of
    subclasses are passed to objects such as :py:class:`dynax.evolution.Flow`,
    :py:class:`dynax.evolution.Map`,
    :py:class:`dynax.linearize.input_output_linearize`, or
    :py:class:`dynax.linearize.discrete_input_output_linearize`.

    Subclasses must set values for the `n_inputs` and `initial_state`
    attributes and implement the `vector_field` method. The `output` method
    describes the measurement equations. By default, the full state vector is
    returned as output.

    Example::

        class IntegratorAndGain(AbstractSystem):
            n_states = 1
            n_inputs = "scalar"
            gain: float

            def vector_field(self, x, u, t):
                dx = u
                return dx

            def output(self, x, u, t):
                return self.gain*x

    `AbstractSystem` is a dataclass and as such defines a default constructor
    which can make it necessary to implement a custom `__init__` method.

    """

    initial_state: eqx.AbstractVar[Array | np.ndarray]
    """Initial state vector."""
    n_inputs: eqx.AbstractVar[int | Literal["scalar"]]
    """Number of inputs."""

    def __check_init__(self):
        """Validate the instance after initialization.

        Raises:
            AttributeError: If `initial_state` or `n_inputs` is not set.
            ValueError: If the shapes of `vector_field` or `output` cannot be
                evaluated, or if either returns something other than a single
                array or scalar.
        """
        # TODO: remove
        # Check that required attributes are initialized
        required_attrs = ["initial_state", "n_inputs"]
        for attr in required_attrs:
            if not hasattr(self, attr):
                raise AttributeError(f"Attribute '{attr}' not initialized.")

        with jax.ensure_compile_time_eval():
            # Check that vector_field and output returns Arrays or scalars - not
            # PyTrees. Only shapes are evaluated; no numerical work is done.
            x = jax.ShapeDtypeStruct(self.initial_state.shape, jnp.float64)
            u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64)
            try:
                dx = eqx.filter_eval_shape(self.vector_field, x, u, t=1.0)
                y = eqx.filter_eval_shape(self.output, x, u, t=1.0)
            except Exception as e:
                raise ValueError(
                    "Can not evaluate output shapes. Check your definitions!"
                ) from e
            # The names must be two separate strings: a single string
            # "vector_field, output" would be zipped character by character
            # and the error message would name "v" or "e".
            for val, func in zip((dx, y), ("vector_field", "output"), strict=True):
                if not isinstance(val, jax.ShapeDtypeStruct):
                    raise ValueError(
                        f"{func} must return arrays or scalars, not {type(val)}"
                    )

    @abstractmethod
    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute state derivative.

        Args:
            x: State vector.
            u: Optional input vector.
            t: Optional time.

        Returns:
            State derivative.

        """
        raise NotImplementedError

    def output(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute output.

        Args:
            x: State vector.
            u: Optional input vector.
            t: Optional time.

        Returns:
            System output. The default implementation returns `x` unchanged.

        """
        return x

    @property
    def n_outputs(self) -> int | Literal["scalar"]:
        """The size of the output vector.

        ``"scalar"`` if `output` returns a zero-dimensional value, otherwise
        the length of the first axis of the output.
        """
        with jax.ensure_compile_time_eval():
            x = jax.ShapeDtypeStruct(self.initial_state.shape, jnp.float64)
            u = jax.ShapeDtypeStruct(dim2shape(self.n_inputs), jnp.float64)
            y = eqx.filter_eval_shape(self.output, x, u, t=1.0)
            n_out = "scalar" if y.ndim == 0 else y.shape[0]
        return n_out

    def linearize(
        self,
        x0: Array | None = None,
        u0: Array | None = None,
        t0: float | None = None,
    ) -> "LinearSystem":
        """Compute the Jacobian linearization around a point.

        Args:
            x0: State at which to linearize. Defaults to `initial_state`.
            u0: Input at which to linearize. Defaults to zero input.
            t0: Time at which to linearize. Defaults to zero.

        Returns:
            Linearized system.

        """
        if x0 is None:
            x0 = jnp.array(self.initial_state)
        if u0 is None:
            u0 = jnp.zeros(dim2shape(self.n_inputs))
        if t0 is None:
            t0 = 0.0
        A, B, C, D = _linearize(self.vector_field, self.output, x0, u0, t0)
        return LinearSystem(A, B, C, D)  # type: ignore[call-arg]

    def pretty(self) -> str:
        """Return a pretty formatted string representation.

        The string includes the constraints of all trainable parameters and
        the values of all parameters.
        """
        return pretty(self)
class AbstractControlAffine(AbstractSystem):
    r"""Base class for control-affine dynamical systems.

    Both in continuous-time

    .. math::

        ẋ &= f(x) + g(x)u \\
        y &= h(x) + i(x)u

    or the discrete-time equivalent.

    Subclasses must implement the `f` and `g` methods that characterize the
    vector field. Optionally, the `h` and `i` methods can be implemented to
    describe the measurement equations. By default, the full state vector is
    returned as output.

    """

    @abstractmethod
    def f(self, x: Array) -> Array:
        """The constant-input part of the vector field."""

    @abstractmethod
    def g(self, x: Array) -> Array:
        """The input-proportional part of the vector field."""

    def h(self, x: Array) -> Array:
        """The constant-input part of the output equation.

        The default implementation returns `x` itself.
        """
        return x

    def i(self, x: Array) -> Array:
        """The input-proportional part of the output equation.

        The default implementation returns a scalar zero.
        """
        return jnp.array(0.0)

    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute ``f(x) + g(x)·u``; the input term is skipped if `u` is None.

        Args:
            x: State vector.
            u: Optional input vector.
            t: Optional time (unused).

        Returns:
            State derivative.
        """
        out = self.f(x)
        if u is not None:
            # Out-of-place addition: ``+=`` would modify the array returned by
            # `f` in place when it is a mutable (NumPy) array.
            out = out + self.g(x).dot(u)
        return out

    def output(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute ``h(x) + i(x)·u``; the input term is skipped if `u` is None.

        Args:
            x: State vector.
            u: Optional input vector.
            t: Optional time (unused).

        Returns:
            System output.
        """
        out = self.h(x)
        if u is not None:
            # Out-of-place addition: the default `h` returns `x` itself, so
            # ``+=`` would overwrite the caller's state when `x` is a mutable
            # (NumPy) array.
            out = out + self.i(x).dot(u)
        return out
class LinearSystem(AbstractControlAffine):
    r"""A linear, time-invariant dynamical system.

    .. math::

        ẋ &= Ax + Bu \\
        y &= Cx + Du

    Args:
        A, B, C, D: System matrices of appropriate shape.

    """

    A: Array
    """State matrix."""
    B: Array
    """Input matrix."""
    C: Array
    """Output matrix."""
    D: Array
    """Feedthrough matrix."""

    initial_state: Array = static_field(default=None)
    n_inputs: int | Literal["scalar"] = static_field(init=False, default=None)

    def __post_init__(self):
        """Fill in a default `initial_state` and infer `n_inputs` from `B`.

        Raises:
            ValueError: If the number of dimensions of `B` does not fit the
                number of dimensions of the state.
        """
        # Without this context manager, `initial_state` will leak later
        with jax.ensure_compile_time_eval():
            if self.initial_state is None:
                if self.A.ndim == 0:
                    self.initial_state = jnp.array(0)
                else:
                    self.initial_state = jnp.zeros(self.A.shape[0])

            # Number of dimensions `B` has when the input is a scalar: zero
            # for a scalar state, one otherwise. One extra trailing axis
            # means a vector input whose length is that axis' size.
            scalar_input_rank = 0 if self.initial_state.ndim == 0 else 1
            input_rank = self.B.ndim
            if input_rank == scalar_input_rank:
                self.n_inputs = "scalar"
            elif input_rank == scalar_input_rank + 1:
                self.n_inputs = self.B.shape[-1]
            else:
                raise ValueError("Dimension mismatch.")

    def f(self, x: Array) -> Array:
        """Return ``A·x``."""
        state_term = self.A.dot(x)
        return state_term

    def g(self, x: Array) -> Array:
        """Return the input matrix `B` (independent of `x`)."""
        return self.B

    def h(self, x: Array) -> Array:
        """Return ``C·x``."""
        output_term = self.C.dot(x)
        return output_term

    def i(self, x: Array) -> Array:
        """Return the feedthrough matrix `D` (independent of `x`)."""
        return self.D
class _CoupledSystemMixin(eqx.Module):
    """Helpers for systems whose state stacks the states of two subsystems."""

    _sys1: AbstractSystem
    _sys2: AbstractSystem

    def _pack_states(self, x1: Array, x2: Array) -> Array:
        """Concatenate both states into one 1-d vector (scalars become length 1)."""
        parts = tuple(jnp.atleast_1d(part) for part in (x1, x2))
        return jnp.concatenate(parts)

    def _unpack_states(self, x: Array) -> tuple[Array, Array]:
        """Split a packed vector and restore each subsystem's state shape."""
        first = self._sys1.initial_state
        # Number of leading entries of `x` that belong to the first subsystem.
        if jnp.ndim(first) == 0:
            split = 1
        else:
            split = first.size
        head, tail = x[:split], x[split:]
        return (
            head.reshape(first.shape),
            tail.reshape(self._sys2.initial_state.shape),
        )
class SeriesSystem(AbstractSystem, _CoupledSystemMixin):
    r"""Two systems in series.

    .. math::

        ẋ_1 &= f_1(x_1, u, t)   \\
        y_1 &= h_1(x_1, u, t)   \\
        ẋ_2 &= f_2(x_2, y1, t)  \\
        y_2 &= h_2(x_2, y1, t)

    .. aafig::

               +------+      +------+
        u --+->+ sys1 +--y1->+ sys2 +--> y2
               +------+      +------+

    Args:
        sys1: System with :math:`n` outputs.
        sys2: System with :math:`n` inputs.

    """

    initial_state: Array = static_field(init=False)
    n_inputs: int | Literal["scalar"] = static_field(init=False)

    def __init__(self, sys1: AbstractSystem, sys2: AbstractSystem):
        self._sys1 = sys1
        self._sys2 = sys2
        # The combined state is the first system's state followed by the
        # second system's state.
        x1_init = jnp.asarray(sys1.initial_state)
        x2_init = jnp.asarray(sys2.initial_state)
        self.initial_state = self._pack_states(x1_init, x2_init)
        # The external input enters the first system only.
        self.n_inputs = sys1.n_inputs

    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute the packed state derivative of both systems.

        Args:
            x: Packed state vector of both systems.
            u: Optional input to the first system.
            t: Optional time.

        Returns:
            Packed state derivative.
        """
        first, second = self._sys1, self._sys2
        x1, x2 = self._unpack_states(x)
        # The output of the first system drives the second one.
        y1 = first.output(x1, u, t)
        return self._pack_states(
            first.vector_field(x1, u, t),
            second.vector_field(x2, y1, t),
        )

    def output(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute the output of the second system.

        Args:
            x: Packed state vector of both systems.
            u: Optional input to the first system.
            t: Optional time.

        Returns:
            Output of the second system when fed the first system's output.
        """
        x1, x2 = self._unpack_states(x)
        return self._sys2.output(x2, self._sys1.output(x1, u, t), t)
class FeedbackSystem(AbstractSystem, _CoupledSystemMixin):
    r"""Two systems connected via feedback.

    .. math::

        ẋ_1 &= f_1(x_1, u + y_2, t) \\
        y_1 &= h_1(x_1, t)          \\
        ẋ_2 &= f_2(x_2, y_1, t)     \\
        y_2 &= h_2(x_2, y_1, t)

    .. aafig::

               +------+
        u --+->+ sys1 +--+-> y1
            ^  +------+  |
            |            |
          y2|  +------+  |
            +--+ sys2 |<-+
               +------+

    Args:
        sys1: System in forward path with :math:`n` inputs.
        sys2: System in feedback path with :math:`n` outputs.

    """

    initial_state: Array = static_field(init=False)
    n_inputs: int | Literal["scalar"] = static_field(init=False)

    def __init__(self, sys1: AbstractSystem, sys2: AbstractSystem):
        self._sys1 = sys1
        self._sys2 = sys2
        # The combined state is the forward system's state followed by the
        # feedback system's state.
        x1_init = jnp.asarray(sys1.initial_state)
        x2_init = jnp.asarray(sys2.initial_state)
        self.initial_state = self._pack_states(x1_init, x2_init)
        self.n_inputs = sys1.n_inputs

    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute the packed state derivative of the closed loop.

        Args:
            x: Packed state vector of both systems.
            u: Optional external input; treated as zero if None.
            t: Optional time.

        Returns:
            Packed state derivative.
        """
        forward, feedback = self._sys1, self._sys2
        external = jnp.zeros(dim2shape(forward.n_inputs)) if u is None else u
        x1, x2 = self._unpack_states(x)
        # The forward output is evaluated without input before the loop is
        # closed with the feedback output.
        y1 = forward.output(x1, None, t)
        y2 = feedback.output(x2, y1, t)
        return self._pack_states(
            forward.vector_field(x1, external + y2, t),
            feedback.vector_field(x2, y1, t),
        )

    def output(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute the output of the forward system.

        Args:
            x: Packed state vector of both systems.
            u: Optional external input (not used by the output).
            t: Optional time.

        Returns:
            Output of the forward system evaluated without input.
        """
        forward_state, _ = self._unpack_states(x)
        return self._sys1.output(forward_state, None, t)
class StaticStateFeedbackSystem(AbstractSystem):
    r"""System with static state-feedback.

    .. math::

        ẋ &= f(x, v(x), t) \\
        y &= h(x, u, t)

    .. aafig::

                           +-----+
        u --+------------->+ sys +----> y
            ^              +--+--+
            |                 |
            |                 | x
            |  +--------+     |
            +--+ "v(x)" +<----+
               +--------+

    Args:
        sys: System with vector field :math:`f` and output :math:`h`.
        v: Static feedback law :math:`v`.

    """

    _sys: AbstractSystem
    _v: Callable[[Array], Array]

    initial_state: Array = static_field(init=False)
    n_inputs: int | Literal["scalar"] = static_field(init=False)

    def __init__(self, sys: AbstractSystem, v: Callable[[Array], Array]):
        self._sys = sys
        self._v = staticmethod(v)
        # State and input dimensions are those of the wrapped system.
        self.initial_state = jnp.asarray(sys.initial_state)
        self.n_inputs = sys.n_inputs

    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute the state derivative with the feedback law as input.

        Args:
            x: State vector.
            u: Optional input (not used by the vector field).
            t: Optional time.

        Returns:
            State derivative of the wrapped system driven by ``v(x)``.
        """
        return self._sys.vector_field(x, self._v(x), t)

    def output(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute the output of the wrapped system.

        Args:
            x: State vector.
            u: Optional input vector.
            t: Optional time.

        Returns:
            Output of the wrapped system for ``(x, u, t)``.
        """
        return self._sys.output(x, u, t)
class DynamicStateFeedbackSystem(AbstractSystem, _CoupledSystemMixin):
    r"""System with dynamic state-feedback.

    .. math::

        ẋ_1 &= f_1(x_1, v(x_1, x_2, u), t) \\
        ẋ_2 &= f_2(x_2, u, t)              \\
        y   &= h_1(x_1, u, t)

    .. aafig::

              +--------------+     +-----+
        u -+->+ v(x1, x2, u) +--v->+ sys +-> y
           |  +-+-------+----+     +--+--+
           |    ^       ^             |
           |    | x2    | x1          |
           |    |       +-------------+
           |  +-+----+
           +->+ sys2 |
              +------+

    Args:
        sys1: System with vector field :math:`f_1` and output :math:`h`.
        sys2: System with vector field :math:`f_2`.
        v: dynamic feedback law :math:`v`.

    """

    _v: Callable[[Array, Array, Array | float], Array]

    initial_state: Array = static_field(init=False)
    n_inputs: int | Literal["scalar"] = static_field(init=False)

    def __init__(
        self,
        sys1: AbstractSystem,
        sys2: AbstractSystem,
        v: Callable[[Array, Array, Array | float], Array],
    ):
        self._sys1 = sys1
        self._sys2 = sys2
        self._v = staticmethod(v)
        # The combined state is the first system's state followed by the
        # second system's state.
        self.initial_state = self._pack_states(
            jnp.asarray(sys1.initial_state), jnp.asarray(sys2.initial_state)
        )
        self.n_inputs = sys1.n_inputs

    def vector_field(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute the packed state derivative of both systems.

        Args:
            x: Packed state vector of both systems.
            u: Optional external input; treated as zero if None.
            t: Optional time.

        Returns:
            Packed state derivative.
        """
        if u is None:
            u = jnp.zeros(dim2shape(self._sys1.n_inputs))
        x1, x2 = self._unpack_states(x)
        v = self._v(x1, x2, u)
        dx1 = self._sys1.vector_field(x1, v, t)
        dx2 = self._sys2.vector_field(x2, u, t)
        # Pack via the shared helper rather than a bare `jnp.concatenate`:
        # it promotes scalar (0-d) derivatives to 1-d first, so subsystems
        # with scalar state work here as they do in the other coupled systems.
        return self._pack_states(dx1, dx2)

    def output(
        self, x: Array, u: Array | None = None, t: FloatScalarLike | None = None
    ) -> Array:
        """Compute the output of the first system.

        Args:
            x: Packed state vector of both systems.
            u: Optional input vector.
            t: Optional time.

        Returns:
            Output of the first system for ``(x_1, u, t)``.
        """
        x1, _ = self._unpack_states(x)
        y = self._sys1.output(x1, u, t)
        return y