Laplace approximation for log-concave state space models

See also the corresponding section in my thesis

Consider a PGSSM with states \(X\) and observations \(Y\). If the joint distribution of \(X\) and \(Y\) is not Gaussian, we cannot directly apply the standard Kalman filter and smoother.

Here we implement the alternative Laplace approximation (LA) method from (Durbin and Koopman 2012), Chapter 10.6, where it is called mode approximation. Its main idea is to approximate the posterior distribution by a Gaussian distribution obtained from a second-order Taylor expansion of the log-pdf around the posterior mode.

This essentially means matching the first and second derivatives of the observation log-likelihoods at the mode. As the mode is a (global) maximum of the posterior distribution, we can find it by a Newton-Raphson iteration, which (Durbin and Koopman 2012) show can be implemented efficiently: each Newton step amounts to a single pass of a Kalman smoother.
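Concretely, each observation log-density is replaced by its second-order Taylor expansion around the current estimate \(\hat s_t\) of the mode,

\[ \log p(y_t|s_t) \approx \log p(y_t|\hat s_t) + \left.\frac{\partial \log p(y_t|\cdot)}{\partial s_t}\right|_{\hat s_t}^\top (s_t - \hat s_t) + \frac 12 (s_t - \hat s_t)^\top \left.\frac{\partial^2 \log p(y_t|\cdot)}{\partial (s_t)^2}\right|_{\hat s_t} (s_t - \hat s_t), \]

which, up to an additive constant, is the log-density of a Gaussian observation model in \(s_t\).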

The LA procedure is based on the following observation: let \(\hat s = (\hat s_0, \dots, \hat s_n)\) be the posterior mode and consider the surrogate Gaussian model with the same state equation and the observation equations

\[ \begin{align*} S_t &= B_t X_t \\ Z_t &= S_t + \eta_t \\ \eta_t &\sim \mathcal N\left(0, \Omega_t\right) \end{align*} \]

where \(\Omega_t^{-1} = -\left.\frac{\partial^2 \log p(y_t|\cdot)}{\partial (s_t)^2}\right|_{\hat s_t}\). For the pseudo-observations \(z_t = \hat s_t + \Omega_t \left.\frac{\partial \log p(y_t|\cdot)}{\partial s_t}\right|_{\hat s_t}\), this surrogate model has the same posterior mode \(\hat s\).

In most cases we are interested in, the observations are conditionally independent given the signals, so that \(\Omega\) is a diagonal matrix; this makes inversion much faster, as we only have to invert the diagonal entries. This implementation assumes this to hold, but could be extended to handle the general case as well (replace the calls to vdiag by solve).
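Under the diagonal assumption, forming the surrogate parameters reduces to elementwise array operations. The following is a minimal sketch (not the library code); here d_ll and dd_ll stand for the elementwise first and second derivatives of the observation log-likelihoods evaluated at the current \(\hat s\).

Code
import jax.numpy as jnp

def surrogate_parameters(s_hat, d_ll, dd_ll):
    # d_ll, dd_ll: derivatives of log p(y_t | s_t) at s_hat, same shape as s_hat;
    # with Omega diagonal, inversion is just elementwise division
    Omega = -1.0 / dd_ll          # diagonal entries of Omega_t
    z = s_hat + Omega * d_ll      # pseudo-observations z_t
    return z, Omega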

This is used in a fixed point iteration (a sketch follows the list):

  1. Start with an initial guess \(\hat s\).
  2. Setup the above Gaussian approximation.
  3. Perform a pass of the signal smoother, obtaining the posterior mode \(\hat s^+\).
  4. Set \(\hat s = \hat s^+\) and iterate until convergence.
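A minimal sketch of this loop, with the surrogate construction and the Gaussian signal smoother passed in as callables; build_surrogate and smooth_signal_fn are placeholders for steps 2 and 3, not functions of this package. The library's laplace_approximation below implements the same iteration with one Kalman smoother pass per step.

Code
import jax.numpy as jnp

def laplace_fixed_point(y, s_init, build_surrogate, smooth_signal_fn, n_iter, eps=1e-5):
    # build_surrogate(y, s_hat) -> (z, Omega): Gaussian approximation at the current guess
    # smooth_signal_fn(z, Omega) -> signal mode of the surrogate Gaussian model
    s_hat = s_init
    for _ in range(n_iter):
        z, Omega = build_surrogate(y, s_hat)        # step 2
        s_new = smooth_signal_fn(z, Omega)          # step 3
        if jnp.max(jnp.abs(s_new - s_hat)) < eps:   # stop once the mode has converged
            return s_new
        s_hat = s_new                               # step 4
    return s_hat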

Currently, we assume that \(\Omega\) is always positive definite, i.e. that \(s\mapsto p(y|s)\) is strictly log-concave. This is the case for natural exponential families, but might be violated otherwise. In that case, the Kalman filter and smoother can still be used to perform the Laplace approximation, but the approximation then has to be based on the ideas developed in (Jungbacker and Koopman 2007).
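For the negative binomial observations of the running example, strict log-concavity is easy to verify: the second derivative of the log-likelihood derived below is \(-(y + r)\,p(1-p) < 0\) for every signal value. A quick numerical check (the grid and parameter values here are chosen arbitrarily):

Code
import jax.numpy as jnp
from jax.scipy.special import expit

s = jnp.linspace(-5.0, 5.0, 101)   # grid of signal values
r, y = 20.0, 3.0                   # overdispersion and observation, arbitrary
p = expit(s - jnp.log(r))
dd = -(y + r) * p * (1 - p)        # second derivative of the nb log-likelihood
assert bool(jnp.all(dd < 0))       # strictly log-concave, so Omega is positive definite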

Code
from functools import partial

import fastcore.test as fct
import jax.numpy as jnp
import jax.random as jrn
import matplotlib.pyplot as plt
from jax import jit, vmap
from jax.scipy.special import expit

from isssm.kalman import kalman

source

posterior_mode

 posterior_mode (proposal:isssm.typing.GLSSMProposal)

source

laplace_approximation

 laplace_approximation (y:jaxtyping.Float[Array,'n+1 p'],
                        model:isssm.typing.PGSSM, n_iter:int,
                        log_lik=None, d_log_lik=None, dd_log_lik=None,
                        eps:jaxtyping.Float=1e-05, link=<function
                        <lambda>>)
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, 'n+1 p'] |  | observation |
| model | PGSSM |  |  |
| n_iter | int |  | number of iterations |
| log_lik | NoneType | None | log likelihood function |
| d_log_lik | NoneType | None | derivative of log likelihood function |
| dd_log_lik | NoneType | None | second derivative of log likelihood function |
| eps | Float | 1e-05 | precision of iterations |
| link | function |  | default link to use in initial guess |
| **Returns** | **tuple** |  |  |

We return to our running example and use mode estimation to obtain the mode of the conditional distribution of the signals given the observations.

from isssm.pgssm import nb_pgssm_running_example, simulate_pgssm
s_order = 5
model = nb_pgssm_running_example(
    s_order=s_order,
    Sigma0_seasonal=jnp.eye(s_order - 1),
    x0_seasonal=jnp.zeros(s_order - 1),
)

key = jrn.PRNGKey(511)
key, subkey = jrn.split(key)
N = 1
(X,), (Y,) = simulate_pgssm(model, N, subkey)

proposal, info = laplace_approximation(Y, model, 10)
smooth_signal = posterior_mode(proposal)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.set_title(f"Observations and signal mode after {info.n_iter} iterations")
ax1.plot(Y[:], linestyle="--", color="gray", label="$Y_t$")
ax1.plot(jnp.exp(smooth_signal), color="blue", label="$\\exp (S_t) $")
ax2.set_title("true signal and mode")
ax2.plot(vmap(jnp.matmul)(model.B, X), linestyle="--", color="gray", label="$S_t$")
ax2.plot(smooth_signal, color="blue", label="mode")
ax1.legend()
ax2.legend()

# unique legend
# handles, labels = plt.gca().get_legend_handles_labels()
# by_label = dict(zip(labels, handles))
# plt.legend(by_label.values(), by_label.keys())
plt.show()

The default implementation of laplace_approximation uses automatic differentiation to evaluate the first and second derivatives necessary to implement the LA. You can also provide the derivatives yourself, e.g. for efficiency or numerical stability.

Code
r = 20.
def nb_log_lik(s_ti, r_ti, y_ti):
    return jnp.sum(y_ti * jnp.log(expit(s_ti - jnp.log(r_ti))) - r_ti * jnp.log(jnp.exp(s_ti) + r_ti))

def d_nb_log_lik(s_ti, r_ti, y_ti):
    return y_ti - (y_ti + r_ti) * expit(s_ti - jnp.log(r_ti))

def dd_nb_log_lik(s_ti, r_ti, y_ti):
    return -(y_ti + r_ti) * expit(s_ti - jnp.log(r_ti)) * (1 - expit(s_ti - jnp.log(r_ti)))

print("10 iterations of LA with AD ")

print("10 iterations of LA with analytical gradients")
10 iterations of LA with AD 
2.27 s ± 190 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
10 iterations of LA with analytical gradients
1.94 s ± 115 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
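For reference, the two calls being compared can be written as follows; this is a sketch, with the keyword names taken from the signature table above.

Code
# default: derivatives obtained by automatic differentiation
proposal_ad, info_ad = laplace_approximation(Y, model, 10)

# analytic derivatives supplied explicitly
proposal_an, info_an = laplace_approximation(
    Y, model, 10,
    log_lik=nb_log_lik, d_log_lik=d_nb_log_lik, dd_log_lik=dd_nb_log_lik,
)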

References

Durbin, J., and S. J. Koopman. 2012. Time Series Analysis by State Space Methods. 2nd ed. Oxford Statistical Science Series 38. Oxford: Oxford University Press.
Jungbacker, Borus, and Siem Jan Koopman. 2007. “Monte Carlo Estimation for Nonlinear Non-Gaussian State Space Models.” Biometrika 94 (4): 827–39. https://doi.org/10.1093/biomet/asm074.