Partially Gaussian State Space Models with linear Signal

See also the corresponding section in my thesis

import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as jrn
from jax import vmap  # used for the signal plot further down

A partially Gaussian state space model with linear signal is a state space model in which the distribution of the states \(X_t\) is Gaussian, while the (conditional) distribution of the observations \(Y_t\) is non-Gaussian and depends on the states only through the signals \(S_t = B_tX_t\) for a matrix \(B_t\).

That is, we consider

\[ \begin{align*} X_0 &\sim \mathcal N(u_{0}, \Sigma_0)\\ X_{t + 1} &= u_{t + 1} + A_t X_t + \varepsilon_{t + 1}\\ \varepsilon_{t} &\sim \mathcal N(0, \Sigma_t)\\ Y_t | X_t &\sim Y_t | S_t \sim p(y_t|s_t). \end{align*} \]

Implementation Detail

Dependency on parameters

To facilitate faster evaluation we assume that the conditional density of observations \(p(y_{t}|s_{t})\) depends on parameters \(\xi \in \mathbb R^{(n + 1)\times p \times l}\), i.e. for every observation there are exactly \(l\) parameters to consider. The implementation will always assume that we can call

dist(s, xi)

where dist constructs a tensorflow_probability.distributions.Distribution object and allows for broadcasting in both \(s\) and \(\xi\).
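To make the calling convention concrete, here is a minimal sketch of what broadcasting in \(s\) and \(\xi\) means. It is a toy stand-in written in plain NumPy, not tensorflow_probability: the class ToyNegBinom, its mean parametrization via \(\exp(s)\), and the choice \(l = 1\) (a single overdispersion parameter \(r\) per observation) are assumptions made for illustration only.

```python
import math
import numpy as np

# elementwise log-gamma without requiring SciPy
_lgamma = np.vectorize(math.lgamma)


class ToyNegBinom:
    """Hypothetical stand-in for the object returned by dist(s, xi)."""

    def __init__(self, s, xi):
        self.mu = np.exp(s)   # mean via a log link, same shape as s
        self.r = xi[..., 0]   # l = 1: one parameter per observation

    def log_prob(self, y):
        mu, r = self.mu, self.r
        return (
            _lgamma(y + r) - _lgamma(r) - _lgamma(y + 1.0)
            + r * np.log(r / (r + mu))
            + y * np.log(mu / (r + mu))
        )


n, p = 3, 2
s = np.zeros((n + 1, p))            # signals, shape (n + 1, p)
xi = np.full((n + 1, p, 1), 20.0)   # parameters, shape (n + 1, p, l)
y = np.ones((n + 1, p))             # observations, shape (n + 1, p)

lp = ToyNegBinom(s, xi).log_prob(y)  # one log-probability per entry of y
```

All \(n + 1\) time points and \(p\) observation dimensions are evaluated in a single call, which is the reason for fixing the shape of \(\xi\) in advance.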

Simulation

As the states are Gaussian, we can first simulate the states \(X\) and then, conditional on them, calculate the signals \(S\) and, using the parameters \(\xi\), sample the observations \(Y\).
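The two-step idea can be sketched as follows. This is a plain NumPy illustration under stated assumptions, not the implementation of simulate_pgssm: the function simulate_pgssm_sketch, its argument layout, and the Poisson observation model with log link are all hypothetical choices made to keep the example short.

```python
import numpy as np


def simulate_pgssm_sketch(u, A, Sigma, B, sample_obs, rng):
    """Simulate one path. Shapes (assumed for this sketch):
    u (n+1, m), A (n, m, m), Sigma (n+1, m, m), B (n+1, p, m)."""
    n_plus_1, m = u.shape
    X = np.empty((n_plus_1, m))
    # step 1: simulate the Gaussian states
    X[0] = rng.multivariate_normal(u[0], Sigma[0])
    for t in range(n_plus_1 - 1):
        eps = rng.multivariate_normal(np.zeros(m), Sigma[t + 1])
        X[t + 1] = u[t + 1] + A[t] @ X[t] + eps
    # step 2: signals S_t = B_t X_t, then observations conditional on S
    S = np.einsum("tpm,tm->tp", B, X)
    Y = sample_obs(S, rng)
    return X, S, Y


rng = np.random.default_rng(0)
n, m, p = 50, 2, 1
u = np.zeros((n + 1, m))
A = np.tile(0.9 * np.eye(m), (n, 1, 1))
Sigma = np.tile(0.01 * np.eye(m), (n + 1, 1, 1))
B = np.tile(np.array([[1.0, 0.0]]), (n + 1, 1, 1))

# toy observation model: Poisson with mean exp(s)
X, S, Y = simulate_pgssm_sketch(
    u, A, Sigma, B, lambda S, rng: rng.poisson(np.exp(S)), rng
)
```

Only the last step depends on the observation distribution, so swapping in a different conditional distribution leaves the state simulation untouched.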



simulate_pgssm

 simulate_pgssm (pgssm:isssm.typing.PGSSM, N:int,
                 key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,
                 '2']])
Name     Type   Details
pgssm    PGSSM
N        int    number of samples
key      Union  random key
Returns  tuple

Running example: Negative Binomial model

As an example, consider a variant of the multivariate AR(1) process model with a seasonal component, where observations now follow a conditional negative binomial distribution, i.e. \[Y^i_t| X_{t} \sim \text{NegBinom}(\exp((BX_t)^{i}), r),\] independently for \(i = 1, 2\).

The states \(X_t\) consist of two components: a trend component (a trend together with its velocity) and a seasonal component.

For the velocity component, we model a stationary process with a small stationary variance. Stationarity ensures that sampling from the model will usually not lead to numerical issues. Because of the log link for the negative binomial observations, we want states to stay within, say, \((-2, 2)\) most of the time: below \(-2\) we will see many \(0\) observations, and far above \(2\) we may have problems sampling because \(\mathbf E (Y^{i}_{t} | X_{t}) = \exp ((BX_{t})^{i})\) becomes large.
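A quick calculation illustrates the claim about zeros. The sketch assumes the usual negative binomial with mean \(\mu = \exp(s)\) and size \(r\), for which \(\mathbf P(Y = 0) = (r / (r + \mu))^{r}\); the helper nb_zero_prob is introduced here for illustration and is not part of the package.

```python
import math


def nb_zero_prob(s, r):
    """P(Y = 0) for a negative binomial with mean exp(s) and size r."""
    mu = math.exp(s)
    return (r / (r + mu)) ** r


r = 20.0  # default overdispersion of the running example
for s in (-2.0, 0.0, 2.0):
    print(f"s = {s:+.0f}: mean = {math.exp(s):.3f}, P(Y = 0) = {nb_zero_prob(s, r):.3f}")
```

Under these assumptions, a signal of \(-2\) gives a zero with probability of roughly 0.87, while at \(2\) zeros become rare and the conditional mean is already above 7.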

This model has the advantage that we can check whether our implementation can handle multiple issues:

  • the states have degenerate distribution (due to the seasonal component),
  • the observations are multivariate,
  • the observations are integer-valued and
  • the observations are non-Gaussian.
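To see why a seasonal component leads to a degenerate state distribution, consider the textbook dummy-variable seasonal with period s_order = 5, where each new seasonal effect is minus the sum of the previous four plus noise. Whether the running example uses exactly this construction is an assumption; the matrices A_seasonal and Sigma_seasonal below are illustrative only.

```python
import numpy as np

s_order = 5          # period of the seasonal pattern
k = s_order - 1      # number of seasonal states

# transition: the new effect is minus the sum of the last k effects,
# the older effects are shifted down by one slot
A_seasonal = np.zeros((k, k))
A_seasonal[0, :] = -1.0
A_seasonal[1:, :-1] = np.eye(k - 1)

# innovation covariance: noise enters the first coordinate only
Sigma_seasonal = np.zeros((k, k))
Sigma_seasonal[0, 0] = 0.1

rank = np.linalg.matrix_rank(Sigma_seasonal)         # smaller than k: degenerate
shifted = A_seasonal @ np.array([1.0, 2.0, 3.0, 4.0])
one_period = np.linalg.matrix_power(A_seasonal, s_order)
```

The innovation covariance has rank one, so the conditional distribution of the next seasonal state given the current one has no density on \(\mathbf R^{4}\). Without noise, the pattern repeats after s_order steps, which is what makes the component seasonal.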

We set some sensible defaults and will reuse this model throughout this documentation.



nb_pgssm_running_example

 nb_pgssm_running_example (x0_trend:jaxtyping.Float[Array,'m']=Array([0.,
                           0.], dtype=float64), r:jaxtyping.Float=20.0,
                           s2_trend:jaxtyping.Float=0.01,
                           s2_speed:jaxtyping.Float=0.1,
                           alpha:jaxtyping.Float=0.1,
                           omega2:jaxtyping.Float=0.01, n:int=100,
                           x0_seasonal:jaxtyping.Float[Array,'s']=Array([0.,
                           0., 0., 0.], dtype=float64),
                           s2_seasonal:jaxtyping.Float=0.1,
                           Sigma0_seasonal:jaxtyping.Float[Array,'s s']=Array(
                           [[0.1, 0. , 0. , 0. ], [0. , 0.1, 0. , 0. ],
                           [0. , 0. , 0.1, 0. ], [0. , 0. , 0. , 0.1]],
                           dtype=float64), s_order:int=5)

a structural time series model with NBinom observations

Name             Type                 Default               Details
x0_trend         Float[Array, 'm']    [0. 0.]
r                Float                20.0
s2_trend         Float                0.01
s2_speed         Float                0.1
alpha            Float                0.1
omega2           Float                0.01
n                int                  100
x0_seasonal      Float[Array, 's']    [0. 0. 0. 0.]
s2_seasonal      Float                0.1
Sigma0_seasonal  Float[Array, 's s']  [[0.1 0.  0.  0. ]
                                       [0.  0.1 0.  0. ]
                                       [0.  0.  0.1 0. ]
                                       [0.  0.  0.  0.1]]
s_order          int                  5
Returns          PGSSM                                      the running example for this package
key = jrn.PRNGKey(518)
model = nb_pgssm_running_example()
N = 1
key, subkey = jrn.split(key)
(X,), (Y,) = simulate_pgssm(model, N, subkey)
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))
fig.tight_layout()

ax1.set_title("trend")
ax2.set_title("seasonal component")
ax3.set_title("signals")
ax4.set_title("observations")

ax1.plot(X[:, 0])
ax2.plot(X[:, 2])
ax3.plot(vmap(jnp.matmul, (0, 0))(model.B, X))
ax4.plot(Y)

plt.show()

Notice that the observations are now integer-valued.

Joint density

To evaluate the joint density we use the same approach as described in [00_glssm#Joint Density], replacing the observation density with the PGSSM one.
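From the model definition, the joint density factorizes as \(p(x, y) = p(x_0) \prod_{t = 0}^{n - 1} p(x_{t + 1} | x_t) \prod_{t = 0}^{n} p(y_t | s_t)\). The following NumPy sketch evaluates its logarithm under simplifying assumptions: all covariance matrices are taken to be nonsingular (which does not hold for the degenerate seasonal component of the running example), and the observations are Poisson with log link. The names mvn_logpdf, joint_log_prob_sketch and poisson_log_prob are hypothetical and not part of the package.

```python
import math
import numpy as np


def mvn_logpdf(x, mean, cov):
    """Log-density of N(mean, cov) at x, assuming cov is nonsingular."""
    d = x - mean
    _, logdet = np.linalg.slogdet(cov)
    quad = d @ np.linalg.solve(cov, d)
    return -0.5 * (len(x) * math.log(2 * math.pi) + logdet + quad)


def poisson_log_prob(y, s):
    """Toy observation model: Poisson with mean exp(s)."""
    return y * s - np.exp(s) - np.vectorize(math.lgamma)(y + 1.0)


def joint_log_prob_sketch(x, y, u, A, Sigma, B, obs_log_prob):
    # initial state
    lp = mvn_logpdf(x[0], u[0], Sigma[0])
    # state transitions
    for t in range(x.shape[0] - 1):
        lp += mvn_logpdf(x[t + 1], u[t + 1] + A[t] @ x[t], Sigma[t + 1])
    # observations, which depend on the states only through the signals
    s = np.einsum("tpm,tm->tp", B, x)
    return lp + np.sum(obs_log_prob(y, s))


# tiny example with n = 1, m = 1, p = 1
x = np.zeros((2, 1))
y = np.zeros((2, 1))
u = np.zeros((2, 1))
A = np.ones((1, 1, 1))
Sigma = np.ones((2, 1, 1))
B = np.ones((2, 1, 1))
value = joint_log_prob_sketch(x, y, u, A, Sigma, B, poisson_log_prob)
```

In this tiny example, each of the two standard normal terms contributes \(-\tfrac12 \log(2\pi)\) and each Poisson term contributes \(-1\), so the value can be checked by hand.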



log_prob

 log_prob (x:jaxtyping.Float[Array,'n+1 m'],
           y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.PGSSM)
Name   Type                   Details
x      Float[Array, 'n+1 m']  states
y      Float[Array, 'n+1 p']  observations
model  PGSSM


log_probs_y

 log_probs_y (x:jaxtyping.Float[Array,'n+1 m'],
              y:jaxtyping.Float[Array,'n+1 p'],
              v:jaxtyping.Float[Array,'n+1 p'],
              B:jaxtyping.Float[Array,'n+1 p m'], dist, xi)
Name  Type                     Details
x     Float[Array, 'n+1 m']    states
y     Float[Array, 'n+1 p']    observations
v     Float[Array, 'n+1 p']    signal biases
B     Float[Array, 'n+1 p m']  signal matrices
dist                           observation distribution
xi                             observation parameters
log_prob(X, Y, model)
Array(-17.57867495, dtype=float64)