```python
# libraries for this notebook
import jax
import jax.numpy as jnp
import jax.random as jrn
from jax import vmap

import matplotlib.pyplot as plt
import numpy.testing as npt

from isssm.glssm import simulate_glssm
from isssm.models.stsm import stsm
from isssm.typing import GLSSM
```
Kalman filter and smoother variants in JAX
Consider a GLSSM of the form \[ \begin{align*} X_0 &\sim \mathcal N (u_0, \Sigma_0) &\\ X_{t + 1} &= u_{t + 1} + A_t X_{t} + \varepsilon_{t + 1} &t = 0, \dots, n - 1\\ \varepsilon_t &\sim \mathcal N (0, \Sigma_t) & t = 1, \dots, n \\ Y_t &= v_{t} + B_t X_t + \eta_t & t =0, \dots, n & \\ \eta_t &\sim \mathcal N(0, \Omega_t), & t=0, \dots, n. \end{align*} \]
Kalman Filter
For \(t, s \in \{0, \dots, n\}\) consider the following BLPs and associated covariance matrices \[ \begin{align*} \hat X_{t|s} &= \mathbf E \left( X_t | Y_s, \dots, Y_0\right) \\ \Xi_{t | s} &= \text{Cov} \left(X_t | Y_s, \dots, Y_0 \right)\\ \hat Y_{t|s} &= \mathbf E \left( Y_t | Y_s, \dots, Y_0\right) \\ \Psi_{t | s} &= \text{Cov} \left(Y_t | Y_s, \dots, Y_0 \right) \end{align*} \]
The Kalman filter consists of the following two-step recursion:
Initialization
\[ \begin{align*} \hat X_{0|0} &= u_0\\ \Xi_{0|0} &= \Sigma_0 \end{align*} \]
Iterate for \(t = 0, \dots, n-1\)
Prediction
\[ \begin{align*} \hat X_{t + 1|t} &= u_{t + 1} + A_t \hat X_{t | t} \\ \Xi_{t + 1 | t} &= A_t \Xi_{t|t} A_t^T + \Sigma_{t + 1}\\ \end{align*} \]
Filtering
\[ \begin{align*} \hat Y_{t + 1 | t} &= v_{t + 1} + B_{t + 1} \hat X_{t + 1 | t} \\ \Psi_{t + 1| t} &= B_{t + 1} \Xi_{t + 1 | t} B_{t + 1}^T + \Omega_{t + 1}\\ K_t &= \Xi_{t + 1 | t} B_{t + 1}^T \Psi_{t + 1 | t} ^{-1} \\ \hat X_{t + 1 | t + 1} &= \hat X_{t + 1 | t} + K_t (Y_{t + 1} - \hat Y_{t + 1 | t})\\ \Xi_{t + 1 | t + 1} &= \Xi_{t + 1 | t} - K_t \Psi_{t + 1| t} K_t^T \end{align*} \]
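To make the recursion concrete, here is a minimal, self-contained sketch of a single prediction/filtering step. The function and argument names are made up for illustration; the library's `kalman` (documented below) runs the full recursion over all time points.

```python
import jax.numpy as jnp


def kalman_step(x_filt, Xi_filt, u_next, A, Sigma_next, v_next, B_next, Omega_next, y_next):
    # prediction: push the filtered estimate through the state dynamics
    x_pred = u_next + A @ x_filt
    Xi_pred = A @ Xi_filt @ A.T + Sigma_next
    # filtering: correct the prediction with the new observation
    y_pred = v_next + B_next @ x_pred
    Psi_pred = B_next @ Xi_pred @ B_next.T + Omega_next
    # Kalman gain K_t = Xi_{t+1|t} B_{t+1}^T Psi_{t+1|t}^{-1}, via a linear solve
    K = jnp.linalg.solve(Psi_pred, B_next @ Xi_pred).T
    x_filt_next = x_pred + K @ (y_next - y_pred)
    Xi_filt_next = Xi_pred - K @ Psi_pred @ K.T
    return x_filt_next, Xi_filt_next
```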
kalman
kalman (y:jaxtyping.Float[Array,'n+1 p'], glssm:isssm.typing.GLSSM)
Perform the Kalman filter
|  | Type | Details |
|---|---|---|
| y | Float[Array, ‘n+1 p’] | observations |
| glssm | GLSSM | model |
| Returns | FilterResult | filtered & predicted states and covariances |
Let us check that our implementation works as expected by simulating a single sample from the joint distribution of a structural time series model with seasonality of order 2.
```python
# structural time series model with seasonality of order 2
glssm_model = stsm(jnp.ones(3), 0.0, 0.1, 0.1, 100, jnp.eye(3), 3, 2)

key = jrn.PRNGKey(53405234)
key, subkey = jrn.split(key)
# simulate a single draw of states and observations
(x,), (y,) = simulate_glssm(glssm_model, 1, subkey)
```

```python
x_filt, Xi_filt, x_pred, Xi_pred = kalman(y, glssm_model)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(9, 3))
fig.tight_layout()
ax1.set_title("")
ax1.plot(y, label="$Y$")
ax1.plot(x[:, 0], label="$X$")
ax1.plot(x_filt[:, 0], label="$\\hat X_{{t|t}}$")
ax1.legend()
ax2.plot(x[:, 2:])
ax2.set_title("Seasonal component $X_{{t, (2,3)}}$")
ax3.set_title("Filtered seasonal component $\\hat X_{{t, (2,3)|t}}$")
ax3.plot(x_filt[:, 2:])
plt.show()
```
Let us hasten to add that there is no reason to believe that \(X_{t}\) and \(\hat X_{t|t}\) should be close. Nevertheless, this comparison serves as a sanity check that our implementation gives reasonable estimates.
Kalman smoother
The Kalman smoother uses the filter result to obtain \(\hat X_{t | n}\) and \(\Xi_{t | n}\) for \(t = 0, \dots, n\).
It is based on the following recursion with initialisation by the filtering result \(\hat X_{n | n}\) and \(\Xi_{n|n}\) and the (reverse) gain \(G_t\).
\[ \begin{align*} G_t &= \Xi_{t | t} A_t^T \Xi_{t + 1 | t} ^{-1}\\ \hat X_{t | n} &= \hat X_{t | t} + G_t (\hat X_{t + 1| n} - \hat X_{t + 1 | t}) \\ \Xi_{t | n} &= \Xi_{t | t} - G_t (\Xi_{t + 1 | t} - \Xi_{t + 1 | n}) G_t^T \end{align*} \]
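A single backward step of this recursion can be sketched as follows; names are illustrative only, the library's `smoother` (documented below) applies this step for \(t = n - 1, \dots, 0\).

```python
def smoother_step(x_smooth_next, Xi_smooth_next, x_filt, Xi_filt, x_pred_next, Xi_pred_next, A):
    # reverse gain G_t = Xi_{t|t} A_t^T Xi_{t+1|t}^{-1}, via a linear solve
    G = jnp.linalg.solve(Xi_pred_next, A @ Xi_filt).T
    x_smooth = x_filt + G @ (x_smooth_next - x_pred_next)
    Xi_smooth = Xi_filt - G @ (Xi_pred_next - Xi_smooth_next) @ G.T
    return x_smooth, Xi_smooth
```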
smoother
smoother (filter_result:isssm.typing.FilterResult, A:jaxtyping.Float[Array,'n m m'])
perform the Kalman smoother
|  | Type | Details |
|---|---|---|
| filter_result | FilterResult |  |
| A | Float[Array, ‘n m m’] | transition matrices |
| Returns | SmootherResult |  |
Let us apply the Kalman smoother to our simulated observations.
```python
filtered = kalman(y, glssm_model)
x_smooth, Xi_smooth = smoother(filtered, glssm_model.A)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3))
fig.tight_layout()
ax1.set_title("filter and smoother")
ax1.plot(x_filt[:, 0], label="$\\hat X_{{t | t}} $")
ax1.plot(x_smooth[:, 0], label="$\\hat X_{{t | n}}$")
ax1.legend()
ax2.set_title("Smoothed seasonal components")
ax2.plot(x_smooth[:, 2])
ax2.plot(x[:, 2])
plt.show()
```
The smoothed states are indeed smoother.
Missing observations
When entries of \(Y_{t}\) are missing, we can adapt the model by updating both \(B_{t}\) and \(\Omega_{t}\). Let \[M_{t} = \operatorname{diag} \left( \mathbf 1_{Y_{t, 1} \text{ observed}}, \dots, \mathbf 1_{Y_{t, p} \text{ observed}}\right),\] then we replace \(B_{t}\) by \(M_{t}B_{t}\) and \(\Omega_{t}\) by \(M_{t}\Omega_{t}M_{t}^{T}\), see also section 4.10 in (J. Durbin and Koopman 2012). The missing entries of \(Y_{t}\) also have to be replaced by a placeholder value, which is why account_for_nans below returns adjusted observations as well.
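A rough sketch of this masking for a single time point could look like the following (a hypothetical helper, not the library's implementation). Here the missing components additionally receive a unit variance so that \(\Psi_{t|t-1}\) stays invertible; since the corresponding rows of the masked \(B_{t}\) are zero, this does not affect the filtered states.

```python
def mask_missing(y_t, B_t, Omega_t):
    observed = ~jnp.isnan(y_t)
    M = jnp.diag(observed.astype(B_t.dtype))      # M_t from above
    B_masked = M @ B_t                            # rows of B_t for missing entries become zero
    Omega_masked = M @ Omega_t @ M.T
    # unit variance on missing components keeps Psi invertible; the corresponding
    # columns of the Kalman gain are zero, so the filtered states are unaffected
    Omega_masked = Omega_masked + jnp.diag((~observed).astype(B_t.dtype))
    y_masked = jnp.where(observed, y_t, 0.0)      # placeholder value for missing entries
    return y_masked, B_masked, Omega_masked
```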
account_for_nans
account_for_nans (model:isssm.typing.GLSSM, y:jaxtyping.Float[Array,'n+1 p'])
Let’s try removing some observations in the middle. Notice that during the missing stretch the filter essentially keeps its last estimate, only applying the system dynamics to propagate the current best estimate.
```python
y_missing = y.at[40:60].set(jnp.nan)
model_missing, y_accounted = account_for_nans(glssm_model, y_missing)
filter_result_missing = kalman(y_accounted, model_missing)
x_filt_missing, _, _, _ = filter_result_missing
x_smooth_missing, _ = smoother(filter_result_missing, model_missing.A)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(9, 3))
fig.tight_layout()
ax1.set_title("")
ax1.plot(y_missing, label="$Y$")
ax1.plot(x[:, 0], label="$X$")
ax1.plot(x_filt_missing[:, 0], label="$\\hat X_{{t|t}}$")
ax1.plot(x_smooth_missing[:, 0], label="$\\hat X_{{t|n}}$")
ax1.legend()
ax2.plot(x[:, 2:])
ax2.set_title("Seasonal component $X_{{t, (2,3)}}$")
ax3.set_title("Filtered seasonal component $\\hat X_{{t, (2,3)|t}}$")
ax3.plot(x_filt_missing[:, 2:])
plt.show()
```
Prediction intervals
As the conditional distribution of states given observations is Gaussian, we can obtain marginal prediction intervals, i.e. for every individual state \(X_{t, i}\), \(t = 0, \dots, n\), \(i = 1, \dots, m\), by using the Gaussian inverse CDF.
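A hedged sketch of how such marginal intervals can be computed from means and covariances; the argument names are illustrative, while `filter_intervals` and `smoother_intervals` (documented below) operate on the corresponding result tuples.

```python
from jax.scipy.special import ndtri  # standard normal quantile function


def marginal_intervals(means, covs, alpha=0.05):
    # means: (n+1, m), covs: (n+1, m, m)
    z = ndtri(1.0 - alpha / 2.0)
    stds = jnp.sqrt(jnp.diagonal(covs, axis1=-2, axis2=-1))
    return means - z * stds, means + z * stds
```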
smoother_intervals
smoother_intervals (result:isssm.typing.SmootherResult, alpha:jaxtyping.Float=0.05)
filter_intervals
filter_intervals (result:isssm.typing.FilterResult, alpha:jaxtyping.Float=0.05)
```python
filtered = kalman(y, glssm_model)
s_result = smoother(filtered, glssm_model.A)

s_lower, s_upper = smoother_intervals(s_result)
f_lower, f_upper = filter_intervals(filtered)

x_smooth, _ = s_result

fig, ax1 = plt.subplots(1, 1, figsize=(6, 3))
fig.tight_layout()
ax1.set_title("Filtering and smoothing intervals")
ax1.plot(x[:20, 0], label="$X_{t,1}$")
ax1.plot(s_lower[:20, 0], linestyle="--", color="grey")
ax1.plot(s_upper[:20, 0], linestyle="--", color="grey", label="95% smoothing PI")
ax1.plot(f_lower[:20, 0], linestyle="--", color="orange")
ax1.plot(f_upper[:20, 0], linestyle="--", color="orange", label="95% filtering PI")
ax1.legend()
plt.show()
```
Sampling from the smoothing distribution
After having run the Kalman filter we can use a recursion due to Frühwirth-Schnatter (Frühwirth-Schnatter 1994) to obtain samples from the joint conditional distribution of the states given the observations.
By the dependency structure of states and observations the conditional densities can be factorized in the following way:
\[ \begin{align*} p(x_0, \dots, x_n | y_0, \dots, y_n) &= p(x_n | y_0, \dots, y_n) \prod_{t = n - 1}^0 p(x_{t}| x_{t + 1}, \dots, x_n, y_0, \dots, y_n) \\ &= p(x_n | y_0, \dots, y_n) \prod_{t = n - 1}^0 p(x_{t}| x_{t + 1}, y_0, \dots, y_n) \end{align*} \]
and the conditional distributions are again Gaussian with conditional expectation \[ \mathbf E (X_{t} | X_{t + 1}, Y_0, \dots, Y_n) = \hat X_{t|t} + G_t (X_{t + 1} - \hat X_{t + 1|t}) \] and conditional covariance matrix \[ \text{Cov} (X_t | X_{t + 1}, Y_0, \dots, Y_n) = \Xi_{t|t} - G_t\Xi_{t + 1 | t} G_t^T \]
where \(G_t = \Xi_{t|t} A_t^T \Xi_{t + 1|t}^{-1}\) is the smoothing gain.
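A single backward sampling step can therefore be sketched as follows (illustrative names only; `FFBS` below performs the forward filtering pass and then iterates this step from \(t = n - 1\) down to \(0\)):

```python
def backward_sample_step(key, x_next, x_filt, Xi_filt, x_pred_next, Xi_pred_next, A):
    # smoothing gain G_t = Xi_{t|t} A_t^T Xi_{t+1|t}^{-1}
    G = jnp.linalg.solve(Xi_pred_next, A @ Xi_filt).T
    mean = x_filt + G @ (x_next - x_pred_next)
    cov = Xi_filt - G @ Xi_pred_next @ G.T
    # draw X_t from its Gaussian conditional distribution given X_{t+1} and all observations
    return jrn.multivariate_normal(key, mean, cov)
```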
FFBS
FFBS (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.GLSSM, N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])
The Forward-Filter Backwards-Sampling Algorithm from (Frühwirth-Schnatter 1994).
|  | Type | Details |
|---|---|---|
| y | Float[Array, ‘n+1 p’] | Observations \(y\) |
| model | GLSSM | GLSSM |
| N | int | number of samples |
| key | Union | random state |
| Returns | Float[Array, ‘N n+1 m’] | N samples from the smoothing distribution |
```python
key, subkey = jrn.split(key)
(X_sim,) = FFBS(y, glssm_model, 1, subkey)
assert X_sim.shape == x.shape

x_smooth, Xi_smooth = smoother(filtered, glssm_model.A)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
fig.tight_layout()
ax1.set_title("Smoothed states")
ax1.plot(x_smooth)
ax2.set_title("Simulation from smoothing distribution")
ax2.plot(X_sim)
plt.show()
```
The disturbance smoother
When the interest lies in the signal \(S_{t} = B_{t}X_{t}\), \(t = 0, \dots, n\), it is often more efficient to use the following disturbance smoother, see Section 4.5 in (J. Durbin and Koopman 2012) for details. The recursions run from \(t = n\) to \(t = 0\) and are initialized by \(r_n = 0 \in \mathbb R^{m}\). While it is also possible to obtain smoothed state innovations \(\hat \varepsilon_{t | n}\), we will not be interested in them in the following, so we skip them.
\[ \begin{align*} \hat\eta_{t | n} &= \Omega_{t} \left( \Psi_{t| t - 1}^{-1}(Y_{t} - \hat Y_{t | t - 1}) - K_{t}^{T}A_{t}^{T}r_{t} \right) \\ L_{t} &= A_{t} \left( I - K_{t}B_{t} \right) \\ r_{t - 1} &= B_{t}^T \Psi_{t | t - 1}^{-1}\left( Y_{t} - \hat Y_{t| t - 1} \right) + L_{t}^{T}r_{t} \end{align*} \] While it is also possible to derive smoothed covariance matrices, we will not need them, as we can use the simulation smoother, which is based on mean adjustments.
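One backward step of this recursion might look as follows; this is a sketch with made-up names, while the library's `disturbance_smoother` (documented below) runs the recursion over all time points.

```python
def disturbance_step(r_t, y_t, y_pred, Psi_pred, K, A, B, Omega):
    # weighted innovation Psi_{t|t-1}^{-1} (Y_t - hat Y_{t|t-1}), via a linear solve
    weighted_innovation = jnp.linalg.solve(Psi_pred, y_t - y_pred)
    eta_smooth = Omega @ (weighted_innovation - K.T @ A.T @ r_t)
    L = A @ (jnp.eye(A.shape[0]) - K @ B)
    r_prev = B.T @ weighted_innovation + L.T @ r_t
    return r_prev, eta_smooth
```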
smoothed_signals
smoothed_signals (filtered:isssm.typing.FilterResult, y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.GLSSM)
compute smoothed signals from filter result
|  | Type | Details |
|---|---|---|
| filtered | FilterResult | filter result |
| y | Float[Array, ‘n+1 p’] | observations |
| model | GLSSM | model |
| Returns | Float[Array, ‘n+1 m’] | smoothed signals |
disturbance_smoother
disturbance_smoother (filtered:isssm.typing.FilterResult, y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.GLSSM)
perform the disturbance smoother for observation disturbances only
|  | Type | Details |
|---|---|---|
| filtered | FilterResult | filter result |
| y | Float[Array, ‘n+1 p’] | observations |
| model | GLSSM | model |
| Returns | Float[Array, ‘n+1 p’] | smoothed disturbances |
```python
# signals via the Kalman smoother: S_t = B_t hat X_{t|n}
s_smooth_ks = vmap(jnp.matmul)(glssm_model.B, x_smooth)
# signals via the disturbance smoother
s_smooth = smoothed_signals(filtered, y, glssm_model)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
fig.tight_layout()
ax1.set_title("kalman smoother")
ax1.plot(s_smooth_ks)
ax2.set_title("disturbance smoother")
ax2.plot(s_smooth)
ax3.set_title("difference")
ax3.plot(s_smooth_ks - s_smooth)
plt.show()
```
```python
vmm = vmap(jnp.matmul)
s_big = 100
big_model = stsm(jnp.zeros(2 + s_big - 1), 0., .1, .1, 100, jnp.eye(2 + s_big - 1), 3, s_big)
key, subkey = jrn.split(key)
_, (big_y,) = simulate_glssm(big_model, 1, subkey)
big_filtered = kalman(big_y, big_model)
```
```
48.4 ms ± 312 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
322 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
We see that for a large state space, but low-dimensional signal, the signal smoother drastically outperforms the simple smoother, and should be preferred if our main interest lies in the signals.
The simulation smoother
The simulation smoother (James Durbin and Koopman 2002) is a method for sampling from the smoothing distribution, without explicitly calculating conditional covariance matrices. It is based on the disturbance smoother. We will implement it for the signal only.
- Calculate the conditional expectation \(\mathbf E \left( \eta_{t} | Y_{0} = y_{0}, \dots, Y_{n} = y_{n} \right)\) by the disturbance smoother.
- Generate a new draw \((X^+, Y^+)\) from the state space model using innovations \(\eta^+\).
- Calculate the conditional expectation \(\mathbf E \left( \eta^{+} | Y_{0} = y_{0}^+, \dots, Y_{n} = y_{n}^+\right)\) by the disturbance smoother.
Then \[ \mathbf E \left( \eta_{t} | Y_{0} = y_{0}, \dots, Y_{n} = y_{n} \right)+ \left(\eta^{+} - \mathbf E \left( \eta^{+} | Y_{0} = y_{0}^+, \dots, Y_{n} = y_{n}^+\right)\right) \] is a draw from the smoothing distribution of \(\eta | Y_{0} = y_{0}, \dots, Y_{n} = y_{n}\), because the second term is centered and independent of the first term. The first term contributes the mean, the second term the covariance.
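Schematically, the final combination is just a mean adjustment. The sketch below only spells out that last step, with descriptive argument names; the library's `simulation_smoother` (documented next) performs all three steps for the signals.

```python
def mean_adjust(eta_hat, eta_plus, eta_plus_hat):
    # eta_hat:      E(eta | Y = y), disturbance smoother on the observed data
    # eta_plus:     disturbances of an independent draw (X^+, Y^+) from the model
    # eta_plus_hat: E(eta^+ | Y^+ = y^+), disturbance smoother on the simulated data
    # the first term fixes the conditional mean, the bracketed difference contributes
    # zero-mean noise with the correct conditional covariance
    return eta_hat + (eta_plus - eta_plus_hat)
```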
simulation_smoother
simulation_smoother (model:isssm.typing.GLSSM, y:jaxtyping.Float[Array,'n+1 p'], N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])
Simulate from the smoothing distribution of signals
|  | Type | Details |
|---|---|---|
| model | GLSSM | model |
| y | Float[Array, ‘n+1 p’] | observations |
| N | int | number of samples to draw |
| key | Union | random number seed |
| Returns | Float[Array, ‘N n+1 m’] | N samples from the smoothing distribution of signals |
```python
smooth_signals_sim = simulation_smoother(glssm_model, y, 10, subkey)

# marginal smoothing variances of the signal: B_t Xi_{t|n} B_t^T
signal_vars = vmap(lambda B, Xi: B @ Xi @ B.T)(glssm_model.B, Xi_smooth)[:, 0, 0]
plt.title("Residuals in smoothed signals w/ 95% marginal PI")
# +/- 2 standard deviations for an approximate 95% marginal PI
plt.plot(2 * jnp.sqrt(signal_vars), color="black", label="95% PI")
plt.plot(-2 * jnp.sqrt(signal_vars), color="black")
plt.plot(smooth_signals_sim[:, :, 0].T - s_smooth, alpha=0.05)
plt.show()
```
```python
vmm = vmap(vmap(jnp.matmul), (None, 0))
s_big = 300
big_model = stsm(jnp.zeros(2 + s_big - 1), 0., .1, .1, 100, jnp.eye(2 + s_big - 1), 3, s_big)
key, subkey = jrn.split(key)
_, (big_y,) = simulate_glssm(big_model, 1, subkey)
N = 100
key, subkey = jrn.split(key)
# ignore antithetics, could also be used for FFBS
```
```
974 ms ± 14.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.71 s ± 124 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
Recovering states from signals
Both the signal and simulation smoother operate on the signals. If we are interested in the states, we have to recover them from the signals. As
\[ S_{t} = B_{t}X_{t} \]
for all \(t\), we can recover the mode (which is the mean in the Gaussian case) of the states by running the Kalman filter and smoother on the signal model, i.e. the model where we set \(\Omega_t = \mathbf 0_{p\times p}\) and \(y_t = s_t\). As the joint distribution of \((X,S)\) is Gaussian, the Kalman filter and smoother compute the conditional distribution of \(X\) given \(S\), which is again Gaussian, so its mean coincides with its mode.
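As a rough sketch of this idea (assuming, hypothetically, that GLSSM is a NamedTuple with an `Omega` field; `to_signal_model` below provides the proper construction):

```python
def state_mode_from_signals(model, s_mode):
    # treat the signals as noiseless observations: Omega_t = 0, y_t = s_t
    signal_model = model._replace(Omega=jnp.zeros_like(model.Omega))
    filtered = kalman(s_mode, signal_model)
    x_mode, _ = smoother(filtered, signal_model.A)  # conditional mean = mode in the Gaussian case
    return x_mode
```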
state_mode
state_mode (model:isssm.typing.GLSSM|isssm.typing.PGSSM, signal_mode:jaxtyping.Float[Array,'n+1 p'])
state_conditional_on_signal
state_conditional_on_signal (model:isssm.typing.GLSSM|isssm.typing.PGSSM, signal_mode:jaxtyping.Float[Array,'n+1 p'])
to_signal_model
to_signal_model (model:isssm.typing.GLSSM|isssm.typing.PGSSM)