```python
from functools import partial

import fastcore.test as fct
import jax.numpy as jnp
import jax.random as jrn
import matplotlib.pyplot as plt
from jax import jit, vmap
from jax.scipy.special import expit

from isssm.kalman import kalman
```
Consider a PGSSM with states \(X\) and observations \(Y\). If the joint distribution of \(X\) and \(Y\) is not Gaussian, we cannot apply the standard Kalman filter and smoother.
Here we implement an alternative, the Laplace approximation (LA) method from (Durbin and Koopman 2012), Chapter 10.6, where it is called mode approximation. Its main idea is to approximate the posterior distribution by a Gaussian distribution obtained from a second-order Taylor expansion of the log-pdf around the posterior mode.
This essentially means matching the first- and second-order derivatives of the observation log-likelihoods at the mode. As the mode is a (global) maximum of the posterior distribution, we can find it by a Newton-Raphson iteration, which (Durbin and Koopman 2012) show can be implemented efficiently, each iteration amounting to a single pass of a Kalman smoother.
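To spell out the matching, the second-order expansion of the observation log-likelihood around a point \(\hat s_t\) reads
\[ \log p(y_t|s_t) \approx \log p(y_t|\hat s_t) + (s_t - \hat s_t)' \left.\frac{\partial \log p(y_t|\cdot)}{\partial s_t}\right|_{\hat s_t} + \frac 12 (s_t - \hat s_t)' \left.\frac{\partial^2 \log p(y_t|\cdot)}{\partial s_t\, \partial s_t'}\right|_{\hat s_t} (s_t - \hat s_t), \]
and, up to an additive constant that does not depend on \(s_t\), the right-hand side equals the Gaussian log-density \(-\frac 12 (z_t - s_t)' \Omega_t^{-1} (z_t - s_t)\) with \(\Omega_t\) and \(z_t\) as defined in the surrogate model below.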
The LA procedure is based on the following observation: at the mode \(\hat s = (\hat s_0, \dots, \hat s_n)\), the surrogate Gaussian model with the same state equation and observation equations
\[ \begin{align*} S_t &= B_t X_t \\ Z_t &= S_t + \eta_t \\ \eta_t &\sim \mathcal N\left(0, \Omega_t\right) \end{align*} \]
with
\[ \Omega_t^{-1} = -\left.\frac{\partial^2 \log p(y_t|\cdot)}{\partial s_t\, \partial s_t'}\right|_{\hat s_t} \qquad \text{and} \qquad z_t = \hat s_t + \Omega_t \left.\frac{\partial \log p(y_t|\cdot)}{\partial s_t}\right|_{\hat s_t} \]
again has mode \(\hat s\).
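As an illustration (not part of the isssm implementation), here is how \(\Omega_t\) and \(z_t\) could be computed for a single scalar signal with a Poisson observation, obtaining the derivatives by automatic differentiation; the Poisson log-likelihood and the numerical values are assumptions for this example only.

```python
from jax import grad

# Poisson observation with rate exp(s_t); terms constant in s_t are dropped
def pois_log_lik(s_t, y_t):
    return y_t * s_t - jnp.exp(s_t)

s_hat_t, y_t = 1.2, 4.0                      # current guess of the mode and observation
d1 = grad(pois_log_lik)(s_hat_t, y_t)        # first derivative at the current guess
d2 = grad(grad(pois_log_lik))(s_hat_t, y_t)  # second derivative at the current guess

Omega_t = -1.0 / d2            # Omega_t = (-dd log p)^{-1}, a scalar in this toy case
z_t = s_hat_t + Omega_t * d1   # synthetic observation for the surrogate Gaussian model
```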
In most cases we are interested in, the observations are conditionally independent given the signals, so that \(\Omega\) is a diagonal matrix and inversion reduces to inverting its diagonal entries, which is much faster. This implementation assumes this to hold, but could be extended to handle the general case as well (replace the calls to vdiag by calls to solve).
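To illustrate what the diagonal structure buys (a toy example, not the library's internal code): with conditionally independent components, applying \(\Omega_t\) is an elementwise operation, whereas a general \(\Omega_t\) would require a linear solve.

```python
# toy example: per-component second derivatives (diagonal case) and first derivatives
dd = jnp.array([-2.0, -0.5, -1.5])
g = jnp.array([0.3, -1.0, 0.2])

# diagonal case: Omega_t = diag(-1 / dd), so applying Omega_t is elementwise
omega_g_diag = (-1.0 / dd) * g

# general case: full Hessian H_t with Omega_t = (-H_t)^{-1}, requiring a linear solve
H = jnp.diag(dd)  # diagonal here, but standing in for a full Hessian
omega_g_general = jnp.linalg.solve(-H, g)

print(jnp.allclose(omega_g_diag, omega_g_general))  # True
```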
This is used in a fixed-point iteration: starting from an initial guess for \(\hat s\), compute \(z_t\) and \(\Omega_t\) at the current guess, run the Kalman smoother in the resulting surrogate Gaussian model to obtain an updated guess for the mode, and repeat until successive guesses change by less than a prescribed tolerance.
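A structural sketch of this loop is given below; the helpers linearize and smoothed_mode are hypothetical stand-ins for the linearization step and the Kalman smoother pass in the surrogate model, not isssm functions.

```python
def laplace_fixed_point(s_init, linearize, smoothed_mode, n_iter, eps=1e-5):
    # linearize(s): synthetic observations z and covariances Omega at the current guess s
    # smoothed_mode(z, Omega): signal mode of the surrogate Gaussian model
    s = s_init
    for _ in range(n_iter):
        z, Omega = linearize(s)
        s_new = smoothed_mode(z, Omega)
        converged = jnp.max(jnp.abs(s_new - s)) < eps  # stop once the mode stops moving
        s = s_new
        if converged:
            break
    return s
```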
Currently, we assume that \(\Omega\) is always positive definite, i.e. that \(s \mapsto p(y|s)\) is strictly log-concave. This is the case for natural exponential families but might be violated otherwise. In the latter case, the Kalman filter and smoother can still be used to perform the Laplace approximation, but the implementation then has to be based on the ideas developed in (Jungbacker and Koopman 2007).
```python
posterior_mode (proposal: isssm.typing.GLSSMProposal)
```

```python
laplace_approximation (
    y: jaxtyping.Float[Array, 'n+1 p'],
    model: isssm.typing.PGSSM,
    n_iter: int,
    log_lik=None,
    d_log_lik=None,
    dd_log_lik=None,
    eps: jaxtyping.Float = 1e-05,
    link=<function <lambda>>,
)
```

|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, 'n+1 p'] |  | observation |
| model | PGSSM |  |  |
| n_iter | int |  | number of iterations |
| log_lik | NoneType | None | log likelihood function |
| d_log_lik | NoneType | None | derivative of log likelihood function |
| dd_log_lik | NoneType | None | second derivative of log likelihood function |
| eps | Float | 1e-05 | precision of iterations |
| link | function |  | default link to use in initial guess |
| **Returns** | **tuple** |  |  |
We return to our running example and use mode estimation to obtain the mode of the conditional distribution of the signals given the observations.
```python
s_order = 5
model = nb_pgssm_running_example(
    s_order=s_order,
    Sigma0_seasonal=jnp.eye(s_order - 1),
    x0_seasonal=jnp.zeros(s_order - 1),
)
key = jrn.PRNGKey(511)
key, subkey = jrn.split(key)
N = 1
(X,), (Y,) = simulate_pgssm(model, N, subkey)
proposal, info = laplace_approximation(Y, model, 10)
smooth_signal = posterior_mode(proposal)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.set_title(f"Observations and signal mode after {info.n_iter} iterations")
ax1.plot(Y[:], linestyle="--", color="gray", label="$Y_t$")
ax1.plot(jnp.exp(smooth_signal), color="blue", label="$\\exp (S_t) $")
ax2.set_title("true signal and mode")
ax2.plot(vmap(jnp.matmul)(model.B, X), linestyle="--", color="gray", label="$S_t$")
ax2.plot(smooth_signal, color="blue", label="mode")
ax1.legend()
ax2.legend()
# unique legend
# handles, labels = plt.gca().get_legend_handles_labels()
# by_label = dict(zip(labels, handles))
# plt.legend(by_label.values(), by_label.keys())
plt.show()
```
By default, laplace_approximation uses automatic differentiation to evaluate the first and second derivatives necessary to implement the LA. You can also provide the derivatives yourself, e.g. for efficiency or numerical stability.
```python
r = 20.0

def nb_log_lik(s_ti, r_ti, y_ti):
    # log p(y | s) for a negative binomial observation with mean exp(s) and
    # dispersion r, dropping terms that do not depend on s
    return jnp.sum(
        y_ti * jnp.log(expit(s_ti - jnp.log(r_ti)))
        - r_ti * jnp.log(jnp.exp(s_ti) + r_ti)
    )

def d_nb_log_lik(s_ti, r_ti, y_ti):
    # first derivative of nb_log_lik with respect to s_ti
    return y_ti - (y_ti + r_ti) * expit(s_ti - jnp.log(r_ti))

def dd_nb_log_lik(s_ti, r_ti, y_ti):
    # second derivative of nb_log_lik with respect to s_ti
    return -(y_ti + r_ti) * expit(s_ti - jnp.log(r_ti)) * (1 - expit(s_ti - jnp.log(r_ti)))
```
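As a quick sanity check (not part of the original example; the test point below is arbitrary), the analytical derivatives can be compared against automatic differentiation:

```python
from jax import grad

# arbitrary test point; grad differentiates with respect to the first argument s_ti
s_test, y_test = 0.7, 3.0
fct.test_close(d_nb_log_lik(s_test, r, y_test), grad(nb_log_lik)(s_test, r, y_test))
fct.test_close(dd_nb_log_lik(s_test, r, y_test), grad(grad(nb_log_lik))(s_test, r, y_test))
```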
print("10 iterations of LA with AD ")
print("10 iterations of LA with analytical gradients")
```
10 iterations of LA with AD
2.27 s ± 190 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
10 iterations of LA with analytical gradients
1.94 s ± 115 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
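In this run, providing analytical derivatives reduces the runtime by roughly 15 % compared to relying on automatic differentiation.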