from functools import partial
import jax.numpy as jnp
from jax import vmap
from jaxtyping import Array, Float
# | export
from tensorflow_probability.substrates.jax.distributions import \
MultivariateNormalFullCovariance as MVN
from isssm.typing import PGSSM
def log_weights_t(
    s_t: Float[Array, "p"],  # signal
    y_t: Float[Array, "p"],  # observation
    xi_t: Float[Array, "p"],  # parameters
    dist,  # observation distribution
    z_t: Float[Array, "p"],  # synthetic observation
    Omega_t: Float[Array, "p p"],  # synthetic observation covariance, assumed diagonal
) -> Float:  # single log weight
    """Log weight for a single time point."""
    p_ys = dist(s_t, xi_t).log_prob(y_t).sum()
    # omega_t = jnp.sqrt(jnp.diag(Omega_t))
    # g_zs = MVN_diag(s_t, omega_t).log_prob(z_t).sum()
    g_zs = MVN(s_t, Omega_t).log_prob(z_t).sum()
    return p_ys - g_zs
def log_weights(
    s: Float[Array, "n+1 p"],  # signals
    y: Float[Array, "n+1 p"],  # observations
    dist,  # observation distribution
    xi: Float[Array, "n+1 p"],  # observation parameters
    z: Float[Array, "n+1 p"],  # synthetic observations
    Omega: Float[Array, "n+1 p p"],  # synthetic observation covariances
) -> Float:  # log weights
    """Log weights for all time points"""
    p_ys = dist(s, xi).log_prob(y).sum()
    # avoid triangular solve problems
    # omega = jnp.sqrt(vmap(jnp.diag)(Omega))
    # g_zs = MVN_diag(s, omega).log_prob(z).sum()
    g_zs = MVN(s, Omega).log_prob(z).sum()
    return p_ys - g_zs

Importance Sampling for Partially Gaussian State Space Models
After having observed \(Y\), one is usually interested in properties of the conditional distribution of the states \(X\) given \(Y\). Typically this means computing terms of the form
\[ \begin{align*} \mathbf E (f(X) | Y) &= \mathbf E (f(X_0, \dots, X_n) | Y_0, \dots, Y_n) \\ &= \int f(x_0, \dots, x_n) p(x_0, \dots, x_n | y_0, \dots, y_n) \mathrm d x_0 \dots \mathrm d x_n. \end{align*} \]
As the density \(p(x|y)\) is known only up to a constant, we resort to importance sampling with a GLSSM, represented by its Gaussian densities \(g\). The Laplace approximation and (modified) efficient importance sampling perform this task for log-concave state space models where the states are jointly Gaussian.
Both methods construct surrogate linear Gaussian state space models that are parameterized by synthetic observations \(z_t\) and their covariance matrices \(\Omega_t\). Usually \(\Omega_t\) is a diagonal matrix, which is justified if the components of the observation vector at time \(t\), \(Y^i_t\), \(i = 1, \dots, p\), are conditionally independent given the state \(X_t\).
These models are then based on the following SSM: \[ \begin{align*} X_0 &\sim \mathcal N (x_0, \Sigma_0) &&\\ X_{t + 1} &= A_t X_{t} + \varepsilon_{t + 1} && t = 0, \dots, n - 1\\ \varepsilon_t &\sim \mathcal N (0, \Sigma_t) && t = 1, \dots, n \\ S_t &= B_t X_t &&\\ Z_t &= S_t + \eta_t && t = 0, \dots, n \\ \eta_t &\sim \mathcal N(0, \Omega_t) && t = 0, \dots, n. \end{align*} \]
In this setting we can transform the expectation w.r.t. the conditional density \(p(x|y)\) into one w.r.t. the density \(g(x|z)\).
\[ \begin{align*} \int f(x) p(x|y) \mathrm d x &= \int f(x) \frac{p(x|y)}{g(x|z)} g(x|z) \mathrm d x\\ &= \int f(x) \frac{p(y|x)}{g(z|x)} \frac{g(z)}{p(y)} g(x|z) \mathrm d x. \end{align*} \]
Let \(w(x) = \frac{p(y|x)}{g(z|x)} = \frac{p(y|s)}{g(z|s)}\) be the (unnormalized) importance sampling weights which only depend on \(s_t = B_t x_t\), \(t = 0, \dots, n\).
Importance samples are generated from the smoothing distribution in the surrogate model, i.e. from \(g(x|z)\) using e.g. the FFBS or simulation smoother algorithm.
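Before turning to the implementation for PGSSMs, here is a toy illustration of the identity above, outside the state space setting and not part of the package: we estimate \(\mathbf E_p(X)\) for a normal target with mean 1 and standard deviation 1, using normalised weights and samples from a normal proposal with mean 0 and standard deviation 2. All names in this sketch are made up for the example.

import jax.numpy as jnp
import jax.random as jrn
from tensorflow_probability.substrates.jax.distributions import Normal

key = jrn.PRNGKey(0)
proposal = Normal(loc=0.0, scale=2.0)  # plays the role of g
target = Normal(loc=1.0, scale=1.0)  # plays the role of p
x = proposal.sample(10_000, seed=key)
# log w(x) = log p(x) - log g(x), in general only known up to a constant
log_w = target.log_prob(x) - proposal.log_prob(x)
w = jnp.exp(log_w - log_w.max())
W = w / w.sum()  # normalised weights
estimate = (W * x).sum()  # approximately E_p(X) = 1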
import matplotlib.pyplot as plt

from isssm.laplace_approximation import laplace_approximation
from isssm.pgssm import nb_pgssm_running_example, simulate_pgssm

from functools import partial

import jax.random as jrn

# | export
from jaxtyping import Array, Float, PRNGKeyArray

from isssm.kalman import FFBS, simulation_smoother
from isssm.typing import GLSSM, PGSSM
def pgssm_importance_sampling(
    y: Float[Array, "n+1 p"],  # observations
    model: PGSSM,  # model
    z: Float[Array, "n+1 p"],  # synthetic observations
    Omega: Float[Array, "n+1 p p"],  # covariance of synthetic observations
    N: int,  # number of samples
    key: PRNGKeyArray,  # random key
) -> tuple[
    Float[Array, "N n+1 m"], Float[Array, "N"]
]:  # importance samples and weights
    u, A, D, Sigma0, Sigma, v, B, dist, xi = model
    glssm = GLSSM(u, A, D, Sigma0, Sigma, v, B, Omega)
    key, subkey = jrn.split(key)
    s = simulation_smoother(glssm, z, N, subkey)
    model_log_weights = partial(log_weights, y=y, dist=dist, xi=xi, z=z, Omega=Omega)
    lw = vmap(model_log_weights)(s)
    return s, lw

Let us perform importance sampling for our running example, using the model obtained by the Laplace approximation. Notice that we obtain four times the number of samples that we specified, which is due to the use of antithetic variables.
key = jrn.PRNGKey(512)
key, subkey = jrn.split(key)
s_order = 4
model = nb_pgssm_running_example(
n=100,
s_order=s_order,
Sigma0_seasonal=0.1 * jnp.eye(s_order - 1),
x0_seasonal=jnp.zeros(s_order - 1),
)
N = 1
(x,), (y,) = simulate_pgssm(model, N, subkey)
proposal, info = laplace_approximation(y, model, 100)

fig, axs = plt.subplots(2, 1)
axs[0].plot(y)
axs[1].plot(x[:, 0])

N = 1000
key, subkey = jrn.split(key)
samples, lw = pgssm_importance_sampling(y, model, proposal.z, proposal.Omega, N, subkey)
plt.scatter(jnp.arange(4 * N), lw)
plt.show()

Weights should be computed on the log scale for numerical stability, but for Monte-Carlo integration we need them on the natural scale, normalised to sum to one. These are called auto-normalised weights and are defined by \(W(X^i) = \frac{w(X^i)}{\sum_{j = 1}^N w(X^j)}\).
As the weights are only known up to a constant, we make exponentiation numerically stable by subtracting (on the log scale) the largest log weight, ensuring that \(\log w^i \leq 0\) for all weights and so \(\sum_{i = 1}^N w^i \leq N\).
normalize_weights
normalize_weights (log_weights:jaxtyping.Float[Array,'N'])
Normalize importance sampling weights.
|  | Type | Details |
|---|---|---|
| log_weights | Float[Array, ‘N’] | log importance sampling weights |
| Returns | Float[Array, ‘N’] | normalized importance sampling weights |
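A minimal sketch of this normalisation, consistent with the description above (the exported `normalize_weights` may differ in details):

def normalize_weights_sketch(log_weights):
    # subtract the largest log weight before exponentiating, so all weights are <= 1
    w = jnp.exp(log_weights - log_weights.max())
    return w / w.sum()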
weights = normalize_weights(lw)
plt.boxplot(weights[None] * N)
plt.show()

Effective Sample Size
The effective sample size is an important diagnostic for the performance of importance sampling; it is defined by
\[ \text{ESS} = \frac{\left(\sum_{i = 1}^{N} w(X^i)\right)^2}{\sum_{i = 1}^N w^2(X^i)} = \frac{1}{\sum_{i = 1}^N W^2(X^i)}. \]
To compare different approximations one may also be interested in \(\frac{\text{ESS}}{N} \cdot 100\%\), the percentage of effective samples.
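A minimal sketch of how the effective sample size in percent can be computed stably from log weights (the exported functions below may differ in details):

from jax.scipy.special import logsumexp

def ess_pct_sketch(log_weights):
    # log ESS = 2 * log(sum w) - log(sum w^2), evaluated on the log scale
    log_ess = 2 * logsumexp(log_weights) - logsumexp(2 * log_weights)
    (N,) = log_weights.shape
    return jnp.exp(log_ess) / N * 100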
ess_pct
ess_pct (log_weights:jaxtyping.Float[Array,'N'])
|  | Type | Details |
|---|---|---|
| log_weights | Float[Array, ‘N’] | log weights |
| Returns | Float | the effective sample size in percent, also called efficiency factor |
ess_lw
ess_lw (log_weights:jaxtyping.Float[Array,'N'])
Compute the effective sample size of a set of log weights
|  | Type | Details |
|---|---|---|
| log_weights | Float[Array, ‘N’] | the log weights |
| Returns | Float | the effective sample size |
ess
ess (normalized_weights:jaxtyping.Float[Array,'N'])
Compute the effective sample size of a set of normalized weights
|  | Type | Details |
|---|---|---|
| normalized_weights | Float[Array, ‘N’] | normalized weights |
| Returns | Float | the effective sample size |
ess(weights), ess_pct(lw)

Monte-Carlo Integration
mc_integration
mc_integration (samples:jaxtyping.Float[Array,'N...'], log_weights:jaxtyping.Float[Array,'N'])
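The exported implementation is not reproduced here; presumably it computes the self-normalised estimate \(\sum_{i = 1}^N W(X^i) f(X^i)\). A minimal sketch under that assumption:

def mc_integration_sketch(samples, log_weights):
    # weighted average over the sample axis with auto-normalised weights,
    # broadcasting the weights over any remaining axes of the samples
    W = normalize_weights(log_weights)
    return jnp.einsum("i,i...->...", W, samples)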
sample_mean = mc_integration(samples, lw)
true_signal = (model.B @ x[..., None]).squeeze(axis=-1)
plt.plot(true_signal, label="true signal")
plt.plot(sample_mean, label="estimated mean signal")
plt.legend()
plt.show()

Prediction
For prediction we are interested in the conditional expectations \[ \mathbf E \left(Y_{n + t} | Y_0, \dots, Y_n\right), \] where we assume that the PGSSM has some known continuation after time \(n\). We can estimate this conditional expectation by importance sampling. Given our samples \(X^{i}_n\), \(i = 1, \dots, N\), we simulate forward in time to obtain \(X^i_{n + t}\). We may then estimate the conditional expectation by \[ \sum_{i = 1}^N W^i \, \mathbf E \left( Y_{n + t} | X_{n + t} = X^i_{n + t}\right), \] which is justified by the dependency structure of the model (\(Y_{n + t}\) is independent of \(Y_0, \dots, Y_n\) given \(X_{n + t}\)).
For prediction intervals, we follow the strategy of (Durbin and Koopman 2012), Chapter 11.5.3: sort the univariate predictions and reorder the corresponding weights in the same way. Then the ECDF at \(Y^i\) can be estimated as \[ \sum_{j = 1}^i W^j. \] Linearly interpolating between these values gives an ECDF which we use to construct prediction intervals.
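For instance, with three sorted predictions \(1, 2, 4\) and normalised weights \(0.2, 0.5, 0.3\), the cumulative weights are \(0.2, 0.7, 1.0\); the estimated \(50\%\)-quantile then lies between the first two predictions, at \(1 + \tfrac{0.5 - 0.2}{0.7 - 0.2}(2 - 1) = 1.6\).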
from jax import jit
from scipy.optimize import minimize
from isssm.glssm import simulate_states
from isssm.kalman import kalman
# | export
from isssm.typing import GLSSMProposal, GLSSMState
from isssm.util import mm_time_sim
def future_prediction_interval(dist, signal_samples, xi, log_weights, p):
    def integer_ecdf(y):
        # importance sampling estimate of the predictive CDF at an integer y
        return (
            dist(signal_samples, xi).cdf(y).squeeze(axis=-1)
            * normalize_weights(log_weights)
        ).sum()

    def ecdf(y):
        # interpolate the integer-valued CDF linearly between floor(y) and ceil(y)
        y_floor = jnp.floor(y)
        y_ceil = jnp.ceil(y)
        y_frac = y - y_floor
        return integer_ecdf(y_floor) * (1 - y_frac) + integer_ecdf(y_ceil) * y_frac

    def pinball_loss(y, p):
        # squared absolute deviation of the estimated ECDF from the target probability p
        return (jnp.abs(ecdf(y) - p).sum()) ** 2

    mean = mc_integration(dist(signal_samples, xi).mean(), log_weights)
    result = minimize(pinball_loss, mean, args=(p,), method="Nelder-Mead")
    return result.x
def _prediction_percentiles(Y, weights, probs):
    Y_sorted = jnp.sort(Y)
    weights_sorted = weights[jnp.argsort(Y)]
    cumsum = jnp.cumsum(weights_sorted)
    # find indices where the cumulative weights bracket probs,
    # take the corresponding Y_sorted values
    # and interpolate linearly between them
    indices = jnp.searchsorted(cumsum, probs)
    indices = jnp.clip(indices, 1, len(Y_sorted) - 1)
    left_indices = indices - 1
    right_indices = indices
    left_cumsum = cumsum[left_indices]
    right_cumsum = cumsum[right_indices]
    left_Y = Y_sorted[left_indices]
    right_Y = Y_sorted[right_indices]
    # linear interpolation
    quantiles = left_Y + (probs - left_cumsum) / (right_cumsum - left_cumsum) * (
        right_Y - left_Y
    )
    return quantiles


# vectorize over the time and observation component axes of the samples,
# keeping the sample axis so each univariate prediction can be sorted
prediction_percentiles = vmap(
    vmap(_prediction_percentiles, (1, None, None), 1), (2, None, None), 2
)
def predict(
    model: PGSSM,
    y: Float[Array, "n+1 p"],
    proposal: GLSSMProposal,
    future_model: PGSSM,
    N: int,
    key: PRNGKeyArray,
):
    key, subkey = jrn.split(key)
    signal_samples, log_weights = pgssm_importance_sampling(
        y, model, proposal.z, proposal.Omega, N, subkey
    )
    (N,) = log_weights.shape
    signal_model = GLSSM(
        proposal.u,
        proposal.A,
        proposal.D,
        proposal.Sigma0,
        proposal.Sigma,
        proposal.v,
        proposal.B,
        proposal.Omega,
    )

    @jit
    def future_sample(signal_sample, key):
        # condition on the sampled signal path, then simulate the states forward
        x_filt, Xi_filt, _, _ = kalman(signal_sample, signal_model)
        state = GLSSMState(
            future_model.u.at[0].set(x_filt[-1]),
            future_model.A,
            future_model.D,
            Xi_filt[-1],
            future_model.Sigma,
        )
        (x,) = simulate_states(state, 1, key)
        return x

    key, *subkeys = jrn.split(key, N + 1)
    subkeys = jnp.array(subkeys)
    future_x = vmap(future_sample)(signal_samples, subkeys)
    future_s = mm_time_sim(future_model.B, future_x)
    future_y = future_model.dist(future_s, future_model.xi).mean()
    return (future_x, future_s, future_y), log_weights

key, subkey = jrn.split(key)
n_ahead = 10
ten_steps_ahead_model = PGSSM(
model.u[: n_ahead + 1],
model.A[:n_ahead],
model.D[:n_ahead],
model.Sigma0,
model.Sigma[:n_ahead],
model.v[: n_ahead + 1],
model.B[: n_ahead + 1],
model.dist,
model.xi[: n_ahead + 1],
)
(_, s_pred, y_pred), log_weights_pred = predict(
model, y, proposal, ten_steps_ahead_model, 1000, subkey
)
mean_y = (y_pred * normalize_weights(log_weights_pred)[:, None, None]).sum(axis=0)
past_inds = jnp.arange(y.shape[0])
future_inds = jnp.arange(y.shape[0], y.shape[0] + n_ahead)
percentiles = prediction_percentiles(
y_pred, normalize_weights(log_weights_pred), jnp.array([0.1, 0.5, 0.9])
)
lower, mid, upper = percentiles
plt.plot(past_inds, y, label="observed")
plt.plot(future_inds, mid[1:], label="median")
plt.plot(
future_inds,
lower[1:],
linestyle="--",
color="grey",
label="80% prediction interval",
)
plt.plot(future_inds, upper[1:], linestyle="--", color="grey")
plt.legend()
plt.show()

prediction
prediction (f:<built-in function callable>, y, proposal:isssm.typing.GLSSMProposal, model:isssm.typing.PGSSM, N:int, key:Union[jaxtyping.Key[Array,''], jaxtyping.UInt32[Array,'2']], probs:jaxtyping.Float[Array,'k'], prediction_model=None)
def f(x, s, y_prime):
    return y_prime[:, 0:1]
key, subkey = jrn.split(key)
mean, sd, quants = prediction(
f, y, proposal, model, 10000, subkey, jnp.array([0.1, 0.5, 0.9])
)
plt.plot(y, label="observed")
plt.plot(mean, label="predicted")
plt.plot(quants[0], linestyle="--", color="grey", label="prediction interval")
plt.plot(quants[2], linestyle="--", color="grey")
plt.xlim(90, 100)
plt.legend()
plt.show()

quants[2, 99], y[99]

from tensorflow_probability.substrates.jax.distributions import NegativeBinomial

NegativeBinomial(20, jnp.log(356) - jnp.log(20)).cdf(557)

Array(0.98567769, dtype=float64)
future_prediction_interval(
model.dist, s_pred[:, 0], model.xi[0], log_weights_pred, 0.5
), mid[0]