Basic Structural Time Series Model

Second-order stationary time series with a seasonal component
from isssm.glssm import simulate_glssm
import jax.numpy as jnp
import jax.random as jrn
import matplotlib.pyplot as plt
from isssm.kalman import kalman, smoother
from isssm.models.glssm import mv_ar1

We implement the univariate model described in Section 3.2.2 of Durbin and Koopman (2012) and refer the reader to their discussion for details.
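For orientation, the model can be written as below. This is a sketch that assumes stsm follows Durbin and Koopman's local linear trend plus dummy-variable seasonal formulation, with σ²_μ, σ²_ν, σ²_seasonal and σ²_ε corresponding to s2_mu, s2_nu, s2_seasonal and o2, s to s_order, and α to alpha_velocity. Placing α in the velocity equation is our assumption; for α = 1 the equations reduce to the undamped model. The state stacks the level μ_t, the velocity ν_t and the s − 1 seasonal effects γ_t, …, γ_{t−s+2}, so its dimension is 2 + s − 1, matching the size of x0 in the example.

```latex
\begin{aligned}
y_t &= \mu_t + \gamma_t + \varepsilon_t, & \varepsilon_t &\sim \mathcal{N}(0, \sigma^2_\varepsilon),\\
\mu_{t+1} &= \mu_t + \nu_t + \xi_t, & \xi_t &\sim \mathcal{N}(0, \sigma^2_\mu),\\
\nu_{t+1} &= \alpha\,\nu_t + \zeta_t, & \zeta_t &\sim \mathcal{N}(0, \sigma^2_\nu),\\
\gamma_{t+1} &= -\sum_{j=1}^{s-1} \gamma_{t+1-j} + \omega_t, & \omega_t &\sim \mathcal{N}(0, \sigma^2_{\text{seasonal}}).
\end{aligned}
```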



stsm

 stsm (x0:jaxtyping.Float[Array,'m'], s2_mu:jaxtyping.Float,
       s2_nu:jaxtyping.Float, s2_seasonal:jaxtyping.Float, n:int,
       Sigma0:jaxtyping.Float[Array,'m m'], o2:jaxtyping.Float,
       s_order:int, alpha_velocity:jaxtyping.Float=1.0)
Parameter       Type                 Default  Details
x0              Float[Array, 'm']             initial state
s2_mu           Float                         variance of trend innovations
s2_nu           Float                         variance of velocity innovations
s2_seasonal     Float                         variance of seasonal innovations
n               int                           number of time points
Sigma0          Float[Array, 'm m']           initial state covariance
o2              Float                         variance of observation noise
s_order         int                           order of the seasonal component
alpha_velocity  Float                1.0      damping factor for the velocity
Returns         GLSSM
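To make the parameter table concrete, the following minimal JAX sketch shows how these parameters could map onto the system matrices of a time-invariant linear Gaussian state space model. It assumes the Durbin and Koopman formulation with a dummy-variable seasonal and the damping factor acting on the velocity; the helper stsm_matrices_sketch and its state ordering (level, velocity, then the s_order − 1 most recent seasonal effects) are hypothetical and not part of isssm, whose stsm may order or store things differently.

```python
import jax.numpy as jnp


def stsm_matrices_sketch(s_order, s2_mu, s2_nu, s2_seasonal, o2, alpha_velocity=1.0):
    """Hypothetical helper, not part of isssm: system matrices (A, B, Sigma, Omega)
    of a basic structural model.

    Assumed state ordering: (mu_t, nu_t, gamma_t, gamma_{t-1}, ..., gamma_{t-s+2}).
    """
    m = 2 + s_order - 1  # level, velocity and s_order - 1 seasonal effects
    A = jnp.zeros((m, m))
    A = A.at[0, 0].set(1.0).at[0, 1].set(1.0)  # mu_{t+1} = mu_t + nu_t
    A = A.at[1, 1].set(alpha_velocity)  # nu_{t+1} = alpha * nu_t (assumed damping)
    A = A.at[2, 2:].set(-1.0)  # gamma_{t+1} = -(gamma_t + ... + gamma_{t-s+2})
    for i in range(3, m):
        A = A.at[i, i - 1].set(1.0)  # shift older seasonal effects back one slot
    # observation: y_t = mu_t + gamma_t + eps_t
    B = jnp.zeros((1, m)).at[0, 0].set(1.0).at[0, 2].set(1.0)
    # innovations enter only the level, velocity and current seasonal effect
    Sigma = jnp.diag(jnp.zeros(m).at[:3].set(jnp.array([s2_mu, s2_nu, s2_seasonal])))
    Omega = jnp.array([[o2]])
    return A, B, Sigma, Omega


# same parameter values as the example below, with s_order = 4
A, B, Sigma, Omega = stsm_matrices_sketch(4, 0.0, 0.01, 1.0, 1.0, alpha_velocity=0.5)
print(A)
```

The state covariance is singular by design: only three of the 2 + s_order − 1 states receive noise, while the remaining seasonal slots are deterministic copies of earlier effects.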
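s_ord = 4  # number of seasons
m = 2 + s_ord - 1  # state dimension: level, velocity and s_ord - 1 seasonal states
glssm = stsm(
    x0=jnp.zeros(m),
    s2_mu=0.0,
    s2_nu=0.01,
    s2_seasonal=1.0,
    n=100,
    Sigma0=jnp.eye(m),
    o2=1.0,
    s_order=s_ord,
    alpha_velocity=0.5,
)

# draw a single sample path of states x and observations y
key = jrn.PRNGKey(534512423)
key, subkey = jrn.split(key)
(x,), (y,) = simulate_glssm(glssm, 1, subkey)

# run the Kalman filter, then the smoother, to obtain the smoothed states
x_smooth, _ = smoother(kalman(y, glssm), glssm.A)

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].set_title("observations")
axs[0].plot(y)
axs[1].set_title("states")
axs[1].plot(x)
axs[2].set_title("smoothed states")
axs[2].plot(x_smooth)
fig.tight_layout()  # adjust spacing after titles and plots exist
plt.show()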
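For intuition about what the simulation step produces, here is a simplified, self-contained JAX sketch that draws one path of a time-invariant linear Gaussian state space model with jax.lax.scan. The function simulate_lgssm_sketch is hypothetical and not the isssm API; it assumes time-invariant matrices, diagonal state noise given as standard deviations, and a fixed (non-random) initial state. The usage below hard-codes the assumed basic structural model matrices for four seasons without damping and switches all noise off.

```python
import jax
import jax.numpy as jnp
import jax.random as jrn


def simulate_lgssm_sketch(key, x0, A, B, state_sd, obs_sd, n):
    """Hypothetical helper: one path of
        x_{t+1} = A x_t + state_sd * eta_t,  eta_t ~ N(0, I)
        y_t     = B x_t + obs_sd * eps_t,    eps_t ~ N(0, I)
    for t = 0, ..., n - 1, starting from the fixed state x0.
    """
    m = x0.shape[0]
    p = B.shape[0]

    def step(x, k):
        k_state, k_obs = jrn.split(k)
        y = B @ x + obs_sd * jrn.normal(k_obs, (p,))
        x_next = A @ x + state_sd * jrn.normal(k_state, (m,))
        return x_next, (x, y)  # carry the next state, emit the current pair

    _, (xs, ys) = jax.lax.scan(step, x0, jrn.split(key, n))
    return xs, ys  # shapes (n, m) and (n, p)


# assumed ordering: level, velocity, gamma_t, gamma_{t-1}, gamma_{t-2}
A = jnp.array([
    [1.0, 1.0, 0.0, 0.0, 0.0],    # level
    [0.0, 1.0, 0.0, 0.0, 0.0],    # velocity (alpha = 1)
    [0.0, 0.0, -1.0, -1.0, -1.0], # new seasonal effect
    [0.0, 0.0, 1.0, 0.0, 0.0],    # shift
    [0.0, 0.0, 0.0, 1.0, 0.0],    # shift
])
B = jnp.array([[1.0, 0.0, 1.0, 0.0, 0.0]])
x0 = jnp.array([0.0, 0.0, 1.0, 2.0, 3.0])
xs, ys = simulate_lgssm_sketch(jrn.PRNGKey(0), x0, A, B, jnp.zeros(5), 0.0, 8)
# ys[:, 0] is 1, -6, 3, 2, 1, -6, 3, 2: the seasonal pattern repeats every 4 steps
```

Using lax.scan keeps the recursion inside a single traced loop, so the sketch can be wrapped in jax.jit without unrolling the time dimension.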

References

Durbin, J., and S. J. Koopman. 2012. Time Series Analysis by State Space Methods. 2nd ed. Oxford Statistical Science Series 38. Oxford: Oxford University Press.