```python
import jax

# use x64 (double precision floats) for testing purposes
jax.config.update("jax_enable_x64", True)
```

Gaussian Linear State Space Models
Consider a Gaussian state space model of the form \[ \begin{align*} X_0 &\sim \mathcal N (u_{0}, \Sigma_0) &\\ X_{t + 1} &= u_{t + 1} + A_t X_{t} + D_{t}\varepsilon_{t + 1} &t = 0, \dots, n - 1\\ \varepsilon_t &\sim \mathcal N (0, \Sigma_t) & t = 1, \dots, n \\ Y_t &= v_{t} + B_t X_t + \eta_t & t =0, \dots, n \\ \eta_t &\sim \mathcal N(0, \Omega_t) & t=0, \dots, n, \end{align*} \] where \(X_0, \varepsilon_1, \dots, \varepsilon_n, \eta_0, \dots, \eta_n\) are mutually independent. As the joint distribution of \((X_0, \dots, X_n, Y_0, \dots, Y_n)\) is then Gaussian, we call this a Gaussian linear state space model (GLSSM).
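To see why the joint distribution is Gaussian, note that, with \(X_0\), the \(\varepsilon_t\) and the \(\eta_t\) mutually independent, every \(X_t\) and \(Y_t\) is an affine function of these jointly Gaussian variables. Their marginal moments follow a short recursion; writing \(\mu_t = \mathbf E X_t\) and \(P_t = \operatorname{Cov} X_t\) (notation introduced only for this derivation), \[ \begin{align*} \mu_0 &= u_0, & P_0 &= \Sigma_0, \\ \mu_{t+1} &= u_{t+1} + A_t \mu_t, & P_{t+1} &= A_t P_t A_t^T + D_t \Sigma_{t+1} D_t^T, \\ \mathbf E Y_t &= v_t + B_t \mu_t, & \operatorname{Cov} Y_t &= B_t P_t B_t^T + \Omega_t, \end{align*} \] where the covariance updates use that \(\varepsilon_{t+1}\) is independent of \(X_t\) and \(\eta_t\) is independent of \(X_t\).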
The dimensions of the components are as follows: \[ \begin{align*} u_{t}, X_{t} &\in \mathbf R^{m} \\ \varepsilon_{t} &\in \mathbf R^{l} \\ v_{t}, Y_{t}, \eta_{t} &\in \mathbf R^{p} \end{align*} \] and \[ \begin{align*} A_{t} &\in \mathbf R^{m\times m} \\ D_{t} &\in \mathbf R^{m \times l} \\ \Sigma_{0} &\in \mathbf R^{m \times m} \\ \Sigma_{t} &\in \mathbf R^{l \times l}\\ B_{t} &\in \mathbf R^{p \times m} \\ \Omega_{t} &\in\mathbf R^{p \times p} \end{align*} \] and we assume that \(D_t\) consists of \(l\) distinct columns of the identity matrix \(I_m\), so that \(D_t^T D_t = I_{l}\).
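The assumption on \(D_t\) can be illustrated with a small JAX sketch. The sizes \(m = 3\), \(l = 2\) and the selected columns below are hypothetical choices for illustration only. Building \(D_t\) from columns of \(I_m\) gives \(D_t^T D_t = I_l\), and the state noise \(D_t \varepsilon_{t+1}\) then has covariance \(D_t \Sigma_{t+1} D_t^T\), an \(m \times m\) matrix of rank \(l\) that is singular whenever \(l < m\).

```python
import jax.numpy as jnp

# Illustrative sizes (an assumption, not taken from any model in this document):
# m = 3 state components driven by l = 2 innovations.
m, l = 3, 2

# Hypothetical selection: the innovations enter state components 0 and 2,
# so D consists of columns 0 and 2 of the m x m identity matrix.
D = jnp.eye(m)[:, jnp.array([0, 2])]  # shape (m, l)

# Distinct identity columns are orthonormal, hence D^T D = I_l.
DtD = D.T @ D

# D embeds an innovation eps in R^l into R^m; the unselected component stays zero.
eps = jnp.array([0.5, -1.0])
embedded = D @ eps  # [0.5, 0.0, -1.0]

# The state noise D eps has covariance D Sigma D^T, an m x m matrix of rank l.
Sigma = jnp.diag(jnp.array([1.0, 2.0]))
noise_cov = D @ Sigma @ D.T
rank = jnp.linalg.matrix_rank(noise_cov)
```

In this example the second state component receives no noise, which is why row and column 1 of `noise_cov` are zero.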
Sampling from the joint distribution
To obtain a sample \((X_0, \dots, X_n), (Y_0, \dots, Y_n)\), we first simulate from the joint distribution of the states and then, as the observations are conditionally independent of one another given the states, simulate all observations at once.
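This two-step procedure can be sketched in plain JAX. It is not the isssm implementation: `sample_glssm_sketch` is a hypothetical helper, and for brevity it assumes time-invariant parameters \(u, A, D, \Sigma, v, B, \Omega\) (only \(u_0\) and \(\Sigma_0\) are special at \(t = 0\)). The states come from the forward recursion via `jax.lax.scan`; the observations are then drawn in one vectorized step, and `jax.vmap` over keys yields \(N\) independent paths.

```python
import jax
import jax.numpy as jnp
import jax.random as jrn


def sample_glssm_sketch(key, n, u0, Sigma0, u, A, D, Sigma, v, B, Omega):
    """Draw one path (X_0..X_n, Y_0..Y_n) from a GLSSM.

    Hypothetical helper (not part of isssm); u, A, D, Sigma, v, B, Omega are
    assumed constant over time to keep the sketch short.
    """
    key_x0, key_eps, key_eta = jrn.split(key, 3)
    l, p = Sigma.shape[0], Omega.shape[0]

    # Step 1: states via X_{t+1} = u + A X_t + D eps_{t+1}, eps_{t+1} ~ N(0, Sigma)
    x0 = jrn.multivariate_normal(key_x0, u0, Sigma0)
    eps = jrn.multivariate_normal(key_eps, jnp.zeros(l), Sigma, shape=(n,))

    def step(x_prev, eps_next):
        x_next = u + A @ x_prev + D @ eps_next
        return x_next, x_next  # (new carry, value stacked by scan)

    _, x_rest = jax.lax.scan(step, x0, eps)
    X = jnp.concatenate([x0[None], x_rest])  # shape (n+1, m)

    # Step 2: given the states, the Y_t are independent, so draw them at once:
    # Y_t = v + B X_t + eta_t, eta_t ~ N(0, Omega)
    eta = jrn.multivariate_normal(key_eta, jnp.zeros(p), Omega, shape=(n + 1,))
    Y = v + X @ B.T + eta  # shape (n+1, p)
    return X, Y


# Usage: hypothetical trend + velocity model with noise only on the velocity.
n, N = 50, 10
A = jnp.array([[1.0, 1.0], [0.0, 1.0]])
D = jnp.array([[0.0], [1.0]])  # column 1 of I_2, so D^T D = I_1
B = jnp.array([[1.0, 0.0]])
params = dict(u0=jnp.zeros(2), Sigma0=jnp.eye(2), u=jnp.zeros(2), A=A, D=D,
              Sigma=0.01 * jnp.eye(1), v=jnp.zeros(1), B=B, Omega=jnp.eye(1))

keys = jrn.split(jrn.PRNGKey(0), N)
Xs, Ys = jax.vmap(lambda k: sample_glssm_sketch(k, n, **params))(keys)
# Xs has shape (N, n+1, 2) and Ys has shape (N, n+1, 1)
```

Scanning over pre-drawn innovations keeps the state recursion a single compiled loop, while the observation step needs no loop at all because, given the states, the \(Y_t\) are independent.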
simulate_states
simulate_states (state:isssm.typing.GLSSMState, N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array, '2']])
Simulate states of a GLSSM
| | Type | Details |
|---|---|---|
| state | GLSSMState | |
| N | int | number of samples to draw |
| key | Union | the random state |
| Returns | Float[Array, 'N n+1 m'] | array of N samples from the state distribution |
simulate_glssm
simulate_glssm (glssm:isssm.typing.GLSSM, N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array, '2']])
Simulate states and observations of a GLSSM
| | Type | Details |
|---|---|---|
| glssm | GLSSM | |
| N | int | number of sample paths |
| key | Union | the random state |
| Returns | tuple[Float[Array, 'N n+1 m'], Float[Array, 'N n+1 p']] | N sampled state paths and the corresponding observations |
As a toy example, we consider a structural time series model with trend and velocity, as well as a seasonal component of order 2, see here for details. The following code creates the model and simulates once from its joint distribution.

```python
import jax.numpy as jnp
import jax.random as jrn
import matplotlib.pyplot as plt
from isssm.models.stsm import stsm

model = stsm(jnp.zeros(2 + 1), 0.0, 0.01, 1.0, 100, jnp.eye(2 + 1), 1.0, 2)
key = jrn.PRNGKey(53412312)
key, subkey = jrn.split(key)
# draw N = 1 sample path and unpack the leading sample axis
(X,), (Y,) = simulate_glssm(model, 1, subkey)
```

```python
fig, (ax1, ax2) = plt.subplots(2)
ax1.set_title("$X_t$")
ax1.plot(X)
ax2.set_title("$Y_t$")
ax2.plot(Y)
plt.show()
```

In the resulting figure, the trend varies smoothly, while the seasonal component is visible in the observations \(Y_t\).
Joint Density
Since the states form a Markov chain and each observation \(Y_t\) depends on the states only through \(X_t\), the joint density factorizes as
\[ p(x,y) = \prod_{t = 0}^n p(x_{t}| x_{t -1}) p(y_{t}|x_{t}) \] where \(p(x_0|x_{-1}) = p(x_0)\). The following functions return these components or evaluate the joint density directly.
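The factorization can be checked with a minimal JAX sketch that evaluates \(\log p(x, y)\) term by term using `jax.scipy.stats.multivariate_normal.logpdf`. The helper `log_prob_sketch` is hypothetical, not the isssm implementation, and assumes time-invariant parameters and \(D_t = I_m\) (so \(l = m\) and each transition density is non-degenerate).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn


def log_prob_sketch(x, y, u0, Sigma0, u, A, Sigma, v, B, Omega):
    """log p(x, y) = sum_t [log p(x_t | x_{t-1}) + log p(y_t | x_t)].

    Hypothetical helper (not part of isssm): time-invariant parameters and
    D = I, so the state innovations have full covariance Sigma.
    """
    # t = 0: p(x_0 | x_{-1}) := p(x_0), the N(u0, Sigma0) density
    lp_x0 = mvn.logpdf(x[0], u0, Sigma0)
    # t = 1, ..., n: x_t | x_{t-1} ~ N(u + A x_{t-1}, Sigma)
    means_x = u + x[:-1] @ A.T
    lp_x = jax.vmap(lambda xt, mt: mvn.logpdf(xt, mt, Sigma))(x[1:], means_x)
    lp_states = jnp.concatenate([lp_x0[None], lp_x])  # shape (n+1,)

    # t = 0, ..., n: y_t | x_t ~ N(v + B x_t, Omega)
    means_y = v + x @ B.T
    lp_obs = jax.vmap(lambda yt, mt: mvn.logpdf(yt, mt, Omega))(y, means_y)

    return lp_states.sum() + lp_obs.sum()


# Usage: hypothetical scalar random walk observed with noise (m = p = 1, n = 2).
x = jnp.array([[0.0], [0.3], [0.1]])
y = jnp.array([[0.2], [0.5], [-0.4]])
I1, z1 = jnp.eye(1), jnp.zeros(1)
lp = log_prob_sketch(x, y, u0=z1, Sigma0=I1, u=z1, A=I1, Sigma=0.5 * I1,
                     v=z1, B=I1, Omega=2.0 * I1)
```

Each factor is an ordinary multivariate normal log density, so the per-time arrays `lp_states` and `lp_obs` hold \(\log p(x_t \vert x_{t-1})\) and \(\log p(y_t \vert x_t)\), and the joint log density is their total sum.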
log_prob
log_prob (x:jaxtyping.Float[Array,'n+1 m'], y:jaxtyping.Float[Array,'n+1 p'], glssm:isssm.typing.GLSSM)
joint log probability of states and observations
| | Type | Details |
|---|---|---|
| x | Float[Array, 'n+1 m'] | |
| y | Float[Array, 'n+1 p'] | |
| glssm | GLSSM | |
| Returns | Float | \(\log p(x,y)\) |
log_probs_y
log_probs_y (y:jaxtyping.Float[Array,'n+1 p'], x:jaxtyping.Float[Array,'n+1 m'], obs_model:isssm.typing.GLSSMObservationModel)
log probabilities \(\log p(y_t | x_t)\)
| | Type | Details |
|---|---|---|
| y | Float[Array, 'n+1 p'] | the observations |
| x | Float[Array, 'n+1 m'] | the states |
| obs_model | GLSSMObservationModel | the observation model |
| Returns | Float[Array, 'n+1'] | log probabilities \(\log p(y_t \vert x_t)\) |
log_probs_x
log_probs_x (x:jaxtyping.Float[Array,'n+1 m'], state:isssm.typing.GLSSMState)
log probabilities \(\log p(x_t | x_{t-1})\)
| | Type | Details |
|---|---|---|
| x | Float[Array, 'n+1 m'] | the states |
| state | GLSSMState | the state model |
| Returns | Float[Array, 'n+1'] | log probabilities \(\log p(x_t \vert x_{t-1})\) |
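```python
import fastcore.test as fct

np1 = X.shape[0]  # number of time points n + 1
fct.test_eq(log_probs_x(X, to_states(model)).shape, (np1,))
fct.test_eq(log_probs_y(Y, X, to_observation_model(model)).shape, (np1,))
fct.test_eq(log_prob(X, Y, model).shape, ())  # the joint log density is a scalar
```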