Modified Efficient Importance Sampling

See also the corresponding section in my thesis

Modified efficient importance sampling (MEIS) is used to improve on the Laplace approximation. Building on the efficient importance sampling of Richard and Zhang (2007), the goal is to choose the Gaussian proposal that minimizes the variance of the log-weights in importance sampling (Koopman, Lit, and Nguyen 2019). If the importance sampling model is Gaussian with observations that are conditionally independent across time and space, each iteration reduces to a least squares problem that can be solved efficiently.

The objective is to minimize

\[ \int \left(\log p (y | x) - \log g(z|x) - \lambda \right)^2 p(x|y) \,\mathrm d x \]

where \(\lambda = \mathbf E \left(\log p(y|x) - \log g(z|x) \mid Y = y\right)\). This is approximated by its importance sampling version

\[ \sum_{i = 1}^N \left(\log p(y|X^i) - \log g(z|X^i) - \lambda\right)^2 w(X^i). \]

Using Gaussian proposals \(g\), we have for the signals \(S_t = B_t X_t\)

\[ \begin{align*} \log g(z_{t}|s_{t}) &= -\frac{1}{2} (z_{t} - s_{t})^T\Omega^{-1}_{t}(z_{t} - s_{t}) - \frac{p}{2} \log (2\pi) - \frac{1}{2} \log \det \Omega_{t}. \end{align*} \] where \(\Omega_t = \operatorname{diag} \left( \omega_{t,1}, \dots, \omega_{t,p}\right)\).
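For concreteness, this diagonal-covariance log-density can be evaluated directly from the diagonal entries \(\omega_{t,1}, \dots, \omega_{t,p}\); the following is a minimal sketch (the helper name log_g_diag is illustrative and not part of isssm):

import jax.numpy as jnp

def log_g_diag(z_t, s_t, omega_t):
    """Gaussian log-density log g(z_t | s_t) with diagonal covariance Omega_t.

    z_t, s_t and omega_t are (p,) arrays; omega_t holds the diagonal of Omega_t.
    """
    p = z_t.shape[0]
    resid = z_t - s_t
    return (
        -0.5 * jnp.sum(resid**2 / omega_t)   # quadratic form with diagonal Omega_t^{-1}
        - 0.5 * p * jnp.log(2 * jnp.pi)      # normalizing constant
        - 0.5 * jnp.sum(jnp.log(omega_t))    # log det Omega_t for a diagonal matrix
    )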

Due to the large dimension of the problem, we solve it for each \(t\) separately:

\[ \begin{align*} \sum_{i = 1}^N\left(\log p(y_t|s_t^{i}) - \log g(z_{t}| s_t^{i}) - \lambda_{t}\right)^2 w(s_t^i) &= \sum_{i = 1}^N\left(\log p(y_t|s_t^{i}) +\frac{1}{2} (z_{t} - s_{t}^{i})^T\Omega^{-1}_{t}(z_{t} - s_{t}^{i}) + \frac{p}{2} \log (2\pi) + \frac{1}{2} \log \det \Omega_{t} - \lambda_{t}\right)^2 w(s_t^i) \end{align*} \]

and minimized over the unknown parameters \(\left(z_t, \Omega_t, \lambda_t\right)\). Expanding the quadratic form and collecting all terms that do not depend on \(s^i_t\) into a single intercept \(\lambda_t - C_t\), this becomes a weighted least squares problem with “observations” \(\log p(y_t|s^i_t)\), regressors \(s^i_{t,j}\) and \((s^i_{t,j})^2\), and weights \(w(s^i_t)\).
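The optimal_parameters function documented below takes exactly these per-time-point inputs (sampled signals, weights, and \(\log p(y_t|s^i_t)\)). As an illustration of the weighted least squares step and of how \((z_t, \Omega_t)\) can be recovered from the regression coefficients, consider the following sketch; the helper wls_proposal_parameters and its parametrization are assumptions for exposition, not the package implementation:

import jax.numpy as jnp

def wls_proposal_parameters(signals, log_p, weights):
    """Fit log p(y_t | s_t^i) ~ a + b^T s_t^i + sum_j c_j (s_{t,j}^i)^2 by weighted least squares.

    signals: (N, p) sampled signals s_t^i, log_p: (N,) values of log p(y_t | s_t^i),
    weights: (N,) normalized importance weights w(s_t^i).
    Returns the implied z_t and the diagonal of Omega_t.
    """
    N, p = signals.shape
    # design matrix: intercept, linear and quadratic terms in the signal
    X = jnp.concatenate([jnp.ones((N, 1)), signals, signals**2], axis=1)
    sw = jnp.sqrt(weights)
    beta, *_ = jnp.linalg.lstsq(sw[:, None] * X, sw * log_p)
    b, c = beta[1 : 1 + p], beta[1 + p :]
    omega_t = -1.0 / (2.0 * c)   # from c_j = -1 / (2 omega_{t,j})
    z_t = omega_t * b            # from b = Omega_t^{-1} z_t
    return z_t, omega_t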

To make the estimation memory efficient, we combine the FFBS algorithm (see [00_glssm.ipynb]) with the optimization procedure, so that the memory requirement of this algorithm is \(\mathcal O(N)\) instead of \(\mathcal O(N\cdot n)\).
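Schematically, given signal trajectories drawn via FFBS from the current proposal together with their normalized weights, one MEIS parameter update can be vectorized over time with vmap. The sketch below builds on the illustrative wls_proposal_parameters above and omits the FFBS sampling step and the \(\mathcal O(N)\) streaming bookkeeping:

from jax import vmap

def update_proposal(signals, log_p, weights):
    """One MEIS parameter update from already-sampled signal trajectories.

    signals: (N, n+1, p) sampled trajectories, log_p: (N, n+1) pointwise log p(y_t | s_t^i),
    weights: (N,) normalized importance weights, shared across time.
    Returns updated z of shape (n+1, p) and diagonal Omega entries of shape (n+1, p).
    """
    # solve the weighted least squares problem independently for every time point t
    per_t = vmap(wls_proposal_parameters, in_axes=(1, 1, None))
    return per_t(signals, log_p, weights)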


source

modified_efficient_importance_sampling

 modified_efficient_importance_sampling (y:jaxtyping.Float[Array,'n+1 p'],
                                         model:isssm.typing.PGSSM,
                                         z_init:jaxtyping.Float[Array,'n+1 p'],
                                         Omega_init:jaxtyping.Float[Array,'n+1 p p'],
                                         n_iter:int, N:int,
                                         key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']],
                                         eps:jaxtyping.Float=1e-05)
             Type                      Default  Details
y            Float[Array, 'n+1 p']              observations
model        PGSSM                              model
z_init       Float[Array, 'n+1 p']              initial z estimate
Omega_init   Float[Array, 'n+1 p p']            initial Omega estimate
n_iter       int                                number of iterations
N            int                                number of samples
key          Union                              random key
eps          Float                     1e-05    convergence threshold
Returns      tuple

source

optimal_parameters

 optimal_parameters (signal:jaxtyping.Float[Array,'N p'],
                     weights:jaxtyping.Float[Array,'N'],
                     log_p:jaxtyping.Float[Array,'N'])
from isssm.importance_sampling import log_weights
from isssm.pgssm import simulate_pgssm, nb_pgssm_running_example
from isssm.kalman import FFBS
from isssm.kalman import kalman, smoother
import jax.random as jrn
from isssm.laplace_approximation import laplace_approximation
import jax.numpy as jnp
from jax import vmap
from functools import partial
import matplotlib.pyplot as plt
from isssm.importance_sampling import ess_lw
from isssm.typing import PGSSM
from isssm.kalman import smoothed_signals
from isssm.laplace_approximation import posterior_mode
from isssm.typing import GLSSM, GLSSMProposal  # GLSSM (used below) assumed to live alongside PGSSM/GLSSMProposal
model = nb_pgssm_running_example(
    s_order=2, Sigma0_seasonal=jnp.eye(2 - 1), x0_seasonal=jnp.zeros(2 - 1)
)
key = jrn.PRNGKey(511)
key, subkey = jrn.split(key)
N = 1
(x,), (y,) = simulate_pgssm(model, N, subkey)
proposal_la, info_la = laplace_approximation(y, model, 10)

N = int(1e4)
key, subkey = jrn.split(key)
proposal_meis, info_meis = modified_efficient_importance_sampling(
    y, model, proposal_la.z, proposal_la.Omega, 10, N, subkey
)
glssm_meis = GLSSM(
    model.u,
    model.A,
    model.A,
    model.Sigma0,
    model.Sigma,
    model.v,
    model.B,
    proposal_meis.Omega,
)

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
fig.tight_layout()

z_diff = ((proposal_meis.z - proposal_la.z) ** 2).mean(axis=1)
axs[0].set_title("Difference in z: LA vs. MEIS")
axs[0].plot(z_diff)

Omega_diff = ((proposal_meis.Omega - proposal_la.Omega) ** 2).mean(axis=(1, 2))
axs[1].set_title("Difference in Omega: LA vs. MEIS")
axs[1].plot(Omega_diff)

s_smooth_la = posterior_mode(proposal_la)
s_smooth_meis = posterior_mode(proposal_meis)

axs[2].set_title("Smoothed signals")
axs[2].plot(s_smooth_la, label="LA")
axs[2].plot(s_smooth_meis, label="MEIS")
axs[2].legend()

plt.show()

from isssm.importance_sampling import pgssm_importance_sampling
from isssm.importance_sampling import ess_pct
from isssm.importance_sampling import normalize_weights  # used below; assumed exported here
N = 1000
_, lw_la = pgssm_importance_sampling(
    y, model, proposal_la.z, proposal_la.Omega, N, jrn.PRNGKey(423423)
)
samples, lw = pgssm_importance_sampling(
    y, model, proposal_meis.z, proposal_meis.Omega, N, jrn.PRNGKey(423423)
)
weights = normalize_weights(lw)
plt.title(f"EIS weights, ESS = {ess_pct(lw):.2f}% (vs. LA ESS = {ess_pct(lw_la):.2f}%)")
plt.hist(4 * N * weights[None, :], bins=50)
plt.show()

MEIS increases the effective sample size (ESS) of importance sampling compared to the Laplace approximation proposal.

References

Koopman, Siem Jan, Rutger Lit, and Thuy Minh Nguyen. 2019. “Modified Efficient Importance Sampling for Partially Non-Gaussian State Space Models.” Statistica Neerlandica 73 (1): 44–62. https://doi.org/10.1111/stan.12128.
Richard, Jean-Francois, and Wei Zhang. 2007. “Efficient High-Dimensional Importance Sampling.” Journal of Econometrics 141 (2): 1385–1411. https://doi.org/10.1016/j.jeconom.2007.02.007.