import jax

# use 64-bit floating point (x64) for testing purposes
jax.config.update("jax_enable_x64", True)
Gaussian Linear State Space Models
Consider a Gaussian state space model of the form \[ \begin{align*} X_0 &\sim \mathcal N (u_{0}, \Sigma_0) &\\ X_{t + 1} &= u_{t + 1} + A_t X_{t} + D_{t}\varepsilon_{t + 1} &t = 0, \dots, n - 1\\ \varepsilon_t &\sim \mathcal N (0, \Sigma_t) & t = 1, \dots, n \\ Y_t &= v_{t} + B_t X_t + \eta_t & t =0, \dots, n & \\ \eta_t &\sim \mathcal N(0, \Omega_t) & t=0, \dots, n. \end{align*} \] As the joint distribution of \((X_0, \dots, X_n, Y_0, \dots, Y_n)\) is Gaussian, we call it a Gaussian linear state space model (GLSSM).
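Because each \(X_{t+1}\) is an affine function of \(X_t\) plus independent Gaussian noise, the state means and covariances can be propagated recursively: \(\mathbf E X_{t+1} = u_{t+1} + A_t \mathbf E X_t\) and \(\operatorname{Cov} X_{t+1} = A_t \operatorname{Cov}(X_t) A_t^T + D_t \Sigma_{t+1} D_t^T\). The following is a minimal sketch of that recursion. The helper `state_moments` and its array layout (time-indexed arrays stacked along the first axis) are hypothetical and not part of `isssm`.

```python
import jax
import jax.numpy as jnp


def state_moments(u, A, D, Sigma0, Sigma):
    """Means and covariances of X_0, ..., X_n (hypothetical helper, sketch).

    Assumed layout: u (n+1, m) holds u_0..u_n; A (n, m, m) holds A_0..A_{n-1};
    D (n, m, l) holds D_0..D_{n-1}; Sigma0 (m, m); Sigma (n, l, l) holds
    Sigma_1..Sigma_n.
    """

    def step(carry, inputs):
        mean, cov = carry
        u_next, A_t, D_t, Sigma_next = inputs
        # E[X_{t+1}] = u_{t+1} + A_t E[X_t]
        mean_next = u_next + A_t @ mean
        # Cov[X_{t+1}] = A_t Cov[X_t] A_t^T + D_t Sigma_{t+1} D_t^T
        cov_next = A_t @ cov @ A_t.T + D_t @ Sigma_next @ D_t.T
        return (mean_next, cov_next), (mean_next, cov_next)

    _, (means, covs) = jax.lax.scan(step, (u[0], Sigma0), (u[1:], A, D, Sigma))
    # prepend the moments of X_0
    means = jnp.concatenate([u[:1], means])
    covs = jnp.concatenate([Sigma0[None], covs])
    return means, covs


# scalar random walk: X_{t+1} = X_t + eps_{t+1}, all variances equal to 1
n, m = 3, 1
means, covs = state_moments(
    u=jnp.zeros((n + 1, m)),
    A=jnp.ones((n, m, m)),
    D=jnp.ones((n, m, 1)),
    Sigma0=jnp.eye(m),
    Sigma=jnp.ones((n, 1, 1)),
)
```

For this random walk the variance grows linearly, \(\operatorname{Var} X_t = t + 1\), while the means stay at zero.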
The dimensions of the components are as follows: \[ \begin{align*} u_{t}, X_{t} &\in \mathbf R^{m} \\ \varepsilon_{t} &\in \mathbf R^{l} \\ v_{t}, Y_{t}, \eta_{t} &\in \mathbf R^{p} \end{align*} \] and \[ \begin{align*} A_{t} &\in \mathbf R^{m\times m} \\ D_{t} &\in \mathbf R^{m \times l} \\ \Sigma_{0} &\in \mathbf R^{m \times m} \\ \Sigma_{t} &\in \mathbf R^{l \times l}\\ B_{t} &\in \mathbf R^{p \times m} \\ \Omega_{t} &\in\mathbf R^{p \times p} \end{align*} \] and we assume that \(D_t\) is a submatrix of the identity matrix, such that \(D_t^T D_t = I_{l}\).
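To illustrate the assumption on \(D_t\): choosing \(D_t\) as a selection of columns of \(I_m\) lets the \(l\) noise components enter \(l\) of the \(m\) state coordinates unchanged, and \(\varepsilon_{t+1}\) can be recovered from \(D_t \varepsilon_{t+1}\) by applying \(D_t^T\). In the sketch below, the selected coordinates are a hypothetical choice made only for illustration.

```python
import jax.numpy as jnp

m = 3
# hypothetical choice: noise enters only the 2nd and 3rd state coordinates
selected = jnp.array([1, 2])
D = jnp.eye(m)[:, selected]  # shape (m, l) = (3, 2)

# D^T D = I_l: the selected columns of the identity are orthonormal
DtD = D.T @ D
# D D^T is the diagonal 0/1 projection onto the selected coordinates
DDt = D @ D.T

# the noise is recovered from D eps by applying D^T
eps = jnp.array([0.5, -1.5])
recovered = D.T @ (D @ eps)
```

Note that \(D_t D_t^T\) is not the identity when \(l < m\), so the covariance \(D_t \Sigma_{t+1} D_t^T\) of the state noise is singular in that case.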
Sampling from the joint distribution
To obtain a sample \((X_0, \dots, X_n), (Y_0, \dots, Y_n)\), we first simulate from the joint distribution of the states and then, as the observations are conditionally independent of one another given the states, simulate all observations at once.
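The two-stage procedure can be sketched in plain JAX as follows: the states are generated recursively with `jax.lax.scan`, and the observation noise for all \(n + 1\) time points is then drawn in a single call. This is not the `isssm` implementation; the function `simulate_path` is hypothetical, and for brevity it assumes time-invariant system matrices (\(u_t = u\), \(A_t = A\), and so on for \(t \geq 1\)).

```python
import jax
import jax.numpy as jnp
import jax.random as jrn


def simulate_path(key, n, u0, Sigma0, u, A, D, Sigma, v, B, Omega):
    """Draw one path (X_0..X_n), (Y_0..Y_n) of a GLSSM (hypothetical sketch).

    Simplifying assumption for this example only: u, A, D, Sigma are the same
    for every transition, and v, B, Omega are the same for every time point.
    """
    l = D.shape[1]
    p = B.shape[0]
    key_x0, key_eps, key_eta = jrn.split(key, 3)

    # 1. states: X_0 ~ N(u_0, Sigma_0), then X_{t+1} = u + A X_t + D eps_{t+1}
    x0 = jrn.multivariate_normal(key_x0, u0, Sigma0)
    eps = jrn.multivariate_normal(key_eps, jnp.zeros(l), Sigma, shape=(n,))

    def step(x, eps_next):
        x_next = u + A @ x + D @ eps_next
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, eps)
    X = jnp.concatenate([x0[None], xs])  # shape (n+1, m)

    # 2. observations are conditionally independent given the states,
    #    so all n+1 noise terms are drawn in a single call
    eta = jrn.multivariate_normal(key_eta, jnp.zeros(p), Omega, shape=(n + 1,))
    Y = v + X @ B.T + eta  # shape (n+1, p)
    return X, Y


m, l, p, n = 3, 2, 1, 50
X, Y = simulate_path(
    jrn.PRNGKey(0), n,
    u0=jnp.zeros(m), Sigma0=jnp.eye(m),
    u=jnp.zeros(m), A=0.9 * jnp.eye(m), D=jnp.eye(m)[:, :l], Sigma=jnp.eye(l),
    v=jnp.zeros(p), B=jnp.ones((p, m)), Omega=jnp.eye(p),
)
```

Only the state recursion is sequential; the observation step is a single vectorised operation over all time points.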
simulate_states
simulate_states (state:isssm.typing.GLSSMState, N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array, '2']])
Simulate states of a GLSSM
| | Type | Details |
|---|---|---|
| state | GLSSMState | |
| N | int | number of samples to draw |
| key | Union | the random state |
| Returns | Float[Array, 'N n+1 m'] | array of N samples from the state distribution |
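The documented return shape `'N n+1 m'` corresponds to \(N\) independent state paths stacked along a leading axis. A minimal sketch of how such an array can be produced, assuming time-invariant matrices, is to split the key into \(N\) keys and `jax.vmap` a single-path sampler over them. Both functions below are hypothetical and do not describe how `simulate_states` is implemented.

```python
import jax
import jax.numpy as jnp
import jax.random as jrn


def sample_states_one(key, n, u0, Sigma0, u, A, D, Sigma):
    """One state path X_0..X_n; constant u, A, D, Sigma is an assumption of this sketch."""
    key_x0, key_eps = jrn.split(key)
    x0 = jrn.multivariate_normal(key_x0, u0, Sigma0)
    eps = jrn.multivariate_normal(key_eps, jnp.zeros(D.shape[1]), Sigma, shape=(n,))

    def step(x, eps_next):
        x_next = u + A @ x + D @ eps_next
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, eps)
    return jnp.concatenate([x0[None], xs])  # (n+1, m)


def sample_states_many(key, N, n, u0, Sigma0, u, A, D, Sigma):
    """N independent paths, one PRNG key each, stacked to shape (N, n+1, m)."""
    keys = jrn.split(key, N)
    return jax.vmap(lambda k: sample_states_one(k, n, u0, Sigma0, u, A, D, Sigma))(keys)


m, l, n, N = 3, 2, 20, 5
paths = sample_states_many(
    jrn.PRNGKey(42), N, n,
    u0=jnp.zeros(m), Sigma0=jnp.eye(m),
    u=jnp.zeros(m), A=jnp.eye(m), D=jnp.eye(m)[:, :l], Sigma=jnp.eye(l),
)
```

Because this \(D\) selects only the first two coordinates and \(A = I\), the third state coordinate keeps its initial value along every path.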
simulate_glssm
simulate_glssm (glssm:isssm.typing.GLSSM, N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array, '2']])
Simulate states and observations of a GLSSM
| | Type | Details |
|---|---|---|
| glssm | GLSSM | |
| N | int | number of sample paths |
| key | Union | the random state |
| Returns | (Float[Array, 'N n+1 m'], Float[Array, 'N n+1 p']) | sampled states and observations |
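Given an array of state samples of shape `(N, n+1, m)`, the observations for all paths and time points can be drawn at once, which is what gives the second return value its shape `(N, n+1, p)`. The sketch below assumes \(v_t\), \(B_t\) and \(\Omega_t\) are constant over time; `sample_observations` is a hypothetical helper, and the states are random stand-ins rather than draws from a GLSSM.

```python
import jax.numpy as jnp
import jax.random as jrn


def sample_observations(key, X, v, B, Omega):
    """Draw Y_t = v + B X_t + eta_t for every path and time point (hypothetical sketch).

    X has shape (N, n+1, m); v, B, Omega are held constant over time here,
    a simplification made for this example only.
    """
    N, np1, _ = X.shape
    p = B.shape[0]
    # one noise vector per path and time point, drawn in a single call
    eta = jrn.multivariate_normal(key, jnp.zeros(p), Omega, shape=(N, np1))
    # X @ B.T maps every state x_t to B x_t, giving shape (N, n+1, p)
    return v + X @ B.T + eta


key_x, key_y = jrn.split(jrn.PRNGKey(1))
X = jrn.normal(key_x, (4, 11, 3))  # stand-in for state samples (N=4, n=10, m=3)
B = jnp.array([[1.0, 0.0, 1.0]])   # p = 1
Y = sample_observations(key_y, X, jnp.zeros(1), B, 0.5 * jnp.eye(1))
```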
from isssm.models.stsm import stsm
As a toy example, we consider a structural time series model with trend and velocity, as well as a seasonal component of order 2, see here for details. The following code creates the model and simulates once from its joint distribution.
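import jax.numpy as jnp
import jax.random as jrn

# create the model
model = stsm(jnp.zeros(2 + 1), 0.0, 0.01, 1.0, 100, jnp.eye(2 + 1), 1.0, 2)
key = jrn.PRNGKey(53412312)
key, subkey = jrn.split(key)
# draw a single sample path (N = 1) and unpack it
(X,), (Y,) = simulate_glssm(model, 1, subkey)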
import matplotlib.pyplot as plt

# top panel: states, bottom panel: observations
fig, (ax1, ax2) = plt.subplots(2)
ax1.set_title("$X_t$")
ax1.plot(X)
ax2.set_title("$Y_t$")
ax2.plot(Y)
plt.show()
In this figure, we see that the trend varies smoothly, while the seasonal component is visible in the observations \(Y_t\).
Joint Density
By the dependency structure of the model, the joint density factorizes as
\[ p(x,y) = \prod_{t = 0}^n p(x_{t}| x_{t -1}) p(y_{t}|x_{t}) \] where \(p(x_0|x_{-1}) = p(x_0)\). The following functions return these components or evaluate the joint density directly.
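As a numerical check of this factorization, the sketch below takes a small hypothetical GLSSM with \(n = 1\), \(m = 2\), \(p = 1\) and \(D_t = I\), and compares the sum of conditional log densities with the log density of the joint Gaussian vector \((x_0, x_1, y_0, y_1)\), whose mean and covariance follow from the model equations. All parameter values are made up for illustration.

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag
from jax.scipy.stats import multivariate_normal as mvn

# small hypothetical GLSSM: n = 1, m = 2, p = 1, D = I (so l = m)
u0, u1 = jnp.array([0.5, -1.0]), jnp.array([0.1, 0.2])
Sigma0 = jnp.array([[1.0, 0.3], [0.3, 2.0]])
A = jnp.array([[0.9, 0.1], [0.0, 0.8]])
Sigma1 = jnp.array([[0.5, 0.0], [0.0, 0.2]])
v, B, Omega = jnp.array([0.3]), jnp.array([[1.0, -1.0]]), jnp.array([[0.4]])

x = jnp.array([[0.2, -0.5], [0.4, 0.1]])  # x_0, x_1
y = jnp.array([[1.0], [-0.3]])            # y_0, y_1

# 1. factorised form: log p(x_0) + log p(x_1 | x_0) + sum_t log p(y_t | x_t)
log_p_factorised = (
    mvn.logpdf(x[0], u0, Sigma0)
    + mvn.logpdf(x[1], u1 + A @ x[0], Sigma1)
    + mvn.logpdf(y[0], v + B @ x[0], Omega)
    + mvn.logpdf(y[1], v + B @ x[1], Omega)
)

# 2. the same density from the joint Gaussian of z = (x_0, x_1, y_0, y_1)
mean_x = jnp.concatenate([u0, u1 + A @ u0])
cov_x = jnp.block([[Sigma0, Sigma0 @ A.T], [A @ Sigma0, A @ Sigma0 @ A.T + Sigma1]])
B_big = block_diag(B, B)  # maps (x_0, x_1) to (B x_0, B x_1)
mean_y = jnp.concatenate([v, v]) + B_big @ mean_x
cov_xy = cov_x @ B_big.T
cov_y = B_big @ cov_x @ B_big.T + block_diag(Omega, Omega)
mean_z = jnp.concatenate([mean_x, mean_y])
cov_z = jnp.block([[cov_x, cov_xy], [cov_xy.T, cov_y]])
z = jnp.concatenate([x.ravel(), y.ravel()])
log_p_joint = mvn.logpdf(z, mean_z, cov_z)
```

The two values agree up to floating point error. The factorised form only ever needs \(m\)- and \(p\)-dimensional densities, whereas the joint form requires a covariance matrix whose size grows with \(n\).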
log_prob
log_prob (x:jaxtyping.Float[Array,'n+1 m'], y:jaxtyping.Float[Array,'n+1 p'], glssm:isssm.typing.GLSSM)
joint log probability of states and observations
| | Type | Details |
|---|---|---|
| x | Float[Array, 'n+1 m'] | |
| y | Float[Array, 'n+1 p'] | |
| glssm | GLSSM | |
| Returns | Float | \(\log p(x,y)\) |
log_probs_y
log_probs_y (y:jaxtyping.Float[Array,'n+1 p'], x:jaxtyping.Float[Array,'n+1 m'], obs_model:isssm.typing.GLSSMObservationModel)
log probabilities \(\log p(y_t | x_t)\)
| | Type | Details |
|---|---|---|
| y | Float[Array, 'n+1 p'] | the observations |
| x | Float[Array, 'n+1 m'] | the states |
| obs_model | GLSSMObservationModel | the observation model |
| Returns | Float[Array, 'n+1'] | log probabilities \(\log p(y_t \vert x_t)\) |
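Each entry of the result is a Gaussian log density, \(\log p(y_t | x_t) = \log \mathcal N(y_t; v_t + B_t x_t, \Omega_t)\). A minimal sketch that evaluates all \(n + 1\) terms by mapping over the time axis with `jax.vmap`; `log_probs_y_sketch` and its stacked-array layout are hypothetical, whereas the `isssm` function takes a `GLSSMObservationModel`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn


def log_probs_y_sketch(y, x, v, B, Omega):
    """log N(y_t; v_t + B_t x_t, Omega_t) for t = 0, ..., n (hypothetical sketch).

    Assumed layout: y (n+1, p), x (n+1, m), v (n+1, p), B (n+1, p, m),
    Omega (n+1, p, p).
    """

    def one(y_t, x_t, v_t, B_t, Omega_t):
        return mvn.logpdf(y_t, v_t + B_t @ x_t, Omega_t)

    # vmap over the leading time axis gives one value per time point
    return jax.vmap(one)(y, x, v, B, Omega)


n1, m, p = 5, 2, 1
x = jnp.ones((n1, m))                                # every B_t x_t equals 2
y = jnp.array([[2.0], [3.0], [1.0], [2.0], [4.0]])   # deviations 0, 1, -1, 0, 2
lp = log_probs_y_sketch(
    y, x, v=jnp.zeros((n1, p)), B=jnp.ones((n1, p, m)), Omega=jnp.ones((n1, p, p))
)
```

With unit variance, each term equals \(-\tfrac12 \log(2\pi) - \tfrac12 d_t^2\), where \(d_t\) is the deviation of \(y_t\) from its conditional mean.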
log_probs_x
log_probs_x (x:jaxtyping.Float[Array,'n+1 m'], state:isssm.typing.GLSSMState)
log probabilities \(\log p(x_t | x_{t-1})\)
| | Type | Details |
|---|---|---|
| x | Float[Array, 'n+1 m'] | the states |
| state | GLSSMState | the state model |
| Returns | Float[Array, 'n+1'] | log probabilities \(\log p(x_t \vert x_{t-1})\) |
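When \(l < m\), the transition covariance \(D_{t-1} \Sigma_t D_{t-1}^T\) is singular, so \(X_t | X_{t-1}\) has no density on \(\mathbf R^m\). One way to define these terms, and an assumption of the sketch below rather than a description of what `log_probs_x` does, is to use \(D_{t-1}^T D_{t-1} = I_l\) to recover the noise \(\varepsilon_t = D_{t-1}^T(x_t - u_t - A_{t-1} x_{t-1})\) and evaluate its \(\mathcal N(0, \Sigma_t)\) log density. `log_probs_x_sketch` and its array layout are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn


def log_probs_x_sketch(x, u, A, D, Sigma0, Sigma):
    """log p(x_0) followed by log p(x_t | x_{t-1}), t = 1..n (hypothetical sketch).

    Assumed layout: x, u (n+1, m); A (n, m, m) holds A_0..A_{n-1};
    D (n, m, l) holds D_0..D_{n-1}; Sigma0 (m, m); Sigma (n, l, l) holds
    Sigma_1..Sigma_n.
    """
    first = mvn.logpdf(x[0], u[0], Sigma0)

    def one(x_prev, x_t, u_t, A_prev, D_prev, Sigma_t):
        # recover the l-dimensional noise eps_t = D^T (x_t - u_t - A x_{t-1})
        eps = D_prev.T @ (x_t - u_t - A_prev @ x_prev)
        return mvn.logpdf(eps, jnp.zeros_like(eps), Sigma_t)

    rest = jax.vmap(one)(x[:-1], x[1:], u[1:], A, D, Sigma)
    return jnp.concatenate([first[None], rest])


# scalar random walk with unit variances and increments 1, 0, 2
n = 3
x = jnp.array([[0.0], [1.0], [1.0], [3.0]])
lp = log_probs_x_sketch(
    x,
    u=jnp.zeros((n + 1, 1)),
    A=jnp.ones((n, 1, 1)),
    D=jnp.ones((n, 1, 1)),
    Sigma0=jnp.eye(1),
    Sigma=jnp.ones((n, 1, 1)),
)
```

For \(D_t = I\) (as in this random walk) the recovered noise is the full residual, and the terms coincide with the ordinary Gaussian transition densities.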
np1 = X.shape[0]  # number of time points, n + 1
fct.test_eq(log_probs_x(X, to_states(model)).shape, (np1,))
fct.test_eq(log_probs_y(Y, X, to_observation_model(model)).shape, (np1,))
fct.test_eq(log_prob(X, Y, model).shape, ())