def parameterized_lcm(theta, aux):
    log_s2_eps, log_s2_eta = theta
    n, x0, s2_x0 = aux
    return lcm(n, x0, s2_x0, jnp.exp(log_s2_eps), jnp.exp(log_s2_eta))


theta = jnp.log(jnp.array([2.0, 3.0]))
aux = (100, 0.0, 1.0)
true_model = parameterized_lcm(theta, aux)
_, (y,) = simulate_glssm(true_model, 1, jrn.PRNGKey(15435324))

# start far away from true parameter
result_bfgs = mle_glssm(
    y, parameterized_lcm, 2 * jnp.ones(2), aux, options={"return_all": True}
)
result_ad = mle_glssm_ad(y, parameterized_lcm, 2 * jnp.ones(2), aux)

result_bfgs.x - result_ad.x
Maximum Likelihood Estimation
Gaussian linear models
For GLSSMs we can evaluate the likelihood analytically with a single pass of the Kalman filter. Based on the one-step-ahead predictions \(\hat Y_{t\mid t - 1}\) and their associated covariance matrices \(\Psi_{t\mid t - 1}\) for \(t = 0, \dots, n\) produced by the Kalman filter, the Gaussian negative log-likelihood is the sum of the negative log-densities of the Gaussian distributions with these means and covariance matrices, evaluated at the observations \(Y_t\).
gnll_full
gnll_full (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.GLSSM)
|  | Type | Details |
|---|---|---|
| y | Float[Array, ‘n+1 p’] | observations \(y_t\) |
| model | GLSSM |  |
gnll
gnll (y:jaxtyping.Float[Array,'n+1 p'], x_pred:jaxtyping.Float[Array,'n+1 m'], Xi_pred:jaxtyping.Float[Array,'n+1 m m'], B:jaxtyping.Float[Array,'n+1 p m'], Omega:jaxtyping.Float[Array,'n+1 p p'])
Gaussian negative log-likelihood
|  | Type | Details |
|---|---|---|
| y | Float[Array, ‘n+1 p’] | observations \(y_t\) |
| x_pred | Float[Array, ‘n+1 m’] | predicted states \(\hat X_{t+1\mid t}\) |
| Xi_pred | Float[Array, ‘n+1 m m’] | predicted state covariances \(\Xi_{t+1\mid t}\) |
| B | Float[Array, ‘n+1 p m’] | state observation matrices \(B_{t}\) |
| Omega | Float[Array, ‘n+1 p p’] | observation covariances \(\Omega_{t}\) |
| Returns | Float | Gaussian negative log-likelihood |
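Conceptually, this accumulates the Gaussian log-densities of the observations under their one-step-ahead predictive distributions. The following is only an illustrative sketch of that computation (not the library's implementation), assuming inputs shaped as in the table above:

import jax.numpy as jnp
from jax import vmap
from jax.scipy.stats import multivariate_normal


def gnll_sketch(y, x_pred, Xi_pred, B, Omega):
    # predicted observations and their covariances, one per time step
    y_pred = jnp.einsum("tpm,tm->tp", B, x_pred)
    Psi_pred = jnp.einsum("tpm,tmk,tqk->tpq", B, Xi_pred, B) + Omega
    # sum of Gaussian negative log-densities of the observations
    logpdfs = vmap(multivariate_normal.logpdf)(y, y_pred, Psi_pred)
    return -jnp.sum(logpdfs)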
MLE in GLSSMs
For a parametrized GLSSM, that is, a model that depends on parameters \(\theta\), we can use numerical optimization to find the maximum likelihood estimator.
With these methods the user has to take care to provide an unconstrained parametrization, e.g. using \(\log\)-transformations for positive parameters.
- For low-dimensional state space models obtaining the gradient of the negative log-likelihood by automatic differentiation may be feasible; in this case use the mle_glssm_ad method. Otherwise the derivative-free Nelder-Mead method in mle_glssm may be favorable.
- To stabilize numerical results we minimize \(-\frac{1}{(n + 1)p} \log p_{\theta}(y)\) instead of \(-\log p_\theta (y)\), as sketched below.
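For instance, a minimal sketch of such a scaled objective for the locally constant model above, assuming the parameterized_lcm and gnll_full functions from this page:

from jax import grad


# scaled Gaussian negative log-likelihood in the unconstrained parameter theta
# (log-variances); dividing by y.size = (n + 1) * p keeps the objective on a stable scale
def scaled_gnll(theta, y, aux):
    model = parameterized_lcm(theta, aux)
    return gnll_full(y, model) / y.size


# with an unconstrained parametrization, automatic differentiation
# (as used in mle_glssm_ad) can supply gradients directly
scaled_gnll_grad = grad(scaled_gnll)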
mle_glssm_ad
mle_glssm_ad (y:jaxtyping.Float[Array,'n+1 p'], model_fn, theta0:jaxtyping.Float[Array,'k'], aux, options=None)
Maximum likelihood estimation for GLSSM using automatic differentiation
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, ‘n+1 p’] |  | observations \(y_t\) |
| model_fn |  |  | parameterized GLSSM |
| theta0 | Float[Array, ‘k’] |  | initial parameter guess |
| aux |  |  | auxiliary data for the model |
| options | NoneType | None | options for the optimizer |
| Returns | OptimizeResults |  | result of MLE optimization |
mle_glssm
mle_glssm (y:jaxtyping.Float[Array,'n+1 p'], model_fn, theta0:jaxtyping.Float[Array,'k'], aux, options=None)
Maximum likelihood estimation for GLSSM
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, ‘n+1 p’] |  | observations \(y_t\) |
| model_fn |  |  | parameterized GLSSM |
| theta0 | Float[Array, ‘k’] |  | initial parameter guess |
| aux |  |  | auxiliary data for the model |
| options | NoneType | None | options for the optimizer |
| Returns | OptimizeResult |  | result of MLE optimization |
Numerical differentiation is much faster here, and as accurate as automatic differentiation.
import matplotlib.pyplot as plt

# 2d grid on the log scale
k = 21  # number of evaluations in each dimension
log_s2_eps, log_s2_eta = jnp.meshgrid(
    jnp.linspace(-3, 3, k) + theta[0], jnp.linspace(-3, 3, k) + theta[1]
)
# flatten
thetas = jnp.vstack([log_s2_eps.ravel(), log_s2_eta.ravel()]).T


def gnll_theta(theta):
    return gnll_full(y, parameterized_lcm(theta, aux))


nlls = vmap(gnll_theta)(thetas)
# location of minimum in nlls
i = jnp.argmin(nlls)
# location of minimum in the grid
i_eps, i_eta = i // k, i % k

plt.contourf(log_s2_eps, log_s2_eta, nlls.reshape(k, k), alpha=0.5)
plt.scatter(
    log_s2_eps[i_eps, i_eta],
    log_s2_eta[i_eps, i_eta],
    c="white",
    marker="x",
    label="min_grid",
)
plt.scatter(theta[0], theta[1], c="r", marker="x", label="true")
plt.scatter(*result_bfgs.x, c="g", marker="x", label="$\\hat\\theta$")
plt.legend()
plt.xlabel("$\\log(\\sigma^2_\\varepsilon)$")
plt.ylabel("$\\log(\\sigma^2_\\eta)$")
plt.colorbar()
plt.show()
Inference for Log-Concave State Space Models
For non-Gaussian state space models we cannot evaluate the likelihood analytically and instead resort to simulation methods, more specifically importance sampling.
Importance sampling is performed using a surrogate Gaussian model that shares the state density \(g(x) = p(x)\) and is parameterized by synthetic observations \(z\) and their covariance matrices \(\Omega\). In this surrogate model the Gaussian negative log-likelihood \(\ell_g = -\log g(z)\) and the posterior distribution \(g(x\mid z)\) are tractable, and we can simulate from the posterior.
Having obtained \(N\) independent samples \(X^i\), \(i = 1, \dots, N\), from this surrogate posterior, we can evaluate the likelihood \(p(y)\) by Monte Carlo integration:
\[ \begin{align*} p(y) &= \int p(x, y) \,\mathrm dx \\ &=\int \frac{p(x,y)}{g(x|z)} g(x|z) \,\mathrm dx \\ &= g(z) \int \frac{p(y|x)}{g(z|x)} g(x|z)\,\mathrm dx \\ &\approx g(z) \frac 1 N \sum_{i =1}^N w(X^i) \end{align*} \]
where \(w(X^i) = \frac{p\left(y\mid X^i\right)}{g\left(z\mid X^i\right)}\) are the unnormalized importance sampling weights. Additionally, we use the bias correction term \(\frac{s^{2}_{w}}{2 N \bar w^{2}}\) from (Durbin and Koopman 1997), where \(s^2_w\) is the empirical variance of the weights and \(\bar w\) is their mean.
In total we estimate the negative log-likelihood by
\[ - \log p(y) \approx \ell_g - \log \left(\sum_{i=1}^N w(X^i) \right) + \log N - \frac{s^{2}_{w}}{2 N \bar w^{2}} \]
- Similar to MLE in GLSSMs, we minimize \(-\frac{1}{(n + 1)p} \log p(y)\) instead of \(-\log p(y)\).
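Given the surrogate's Gaussian negative log-likelihood \(\ell_g\) and the log-weights \(\log w(X^i)\), the estimator above can be evaluated in a numerically stable way via log-sum-exp. The helper below is an illustrative sketch, not the library's pgnll:

import jax.numpy as jnp
from jax.scipy.special import logsumexp


def is_nll_sketch(nll_surrogate, log_weights):
    # nll_surrogate = ell_g = -log g(z); log_weights = log w(X^i), i = 1, ..., N
    N = log_weights.shape[0]
    log_sum_w = logsumexp(log_weights)
    # s_w^2 / w_bar^2 is scale invariant, so compute it from the normalized weights
    v = jnp.exp(log_weights - log_sum_w)
    bias_correction = (jnp.var(v, ddof=1) / jnp.mean(v) ** 2) / (2 * N)
    # -log p(y) ~ ell_g - log sum_i w(X^i) + log N - s_w^2 / (2 N w_bar^2)
    return nll_surrogate - log_sum_w + jnp.log(N) - bias_correction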
pgnll
pgnll (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.PGSSM, z:jaxtyping.Float[Array,'n+1 p'], Omega:jaxtyping.Float[Array,'n+1 p p'], N:int, key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])
Log-Concave Negative Log-Likelihood
|  | Type | Details |
|---|---|---|
| y | Float[Array, ‘n+1 p’] | observations |
| model | PGSSM | the model |
| z | Float[Array, ‘n+1 p’] | synthetic observations |
| Omega | Float[Array, ‘n+1 p p’] | covariance of synthetic observations |
| N | int | number of samples |
| key | Union | random key |
| Returns | Float | the approximate negative log-likelihood |
To perform maximum likelihood estimation in a parameterized log-concave state space model we have to evaluate the likelihood many times. Each evaluation at \(\theta\) consists of the following steps:
- Find a surrogate Gaussian model \(g(x,z)\) for \(p_\theta(x,y)\) (e.g. Laplace approximation or efficient importance sampling).
- Generate importance samples from these models using the Kalman smoother.
- Approximate the negative log likelihood using the methods of this module.
This makes maximum likelihood an intensive task for these kinds of models.
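Putting these steps together, a single evaluation of the approximate negative log-likelihood at \(\theta\) might look like the following sketch, which uses the laplace_approximation and pgnll routines that also appear further below; the helper name nll_at_theta is illustrative only:

def nll_at_theta(theta, y, model_fn, aux, key, n_iter_la=10, N=1000):
    model = model_fn(theta, aux)                              # parameterized model at theta
    proposal, _ = laplace_approximation(y, model, n_iter_la)  # surrogate Gaussian model
    # draw N importance samples in the surrogate model and
    # approximate the (scaled) negative log-likelihood
    return pgnll(y, model, proposal.z, proposal.Omega, N, key) / y.size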
For an initial guess we optimize the approximate log-likelihood with the weights component fixed at the mode, see Eq. (21) in (Durbin and Koopman 1997) for further details.
initial_theta
initial_theta (y:jaxtyping.Float[Array,'n+1 p'], model_fn, theta0:jaxtyping.Float[Array,'k'], aux, n_iter_la:int, options=None, jit_target=True)
Initial value for Maximum Likelihood Estimation for PGSSMs
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, ‘n+1 p’] |  | observations \(y_t\) |
| model_fn |  |  | parameterized PGSSM |
| theta0 | Float[Array, ‘k’] |  | initial parameter guess |
| aux |  |  | auxiliary data for the model |
| n_iter_la | int |  | number of LA iterations |
| options | NoneType | None | options for the optimizer |
| jit_target | bool | True | whether to jit the function |
As an example consider a parameterized version of the running example with unknown parameters \(\sigma^2_\varepsilon\) and \(r\).
def model_fn(theta, aux) -> PGSSM:
    log_s2_eps, log_r = theta
    n, x0 = aux
    r = jnp.exp(log_r)
    s2_eps = jnp.exp(log_s2_eps)
    return nb_pgssm_running_example(
        s_order=0,
        Sigma0_seasonal=jnp.eye(0),
        x0_seasonal=jnp.zeros(0),
        s2_speed=s2_eps,
        r=r,
        n=n,
    )


n = 100
theta_lc = jnp.array([jnp.log(1), jnp.log(20.0)])
aux = (n, jnp.ones(2))
model = model_fn(theta_lc, aux)
key = jrn.PRNGKey(512)
key, subkey = jrn.split(key)
_, (y,) = simulate_pgssm(model, 1, subkey)
theta_lc

initial_result = initial_theta(y, model_fn, theta_lc, aux, 10)
theta0 = initial_result.x
initial_result
mle_pgssm
mle_pgssm (y:jaxtyping.Float[Array,'n+1 p'], model_fn, theta0:jaxtyping.Float[Array,'k'], aux, n_iter_la:int, N:int, key:jax.Array, options=None)
Maximum Likelihood Estimation for PGSSMs
|  | Type | Default | Details |
|---|---|---|---|
| y | Float[Array, ‘n+1 p’] |  | observations \(y_t\) |
| model_fn |  |  | parameterized LCSSM |
| theta0 | Float[Array, ‘k’] |  | initial parameter guess |
| aux |  |  | auxiliary data for the model |
| n_iter_la | int |  | number of LA iterations |
| N | int |  | number of importance samples |
| key | Array |  | random key |
| options | NoneType | None | options for the optimizer |
| Returns | Float[Array, ‘k’] |  | MLE |
key, subkey = jrn.split(key)
result = mle_pgssm(y, model_fn, theta0, aux, 10, 1000, subkey)
theta_hat = result.x
result

@jit
def pgnll_full(theta, key):
    model = model_fn(theta, aux)

    proposal, info = laplace_approximation(y, model, 10)
    key, subkey = jrn.split(key)
    proposal_meis, _ = modified_efficient_importance_sampling(
        y, model, proposal.z, proposal.Omega, 10, 100, subkey
    )

    key, subkey = jrn.split(key)
    return pgnll(y, model, proposal_meis.z, proposal_meis.Omega, 100, subkey) / y.size
# 2d grid on the log scale
(log_sigma_min, log_r_min), (log_sigma_max, log_r_max) = jnp.min(
    jnp.vstack((theta_lc, theta0, theta_hat)), axis=0
), jnp.max(jnp.vstack((theta_lc, theta0, theta_hat)), axis=0)
k = 20  # number of evaluations in each dimension
delta = 0.5
log_sigma, log_r = jnp.meshgrid(
    jnp.linspace(log_sigma_min - delta, log_sigma_max + delta, k),
    jnp.linspace(log_r_min - delta, log_r_max + delta, k),
)
# flatten
thetas = jnp.vstack([log_sigma.ravel(), log_r.ravel()]).T

key, subkey = jrn.split(key)
nlls = vmap(pgnll_full, (0, None))(thetas, subkey)
# location of minimum in nlls
i = jnp.argmin(nlls)
# location of minimum in the grid
i_sigma, i_r = i // k, i % k

plt.contourf(log_sigma, log_r, nlls.reshape(k, k))
plt.scatter(
    log_sigma[i_sigma, i_r],
    log_r[i_sigma, i_r],
    c="white",
    marker="x",
    label="min_grid",
)
plt.scatter(theta_lc[0], theta_lc[1], c="r", marker="x", label="true")
plt.scatter(theta0[0], theta0[1], c="black", marker="x", label="$\\theta_0$")
plt.scatter(*theta_hat, c="g", marker="o", label="min_mle")
plt.legend()
plt.xlabel("$\\log(\\sigma^2_\\varepsilon)$")
plt.ylabel("$\\log(r)$")
plt.colorbar()
plt.show()
From the contour plot above we see that \(\log r\) is hard to determine: the likelihood is very flat in the \(r\) direction, which explains the precision-loss warning from the optimizer. Nevertheless, our estimate \(\hat\theta\) seems to have converged to a reasonable value.