import jax.numpy as jnp
import jax.random as jrn
import matplotlib.pyplot as plt

from isssm.glssm import simulate_glssm
from isssm.kalman import kalman, smoother
from isssm.models.glssm import mv_ar1
Basic Structural Time Series Model
Second-order stationary time series with a seasonal component
We implement the univariate model from Section 3.2.2 of Durbin and Koopman (2012) and refer the reader to their discussion.
stsm
stsm (x0:jaxtyping.Float[Array,'m'], s2_mu:jaxtyping.Float, s2_nu:jaxtyping.Float, s2_seasonal:jaxtyping.Float, n:int, Sigma0:jaxtyping.Float[Array,'m m'], o2:jaxtyping.Float, s_order:int, alpha_velocity:jaxtyping.Float=1.0)
| | Type | Default | Details |
|---|---|---|---|
| x0 | Float[Array, 'm'] | | initial state |
| s2_mu | Float | | variance of trend innovations |
| s2_nu | Float | | variance of velocity innovations |
| s2_seasonal | Float | | variance of seasonal innovations |
| n | int | | number of time points |
| Sigma0 | Float[Array, 'm m'] | | initial state covariance |
| o2 | Float | | variance of observation noise |
| s_order | int | | order of seasonal component |
| alpha_velocity | Float | 1.0 | damping factor for velocity |
| **Returns** | **GLSSM** | | |
# Example: simulate from a basic structural time series model and smooth it.
#
# NOTE(review): `stsm` is the function documented above; it is not among the
# visible imports, so confirm it is in scope (e.g. defined in this module).

# Order of the seasonal component. The state dimension is 2 + s_ord - 1:
# two trend entries (presumably level and velocity) plus s_ord - 1 seasonal
# entries.
s_ord = 4
m = 2 + s_ord - 1

# Arguments are passed by keyword to match the documented `stsm` signature.
glssm = stsm(
    x0=jnp.zeros(m),
    s2_mu=0.0,
    s2_nu=0.01,
    s2_seasonal=1.0,
    n=100,
    Sigma0=jnp.eye(m),
    o2=1.0,
    s_order=s_ord,
    alpha_velocity=0.5,
)

# Draw a single joint sample of states and observations; the one-element
# tuple unpacking strips the leading sample axis of size 1.
key = jrn.PRNGKey(534512423)
key, subkey = jrn.split(key)
(x,), (y,) = simulate_glssm(glssm, 1, subkey)

# Run the Kalman filter, then the smoother; keep only the smoothed states.
x_smooth, _ = smoother(kalman(y, glssm), glssm.A)

# One panel each for observations, true states and smoothed states.
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, title, series in zip(
    axs,
    ("observations", "states", "smoothed states"),
    (y, x, x_smooth),
):
    ax.set_title(title)
    ax.plot(series)
# Lay out only after the titles exist so they are taken into account.
fig.tight_layout()
plt.show()
References
Durbin, J., and S. J. Koopman. 2012. Time Series Analysis by State Space Methods. 2nd ed. Oxford Statistical Science Series 38. Oxford: Oxford University Press.