```python
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax.random as jrn
```
A partially Gaussian state space model with linear signal is a state space model in which the states \(X_t\) are Gaussian, while the conditional distribution of the observations \(Y_t\) is non-Gaussian and depends on the states only through the signals \(S_t = B_tX_t\) for matrices \(B_t\).
That is, we consider
\[ \begin{align*} X_0 &\sim \mathcal N(u_{0}, \Sigma_0)\\ X_{t + 1} &= u_{t + 1} + A_t X_t + \varepsilon_{t + 1}\\ \varepsilon_{t} &\sim \mathcal N(0, \Sigma_t)\\ Y_t | X_t &\sim Y_t | S_t \sim p(y_t|s_t). \end{align*} \]
Dependency on parameters
To facilitate faster evaluation, we assume that the conditional density of observations \(p(y_{t}|s_{t})\) depends on parameters \(\xi \in \mathbb R^{(n + 1)\times p \times l}\), i.e. for every observation there are exactly \(l\) parameters to consider. The implementation always assumes that we can call `dist(s, xi)`, where `dist` yields a `tensorflow_probability.distributions.Distribution` object that allows for broadcasting in both \(s\) and \(\xi\).
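To illustrate the broadcasting convention, here is a minimal sketch of a `dist`-like callable for negative binomial observations. It does not use `tensorflow_probability`: the class `NegBinomMeanDispersion` and the function `nb_dist` are hypothetical stand-ins that only provide `log_prob`, and they assume \(\text{NegBinom}(\mu, r)\) is parameterized by its mean \(\mu = \exp(s)\) and a dispersion \(r\) stored as the single (\(l = 1\)) entry of \(\xi\).

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.scipy.special import gammaln


class NegBinomMeanDispersion(NamedTuple):
    """Hypothetical stand-in for a tfp Distribution; only `log_prob` is provided."""
    mean: jnp.ndarray  # E(Y), here exp(s)
    r: jnp.ndarray     # dispersion ("size") parameter

    def log_prob(self, y):
        # negative binomial log-pmf with mean `mean` and dispersion `r`, elementwise
        mu, r = self.mean, self.r
        return (
            gammaln(y + r) - gammaln(r) - gammaln(y + 1.0)
            + r * jnp.log(r / (r + mu))
            + y * jnp.log(mu / (r + mu))
        )


def nb_dist(s, xi):
    """Hypothetical `dist`: s has shape (..., p), xi has shape (..., p, l) with l = 1."""
    return NegBinomMeanDispersion(mean=jnp.exp(s), r=xi[..., 0])


# usage: n + 1 = 101 time points, p = 2 observations each, r = 20 everywhere
s = jnp.zeros((101, 2))            # signals; exp(0) = 1 is the mean of every Y
xi = jnp.full((101, 2, 1), 20.0)   # one parameter (r) per observation
y = jnp.ones((101, 2))
lp = nb_dist(s, xi).log_prob(y)    # elementwise log-densities, shape (101, 2)
```

Because every operation is elementwise, `s` and `xi[..., 0]` only need to be broadcast-compatible with `y`, which is the property the text asks of `dist`.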
As the states are Gaussian, we can first simulate the states \(X\) and then, conditional on them, calculate the signals \(S\) and sample the observations \(Y\) from \(p(y_t|s_t)\).
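The two-stage procedure can be sketched as follows. This is not the package's `simulate_pgssm`; the helpers `simulate_states` and `simulate_nb_observations` are hypothetical, draw a single path, and assume time-varying arrays `u`, `A`, `Sigma`, `B` laid out as in the model definition (with `Sigma[0]` playing the role of \(\Sigma_0\)). Negative binomial draws use the standard gamma-Poisson mixture, again assuming the mean-dispersion parameterization.

```python
import jax
import jax.numpy as jnp
import jax.random as jrn


def simulate_states(u, A, Sigma, key):
    """Draw one path X_0, ..., X_n from the Gaussian state equation.

    u: (n + 1, m) biases u_t; A: (n, m, m) transition matrices A_0, ..., A_{n-1};
    Sigma: (n + 1, m, m) covariances, where Sigma[0] plays the role of Sigma_0.
    """
    keys = jrn.split(key, u.shape[0])
    x0 = jrn.multivariate_normal(keys[0], u[0], Sigma[0])

    def step(x_prev, inputs):
        u_t, A_prev, Sigma_t, k = inputs
        # X_t = u_t + A_{t-1} X_{t-1} + eps_t with eps_t ~ N(0, Sigma_t)
        x_t = jrn.multivariate_normal(k, u_t + A_prev @ x_prev, Sigma_t)
        return x_t, x_t

    _, xs = jax.lax.scan(step, x0, (u[1:], A, Sigma[1:], keys[1:]))
    return jnp.concatenate([x0[None], xs], axis=0)


def simulate_nb_observations(x, B, r, key):
    """Sample Y_t | S_t ~ NegBinom(exp(S_t), r) elementwise, with S_t = B_t X_t."""
    s = jnp.einsum("tpm,tm->tp", B, x)  # signals, shape (n + 1, p)
    mu = jnp.exp(s)                      # conditional means under the log link
    k_gamma, k_pois = jrn.split(key)
    # gamma-Poisson mixture: G ~ Gamma(shape r, scale mu / r), Y | G ~ Poisson(G)
    g = jrn.gamma(k_gamma, r, shape=mu.shape) * mu / r
    return jrn.poisson(k_pois, g)


# usage on a small hypothetical model with m = p = 2
n, m, p = 50, 2, 2
u = jnp.zeros((n + 1, m))
A = jnp.broadcast_to(0.5 * jnp.eye(m), (n, m, m))
Sigma = jnp.broadcast_to(0.1 * jnp.eye(m), (n + 1, m, m))
B = jnp.broadcast_to(jnp.eye(p, m), (n + 1, p, m))
key_x, key_y = jrn.split(jrn.PRNGKey(0))
X = simulate_states(u, A, Sigma, key_x)          # shape (n + 1, m)
Y = simulate_nb_observations(X, B, 20.0, key_y)  # integer counts, shape (n + 1, p)
```

Since the observations never feed back into the states, all of \(X\) can be drawn first and the observations afterwards in one vectorized step.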
simulate_pgssm (pgssm:isssm.typing.PGSSM, N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array, '2']])
| | Type | Details |
|---|---|---|
| pgssm | PGSSM | |
| N | int | number of samples |
| key | Union | random key |
| Returns | tuple | |
As an example, consider a variant of the multivariate AR(1) model with a seasonal component in which the observations now follow a conditional negative binomial distribution, i.e. \[Y^i_t \mid X_{t} \sim \text{NegBinom}(\exp((BX_t)^{i}), r),\] independently for \(i = 1, 2\).
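Assuming \(\text{NegBinom}(\mu, r)\) denotes the negative binomial distribution parameterized by its mean \(\mu\) and dispersion (size) \(r\), which is consistent with the log-link mean \(\exp((BX_t)^i)\) used for this model, the conditional probability mass function of each observation reads

\[ p(y^i_t \mid s^i_t) = \frac{\Gamma(y^i_t + r)}{\Gamma(r)\, y^i_t!} \left(\frac{r}{r + e^{s^i_t}}\right)^{r} \left(\frac{e^{s^i_t}}{r + e^{s^i_t}}\right)^{y^i_t}, \qquad y^i_t \in \{0, 1, 2, \dots\}, \]

with \(s^i_t = (BX_t)^i\), mean \(e^{s^i_t}\) and variance \(e^{s^i_t} + e^{2 s^i_t} / r\), so that larger \(r\) brings the distribution closer to a Poisson distribution.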
The states \(X_t\) consist of two components:
For the velocity component, we model a stationary process with a small stationary variance. Stationarity ensures that sampling from the model usually does not lead to numerical issues. Due to the log-link for negative binomial observations, we want the states, and hence the signals, to stay within, say, \((-2, 2)\) most of the time. Otherwise we will see many \(0\) observations (for signals below \(-2\)) or may run into problems when sampling with large \(\mathbf E (Y^{i}_{t} | X_{t}) = \exp ((BX_{t})^{i})\).
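To make the range \((-2, 2)\) concrete, this sketch evaluates the conditional mean, variance and probability of a zero count at a few signal values. It assumes that \(\text{NegBinom}(\mu, r)\) has mean \(\mu\) and dispersion \(r\) and uses the running example's default \(r = 20\).

```python
import jax.numpy as jnp

# Assumption: NegBinom(mu, r) has mean mu and dispersion ("size") r,
# so Var(Y) = mu + mu**2 / r and P(Y = 0) = (r / (r + mu)) ** r.
r = 20.0                                 # default r of the running example
s = jnp.array([-4.0, -2.0, 0.0, 2.0])    # a few signal values
mean = jnp.exp(s)                        # E(Y | S = s) under the log link
var = mean + mean**2 / r                 # conditional variance
p_zero = (r / (r + mean)) ** r           # probability of observing a zero count
```

Here `p_zero` is roughly 0.98, 0.87, 0.38 and 0.002: signals well below \(-2\) produce almost only zeros, while at \(s = 2\) the mean is already about 7.4 and grows exponentially beyond that.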
This model has the advantage that we can check whether our implementation can handle multiple issues:
We set some sensible defaults and will reuse this model throughout this documentation.
nb_pgssm_running_example (x0_trend:jaxtyping.Float[Array,'m']=Array([0., 0.], dtype=float64), r:jaxtyping.Float=20.0, s2_trend:jaxtyping.Float=0.01, s2_speed:jaxtyping.Float=0.1, alpha:jaxtyping.Float=0.1, omega2:jaxtyping.Float=0.01, n:int=100, x0_seasonal:jaxtyping.Float[Array,'s']=Array([0., 0., 0., 0.], dtype=float64), s2_seasonal:jaxtyping.Float=0.1, Sigma0_seasonal:jaxtyping.Float[Array,'ss']=Array([[0.1, 0. , 0. , 0. ], [0. , 0.1, 0. , 0. ], [0. , 0. , 0.1, 0. ], [0. , 0. , 0. , 0.1]], dtype=float64), s_order:int=5)
a structural time series model with NBinom observations
| | Type | Default | Details |
|---|---|---|---|
| x0_trend | Float[Array, 'm'] | [0. 0.] | |
| r | Float | 20.0 | |
| s2_trend | Float | 0.01 | |
| s2_speed | Float | 0.1 | |
| alpha | Float | 0.1 | |
| omega2 | Float | 0.01 | |
| n | int | 100 | |
| x0_seasonal | Float[Array, 's'] | [0. 0. 0. 0.] | |
| s2_seasonal | Float | 0.1 | |
| Sigma0_seasonal | Float[Array, 's s'] | [[0.1 0. 0. 0. ] [0. 0.1 0. 0. ] [0. 0. 0.1 0. ] [0. 0. 0. 0.1]] | |
| s_order | int | 5 | |
| Returns | PGSSM | | the running example for this package |
```python
from jax import vmap

key = jrn.PRNGKey(518)
model = nb_pgssm_running_example()
N = 1
key, subkey = jrn.split(key)
# draw a single sample path of states X and observations Y
(X,), (Y,) = simulate_pgssm(model, N, subkey)

fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))
ax1.set_title("trend")
ax2.set_title("seasonal component")
ax3.set_title("signals")
ax4.set_title("observations")
ax1.plot(X[:, 0])
ax2.plot(X[:, 2])
# signals S_t = B_t X_t, computed for every time point t
ax3.plot(vmap(jnp.matmul, (0, 0))(model.B, X))
ax4.plot(Y)
fig.tight_layout()
plt.show()
```
Notice that the observations are now integer-valued.
To evaluate the joint density, we use the same approach as described in [00_glssm#Joint Density], replacing the Gaussian observation density with the PGSSM observation density \(p(y_t|s_t)\).
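Written out with the notation of the model definition, the joint log-density of states and observations is

\[ \log p(x, y) = \log \mathcal N(x_0; u_0, \Sigma_0) + \sum_{t=0}^{n-1} \log \mathcal N(x_{t+1}; u_{t+1} + A_t x_t, \Sigma_{t+1}) + \sum_{t=0}^{n} \log p(y_t \mid s_t), \qquad s_t = B_t x_t, \]

where only the last sum differs from the fully Gaussian linear model.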
log_prob (x:jaxtyping.Float[Array,'n+1 m'], y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.PGSSM)

| | Type | Details |
|---|---|---|
| x | Float[Array, 'n+1 m'] | states |
| y | Float[Array, 'n+1 p'] | observations |
| model | PGSSM | |
log_probs_y (x:jaxtyping.Float[Array,'n+1 m'], y:jaxtyping.Float[Array,'n+1 p'], v:jaxtyping.Float[Array,'n+1 p'], B:jaxtyping.Float[Array,'n+1 p m'], dist, xi)

| | Type | Details |
|---|---|---|
| x | Float[Array, 'n+1 m'] | states |
| y | Float[Array, 'n+1 p'] | observations |
| v | Float[Array, 'n+1 p'] | signal biases |
| B | Float[Array, 'n+1 p m'] | signal matrices |
| dist | | observation distribution |
| xi | | observation parameters |
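The decomposition into a Gaussian state part and an observation part can be sketched in JAX as below. This is not the package's `log_prob`/`log_probs_y`: the functions are hypothetical, the signals are assumed to be \(s_t = v_t + B_t x_t\) (suggested by the signal biases `v` in the signature), and a Poisson observation model (`poisson_dist`) stands in for the negative binomial to keep the example short.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn
from jax.scipy.stats import poisson


class PoissonLogLink(NamedTuple):
    """Hypothetical observation distribution exposing only `log_prob`."""
    rate: jnp.ndarray

    def log_prob(self, y):
        return poisson.logpmf(y, self.rate)


def poisson_dist(s, xi):
    # xi is unused: the Poisson distribution has no extra parameters
    return PoissonLogLink(jnp.exp(s))


def log_probs_y_sketch(x, y, v, B, dist, xi):
    """Sum of log p(y_t | s_t), assuming signals s_t = v_t + B_t x_t."""
    s = v + jnp.einsum("tpm,tm->tp", B, x)
    return dist(s, xi).log_prob(y).sum()


def log_probs_x_sketch(x, u, A, Sigma):
    """Gaussian state log-density; A: (n, m, m), Sigma: (n + 1, m, m), Sigma[0] = Sigma_0."""
    lp0 = mvn.logpdf(x[0], u[0], Sigma[0])
    means = u[1:] + jnp.einsum("tij,tj->ti", A, x[:-1])  # u_{t+1} + A_t x_t
    return lp0 + jax.vmap(mvn.logpdf)(x[1:], means, Sigma[1:]).sum()


def log_prob_sketch(x, y, u, A, Sigma, v, B, dist, xi):
    """Joint log-density log p(x, y) = log p(x) + sum_t log p(y_t | s_t)."""
    return log_probs_x_sketch(x, u, A, Sigma) + log_probs_y_sketch(x, y, v, B, dist, xi)


# usage: n = 2, m = p = 1, all states and counts zero
n, m, p = 2, 1, 1
x = jnp.zeros((n + 1, m))
y = jnp.zeros((n + 1, p))
u = jnp.zeros((n + 1, m))
A = jnp.zeros((n, m, m))
Sigma = jnp.ones((n + 1, m, m))
v = jnp.zeros((n + 1, p))
B = jnp.ones((n + 1, p, m))
lp = log_prob_sketch(x, y, u, A, Sigma, v, B, poisson_dist, None)
# three N(0; 0, 1) terms plus three Poisson(0; 1) terms: -1.5 * log(2 * pi) - 3 ≈ -5.757
```

Keeping the observation part in its own function means only `dist` and `xi` change between observation models, while the Gaussian state part is shared with the linear Gaussian case.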