Kalman filter and smoother variants in JAX

See also the corresponding section in my thesis
# libraries for this notebook
import jax
import jax.numpy as jnp
import jax.random as jrn
from jax import vmap
import matplotlib.pyplot as plt
import numpy.testing as npt

from isssm.glssm import simulate_glssm
from isssm.models.stsm import stsm
from isssm.typing import GLSSM

Consider a GLSSM of the form \[ \begin{align*} X_0 &\sim \mathcal N (u_0, \Sigma_0) &\\ X_{t + 1} &= u_{t + 1} + A_t X_{t} + \varepsilon_{t + 1} &t = 0, \dots, n - 1\\ \varepsilon_t &\sim \mathcal N (0, \Sigma_t) & t = 1, \dots, n \\ Y_t &= v_{t} + B_t X_t + \eta_t & t =0, \dots, n & \\ \eta_t &\sim \mathcal N(0, \Omega_t), & t=0, \dots, n. \end{align*} \]
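To make the notation concrete, the following stand-alone sketch simulates a single path from such a model with time-invariant \(u, A, \Sigma, v, B, \Omega\). It only illustrates the generative structure and is independent of the isssm API; simulate_glssm below handles the general, time-varying case.

import jax.numpy as jnp
import jax.random as jrn

def simulate_one_path(key, u, A, Sigma, v, B, Omega, n):
    """Draw (X_0, ..., X_n) and (Y_0, ..., Y_n) for time-invariant u, A, Sigma, v, B, Omega."""
    m, p = A.shape[0], B.shape[0]
    keys = jrn.split(key, 2 * (n + 1))
    x = u + jrn.multivariate_normal(keys[0], jnp.zeros(m), Sigma)  # X_0 ~ N(u_0, Sigma_0)
    xs, ys = [], []
    for t in range(n + 1):
        # observation equation Y_t = v + B X_t + eta_t
        y = v + B @ x + jrn.multivariate_normal(keys[2 * t + 1], jnp.zeros(p), Omega)
        xs.append(x)
        ys.append(y)
        if t < n:
            # state transition X_{t+1} = u + A X_t + eps_{t+1}
            x = u + A @ x + jrn.multivariate_normal(keys[2 * t + 2], jnp.zeros(m), Sigma)
    return jnp.stack(xs), jnp.stack(ys)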

Kalman Filter

For \(t, s \in \{0, \dots, n\}\) consider the following BLPs and associated covariance matrices \[ \begin{align*} \hat X_{t|s} &= \mathbf E \left( X_t | Y_s, \dots, Y_0\right) \\ \Xi_{t | s} &= \text{Cov} \left(X_t | Y_s, \dots, Y_0 \right)\\ \hat Y_{t|s} &= \mathbf E \left( Y_t | Y_s, \dots, Y_0\right) \\ \Psi_{t | s} &= \text{Cov} \left(Y_t | Y_s, \dots, Y_0 \right) \end{align*} \]

The Kalman filter consists of the following two-step recursion:

Initialization

\[ \begin{align*} \hat X_{0|0} &= u_0\\ \Xi_{0|0} &= \Sigma_0 \end{align*} \]

Iterate for \(t = 0, \dots, n-1\)

Prediction

\[ \begin{align*} \hat X_{t + 1|t} &= u_{t + 1} + A_t \hat X_{t | t} \\ \Xi_{t + 1 | t} &= A_t \Xi_{t|t} A_t^T + \Sigma_{t + 1}\\ \end{align*} \]

Filtering

\[ \begin{align*} \hat Y_{t + 1 | t} &= v_{t + 1} + B_{t + 1} \hat X_{t + 1 | t} \\ \Psi_{t + 1| t} &= B_{t + 1} \Xi_{t + 1 | t} B_{t + 1}^T + \Omega_{t + 1}\\ K_t &= \Xi_{t + 1 | t} B_{t + 1}^T \Psi_{t + 1 | t} ^{-1} \\ \hat X_{t + 1 | t + 1} &= \hat X_{t + 1 | t} + K_t (Y_{t + 1} - \hat Y_{t + 1 | t})\\ \Xi_{t + 1 | t + 1} &= \Xi_{t + 1 | t} - K_t \Psi_{t + 1| t} K_t^T \end{align*} \]
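Written out in code, one combined prediction/filtering step reads roughly as follows. This is only a sketch of the recursion above, not the library's kalman, which additionally wraps the step in jax.lax.scan over all time points.

import jax.numpy as jnp

def kalman_step(x_filt, Xi_filt, y_next, u_next, A_t, Sigma_next, v_next, B_next, Omega_next):
    # prediction
    x_pred = u_next + A_t @ x_filt
    Xi_pred = A_t @ Xi_filt @ A_t.T + Sigma_next
    # filtering
    y_pred = v_next + B_next @ x_pred
    Psi_pred = B_next @ Xi_pred @ B_next.T + Omega_next
    # gain K_t = Xi_{t+1|t} B_{t+1}^T Psi_{t+1|t}^{-1}, via a solve instead of an explicit inverse
    K = Xi_pred @ jnp.linalg.solve(Psi_pred, B_next).T
    x_filt_next = x_pred + K @ (y_next - y_pred)
    Xi_filt_next = Xi_pred - K @ Psi_pred @ K.T
    return x_filt_next, Xi_filt_next, x_pred, Xi_pred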


source

kalman

 kalman (y:jaxtyping.Float[Array,'n+1 p'], glssm:isssm.typing.GLSSM)

Perform the Kalman filter

         Type                     Details
y        Float[Array, 'n+1 p']    observations
glssm    GLSSM                    model
Returns  FilterResult             filtered & predicted states and covariances

Let us check that our implementation works as expected by simulating a single sample from the joint distribution of a structural time series model with seasonality of order 2.

glssm_model = stsm(jnp.ones(3), 0.0, 0.1, 0.1, 100, jnp.eye(3), 3, 2)

key = jrn.PRNGKey(53405234)
key, subkey = jrn.split(key)
(x,), (y,) = simulate_glssm(glssm_model, 1, subkey)
x_filt, Xi_filt, x_pred, Xi_pred = kalman(y, glssm_model)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(9, 3))
fig.tight_layout()

ax1.set_title("")
ax1.plot(y, label="$Y$")
ax1.plot(x[:, 0], label="$X$")
ax1.plot(x_filt[:, 0], label="$\\hat X_{{t|t}}$")
ax1.legend()

ax2.plot(x[:, 2:])
ax2.set_title("Seasonal component $X_{{t, (2,3)}}$")

ax3.set_title("Filtered seasonal component $\\hat X_{{t, (2,3)|t}}$")
ax3.plot(x_filt[:, 2:])

plt.show()

Let us hasten to add that there is no reason to expect \(X_{t}\) and \(\hat X_{t|t}\) to be close. Nevertheless, the comparison serves as a sanity check that our implementation produces reasonable estimates.

Kalman smoother

The Kalman smoother uses the filter result to obtain \(\hat X_{t | n}\) and \(\Xi_{t | n}\) for \(t = 0, \dots n\).

It is based on the following backwards recursion for \(t = n - 1, \dots, 0\), initialised with the filtering result \(\hat X_{n | n}\) and \(\Xi_{n|n}\), and uses the (reverse) gain \(G_t\).

\[ \begin{align*} G_t &= \Xi_{t | t} A_t^T \Xi_{t + 1 | t} ^{-1}\\ \hat X_{t | n} &= \hat X_{t | t} + G_t (\hat X_{t + 1| n} - \hat X_{t + 1 | t}) \\ \Xi_{t | n} &= \Xi_{t | t} - G_t (\Xi_{t + 1 | t} - \Xi_{t + 1 | n}) G_t^T \end{align*} \]
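A single backward step of this recursion might look as follows; again this is only a sketch of the formulas, not the library's smoother.

import jax.numpy as jnp

def smoother_step(x_smooth_next, Xi_smooth_next, x_filt, Xi_filt, x_pred_next, Xi_pred_next, A_t):
    # reverse gain G_t = Xi_{t|t} A_t^T Xi_{t+1|t}^{-1}, computed via a solve
    G = jnp.linalg.solve(Xi_pred_next, A_t @ Xi_filt).T
    x_smooth = x_filt + G @ (x_smooth_next - x_pred_next)
    Xi_smooth = Xi_filt - G @ (Xi_pred_next - Xi_smooth_next) @ G.T
    return x_smooth, Xi_smooth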


source

smoother

 smoother (filter_result:isssm.typing.FilterResult,
           A:jaxtyping.Float[Array,'n m m'])

perform the Kalman smoother

               Type                    Details
filter_result  FilterResult
A              Float[Array, 'n m m']   transition matrices
Returns        SmootherResult

Let us apply the Kalman smoother to our simulated observations.

filtered = kalman(y, glssm_model)

x_smooth, Xi_smooth = smoother(filtered, glssm_model.A)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3))
fig.tight_layout()
ax1.set_title("filter and smoother")
ax1.plot(x_filt[:, 0], label="$\\hat X_{{t | t}} $")
ax1.plot(x_smooth[:, 0], label="$\\hat X_{{t | n}}$")
ax1.legend()

ax2.set_title("Smoothed seasonal components")
ax2.plot(x_smooth[:, 2])
ax2.plot(x[:, 2])

plt.show()

The smoothed states are indeed smoother.

Missing observations

When entries of \(Y_{t}\) are missing, we can adapt the model by updating both \(B_{t}\) and \(\Omega_{t}\). Let \[M_{t} = \operatorname{diag} \left( \mathbf 1_{Y_{t, 1} \text{ observed}}, \dots, \mathbf 1_{Y_{t, p} \text{ observed}}\right),\] then we replace \(B_{t}\) by \(M_{t}B_{t}\) and \(\Omega_{t}\) by \(M_{t}\Omega_{t}M_{t}^{T}\), see also Section 4.10 in (J. Durbin and Koopman 2012).
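In code, the adjustment could be sketched like this. The details of isssm's account_for_nans may differ; in particular, adding back unit variances for the missing entries below is only one way to keep \(\Omega_t\) invertible and is an assumption of this sketch.

import jax.numpy as jnp

def mask_missing(y_t, B_t, Omega_t):
    observed = ~jnp.isnan(y_t)                          # indicator of observed entries of y_t
    M = jnp.diag(observed.astype(B_t.dtype))            # M_t
    B_masked = M @ B_t                                  # missing rows no longer load on the state
    # M Omega M^T alone is singular; unit variances on the missing entries keep it invertible
    Omega_masked = M @ Omega_t @ M.T + jnp.diag((~observed).astype(B_t.dtype))
    y_filled = jnp.where(observed, y_t, 0.0)            # placeholder value, ignored by the filter
    return y_filled, B_masked, Omega_masked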


source

account_for_nans

 account_for_nans (model:isssm.typing.GLSSM,
                   y:jaxtyping.Float[Array,'n+1 p'])

Let’s try removing some observations in the middle. Notice that during the gap the filter essentially keeps the last filtered state, only applying the system dynamics to propagate the current best estimate.

y_missing = y.at[40:60].set(jnp.nan)
model_missing, y_accounted = account_for_nans(glssm_model, y_missing)
filter_result_missing = kalman(y_accounted, model_missing)
x_filt_missing, _, _, _ = filter_result_missing

x_smooth_missing, _ = smoother(filter_result_missing, model_missing.A)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(9, 3))
fig.tight_layout()

ax1.set_title("")
ax1.plot(y_missing, label="$Y$")
ax1.plot(x[:, 0], label="$X$")
ax1.plot(x_filt_missing[:, 0], label="$\\hat X_{{t|t}}$")
ax1.plot(x_smooth_missing[:, 0], label="$\\hat X_{{t|n}}$")
ax1.legend()

ax2.plot(x[:, 2:])
ax2.set_title("Seasonal component $X_{{t, (2,3)}}$")

ax3.set_title("Filtered seasonal component $\\hat X_{{t, (2,3)|t}}$")
ax3.plot(x_filt_missing[:, 2:])

plt.show()

Prediction intervals

As the conditional distribution of states given observations is Gaussian, we can obtain marginal prediction intervals, i.e. for every individual state \(X_{t, i}\), \(t = 0, \dots, n\), \(i = 1, \dots, m\), by using the Gaussian inverse CDF.
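A sketch of this computation from stacked means and covariances; the library's filter_intervals and smoother_intervals presumably work along these lines.

import jax.numpy as jnp
from jax.scipy.stats import norm

def marginal_intervals(means, covs, alpha=0.05):
    """Elementwise (1 - alpha) intervals from means (n+1, m) and covariances (n+1, m, m)."""
    sds = jnp.sqrt(jnp.diagonal(covs, axis1=-2, axis2=-1))  # marginal standard deviations
    z = norm.ppf(1.0 - alpha / 2.0)                         # Gaussian quantile, approx. 1.96 for alpha = 0.05
    return means - z * sds, means + z * sds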


source

smoother_intervals

 smoother_intervals (result:isssm.typing.SmootherResult,
                     alpha:jaxtyping.Float=0.05)

source

filter_intervals

 filter_intervals (result:isssm.typing.FilterResult,
                   alpha:jaxtyping.Float=0.05)
filtered = kalman(y, glssm_model)
s_result = smoother(filtered, glssm_model.A)

s_lower, s_upper = smoother_intervals(s_result)
f_lower, f_upper = filter_intervals(filtered)

x_smooth, _ = s_result
fig, ax1 = plt.subplots(1, 1, figsize=(6, 3))
fig.tight_layout()
ax1.set_title("Filtering and smoothing intervals")
ax1.plot(x[:20, 0], label="$X_{t,1}$")
ax1.plot(s_lower[:20, 0], linestyle="--", color="grey")
ax1.plot(s_upper[:20, 0], linestyle="--", color="grey", label="95% smoothing PI")
ax1.plot(f_lower[:20, 0], linestyle="--", color="orange")
ax1.plot(f_upper[:20, 0], linestyle="--", color="orange", label="95% filtering PI")
ax1.legend()

plt.show()

Sampling from the smoothing distribution

After having run the Kalman filter we can use a recursion due to Frühwirth-Schnatter (Frühwirth-Schnatter 1994) to obtain samples from the joint conditional distribution of the states given the observations.

By the dependency structure of states and observations the conditional densities can be factorized in the following way:

\[ \begin{align*} p(x_0, \dots, x_n | y_0, \dots, y_n) &= p(x_n | y_0, \dots, y_n) \prod_{t = n - 1}^0 p(x_{t}| x_{t + 1}, \dots, x_n, y_0, \dots, y_n) \\ &= p(x_n | y_0, \dots, y_n) \prod_{t = n - 1}^0 p(x_{t}| x_{t + 1}, y_0, \dots, y_n) \end{align*} \]

and the conditional distributions are again Gaussian with conditional expectation \[ \mathbf E (X_{t} | X_{t + 1}, Y_0, \dots, Y_n) = \hat X_{t|t} + G_t (X_{t + 1} - \hat X_{t + 1|t}) \] and conditional covariance matrix \[ \text{Cov} (X_t | X_{t + 1}, Y_0, \dots, Y_n) = \Xi_{t|t} - G_t\Xi_{t + 1 | t} G_t^T \]

where \(G_t = \Xi_{t|t} A_t^T \Xi_{t + 1|t}^{-1}\) is the smoothing gain.
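A single backward sampling step following these formulas could be sketched as below; the library's FFBS additionally vectorises over the N samples and scans over time.

import jax.numpy as jnp
import jax.random as jrn

def backward_sample_step(key, x_next, x_filt, Xi_filt, x_pred_next, Xi_pred_next, A_t):
    # smoothing gain G_t = Xi_{t|t} A_t^T Xi_{t+1|t}^{-1}
    G = jnp.linalg.solve(Xi_pred_next, A_t @ Xi_filt).T
    mean = x_filt + G @ (x_next - x_pred_next)   # E(X_t | X_{t+1}, Y_0, ..., Y_n)
    cov = Xi_filt - G @ Xi_pred_next @ G.T       # Cov(X_t | X_{t+1}, Y_0, ..., Y_n)
    return jrn.multivariate_normal(key, mean, cov)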


source

FFBS

 FFBS (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.GLSSM, N:int,
       key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])

The Forward-Filter Backwards-Sampling Algorithm from (Frühwirth-Schnatter 1994).

         Type                      Details
y        Float[Array, 'n+1 p']     Observations \(y\)
model    GLSSM                     GLSSM
N        int                       number of samples
key      Union                     random state
Returns  Float[Array, 'N n+1 m']   N samples from the smoothing distribution

key, subkey = jrn.split(key)
(X_sim,) = FFBS(y, glssm_model, 1, subkey)

assert X_sim.shape == x.shape
x_smooth, Xi_smooth = smoother(filtered, glssm_model.A)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
fig.tight_layout()
ax1.set_title("Smoothed states")
ax1.plot(x_smooth)
ax2.set_title("Simulation from smoothing distribution")
ax2.plot(X_sim)
plt.show()

The disturbance smoother

When the interest lies in the signal \(S_{t} = B_{t}X_{t}\), \(t = 0, \dots, n\), it is often more efficient to use the following disturbance smoother, see Section 4.5 in (J. Durbin and Koopman 2012) for details. The recursions run backwards from \(t = n\) to \(t = 0\) and are initialized by \(r_n = 0 \in \mathbb R^{m}\). While it is also possible to obtain the smoothed state innovations \(\hat \varepsilon_{t | n}\), we will not need them in the following, so we skip them.

\[ \begin{align*} \hat\eta_{t | n} &= \Omega_{t} \left( \Psi_{t| t - 1}^{-1}(Y_{t} - \hat Y_{t | t - 1}) - K_{t}^{T}A_{t}^{T}r_{t} \right) \\ L_{t} &= A_{t} \left( I - K_{t}B_{t} \right) \\ r_{t - 1} &= B_{t}^T \Psi_{t | t - 1}^{-1}\left( Y_{t} - \hat Y_{t| t - 1} \right) + L_{t}^{T}r_{t} \end{align*} \] While it is also possible to derive smoothed covariance matrices, we will not need them, as we can use the simulation smoother, which is based on mean adjustments.
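With the one-step-ahead innovations, prediction covariances and gains from the filter stacked along the time axis, the backward recursion can be written as a reverse jax.lax.scan. The following is a sketch under that assumption, not the library's disturbance_smoother.

import jax
import jax.numpy as jnp

def disturbance_smoother_sketch(innovations, Psis, Ks, As, Bs, Omegas):
    """innovations[t] = y_t - E(Y_t | Y_{t-1}, ..., Y_0); all arguments are stacked over t."""
    m = As.shape[-1]

    def step(r, inputs):
        v, Psi, K, A, B, Omega = inputs
        Psi_inv_v = jnp.linalg.solve(Psi, v)
        eta_hat = Omega @ (Psi_inv_v - K.T @ A.T @ r)   # smoothed observation disturbance
        L = A @ (jnp.eye(m) - K @ B)
        r_prev = B.T @ Psi_inv_v + L.T @ r              # backward recursion for r_{t-1}
        return r_prev, eta_hat

    _, eta_hats = jax.lax.scan(
        step, jnp.zeros(m), (innovations, Psis, Ks, As, Bs, Omegas), reverse=True
    )
    return eta_hats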


source

smoothed_signals

 smoothed_signals (filtered:isssm.typing.FilterResult,
                   y:jaxtyping.Float[Array,'n+1 p'],
                   model:isssm.typing.GLSSM)

compute smoothed signals from filter result

          Type                    Details
filtered  FilterResult            filter result
y         Float[Array, 'n+1 p']   observations
model     GLSSM                   model
Returns   Float[Array, 'n+1 m']   smoothed signals

source

disturbance_smoother

 disturbance_smoother (filtered:isssm.typing.FilterResult,
                       y:jaxtyping.Float[Array,'n+1 p'],
                       model:isssm.typing.GLSSM)

perform the disturbance smoother for observation disturbances only

          Type                    Details
filtered  FilterResult            filter result
y         Float[Array, 'n+1 p']   observations
model     GLSSM                   model
Returns   Float[Array, 'n+1 p']   smoothed disturbances

s_smooth_ks = vmap(jnp.matmul)(glssm_model.B, x_smooth)
s_smooth = smoothed_signals(filtered, y, glssm_model)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
fig.tight_layout()
ax1.set_title("kalman smoother")
ax1.plot(s_smooth_ks)
ax2.set_title("disturbance smoother")
ax2.plot(s_smooth)
ax3.set_title("difference")
ax3.plot(s_smooth_ks - s_smooth)
plt.show()

vmm = vmap(jnp.matmul)
s_big = 100
big_model = stsm(jnp.zeros(2 + s_big - 1), 0., .1, .1, 100, jnp.eye(2 + s_big - 1), 3, s_big)
key, subkey = jrn.split(key)
_, (big_y,) = simulate_glssm(big_model, 1, subkey)
big_filtered = kalman(big_y, big_model)
48.4 ms ± 312 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
322 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

We see that for a large state space but a low-dimensional signal, the signal smoother drastically outperforms the ordinary Kalman smoother and should be preferred when our main interest lies in the signals.

The simulation smoother

The simulation smoother (James Durbin and Koopman 2002) is a method for sampling from the smoothing distribution, without explicitly calculating conditional covariance matrices. It is based on the disturbance smoother. We will implement it for the signal only.

  1. Calculate the conditional expectation \(\mathbf E \left( \eta_{t} | Y_{0} = y_{0}, \dots, Y_{n} = y_{n} \right)\) by the disturbance smoother.
  2. Generate a new draw \((X^+, Y^+)\) from the state space model using innovations \(\eta^+\).
  3. Calculate the conditional expectation \(\mathbf E \left( \eta^{+} | Y_{0} = y_{0}^+, \dots, Y_{n} = y_{n}^+\right)\) by the disturbance smoother.

Then \[ \mathbf E \left( \eta_{t} | Y_{0} = y_{0}, \dots, Y_{n} = y_{n} \right)+ \left(\eta^{+} - \mathbf E \left( \eta^{+} | Y_{0} = y_{0}^+, \dots, Y_{n} = y_{n}^+\right)\right) \] is a draw from the smoothing distribution of \(\eta\) given \(Y_{0} = y_{0}, \dots, Y_{n} = y_{n}\), because the second term is centered and independent of the first term: the first term contributes the conditional mean, the second term the conditional covariance.
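In code, the three steps combine roughly as follows. Here smooth_disturbances and simulate are hypothetical helpers standing in for the disturbance smoother and for drawing a fresh path (and its disturbances) from the model; the actual simulation_smoother works on signals and vectorises over the N draws.

def simulation_smoother_sketch(key, model, y, smooth_disturbances, simulate):
    eta_hat = smooth_disturbances(model, y)            # step 1: E(eta | Y = y)
    x_plus, y_plus, eta_plus = simulate(model, key)    # step 2: unconditional draw from the model
    eta_plus_hat = smooth_disturbances(model, y_plus)  # step 3: E(eta^+ | Y^+ = y^+)
    return eta_hat + (eta_plus - eta_plus_hat)         # one draw from eta | Y = y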


source

simulation_smoother

 simulation_smoother (model:isssm.typing.GLSSM,
                      y:jaxtyping.Float[Array,'n+1 p'], N:int,
                      key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])

Simulate from the smoothing distribution of signals

         Type                      Details
model    GLSSM                     model
y        Float[Array, 'n+1 p']     observations
N        int                       number of samples to draw
key      Union                     random number seed
Returns  Float[Array, 'N n+1 m']   N samples from the smoothing distribution of signals

smooth_signals_sim = simulation_smoother(glssm_model, y, 10, subkey)

signal_vars = vmap(lambda B, Xi: B @ Xi @ B.T)(glssm_model.B, Xi_smooth)[:, 0, 0]
signal_sds = jnp.sqrt(signal_vars)  # marginal standard deviations of the smoothed signal
plt.plot(2 * signal_sds, color="black", label="95% PI")
plt.title("Residuals in smoothed signals w/ 95% marginal PI")
plt.plot(-2 * signal_sds, color="black")
plt.plot(smooth_signals_sim[:, :, 0].T - s_smooth, alpha=0.05)
plt.show()

vmm = vmap(vmap(jnp.matmul), (None, 0))
s_big = 300
big_model = stsm(jnp.zeros(2 + s_big - 1), 0., .1, .1, 100, jnp.eye(2 + s_big - 1), 3, s_big)
key, subkey = jrn.split(key)
_, (big_y,) = simulate_glssm(big_model, 1, subkey)
N = 100
key, subkey = jrn.split(key)

# ignore antithetics, could also be used for FFBS
974 ms ± 14.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.71 s ± 124 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Recovering states from signals

Both the signal and simulation smoother operate on the signals. If we are interested in the states, we have to recover them from the signals. As

\[ S_{t} = B_{t}X_{t} \]

for all \(t\), we can recover the mode (which is the mean in the Gaussian case) of the states by running the Kalman filter and smoother for the signal model, i.e. the model where we set \(\Omega_t = \mathbf 0_{p\times p}\) and \(y_t = s_t\). As the joint distribution of \((X,S)\) is Gaussian, the Kalman filter and smoother compute the conditional distribution \(X | S\), which is again Gaussian, so its mean coincides with its mode.
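As a stand-alone illustration of this construction (the container and its field names below are made up and need not match isssm.typing.GLSSM), one could build the signal model by replacing \(\Omega_t\) with, numerically, a small multiple of the identity instead of the exact zero matrix:

from typing import NamedTuple
import jax.numpy as jnp

class ToyGLSSM(NamedTuple):
    # illustrative field names only
    u: jnp.ndarray
    A: jnp.ndarray
    Sigma: jnp.ndarray
    v: jnp.ndarray
    B: jnp.ndarray
    Omega: jnp.ndarray

def to_signal_model_sketch(model: ToyGLSSM, eps: float = 1e-8) -> ToyGLSSM:
    """Replace Omega_t by eps * I so that the 'observations' of the new model are the signals."""
    n_plus_1, p, _ = model.Omega.shape
    Omega_signal = eps * jnp.broadcast_to(jnp.eye(p), (n_plus_1, p, p))
    return model._replace(Omega=Omega_signal)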


source

state_mode

 state_mode (model:isssm.typing.GLSSM|isssm.typing.PGSSM,
             signal_mode:jaxtyping.Float[Array,'n+1 p'])

source

state_conditional_on_signal

 state_conditional_on_signal (model:isssm.typing.GLSSM|isssm.typing.PGSSM,
                              signal_mode:jaxtyping.Float[Array,'n+1 p'])

source

to_signal_model

 to_signal_model (model:isssm.typing.GLSSM|isssm.typing.PGSSM)

References

Durbin, James, and Siem Jan Koopman. 2002. “A Simple and Efficient Simulation Smoother for State Space Time Series Analysis.” Biometrika 89 (3): 603–16. https://doi.org/10.1093/biomet/89.3.603.
Durbin, J., and S. J. Koopman. 2012. Time Series Analysis by State Space Methods. 2nd ed. Oxford Statistical Science Series 38. Oxford: Oxford University Press.
Frühwirth-Schnatter, Sylvia. 1994. “Data Augmentation and Dynamic Linear Models.” Journal of Time Series Analysis 15 (2): 183–202. https://doi.org/10.1111/j.1467-9892.1994.tb00184.x.