Source code for dynax.interpolation

"""Interpolation utilities for signal data."""

import diffrax as dfx
import equinox
import jax.numpy as jnp
from jax import Array


class InterpolationFunction(equinox.Module):
    """Interpolating cubic-spline function.

    Wraps a :class:`diffrax.CubicInterpolation` built from the backward
    Hermite coefficients of the samples ``(ts, xs)``.
    """

    path: dfx.CubicInterpolation

    def __init__(self, ts: Array, xs: Array):
        """Build the spline from sampled data.

        Args:
            ts: One-dimensional sequence of sample times.
            xs: Sample values; the first axis indexes samples and must have
                the same length as `ts`.

        Raises:
            ValueError: If `ts` is not one-dimensional, if `xs` is a scalar,
                or if `ts` and `xs` have different numbers of samples.
        """
        ts = jnp.asarray(ts)
        xs = jnp.asarray(xs)
        # Validate ranks before comparing lengths: ``len`` of a 0-d array
        # raises TypeError and ``shape[0]`` of a 0-d array raises IndexError,
        # both of which obscure the actual input problem.
        if ts.ndim != 1:
            raise ValueError("time must be a one-dimensional sequence")
        if xs.ndim < 1:
            raise ValueError("data must have at least one dimension")
        if ts.shape[0] != xs.shape[0]:
            raise ValueError("time and data must have same number of samples")
        coeffs = dfx.backward_hermite_coefficients(ts, xs)
        self.path = dfx.CubicInterpolation(ts, coeffs)

    def __call__(self, t: float) -> Array:
        """Evaluate the interpolating function at time `t`."""
        return self.path.evaluate(t)
def spline_it(ts: Array, xs: Array) -> InterpolationFunction:
    """Create an interpolating cubic-spline function.

    Args:
        ts: Time sequence.
        xs: Data points with first axis having the same length as `ts`.

    Returns:
        An :class:`InterpolationFunction` ``f`` such that ``f(t)`` computes
        the interpolated value at time ``t``.

    Raises:
        ValueError: If `ts` and `xs` have different numbers of samples
            (raised by :class:`InterpolationFunction`).
    """
    return InterpolationFunction(ts, xs)