from isssm.importance_sampling import log_weights
from isssm.pgssm import simulate_pgssm, nb_pgssm_running_example
from isssm.kalman import FFBS
from isssm.kalman import kalman, smoother
import jax.random as jrn
from isssm.laplace_approximation import laplace_approximation
import jax.numpy as jnp
from jax import vmap
from functools import partial
import matplotlib.pyplot as plt
from isssm.importance_sampling import ess_lw
from isssm.typing import PGSSM
from isssm.kalman import smoothed_signals
from isssm.laplace_approximation import posterior_mode
from isssm.typing import GLSSMProposal
Modified Efficient Importance Sampling
Modified efficient importance sampling (MEIS) is used to improve on the Laplace approximation. The goal is to minimize the variance of the log-weights in importance sampling (Koopman, Lit, and Nguyen 2019). If the importance sampling model is Gaussian and its observations are conditionally independent across time and space, the iterations reduce to a least squares problem which can be solved efficiently.
\[ \int \left(\log p (y | x) - \log g(z|x) - \lambda \right)^2 p(x|y) \mathrm d x \] where \(\lambda = \mathbf E \left(\log p(y|x) - \log g(z|x)| Y = y\right)\). This is approximated by an importance sampling version
\[ \sum_{i = 1}^N \left(\log p(y|X^i) - \log g(z|X^i) - \lambda\right)^2 w(X^i). \]
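For intuition, the objective and \(\lambda\) can be estimated directly from a weighted sample. The following is a small sketch (not part of `isssm`), assuming the weights \(w(X^i)\) have been normalized to sum to one and the log-densities are precomputed.

import jax.numpy as jnp

def log_weight_variance(log_p, log_g, w):
    # log_p[i] = log p(y | X^i), log_g[i] = log g(z | X^i), w = normalized weights, all shape (N,)
    lam = jnp.sum(w * (log_p - log_g))              # estimate of lambda
    return jnp.sum(w * (log_p - log_g - lam) ** 2)  # estimated variance of the log-weights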
Using Gaussian proposals \(g\), we have for the signals \(S_t = B_t X_t\)
\[ \begin{align*} \log g(z_{t}|s_{t}) &= -\frac{1}{2} (z_{t} - s_{t})^T\Omega^{-1}_{t}(z_{t} - s_{t}) - \frac{p}{2} \log (2\pi) - \frac{1}{2} \log \det \Omega_{t}, \end{align*} \] where \(\Omega_t = \operatorname{diag} \left( \omega_{t,1}, \dots, \omega_{t,p}\right)\).
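As a point of reference, this density is straightforward to evaluate when the diagonal \(\Omega_t\) is stored as its vector of variances; the sketch below is illustrative only and not a function from `isssm`.

import jax.numpy as jnp

def log_g_diag(z_t, s_t, omega_t):
    # z_t, s_t, omega_t: shape (p,); omega_t holds the diagonal of Omega_t
    p = z_t.shape[0]
    return (
        -0.5 * jnp.sum((z_t - s_t) ** 2 / omega_t)
        - 0.5 * p * jnp.log(2 * jnp.pi)
        - 0.5 * jnp.sum(jnp.log(omega_t))
    )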
Due to the large dimension of the problem, we solve it separately for each \(t\):
\[ \begin{align*} \sum_{i = 1}^N\left(\log p(y_t|s_t^{i}) - \log g(z_{t}| s_t^{i}) - \lambda_{t}\right)^2 w(s_t^i) &= \sum_{i = 1}^N\left(\log p(y_t|s_t^{i}) +\frac{1}{2} (z_{t} - s_{t}^{i})^T\Omega^{-1}_{t}(z_{t} - s_{t}^{i}) + \frac{p}{2} \log (2\pi) + \frac{1}{2} \log \det \Omega_{t} - \lambda_{t}\right)^2 w(s_t^i) \end{align*} \]
and minimize over the unknown parameters \(\left(z_t, \Omega_t, \lambda_t - C_t\right)\), where \(C_t\) collects the terms that do not depend on \(s_t^{i}\); this is a weighted least squares problem with “observations” \(\log p(y_t|s_t^{i})\).
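Expanding the quadratic for a diagonal \(\Omega_t\), the regressors are an intercept, \(s_{t,j}\) (with coefficient \(z_{t,j}/\omega_{t,j}\)) and \(-s_{t,j}^2/2\) (with coefficient \(1/\omega_{t,j}\)). The following sketch of one such per-time-step fit is illustrative only; it is not the library's `optimal_parameters`, and the function name is made up.

import jax.numpy as jnp

def wls_step(s, log_p, w):
    # s: (N, p) signal samples for time t, log_p: (N,) values of log p(y_t | s_t^i),
    # w: (N,) normalized importance weights
    N, p = s.shape
    # regressors: intercept, s_j (coefficient z_j / omega_j), -s_j^2 / 2 (coefficient 1 / omega_j)
    X = jnp.concatenate([jnp.ones((N, 1)), s, -0.5 * s**2], axis=1)
    sqrt_w = jnp.sqrt(w)
    beta, *_ = jnp.linalg.lstsq(X * sqrt_w[:, None], log_p * sqrt_w)
    lin, quad = beta[1 : 1 + p], beta[1 + p :]
    omega_t = 1.0 / quad  # fitted diagonal of Omega_t
    z_t = lin * omega_t   # fitted pseudo-observations z_t
    return z_t, omega_t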
To perform the estimation in a memory-efficient way, we combine the FFBS algorithm (see [00_glssm.ipynb]) with the optimization procedure, so that the memory requirement of this algorithm is \(\mathcal O(N)\) instead of \(\mathcal O(N\cdot n)\).
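To illustrate only the memory pattern (this is a toy, not the FFBS-based implementation in this module): scanning over time and reducing each step's samples immediately means only an \((N, p)\) slice is materialized at a time, instead of the full \((n + 1, N, p)\) array of samples.

import jax
import jax.random as jrn

N, p, n_steps = 10_000, 4, 100  # illustrative sizes

def per_step(carry, key):
    s_t = jrn.normal(key, (N, p))    # N signal samples for a single time step
    return carry, s_t.mean(axis=0)   # reduce immediately (stand-in for the per-t WLS fit)

keys = jrn.split(jrn.PRNGKey(0), n_steps)
_, per_step_stats = jax.lax.scan(per_step, None, keys)  # per-step results, shape (n_steps, p)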
modified_efficient_importance_sampling
modified_efficient_importance_sampling (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.PGSSM, z_init:jaxtyping.Float[Array,'n+1 p'], Omega_init:jaxtyping.Float[Array,'n+1 p p'], n_iter:int, N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']], eps:jaxtyping.Float=1e-05)
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, 'n+1 p'] |  | observations |
| model | PGSSM |  | model |
| z_init | Float[Array, 'n+1 p'] |  | initial z estimate |
| Omega_init | Float[Array, 'n+1 p p'] |  | initial Omega estimate |
| n_iter | int |  | number of iterations |
| N | int |  | number of samples |
| key | Union |  | random key |
| eps | Float | 1e-05 | convergence threshold |
| Returns | tuple |  |  |
optimal_parameters
optimal_parameters (signal:jaxtyping.Float[Array,'N p'], weights:jaxtyping.Float[Array,'N'], log_p:jaxtyping.Float[Array,'N'])
# simulate data from the running example model and compute the LA proposal
model = nb_pgssm_running_example(
    s_order=2, Sigma0_seasonal=jnp.eye(2 - 1), x0_seasonal=jnp.zeros(2 - 1)
)
key = jrn.PRNGKey(511)
key, subkey = jrn.split(key)
N = 1
(x,), (y,) = simulate_pgssm(model, N, subkey)
proposal_la, info_la = laplace_approximation(y, model, 10)
from isssm.typing import GLSSM  # assumed location, matching PGSSM and GLSSMProposal above

# run MEIS, starting from the Laplace approximation proposal
N = int(1e4)
key, subkey = jrn.split(key)
proposal_meis, info_meis = modified_efficient_importance_sampling(
    y, model, proposal_la.z, proposal_la.Omega, 10, N, subkey
)
glssm_meis = GLSSM(
    model.u,
    model.A,
    model.A,
    model.Sigma0,
    model.Sigma,
    model.v,
    model.B,
    proposal_meis.Omega,
)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
fig.tight_layout()

# compare the LA and MEIS proposal parameters
z_diff = ((proposal_meis.z - proposal_la.z) ** 2).mean(axis=1)
axs[0].set_title("Difference in z LA vs. MEIS")
axs[0].plot(z_diff)

Omega_diff = ((proposal_meis.Omega - proposal_la.Omega) ** 2).mean(axis=(1, 2))
axs[1].set_title("Difference in Omega LA vs. MEIS")
axs[1].plot(Omega_diff)

s_smooth_la = posterior_mode(proposal_la)
s_smooth_meis = posterior_mode(proposal_meis)

axs[2].set_title("Smoothed signals")
axs[2].plot(s_smooth_la, label="LA")
axs[2].plot(s_smooth_meis, label="MEIS")
axs[2].legend()

plt.show()
from isssm.importance_sampling import pgssm_importance_sampling
from isssm.importance_sampling import ess_pct
from isssm.importance_sampling import normalize_weights  # assumed location of this helper

# compare importance sampling with the LA and MEIS proposals
N = 1000
_, lw_la = pgssm_importance_sampling(
    y, model, proposal_la.z, proposal_la.Omega, N, jrn.PRNGKey(423423)
)
samples, lw = pgssm_importance_sampling(
    y, model, proposal_meis.z, proposal_meis.Omega, N, jrn.PRNGKey(423423)
)
weights = normalize_weights(lw)
plt.title(f"EIS weights, ESS = {ess_pct(lw):.2f}% (vs. LA ESS = {ess_pct(lw_la):.2f}%)")
plt.hist(4 * N * weights[None, :], bins=50)
plt.show()
MEIS increases the ESS of importance sampling relative to the LA proposal.