from isssm.pgssm import simulate_pgssm, nb_pgssm_running_example
from isssm.laplace_approximation import laplace_approximation
from isssm.typing import PGSSM
import jax.numpy as jnp
import jax.random as jrn
import matplotlib.pyplot as plt
Importance Sampling for Partially Gaussian State Space Models
After having observed \(Y\), one is usually interested in properties of the conditional distribution of the states \(X\) given \(Y\). Typically this means computing terms of the form
\[ \begin{align*} \mathbf E (f(X) | Y) &= \mathbf E (f(X_0, \dots, X_n) | Y_0, \dots, Y_n) \\ &= \int f(x_0, \dots, x_n) p(x_0, \dots, x_n | y_0, \dots, y_n) \mathrm d x_0 \dots \mathrm d x_n. \end{align*} \]
As the density \(p(x|y)\) is known only up to a constant, we resort to importance sampling with a GLSSM, represented by its Gaussian densities \(g\). The Laplace approximation and (modified) efficient importance sampling perform this task for log-concave state space models whose states are jointly Gaussian.
Both methods construct surrogate linear Gaussian state space models that are parameterized by synthetic observations \(z_t\) and their covariance matrices \(\Omega_t\). Usually \(\Omega_t\) is a diagonal matrix, which is justified if the components of the observation vector at time \(t\), \(Y^i_t\), \(i = 1, \dots, p\), are conditionally independent given the state \(X_t\).
These models are then based on the following SSM: \[ \begin{align*} X_0 &\sim \mathcal N (x_0, \Sigma_0) &&\\ X_{t + 1} &= A_t X_{t} + \varepsilon_{t + 1} && t = 0, \dots, n - 1\\ \varepsilon_t &\sim \mathcal N (0, \Sigma_t) && t = 1, \dots, n \\ S_t &= B_t X_t && t = 0, \dots, n\\ Z_t &= S_t + \eta_t && t = 0, \dots, n \\ \eta_t &\sim \mathcal N(0, \Omega_t) && t = 0, \dots, n. \end{align*} \]
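To make this structure concrete, the following is a minimal sketch that simulates a single path from such a surrogate GLSSM directly from the equations above; the model matrices `A`, `B`, `Sigma`, `Omega` and the initial moments are placeholders for illustration, not the objects produced by the library.

```python
import jax.numpy as jnp
import jax.random as jrn


def simulate_surrogate(key, x0, Sigma0, A, Sigma, B, Omega):
    """Simulate one path (X, S, Z) of the surrogate GLSSM above (sketch).

    A, Sigma: (n, m, m) transition matrices and innovation covariances
    B:        (n + 1, p, m) signal matrices
    Omega:    (n + 1, p, p) synthetic observation covariances
    """
    n = A.shape[0]
    key, sub = jrn.split(key)
    x = x0 + jnp.linalg.cholesky(Sigma0) @ jrn.normal(sub, x0.shape)
    xs = [x]
    for t in range(n):  # X_{t+1} = A_t X_t + eps_{t+1}
        key, sub = jrn.split(key)
        x = A[t] @ x + jnp.linalg.cholesky(Sigma[t]) @ jrn.normal(sub, x.shape)
        xs.append(x)
    X = jnp.stack(xs)                   # (n + 1, m) states
    S = (B @ X[..., None]).squeeze(-1)  # (n + 1, p) signals S_t = B_t X_t
    key, sub = jrn.split(key)
    eta = jnp.einsum(
        "tij,tj->ti", jnp.linalg.cholesky(Omega), jrn.normal(sub, S.shape)
    )
    return X, S, S + eta                # Z_t = S_t + eta_t
```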
In this setting we can transform the expectation with respect to the conditional density \(p(x|y)\) into one with respect to the density \(g(x|z)\):
\[ \begin{align*} \int f(x) p(x|y) \mathrm d x &= \int f(x) \frac{p(x|y)}{g(x|z)} g(x|z) \mathrm d x\\ &= \int f(x) \frac{p(y|x)}{g(z|x)} \frac{g(z)}{p(y)} g(x|z) \mathrm d x. \end{align*} \]
Let \(w(x) = \frac{p(y|x)}{g(z|x)} = \frac{p(y|s)}{g(z|s)}\) be the (unnormalized) importance sampling weights which only depend on \(s_t = B_t x_t\), \(t = 0, \dots, n\).
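For intuition, here is a minimal sketch of such a log weight at a single time point, assuming a Poisson observation model with log link, i.e. \(p(y_t|s_t) = \text{Poisson}(\exp s_t)\), and a diagonal \(\Omega_t\); the `log_weights_t` function below handles the general observation distribution of the model.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm, poisson


def log_weight_t_poisson(s_t, y_t, z_t, Omega_t):
    """log w_t(s_t) = log p(y_t | s_t) - log g(z_t | s_t), assuming a
    Poisson observation model with log link and diagonal Omega_t."""
    log_p = poisson.logpmf(y_t, jnp.exp(s_t)).sum()                   # log p(y_t | s_t)
    log_g = norm.logpdf(z_t, s_t, jnp.sqrt(jnp.diag(Omega_t))).sum()  # log g(z_t | s_t)
    return log_p - log_g
```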
log_weights
log_weights (s:jaxtyping.Float[Array,'n+1 p'], y:jaxtyping.Float[Array,'n+1 p'], dist, xi:jaxtyping.Float[Array,'n+1 p'], z:jaxtyping.Float[Array,'n+1 p'], Omega:jaxtyping.Float[Array,'n+1 p p'])
Log weights for all time points.

| | Type | Details |
|---|---|---|
| s | Float[Array, ‘n+1 p’] | signals |
| y | Float[Array, ‘n+1 p’] | observations |
| dist | | observation distribution |
| xi | Float[Array, ‘n+1 p’] | observation parameters |
| z | Float[Array, ‘n+1 p’] | synthetic observations |
| Omega | Float[Array, ‘n+1 p p’] | synthetic observation covariances |
| Returns | Float | log weights |
log_weights_t
log_weights_t (s_t:jaxtyping.Float[Array,'p'], y_t:jaxtyping.Float[Array,'p'], xi_t:jaxtyping.Float[Array,'p'], dist, z_t:jaxtyping.Float[Array,'p'], Omega_t:jaxtyping.Float[Array,'p p'])
Log weight for a single time point.

| | Type | Details |
|---|---|---|
| s_t | Float[Array, ‘p’] | signal |
| y_t | Float[Array, ‘p’] | observation |
| xi_t | Float[Array, ‘p’] | parameters |
| dist | | observation distribution |
| z_t | Float[Array, ‘p’] | synthetic observation |
| Omega_t | Float[Array, ‘p p’] | synthetic observation covariance, assumed diagonal |
| Returns | Float | single log weight |
Importance samples are generated from the smoothing distribution of the surrogate model, i.e. from \(g(x|z)\), using e.g. the FFBS or the simulation smoother algorithm.
pgssm_importance_sampling
pgssm_importance_sampling (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.PGSSM, z:jaxtyping.Float[Array,'n+1 p'], Omega:jaxtyping.Float[Array,'n+1 p p'], N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])

| | Type | Details |
|---|---|---|
| y | Float[Array, ‘n+1 p’] | observations |
| model | PGSSM | model |
| z | Float[Array, ‘n+1 p’] | synthetic observations |
| Omega | Float[Array, ‘n+1 p p’] | covariance of synthetic observations |
| N | int | number of samples |
| key | Union | random key |
| Returns | tuple | |
Let us perform importance sampling for our running example, using the proposal obtained by the Laplace approximation. Notice that we obtain four times the number of samples that we specified, which comes from the use of antithetic variables.
key = jrn.PRNGKey(512)
key, subkey = jrn.split(key)
s_order = 4
model = nb_pgssm_running_example(
    n=100,
    s_order=s_order,
    Sigma0_seasonal=0.1 * jnp.eye(s_order - 1),
    x0_seasonal=jnp.zeros(s_order - 1),
)
N = 1
(x,), (y,) = simulate_pgssm(model, N, subkey)
proposal, info = laplace_approximation(y, model, 100)
fig, axs = plt.subplots(2, 1)
axs[0].plot(y)
axs[1].plot(x[:, 0])
N = 1000
key, subkey = jrn.split(key)
samples, lw = pgssm_importance_sampling(y, model, proposal.z, proposal.Omega, N, subkey)
plt.scatter(jnp.arange(4 * N), lw)
plt.show()
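The factor of four comes from combining location and scale antithetics as described in Durbin and Koopman (2012): each simulated path is paired with its reflection around the smoothing mean, and analogously for the scale. As a rough sketch of the location antithetic only, with `x_smooth` a placeholder name for the smoothing mean of the proposal:

```python
# location antithetic (sketch): reflect each draw around the conditional
# mean, doubling the number of samples without further simulation;
# `x_smooth` is a placeholder for the smoothing mean of the surrogate model
antithetic = 2 * x_smooth[None] - samples
all_samples = jnp.concatenate([samples, antithetic], axis=0)
```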
Weights should be calculated on the log scale, but we need them on the usual scale for Monte-Carlo integration. These are called auto-normalised weights and are defined by \(W(X^i) = \frac{w(X^i)}{\sum_{j = 1}^N w(X^j)}\).
As weights are only known up to a constant, we make exponentiation numerically stable by subtracting (on the log scale) the largest log weight, ensuring that \(\log w^i \leq 0\) for all weights and so \(\sum_{i = 1}^N w^i \leq N\).
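A numerically stable implementation boils down to the following sketch:

```python
import jax.numpy as jnp


def normalize_weights_sketch(log_w):
    """Auto-normalised weights from log weights, stabilized by
    subtracting the largest log weight before exponentiating."""
    log_w = log_w - jnp.max(log_w)  # now log w^i <= 0 for all i
    w = jnp.exp(log_w)
    return w / w.sum()              # weights sum to one
```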
normalize_weights
normalize_weights (log_weights:jaxtyping.Float[Array,'N'])
Normalize importance sampling weights.
| | Type | Details |
|---|---|---|
| log_weights | Float[Array, ‘N’] | log importance sampling weights |
| Returns | Float[Array, ‘N’] | normalized importance sampling weights |
weights = normalize_weights(lw)
plt.boxplot(weights[None] * N)
plt.show()
Effective Sample Size
The effective sample size is an important diagnostic for the performance of importance sampling. It is defined by
\[ \text{ESS} = \frac{\left(\sum_{i = 1}^{N} w(X^i)\right)^2}{\sum_{i = 1}^N w^2(X^i)} = \frac{1}{\sum_{i = 1}^N W^2(X^i)} \]
To compare different approximations one may also be interested in \(\frac{\text{ESS}}{N} \cdot 100\%\), the percentage of effective samples.
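In terms of the normalized weights this is a one-liner; a sketch:

```python
import jax.numpy as jnp


def ess_sketch(normalized_weights):
    """ESS = 1 / sum_i W(X^i)^2: equals N for uniform weights and
    approaches 1 when a single weight dominates."""
    return 1.0 / jnp.sum(normalized_weights**2)
```

The percentage version then simply reports \(100 \cdot \text{ESS} / N\).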
ess_pct
ess_pct (log_weights:jaxtyping.Float[Array,'N'])
| | Type | Details |
|---|---|---|
| log_weights | Float[Array, ‘N’] | log weights |
| Returns | Float | the effective sample size in percent, also called efficiency factor |
ess_lw
ess_lw (log_weights:jaxtyping.Float[Array,'N'])
Compute the effective sample size of a set of log weights
| | Type | Details |
|---|---|---|
| log_weights | Float[Array, ‘N’] | the log weights |
| Returns | Float | the effective sample size |
ess
ess (normalized_weights:jaxtyping.Float[Array,'N'])
Compute the effective sample size of a set of normalized weights
| | Type | Details |
|---|---|---|
| normalized_weights | Float[Array, ‘N’] | normalized weights |
| Returns | Float | the effective sample size |
ess(weights), ess_pct(lw)
(Array(889.98669, dtype=float64), Array(22.24966725, dtype=float64))
Monte-Carlo Integration
mc_integration
mc_integration (samples:jaxtyping.Float[Array,'N ...'], log_weights:jaxtyping.Float[Array,'N'])
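The self-normalized estimator behind this function is a weighted average over the samples; a sketch, assuming `f_samples` already contains \(f(X^i)\) for each sample (a placeholder name):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def mc_integration_sketch(f_samples, log_weights):
    """Self-normalized importance sampling estimate sum_i W(X^i) f(X^i),
    with the weights computed stably from the log weights."""
    W = jnp.exp(log_weights - logsumexp(log_weights))  # normalized weights
    return jnp.tensordot(W, f_samples, axes=(0, 0))    # weighted average over samples
```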
sample_mean = mc_integration(samples, lw)
true_signal = (model.B @ x[..., None]).squeeze(axis=-1)
plt.plot(true_signal, label="true signal")
plt.plot(sample_mean, label="estimated mean signal")
plt.legend()
plt.show()
Prediction
For prediction we are interested in the conditional expectations \[ \mathbf E \left(Y_{n + t} | Y_0, \dots, Y_n\right), \] where we assume that the PGSSM has some known continuation after time \(n\). We can estimate this conditional expectation by importance sampling. Given our samples \(X^{i}_n\), \(i = 1, \dots, N\), we simulate forward in time to obtain \(X^i_{n + t}\). We may then estimate the conditional expectation by \[ \sum_{i = 1}^N W^i \, \mathbf E \left( Y_{n + t} | X_{n + t} = X^i_{n + t}\right), \] which is justified by the dependency structure of the model: \(Y_{n + t}\) is independent of \(Y_0, \dots, Y_n\) given \(X_{n + t}\).
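As a sketch, once the forward-simulated conditional means are available, the estimate is just a weighted average; here `log_w` and `cond_means[i, t]` \(= \mathbf E(Y_{n+t} | X_{n+t} = X^i_{n+t})\) are placeholder names:

```python
# sketch: weighted average of per-sample conditional means, using the
# normalized weights of the original importance samples
W = normalize_weights(log_w)                       # shape (N,)
predicted = (W[:, None] * cond_means).sum(axis=0)  # one estimate per horizon t
```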
For prediction intervals, we follow the strategy of (Durbin and Koopman 2012), Chapter 11.5.3: sort the univariate predictions and the corresponding weights in the same order. The ECDF at the \(i\)-th sorted prediction can then be estimated as \[ \sum_{j = 1}^i W^{(j)}, \] where \(W^{(j)}\) is the weight belonging to the \(j\)-th sorted prediction. Linearly interpolating between these values gives an ECDF which we use to create prediction intervals.
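A sketch of this construction for a single horizon, with `y_samples` the univariate predictions and `W` their normalized weights (placeholder names):

```python
import jax.numpy as jnp


def weighted_quantile(y_samples, W, p):
    """Quantile of the weighted ECDF: sort the predictions, accumulate the
    correspondingly sorted weights and interpolate linearly at level p."""
    order = jnp.argsort(y_samples)
    y_sorted = y_samples[order]
    cdf = jnp.cumsum(W[order])           # ECDF values at the sorted predictions
    return jnp.interp(p, cdf, y_sorted)  # linearly interpolated inverse ECDF
```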
predict
predict (model:isssm.typing.PGSSM, y:jaxtyping.Float[Array,'n+1 p'], proposal:isssm.typing.GLSSMProposal, future_model:isssm.typing.PGSSM, N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])
future_prediction_interval
future_prediction_interval (dist, signal_samples, xi, log_weights, p)
key, subkey = jrn.split(key)
n_ahead = 10
# reuse the first n_ahead + 1 time points of the model as its continuation
ten_steps_ahead_model = PGSSM(
    model.u[: n_ahead + 1],
    model.A[:n_ahead],
    model.D[:n_ahead],
    model.Sigma0,
    model.Sigma[:n_ahead],
    model.v[: n_ahead + 1],
    model.B[: n_ahead + 1],
    model.dist,
    model.xi[: n_ahead + 1],
)
(_, s_pred, y_pred), log_weights_pred = predict(
    model, y, proposal, ten_steps_ahead_model, 1000, subkey
)
mean_y = (y_pred * normalize_weights(log_weights_pred)[:, None, None]).sum()
past_inds = jnp.arange(y.shape[0])
future_inds = jnp.arange(y.shape[0], y.shape[0] + n_ahead)
percentiles = prediction_percentiles(
    y_pred, normalize_weights(log_weights_pred), jnp.array([0.1, 0.5, 0.9])
)
lower, mid, upper = percentiles
plt.plot(past_inds, y, label="observed")
plt.plot(future_inds, mid[1:], label="median")
plt.plot(
    future_inds,
    lower[1:],
    linestyle="--",
    color="grey",
    label="80% prediction interval",
)
plt.plot(future_inds, upper[1:], linestyle="--", color="grey")
plt.legend()
plt.show()
prediction
prediction (f:<built-in function callable>, y, proposal:isssm.typing.GLSSMProposal, model:isssm.typing.PGSSM, N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']], probs:jaxtyping.Float[Array,'k'], prediction_model=None)
def f(x, s, y_prime):
    return y_prime[:, 0:1]


key, subkey = jrn.split(key)
mean, sd, quants = prediction(
    f, y, proposal, model, 10000, subkey, jnp.array([0.1, 0.5, 0.9])
)
="observed")
plt.plot(y, label="predicted")
plt.plot(mean, label0], linestyle="--", color="grey", label="prediction interval")
plt.plot(quants[2], linestyle="--", color="grey")
plt.plot(quants[90, 100)
plt.xlim(
plt.legend() plt.show()
quants[2, 99], y[99]
(Array([557.], dtype=float64), Array([356.], dtype=float64))
from tensorflow_probability.substrates.jax.distributions import NegativeBinomial
NegativeBinomial(20, jnp.log(356) - jnp.log(20)).cdf(557)
Array(0.98567769, dtype=float64)
future_prediction_interval(
    model.dist, s_pred[:, 0], model.xi[0], log_weights_pred, 0.5
), mid[0]
(array([49.47880729]), Array([52.31199429], dtype=float64))