Maximum Likelihood Estimation
Gaussian linear models
For GLSSMs we can evaluate the likelihood analytically with a single pass of the Kalman filter. Based on the predictions \(\hat Y_{t \mid t - 1}\) and associated covariance matrices \(\Psi_{t \mid t - 1}\) for \(t = 0, \dots, n\) produced by the Kalman filter we can derive the Gaussian negative log-likelihood: it is the sum over \(t\) of the negative log-density of the observation \(Y_t\) under the Gaussian distribution with that mean and covariance matrix.
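To make this prediction-error decomposition concrete, here is a minimal sketch (illustrative names only, not the library's gnll implementation) that sums the Gaussian negative log-densities of the observations under their one-step-ahead predictive distributions:
import jax.numpy as jnp
from jax import vmap
from jax.scipy.stats import multivariate_normal as mvn


def gaussian_nll_sketch(y, y_pred, Psi_pred):
    # y: (n+1, p) observations, y_pred: (n+1, p) predicted means,
    # Psi_pred: (n+1, p, p) predictive covariances from the Kalman filter
    log_densities = vmap(mvn.logpdf)(y, y_pred, Psi_pred)
    return -jnp.sum(log_densities)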
gnll_full
gnll_full (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.GLSSM)
|  | Type | Details |
|---|---|---|
| y | Float[Array, ‘n+1 p’] | observations \(y_t\) |
| model | GLSSM | the model |
gnll
gnll (y:jaxtyping.Float[Array,'n+1 p'], x_pred:jaxtyping.Float[Array,'n+1 m'], Xi_pred:jaxtyping.Float[Array,'n+1 m m'], B:jaxtyping.Float[Array,'n+1 p m'], Omega:jaxtyping.Float[Array,'n+1 p p'])
Gaussian negative log-likelihood
|  | Type | Details |
|---|---|---|
| y | Float[Array, ‘n+1 p’] | observations \(y_t\) |
| x_pred | Float[Array, ‘n+1 m’] | predicted states \(\hat X_{t+1\mid t}\) |
| Xi_pred | Float[Array, ‘n+1 m m’] | predicted state covariances \(\Xi_{t+1\mid t}\) |
| B | Float[Array, ‘n+1 p m’] | state observation matrices \(B_{t}\) |
| Omega | Float[Array, ‘n+1 p p’] | observation covariances \(\Omega_{t}\) |
| Returns | Float | Gaussian negative log-likelihood |
MLE in GLSSMs
For a parameterized GLSSM, that is, a model that depends on parameters \(\theta\), we can use numerical optimization to find the maximum likelihood estimator.
With these methods, the user has to take care to provide an unconstrained parametrization, e.g. using \(\log\) transformations for positive parameters.
- For low-dimensional state space models, obtaining the gradient of the negative log-likelihood by automatic differentiation may be feasible; in this case use the mle_glssm_ad method, as sketched below. Otherwise the derivative-free Nelder-Mead method in mle_glssm may be preferable.
- To stabilize numerical results we minimize \(-\frac{1}{(n + 1)p} \log p_\theta(y)\) instead of \(-\log p_\theta (y)\).
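The following is a minimal sketch of what such an AD-based optimization can look like; it is not necessarily the implementation of mle_glssm_ad. It assumes gnll_full from above is in scope and that model_fn(theta, aux) constructs a GLSSM in a JAX-traceable way:
from jax.scipy.optimize import minimize


def mle_ad_sketch(y, model_fn, theta0, aux):
    # scaled Gaussian negative log-likelihood as a function of theta
    def objective(theta):
        model = model_fn(theta, aux)
        return gnll_full(y, model) / y.size  # divide by (n + 1) * p

    # jax.scipy.optimize.minimize obtains gradients by automatic differentiation
    return minimize(objective, theta0, method="BFGS")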
mle_glssm_ad
mle_glssm_ad (y:jaxtyping.Float[Array,'n+1 p'], model_fn, theta0:jaxtyping.Float[Array,'k'], aux, options=None)
Maximum likelihood estimation for GLSSM using automatic differentiation
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, ‘n+1 p’] |  | observations \(y_t\) |
| model_fn |  |  | parameterized GLSSM |
| theta0 | Float[Array, ‘k’] |  | initial parameter guess |
| aux |  |  | auxiliary data for the model |
| options | NoneType | None | options for the optimizer |
| Returns | OptimizeResults |  | result of MLE optimization |
mle_glssm
mle_glssm (y:jaxtyping.Float[Array,'n+1 p'], model_fn, theta0:jaxtyping.Float[Array,'k'], aux, options=None)
Maximum likelihood estimation for GLSSM
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, ‘n+1 p’] |  | observations \(y_t\) |
| model_fn |  |  | parameterized GLSSM |
| theta0 | Float[Array, ‘k’] |  | initial parameter guess |
| aux |  |  | auxiliary data for the model |
| options | NoneType | None | options for the optimizer |
| Returns | OptimizeResult |  | result of MLE optimization |
def parameterized_lcm(theta, aux):
    log_s2_eps, log_s2_eta = theta
    n, x0, s2_x0 = aux
    return lcm(n, x0, s2_x0, jnp.exp(log_s2_eps), jnp.exp(log_s2_eta))
theta = jnp.log(jnp.array([2.0, 3.0]))
aux = (100, 0.0, 1.0)
true_model = parameterized_lcm(theta, aux)
_, (y,) = simulate_glssm(true_model, 1, jrn.PRNGKey(15435324))
# start far away from true parameter
result_bfgs = mle_glssm(
    y, parameterized_lcm, 2 * jnp.ones(2), aux, options={"return_all": True}
)
result_ad = mle_glssm_ad(y, parameterized_lcm, 2 * jnp.ones(2), aux)
result_bfgs.x - result_ad.x
Numerical differentiation is much faster here, and as accurate as automatic differentiation.
import matplotlib.pyplot as plt
# 2d grid on the log scale
k = 21 # number of evaluations in each dimension
log_s2_eps, log_s2_eta = jnp.meshgrid(
jnp.linspace(-3, 3, k) + theta[0], jnp.linspace(-3, 3, k) + theta[1]
)
# flatten
thetas = jnp.vstack([log_s2_eps.ravel(), log_s2_eta.ravel()]).T
def gnll_theta(theta):
    return gnll_full(y, parameterized_lcm(theta, aux))
nlls = vmap(gnll_theta)(thetas)
# location of minimum in nlls
i = jnp.argmin(nlls)
# location of minimum in the grid
i_eps, i_eta = i // k, i % k
plt.contourf(log_s2_eps, log_s2_eta, nlls.reshape(k, k), alpha=0.5)
plt.scatter(
log_s2_eps[i_eps, i_eta],
log_s2_eta[i_eps, i_eta],
c="white",
marker="x",
label="min_grid",
)
plt.scatter(theta[0], theta[1], c="r", marker="x", label="true")
plt.scatter(*result_bfgs.x, c="g", marker="x", label="$\\hat\\theta$")
plt.legend()
plt.xlabel("$\\log(\\sigma^2_\\varepsilon)$")
plt.ylabel("$\\log(\\sigma^2_\\eta)$")
plt.colorbar()
plt.show()
Inference for Log-Concave State Space Models
For non-Gaussian state space models we cannot evaluate the likelihood analytically and instead resort to simulation methods, more specifically importance sampling.
Importance sampling is performed using a surrogate Gaussian model that shares the state density \(g(x) = p(x)\) and is parameterized by synthetic observations \(z\) and their covariance matrices \(\Omega\). In this surrogate model the likelihood \(g(z)\), and thus its negative log-likelihood \(\ell_g = -\log g(z)\), as well as the posterior distribution \(g(x|z)\) are tractable, and we can simulate from the posterior.
Having obtained \(N\) independent samples \(X^i\), \(i = 1, \dots, N\), from this surrogate posterior we can approximate the likelihood \(p(y)\) by Monte Carlo integration:
\[ \begin{align*} p(y) &= \int p(x, y) \,\mathrm dx \\ &=\int \frac{p(x,y)}{g(x|z)} g(x|z) \,\mathrm dx \\ &= g(z) \int \frac{p(y|x)}{g(z|x)} g(x|z)\,\mathrm dx \\ &\approx g(z) \frac 1 N \sum_{i =1}^N w(X^i) \end{align*} \]
where \(w(X^i) = \frac{p\left(y|X^i\right)}{g\left(z|X^i\right)}\) are the unnormalized importance sampling weights. Additionally, we use the bias correction term \(\frac{s^{2}_{w}}{2 N \bar w^{2}}\) from (Durbin and Koopman 1997), where \(s^2_w\) is the empirical variance of the weights and \(\bar w\) is their mean.
In total we estimate the negative log-likelihood by
\[ - \log p(y) \approx \ell_g - \log \left(\sum_{i=1}^N w(X^i) \right) + \log N - \frac{s^{2}_{w}}{2 N \bar w^{2}} \]
- Similar to MLE in GLSSMs, we minimize \(-\frac{1}{(n + 1)p} \log p(y)\) instead of \(-\log p(y)\).
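As a concrete illustration of the formula above, here is a minimal sketch of this estimator, assuming the surrogate negative log-likelihood ell_g and the unnormalized log-weights log_w have already been computed (the names are illustrative, not the pgnll API):
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def nll_estimate_sketch(ell_g, log_w):
    # log_w: (N,) unnormalized log importance weights log w(X^i)
    N = log_w.shape[0]
    w = jnp.exp(log_w - log_w.max())  # rescaled weights; the ratio below is scale-invariant
    w_bar, s2_w = jnp.mean(w), jnp.var(w, ddof=1)
    correction = s2_w / (2 * N * w_bar**2)  # Durbin & Koopman (1997) bias correction
    return ell_g - logsumexp(log_w) + jnp.log(N) - correction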
pgnll
pgnll (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.PGSSM, z:jaxtyping.Float[Array,'n+1 p'], Omega:jaxtyping.Float[Array,'n+1 p p'], N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])
Log-Concave Negative Log-Likelihood
|  | Type | Details |
|---|---|---|
| y | Float[Array, ‘n+1 p’] | observations |
| model | PGSSM | the model |
| z | Float[Array, ‘n+1 p’] | synthetic observations |
| Omega | Float[Array, ‘n+1 p p’] | covariance of synthetic observations |
| N | int | number of samples |
| key | Union | random key |
| Returns | Float | the approximate negative log-likelihood |
To perform maximum likelihood estimation in a parameterized log-concave state space model we have to evaluate the likelihood many times. Each evaluation at a parameter value \(\theta\) involves the following steps (sketched below):
- Find a surrogate Gaussian model \(g(x,z)\) for \(p_\theta(x,y)\) (e.g. Laplace approximation or efficient importance sampling).
- Generate importance samples from these models using the Kalman smoother.
- Approximate the negative log likelihood using the methods of this module.
This makes maximum likelihood estimation a computationally intensive task for these kinds of models.
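A minimal sketch of one such likelihood evaluation, using the Laplace approximation as the surrogate and the pgnll function above; the wrapper name is illustrative and it assumes a parameterized model_fn(theta, aux) as in the example below, which additionally refines the proposal with modified efficient importance sampling:
def nll_at_theta_sketch(theta, y, aux, key, N=1000):
    # 1. surrogate Gaussian model via the Laplace approximation
    model = model_fn(theta, aux)
    proposal, _ = laplace_approximation(y, model, 10)
    # 2. + 3. importance sampling estimate of the scaled negative log-likelihood
    return pgnll(y, model, proposal.z, proposal.Omega, N, key) / y.size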
For an initial guess we optimize the approximate log-likelihood with the weights component fixed at the mode; see Eq. (21) in (Durbin and Koopman 1997) for further details.
initial_theta
initial_theta (y:jaxtyping.Float[Array,'n+1 p'], model_fn, theta0:jaxtyping.Float[Array,'k'], aux, n_iter_la:int, options=None, jit_target=True)
Initial value for Maximum Likelihood Estimation for PGSSMs
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, ‘n+1 p’] |  | observations \(y_t\) |
| model_fn |  |  | parameterized PGSSM |
| theta0 | Float[Array, ‘k’] |  | initial parameter guess |
| aux |  |  | auxiliary data for the model |
| n_iter_la | int |  | number of LA iterations |
| options | NoneType | None | options for the optimizer |
| jit_target | bool | True | whether to jit the function |
As an example consider a parameterized version of the running example with unknown parameters \(\sigma^2_\varepsilon\) and \(r\).
def model_fn(theta, aux) -> PGSSM:
    log_s2_eps, log_r = theta
    n, x0 = aux
    r = jnp.exp(log_r)
    s2_eps = jnp.exp(log_s2_eps)
    return nb_pgssm_running_example(
        s_order=0,
        Sigma0_seasonal=jnp.eye(0),
        x0_seasonal=jnp.zeros(0),
        s2_speed=s2_eps,
        r=r,
        n=n,
    )
n = 100
theta_lc = jnp.array([jnp.log(1), jnp.log(20.0)])
aux = (n, jnp.ones(2))
model = model_fn(theta_lc, aux)
key = jrn.PRNGKey(512)
key, subkey = jrn.split(key)
_, (y,) = simulate_pgssm(model, 1, subkey)
theta_lc
initial_result = initial_theta(y, model_fn, theta_lc, aux, 10)
theta0 = initial_result.x
initial_result
mle_pgssm
mle_pgssm (y:jaxtyping.Float[Array,'n+1 p'], model_fn, theta0:jaxtyping.Float[Array,'k'], aux, n_iter_la:int, N:int, key:jax.Array, options=None)
Maximum Likelihood Estimation for PGSSMs
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, ‘n+1 p’] |  | observations \(y_t\) |
| model_fn |  |  | parameterized LCSSM |
| theta0 | Float[Array, ‘k’] |  | initial parameter guess |
| aux |  |  | auxiliary data for the model |
| n_iter_la | int |  | number of LA iterations |
| N | int |  | number of importance samples |
| key | Array |  | random key |
| options | NoneType | None | options for the optimizer |
| Returns | Float[Array, ‘k’] |  | MLE |
key, subkey = jrn.split(key)
result = mle_pgssm(y, model_fn, theta0, aux, 10, 1000, subkey)
theta_hat = result.x
result
@jit
def pgnll_full(theta, key):
    model = model_fn(theta, aux)
    proposal, info = laplace_approximation(y, model, 10)
    key, subkey = jrn.split(key)
    proposal_meis, _ = modified_efficient_importance_sampling(
        y, model, proposal.z, proposal.Omega, 10, 100, subkey
    )
    key, subkey = jrn.split(key)
    return pgnll(y, model, proposal_meis.z, proposal_meis.Omega, 100, subkey) / y.size
# 2d grid on the log scale
(log_sigma_min, log_r_min), (log_sigma_max, log_r_max) = jnp.min(
jnp.vstack((theta_lc, theta0, theta_hat)), axis=0
), jnp.max(jnp.vstack((theta_lc, theta0, theta_hat)), axis=0)
k = 20 # number of evaluations in each dimension
delta = 0.5
log_sigma, log_r = jnp.meshgrid(
jnp.linspace(log_sigma_min - delta, log_sigma_max + delta, k),
jnp.linspace(log_r_min - delta, log_r_max + delta, k),
)
# flatten
thetas = jnp.vstack([log_sigma.ravel(), log_r.ravel()]).T
key, subkey = jrn.split(key)
nlls = vmap(pgnll_full, (0, None))(thetas, subkey)
# location of minimum in nlls
i = jnp.argmin(nlls)
# location of minimum in the grid
i_sigma, i_r = i // k, i % k
plt.contourf(log_sigma, log_r, nlls.reshape(k, k))
plt.scatter(
log_sigma[i_sigma, i_r],
log_r[i_sigma, i_r],
c="white",
marker="x",
label="min_grid",
)
plt.scatter(theta_lc[0], theta_lc[1], c="r", marker="x", label="true")
plt.scatter(theta0[0], theta0[1], c="black", marker="x", label="$\\theta_0$")
plt.scatter(*theta_hat, c="g", marker="o", label="min_mle")
plt.legend()
plt.xlabel("$\\log(\\sigma^2_\\varepsilon)$")
plt.ylabel("$\\log(r)$")
plt.colorbar()
plt.show()
From the above plot we see that \(\log r\) is hard to determine: the likelihood is very flat in the \(r\) direction, which explains the precision loss warning in the optimizer. Nevertheless, our estimate \(\hat\theta\) seems to have converged to a reasonable value.