Importance Sampling for Partially Gaussian State Space Models

See also the corresponding section in my thesis

After having observed \(Y\), one is usually interested in properties of the conditional distribution of the states \(X\) given \(Y\). Typically these are expectations of the form

\[ \begin{align*} \mathbf E (f(X) | Y) &= \mathbf E (f(X_0, \dots, X_n) | Y_0, \dots, Y_n) \\ &= \int f(x_0, \dots, x_n) p(x_0, \dots, x_n | y_0, \dots, y_n) \mathrm d x_0 \dots \mathrm d x_n. \end{align*} \]

As the density \(p(x|y)\) is known only up to a constant, we resort to importance sampling with a GLSSM proposal, represented by its Gaussian densities \(g\). The Laplace approximation and (modified) efficient importance sampling perform this task for log-concave state space models where the states are jointly Gaussian.

Both methods construct surrogate linear Gaussian state space models that are parameterized by synthetic observations \(z_t\) and their covariance matrices \(\Omega_t\). Usually \(\Omega_t\) is a diagonal matrix, which is justified when the components of the observation vector at time \(t\), \(Y^i_t\), \(i = 1, \dots, p\), are conditionally independent given the state \(X_t\).

These models are then based on the following SSM: \[ \begin{align*} X_0 &\sim \mathcal N (x_0, \Sigma_0) && \\ X_{t + 1} &= A_t X_{t} + \varepsilon_{t + 1} && t = 0, \dots, n - 1\\ \varepsilon_t &\sim \mathcal N (0, \Sigma_t) && t = 1, \dots, n \\ S_t &= B_t X_t && \\ Z_t &= S_t + \eta_t && t = 0, \dots, n \\ \eta_t &\sim \mathcal N(0, \Omega_t) && t = 0, \dots, n. \end{align*} \]
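To make this structure concrete, the following is a minimal, self-contained sketch that simulates a single trajectory from such a surrogate GLSSM with toy dimensions and constant system matrices; all names and values here are illustrative and not part of the isssm API.

import jax.numpy as jnp
import jax.random as jrn

m, p, n = 2, 1, 2            # toy dimensions: m states, p observations, n + 1 time points
A = 0.9 * jnp.eye(m)         # state transition A_t (constant here)
B = jnp.ones((p, m))         # signal matrix B_t
Sigma = 0.1 * jnp.eye(m)     # state innovation covariance Sigma_t (also used as Sigma_0)
Omega = 0.5 * jnp.eye(p)     # synthetic observation covariance Omega_t

key = jrn.PRNGKey(0)
key, subkey = jrn.split(key)
x = jrn.multivariate_normal(subkey, jnp.zeros(m), Sigma)   # X_0
zs = []
for t in range(n + 1):
    s = B @ x                                              # signal S_t = B_t X_t
    key, subkey = jrn.split(key)
    zs.append(s + jrn.multivariate_normal(subkey, jnp.zeros(p), Omega))  # Z_t = S_t + eta_t
    key, subkey = jrn.split(key)
    x = A @ x + jrn.multivariate_normal(subkey, jnp.zeros(m), Sigma)     # X_{t+1}
z = jnp.stack(zs)            # synthetic observations, shape (n + 1, p)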

In this setting we can transform the expectation w.r.t. the conditional density \(p(x|y)\) into one w.r.t. the density \(g(x|z)\):

\[ \begin{align*} \int f(x) p(x|y) \mathrm d x &= \int f(x) \frac{p(x|y)}{g(x|z)} g(x|z) \mathrm d x\\ &= \int f(x) \frac{p(y|x)}{g(z|x)} \frac{g(z)}{p(y)} g(x|z) \mathrm d x. \end{align*} \]

Let \(w(x) = \frac{p(y|x)}{g(z|x)} = \frac{p(y|s)}{g(z|s)}\) be the (unnormalized) importance sampling weights, which depend on \(x\) only through the signals \(s_t = B_t x_t\), \(t = 0, \dots, n\).
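For a single time point the log weight is the difference of two log densities: \(\log p(y_t | s_t)\) minus the Gaussian log density of \(z_t\) with mean \(s_t\) and covariance \(\Omega_t\). The following is a minimal sketch assuming, for illustration only, a Poisson observation model with log link (the running example below uses a negative binomial instead); it is not the library's log_weights_t.

import jax.numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd

def log_weight_single(s_t, y_t, z_t, Omega_t):
    # log p(y_t | s_t): an (assumed) Poisson observation model with log link
    log_p = tfd.Poisson(log_rate=s_t).log_prob(y_t).sum()
    # log g(z_t | s_t): Gaussian with mean s_t and diagonal covariance Omega_t
    log_g = tfd.MultivariateNormalDiag(
        loc=s_t, scale_diag=jnp.sqrt(jnp.diag(Omega_t))
    ).log_prob(z_t)
    return log_p - log_g

# toy inputs with p = 2 observation components
log_weight_single(
    jnp.array([0.5, 1.0]),             # signal s_t
    jnp.array([2.0, 3.0]),             # observation y_t
    jnp.array([0.4, 1.2]),             # synthetic observation z_t
    jnp.diag(jnp.array([0.3, 0.2])),   # Omega_t, assumed diagonal
)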


source

log_weights

 log_weights (s:jaxtyping.Float[Array,'n+1 p'],
              y:jaxtyping.Float[Array,'n+1 p'], dist,
              xi:jaxtyping.Float[Array,'n+1 p'],
              z:jaxtyping.Float[Array,'n+1 p'],
              Omega:jaxtyping.Float[Array,'n+1 p p'])

Log weights for all time points

          Type                     Details
s         Float[Array, 'n+1 p']    signals
y         Float[Array, 'n+1 p']    observations
dist                               observation distribution
xi        Float[Array, 'n+1 p']    observation parameters
z         Float[Array, 'n+1 p']    synthetic observations
Omega     Float[Array, 'n+1 p p']  synthetic observation covariances
Returns   Float                    log weights

source

log_weights_t

 log_weights_t (s_t:jaxtyping.Float[Array,'p'],
                y_t:jaxtyping.Float[Array,'p'],
                xi_t:jaxtyping.Float[Array,'p'], dist,
                z_t:jaxtyping.Float[Array,'p'],
                Omega_t:jaxtyping.Float[Array,'p p'])

Log weight for a single time point.

          Type                 Details
s_t       Float[Array, 'p']    signal
y_t       Float[Array, 'p']    observation
xi_t      Float[Array, 'p']    parameters
dist                           observation distribution
z_t       Float[Array, 'p']    synthetic observation
Omega_t   Float[Array, 'p p']  synthetic observation covariance, assumed diagonal
Returns   Float                single log weight

Importance samples are generated from the smoothing distribution in the surrogate model, i.e. from \(g(x|z)\), using e.g. the FFBS or the simulation smoother algorithm.

from isssm.pgssm import simulate_pgssm, nb_pgssm_running_example
from isssm.laplace_approximation import laplace_approximation
from isssm.typing import PGSSM
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as jrn

source

pgssm_importance_sampling

 pgssm_importance_sampling (y:jaxtyping.Float[Array,'n+1 p'],
                            model:isssm.typing.PGSSM,
                            z:jaxtyping.Float[Array,'n+1 p'],
                            Omega:jaxtyping.Float[Array,'n+1 p p'], N:int,
                            key:Union[jaxtyping.Key[Array,''],
                                      jaxtyping.UInt32[Array,'2']])
          Type                     Details
y         Float[Array, 'n+1 p']    observations
model     PGSSM                    model
z         Float[Array, 'n+1 p']    synthetic observations
Omega     Float[Array, 'n+1 p p']  covariance of synthetic observations
N         int                      number of samples
key       Union                    random key
Returns   tuple

Let us perform importance sampling for our running example using the model obtained by the Laplace approximation. Notice that we obtain four times the number of samples we specified, which comes from the use of antithetic variables.

key = jrn.PRNGKey(512)
key, subkey = jrn.split(key)
s_order = 4
model = nb_pgssm_running_example(
    n=100,
    s_order=s_order,
    Sigma0_seasonal=0.1 * jnp.eye(s_order - 1),
    x0_seasonal=jnp.zeros(s_order - 1),
)
N = 1
(x,), (y,) = simulate_pgssm(model, N, subkey)
proposal, info = laplace_approximation(y, model, 100)
fig, axs = plt.subplots(2, 1)
axs[0].plot(y)
axs[1].plot(x[:, 0])

N = 1000
key, subkey = jrn.split(key)
samples, lw = pgssm_importance_sampling(y, model, proposal.z, proposal.Omega, N, subkey)
plt.scatter(jnp.arange(4 * N), lw)
plt.show()

Weights should be calculated on the log scale for numerical stability, but we need them on the usual scale for Monte-Carlo integration. The resulting auto-normalised weights are defined by \(W(X^i) = \frac{w(X^i)}{\sum_{j = 1}^N w(X^j)}\).

As the weights are only known up to a constant, we make the exponentiation numerically stable by subtracting (on the log scale) the largest log weight, ensuring that \(\log w^i \leq 0\) for all weights and thus \(\sum_{i = 1}^N w^i \leq N\).
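The following is a minimal sketch of this stabilised normalisation; it illustrates the trick from the paragraph above and is not the library's normalize_weights implementation.

import jax.numpy as jnp

def normalize_log_weights_sketch(log_w):
    # subtract the largest log weight so that exp() cannot overflow
    shifted = log_w - jnp.max(log_w)
    w = jnp.exp(shifted)        # all values lie in (0, 1]
    return w / jnp.sum(w)       # the constant shift cancels in the ratio

normalize_log_weights_sketch(jnp.array([-1000.0, -1001.0, -1002.5]))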


source

normalize_weights

 normalize_weights (log_weights:jaxtyping.Float[Array,'N'])

Normalize importance sampling weights.

              Type               Details
log_weights   Float[Array, 'N']  log importance sampling weights
Returns       Float[Array, 'N']  normalized importance sampling weights

weights = normalize_weights(lw)
plt.boxplot(weights[None] * N)
plt.show()

Effective Sample Size

The effective sample size is an important diagnostic for the performance of importance sampling. It is defined by

\[ \text{ESS} = \frac{\left(\sum_{i = 1}^{N} w(X^i)\right)^2}{\sum_{i = 1}^N w^2(X^i)} = \frac{1}{\sum_{i = 1}^N W^2(X^i)}. \]

To compare different approximations one may also be interested in \(\frac{\text{ESS}}{N} \cdot 100\%\), the percentage of effective samples.
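The following is a minimal, self-contained sketch of this quantity computed directly from log weights (illustrative only; the library exposes ess, ess_lw and ess_pct below). jax.nn.softmax performs exactly the stabilised normalisation described above.

import jax
import jax.numpy as jnp

def ess_pct_sketch(log_w):
    W = jax.nn.softmax(log_w)           # auto-normalised weights from log weights
    ess = 1.0 / jnp.sum(W**2)           # ESS = 1 / sum_i W_i^2
    return 100.0 * ess / log_w.shape[0]

ess_pct_sketch(jnp.zeros(100))          # equal weights give 100% efficiency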


source

ess_pct

 ess_pct (log_weights:jaxtyping.Float[Array,'N'])
              Type               Details
log_weights   Float[Array, 'N']  log weights
Returns       Float              the effective sample size in percent, also called efficiency factor

source

ess_lw

 ess_lw (log_weights:jaxtyping.Float[Array,'N'])

Compute the effective sample size of a set of log weights

              Type               Details
log_weights   Float[Array, 'N']  the log weights
Returns       Float              the effective sample size

source

ess

 ess (normalized_weights:jaxtyping.Float[Array,'N'])

Compute the effective sample size of a set of normalized weights

                     Type               Details
normalized_weights   Float[Array, 'N']  normalized weights
Returns              Float              the effective sample size

ess(weights), ess_pct(lw)
(Array(889.98669, dtype=float64), Array(22.24966725, dtype=float64))

Monte-Carlo Integration


source

mc_integration

 mc_integration (samples:jaxtyping.Float[Array,'N...'],
                 log_weights:jaxtyping.Float[Array,'N'])
sample_mean = mc_integration(samples, lw)
true_signal = (model.B @ x[..., None]).squeeze(axis=-1)
plt.plot(true_signal, label="true signal")
plt.plot(sample_mean, label="estimated mean signal")
plt.legend()
plt.show()
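Conceptually, the self-normalised Monte-Carlo estimate of \(\mathbf E(f(X)|Y)\) is the weighted average \(\sum_{i=1}^N W(X^i) f(X^i)\). The following is a minimal sketch of that estimator (not the library implementation), assuming the samples stack the \(N\) draws along the first axis and \(f\) is the identity:

import jax
import jax.numpy as jnp

def mc_integration_sketch(samples, log_w):
    W = jax.nn.softmax(log_w)                          # auto-normalised weights
    W = W.reshape((-1,) + (1,) * (samples.ndim - 1))   # broadcast over trailing axes
    return jnp.sum(W * samples, axis=0)                # weighted average over draws

mc_integration_sketch(jnp.ones((8, 5, 2)), jnp.zeros(8)).shape   # (5, 2)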

Prediction

For prediction we are interested in the conditional expectations \[ \mathbf E \left(Y_{n + t} | Y_0, \dots, Y_n\right), \] where we assume that the PGSSM has some known continuation after time \(n\). We can estimate this conditional expectation by importance sampling. Given our samples \(X^{i}_n\), \(i = 1, \dots, N\), we simulate forward in time to obtain \(X^i_{n + t}\). We may then estimate the conditional expectation by \[ \sum_{i = 1}^N W^i \mathbf E \left( Y_{n + t} | X_{n + t} = X^i_{n + t}\right), \] which is justified by the dependency structure of the model (\(Y_{n + t}\) is independent of \(Y_0, \dots, Y_n\) given \(X_{n + t}\)).

For prediction intervals, we follow the strategy of Durbin and Koopman (2012), Chapter 11.5.3: sort the univariate predictions and their corresponding weights in the same order. The ECDF at the \(i\)-th sorted prediction \(Y^{(i)}\) can then be estimated as \[ \sum_{j = 1}^i W^{(j)}. \] Linearly interpolating between these values gives an ECDF from which we construct prediction intervals.
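The following is a minimal sketch of this weighted ECDF construction for a single univariate prediction (illustrative only, not the library's future_prediction_interval or prediction_percentiles):

import jax
import jax.numpy as jnp

def weighted_quantile_sketch(values, log_w, probs):
    W = jax.nn.softmax(log_w)                   # auto-normalised weights
    order = jnp.argsort(values)                 # sort the predictions ...
    values, W = values[order], W[order]         # ... and their weights alike
    cdf = jnp.cumsum(W)                         # weighted ECDF at the sorted values
    return jnp.interp(probs, cdf, values)       # linear interpolation of the ECDF

weighted_quantile_sketch(
    jnp.array([3.0, 1.0, 2.0, 4.0]),            # toy univariate predictions
    jnp.zeros(4),                               # equal log weights
    jnp.array([0.1, 0.5, 0.9]),                 # requested probabilities
)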


source

predict

 predict (model:isssm.typing.PGSSM, y:jaxtyping.Float[Array,'n+1 p'],
          proposal:isssm.typing.GLSSMProposal,
          future_model:isssm.typing.PGSSM, N:int,
          key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])

source

future_prediction_interval

 future_prediction_interval (dist, signal_samples, xi, log_weights, p)
key, subkey = jrn.split(key)
n_ahead = 10
ten_steps_ahead_model = PGSSM(
    model.u[: n_ahead + 1],
    model.A[:n_ahead],
    model.D[:n_ahead],
    model.Sigma0,
    model.Sigma[:n_ahead],
    model.v[: n_ahead + 1],
    model.B[: n_ahead + 1],
    model.dist,
    model.xi[: n_ahead + 1],
)
(_, s_pred, y_pred), log_weights_pred = predict(
    model, y, proposal, ten_steps_ahead_model, 1000, subkey
)

mean_y = (y_pred * normalize_weights(log_weights_pred)[:, None, None]).sum()
past_inds = jnp.arange(y.shape[0])
future_inds = jnp.arange(y.shape[0], y.shape[0] + n_ahead)
percentiles = prediction_percentiles(
    y_pred, normalize_weights(log_weights_pred), jnp.array([0.1, 0.5, 0.9])
)
lower, mid, upper = percentiles
plt.plot(past_inds, y, label="observed")
plt.plot(future_inds, mid[1:], label="median")
plt.plot(
    future_inds,
    lower[1:],
    linestyle="--",
    color="grey",
    label="80% prediction interval",
)
plt.plot(future_inds, upper[1:], linestyle="--", color="grey")
plt.legend()
plt.show()


source

prediction

 prediction (f:<built-in function callable>, y,
             proposal:isssm.typing.GLSSMProposal,
             model:isssm.typing.PGSSM, N:int,
             key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']],
             probs:jaxtyping.Float[Array,'k'], prediction_model=None)
def f(x, s, y_prime):
    return y_prime[:, 0:1]


key, subkey = jrn.split(key)
mean, sd, quants = prediction(
    f, y, proposal, model, 10000, subkey, jnp.array([0.1, 0.5, 0.9])
)

plt.plot(y, label="observed")
plt.plot(mean, label="predicted")
plt.plot(quants[0], linestyle="--", color="grey", label="prediction interval")
plt.plot(quants[2], linestyle="--", color="grey")
plt.xlim(90, 100)
plt.legend()
plt.show()

quants[2, 99], y[99]
(Array([557.], dtype=float64), Array([356.], dtype=float64))
from tensorflow_probability.substrates.jax.distributions import NegativeBinomial

NegativeBinomial(20, jnp.log(356) - jnp.log(20)).cdf(557)
Array(0.98567769, dtype=float64)
future_prediction_interval(
    model.dist, s_pred[:, 0], model.xi[0], log_weights_pred, 0.5
), mid[0]
(array([49.47880729]), Array([52.31199429], dtype=float64))

References

Durbin, J., and S. J. Koopman. 2012. Time Series Analysis by State Space Methods. 2nd ed. Oxford Statistical Science Series 38. Oxford: Oxford University Press.