Gaussian Linear State Space Models

Simulation and components; see also the corresponding chapter in my thesis.
import jax

# use 64-bit floating point (x64) for testing purposes
jax.config.update("jax_enable_x64", True)

Consider a Gaussian state space model of the form \[ \begin{align*} X_0 &\sim \mathcal N (u_{0}, \Sigma_0) &\\ X_{t + 1} &= u_{t + 1} + A_t X_{t} + D_{t}\varepsilon_{t + 1} &t = 0, \dots, n - 1\\ \varepsilon_t &\sim \mathcal N (0, \Sigma_t) & t = 1, \dots, n \\ Y_t &= v_{t} + B_t X_t + \eta_t & t =0, \dots, n & \\ \eta_t &\sim \mathcal N(0, \Omega_t) & t=0, \dots, n. \end{align*} \] As the joint distribution of \((X_0, \dots, X_n, Y_0, \dots, Y_n)\) is Gaussian, we call it a Gaussian linear state space model (GLSSM).
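A short calculation makes the Gaussianity explicit. Assuming, as is implicit in the model, that \(X_0, \varepsilon_1, \dots, \varepsilon_n, \eta_0, \dots, \eta_n\) are mutually independent, every \(X_t\) and \(Y_t\) is an affine function of these jointly Gaussian variables. The marginal moments \(\mu_t = \mathbf E X_t\) and \(P_t = \operatorname{Cov}(X_t)\) then follow the recursion

\[ \begin{align*} \mu_0 &= u_0, & P_0 &= \Sigma_0, \\ \mu_{t + 1} &= u_{t + 1} + A_t \mu_t, & P_{t + 1} &= A_t P_t A_t^T + D_t \Sigma_{t + 1} D_t^T, \\ \mathbf E Y_t &= v_t + B_t \mu_t, & \operatorname{Cov}(Y_t) &= B_t P_t B_t^T + \Omega_t. \end{align*} \]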

The dimensions of the components are as follows: \[ \begin{align*} u_{t}, X_{t} &\in \mathbf R^{m} \\ \varepsilon_{t} &\in \mathbf R^{l} \\ v_{t}, Y_{t}, \eta_{t} &\in \mathbf R^{p} \end{align*} \] and \[ \begin{align*} A_{t} &\in \mathbf R^{m\times m} \\ D_{t} &\in \mathbf R^{m \times l} \\ \Sigma_{0} &\in \mathbf R^{m \times m} \\ \Sigma_{t} &\in \mathbf R^{l \times l}\\ B_{t} &\in \mathbf R^{p \times m} \\ \Omega_{t} &\in\mathbf R^{p \times p} \end{align*} \] and we assume that \(D_t\) is a submatrix of the identity matrix, i.e. it consists of \(l\) distinct columns of \(I_m\), so that \(D_t^T D_t = I_{l}\).
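To make the assumption on \(D_t\) concrete, here is a small NumPy illustration with \(m = 3\) and \(l = 2\); the dimensions and the selected columns are arbitrary choices for this example. Because \(D_t^T D_t = I_l\), applying \(D_t^T\) to the embedded noise \(D_t \varepsilon\) recovers \(\varepsilon\), while \(D_t D_t^T\) is only a projection onto the selected coordinates.

```python
import numpy as np

m, l = 3, 2
# columns 0 and 2 of the 3 x 3 identity: noise enters state coordinates 0 and 2
D = np.eye(m)[:, [0, 2]]

# D^T D is the l x l identity
assert np.allclose(D.T @ D, np.eye(l))

eps = np.array([0.5, -1.0])
embedded = D @ eps           # values 0.5, 0.0, -1.0: coordinate 1 receives no noise
recovered = D.T @ embedded   # equals eps again

# D D^T is not the identity for l < m, it projects onto coordinates 0 and 2
projection = D @ D.T
```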

Sampling from the joint distribution

To obtain a sample \((X_0, \dots, X_n), (Y_0, \dots, Y_n)\), we first simulate from the joint distribution of the states and then, as the observations are conditionally independent of one another given the states, simulate all observations at once.
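The two-stage procedure can be sketched in a few lines of NumPy. This is a hypothetical, self-contained stand-in for `simulate_glssm`, not the library's JAX implementation; the name `simulate_glssm_np` and the array layout (in particular that `Sigma[t]` holds \(\Sigma_{t+1}\)) are assumptions of this sketch. The states are generated by iterating the transition equation, after which the observation noise for all time points is drawn in a single vectorized step.

```python
import numpy as np


def simulate_glssm_np(u, A, D, Sigma0, Sigma, v, B, Omega, N, rng):
    """Draw N joint samples of states and observations from a GLSSM.

    Array layout assumed for this sketch:
      u: (n+1, m), A: (n, m, m), D: (n, m, l), Sigma0: (m, m),
      Sigma: (n, l, l) with Sigma[t] holding Sigma_{t+1},
      v: (n+1, p), B: (n+1, p, m), Omega: (n+1, p, p).
    Returns X of shape (N, n+1, m) and Y of shape (N, n+1, p).
    """
    np1, m = u.shape
    l = Sigma.shape[-1]
    X = np.empty((N, np1, m))
    # X_0 ~ N(u_0, Sigma_0)
    X[:, 0] = rng.multivariate_normal(u[0], Sigma0, size=N)
    # the states are simulated sequentially, as X_{t+1} depends on X_t
    for t in range(np1 - 1):
        eps = rng.multivariate_normal(np.zeros(l), Sigma[t], size=N)  # (N, l)
        X[:, t + 1] = u[t + 1] + X[:, t] @ A[t].T + eps @ D[t].T
    # given the states, the observations are conditionally independent, so all
    # eta_t are drawn at once from standard normals and Cholesky factors of Omega_t
    L = np.linalg.cholesky(Omega)  # (n+1, p, p)
    z = rng.standard_normal((N, np1, v.shape[1]))
    eta = np.einsum("tpq,Ntq->Ntp", L, z)
    Y = v + np.einsum("tpm,Ntm->Ntp", B, X) + eta
    return X, Y


# usage: a local linear trend (level and velocity) observed with noise,
# toy parameter values chosen for illustration only
n, m, l, p = 50, 2, 1, 1
u = np.zeros((n + 1, m))
u[0] = [1.0, 0.5]
A = np.tile([[1.0, 1.0], [0.0, 1.0]], (n, 1, 1))
D = np.tile([[0.0], [1.0]], (n, 1, 1))  # noise enters the velocity only
Sigma0 = 0.1 * np.eye(m)
Sigma = np.tile(0.01 * np.eye(l), (n, 1, 1))
v = np.zeros((n + 1, p))
B = np.tile([[1.0, 0.0]], (n + 1, 1, 1))  # observe the level
Omega = np.tile(np.eye(p), (n + 1, 1, 1))

rng = np.random.default_rng(0)
X, Y = simulate_glssm_np(u, A, D, Sigma0, Sigma, v, B, Omega, N=3, rng=rng)
print(X.shape, Y.shape)  # (3, 51, 2) (3, 51, 1)
```

Only the state recursion needs a loop; the observation step is a single batched operation over time points and samples.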



simulate_states

 simulate_states (state:isssm.typing.GLSSMState, N:int,
                  key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])

Simulate states of a GLSSM

| | Type | Details |
|---|---|---|
| state | GLSSMState | |
| N | int | number of samples to draw |
| key | Union | the random state |
| Returns | Float[Array, 'N n+1 m'] | array of N samples from the state distribution |


simulate_glssm

 simulate_glssm (glssm:isssm.typing.GLSSM, N:int,
                 key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])

Simulate states and observations of a GLSSM

| | Type | Details |
|---|---|---|
| glssm | GLSSM | |
| N | int | number of sample paths |
| key | Union | the random state |
| Returns | (Float[Array, 'N n+1 m'], Float[Array, 'N n+1 p']) | sampled states and observations |

from isssm.models.stsm import stsm

As a toy example, we consider a structural time series model with a trend and its velocity, as well as a seasonal component of order 2 (see here for details). The following code creates the model and draws a single sample from its joint distribution.

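import jax.numpy as jnp
import jax.random as jrn
import matplotlib.pyplot as plt

model = stsm(jnp.zeros(2 + 1), 0.0, 0.01, 1.0, 100, jnp.eye(2 + 1), 1.0, 2)

key = jrn.PRNGKey(53412312)
key, subkey = jrn.split(key)
# draw N = 1 sample path and unpack the single states / observations arrays
(X,), (Y,) = simulate_glssm(model, 1, subkey)

fig, (ax1, ax2) = plt.subplots(2)

ax1.set_title("$X_t$")
ax1.plot(X)

ax2.set_title("$Y_t$")
ax2.plot(Y)
plt.show()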

In this figure, we see that the trend varies smoothly, while the seasonal component is visible in the observations \(Y_t\).

Joint Density

By the dependency structure of the model (the states form a Markov chain and each \(Y_t\) depends on the states only through \(X_t\)), the joint density factorizes as

\[ p(x,y) = \prod_{t = 0}^n p(x_{t}| x_{t -1}) p(y_{t}|x_{t}) \] where \(p(x_0|x_{-1}) = p(x_0)\). The following functions return these components or evaluate the joint density directly.
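The factorization maps directly onto code. The NumPy sketch below is a hypothetical counterpart of the functions documented in this section, not the isssm (JAX) implementation; the `_np` names, the helper `mvn_logpdf` and the array layout are assumptions. One subtlety: when \(l < m\), \(X_t\) given \(X_{t-1}\) is a degenerate Gaussian concentrated on an \(l\)-dimensional affine subspace. Because \(D_{t-1}^T D_{t-1} = I_l\), the noise can be recovered as \(\varepsilon_t = D_{t-1}^T (x_t - u_t - A_{t-1} x_{t-1})\), and the sketch evaluates \(\log p(x_t | x_{t-1})\) as the \(\mathcal N(0, \Sigma_t)\) log density of \(\varepsilon_t\), i.e. as a density on that subspace; it does not check that \(x_t\) actually lies on it.

```python
import numpy as np


def mvn_logpdf(x, mean, cov):
    """Log density of N(mean, cov) at x, computed from a Cholesky factor."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, x - mean)
    k = x.shape[0]
    return -0.5 * (z @ z) - np.log(np.diag(L)).sum() - 0.5 * k * np.log(2 * np.pi)


def log_probs_x_np(x, u, A, D, Sigma0, Sigma):
    """log p(x_t | x_{t-1}) for t = 0, ..., n; x has shape (n+1, m).

    Assumed layout: u (n+1, m), A (n, m, m), D (n, m, l), Sigma0 (m, m),
    Sigma (n, l, l) with Sigma[t - 1] holding Sigma_t.
    """
    out = [mvn_logpdf(x[0], u[0], Sigma0)]  # the factor p(x_0)
    for t in range(1, x.shape[0]):
        # D^T D = I, so D^T recovers eps_t from the transition residual
        eps = D[t - 1].T @ (x[t] - u[t] - A[t - 1] @ x[t - 1])
        out.append(mvn_logpdf(eps, np.zeros(eps.shape[0]), Sigma[t - 1]))
    return np.array(out)


def log_probs_y_np(y, x, v, B, Omega):
    """log p(y_t | x_t) for t = 0, ..., n, with y_t | x_t ~ N(v_t + B_t x_t, Omega_t)."""
    return np.array(
        [mvn_logpdf(y[t], v[t] + B[t] @ x[t], Omega[t]) for t in range(y.shape[0])]
    )


def log_prob_np(x, y, u, A, D, Sigma0, Sigma, v, B, Omega):
    """log p(x, y) as the sum of all factors of the factorization."""
    return (
        log_probs_x_np(x, u, A, D, Sigma0, Sigma).sum()
        + log_probs_y_np(y, x, v, B, Omega).sum()
    )
```

For \(l = m\) the sum coincides with the ordinary Lebesgue density of the full joint Gaussian, so a quick sanity check for a tiny model is to compare `log_prob_np` with `mvn_logpdf` evaluated on the stacked vector of states and observations under their joint mean and covariance.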



log_prob

 log_prob (x:jaxtyping.Float[Array,'n+1 m'],
           y:jaxtyping.Float[Array,'n+1 p'], glssm:isssm.typing.GLSSM)

joint log probability of states and observations

| | Type | Details |
|---|---|---|
| x | Float[Array, 'n+1 m'] | |
| y | Float[Array, 'n+1 p'] | |
| glssm | GLSSM | |
| Returns | Float | \(\log p(x,y)\) |


log_probs_y

 log_probs_y (y:jaxtyping.Float[Array,'n+1 p'],
              x:jaxtyping.Float[Array,'n+1 m'],
              obs_model:isssm.typing.GLSSMObservationModel)

log probabilities \(\log p(y_t | x_t)\)

| | Type | Details |
|---|---|---|
| y | Float[Array, 'n+1 p'] | the observations |
| x | Float[Array, 'n+1 m'] | the states |
| obs_model | GLSSMObservationModel | the observation model |
| Returns | Float[Array, 'n+1'] | log probabilities \(\log p(y_t \vert x_t)\) |


log_probs_x

 log_probs_x (x:jaxtyping.Float[Array,'n+1 m'],
              state:isssm.typing.GLSSMState)

log probabilities \(\log p(x_t | x_{t-1})\)

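| | Type | Details |
|---|---|---|
| x | Float[Array, 'n+1 m'] | |
| state | GLSSMState | the state model |
| Returns | Float[Array, 'n+1'] | log probabilities \(\log p(x_t \vert x_{t-1})\) |

import fastcore.test as fct

# number of time points n + 1 of the sample X drawn above
np1 = X.shape[0]
# per-time-point factors have shape (n+1,), the joint log density is a scalar
fct.test_eq(log_probs_x(X, to_states(model)).shape, (np1,))
fct.test_eq(log_probs_y(Y, X, to_observation_model(model)).shape, (np1,))
fct.test_eq(log_prob(X, Y, model).shape, ())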