Maximum Likelihood Estimation

See also the corresponding section in my thesis

Gaussian linear models

For GLSSMs we can evaluate the likelihood analytically with a single pass of the Kalman filter. Based on the predicted observations \(\hat Y_{t \mid t - 1}\) and their covariance matrices \(\Psi_{t \mid t - 1}\) for \(t = 0, \dots, n\) produced by the Kalman filter, the Gaussian negative log-likelihood is the sum over \(t\) of the negative log-densities of the observations \(Y_t\) under the Gaussian distribution with that mean and covariance matrix.
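
A minimal sketch of this computation (not the library's gnll implementation), assuming the predicted observation means \(\hat Y_{t \mid t-1}\) and covariances \(\Psi_{t \mid t-1}\) are stacked into arrays y_pred of shape (n+1, p) and Psi_pred of shape (n+1, p, p); these names are illustrative only:

import jax.numpy as jnp
from jax import vmap
from jax.scipy.stats import multivariate_normal


def gaussian_nll_sketch(y, y_pred, Psi_pred):
    # per-time-step Gaussian log-density of y_t under N(y_pred_t, Psi_pred_t)
    log_densities = vmap(multivariate_normal.logpdf)(y, y_pred, Psi_pred)
    # the negative log-likelihood is minus the sum over t = 0, ..., n
    return -jnp.sum(log_densities)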


source

gnll_full

 gnll_full (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.GLSSM)
Type Details
y Float[Array, ‘n+1 p’] observations \(y_t\)
model GLSSM the model

source

gnll

 gnll (y:jaxtyping.Float[Array,'n+1 p'],
       x_pred:jaxtyping.Float[Array,'n+1 m'],
       Xi_pred:jaxtyping.Float[Array,'n+1 m m'],
       B:jaxtyping.Float[Array,'n+1 p m'],
       Omega:jaxtyping.Float[Array,'n+1 p p'])

Gaussian negative log-likelihood

Type Details
y Float[Array, ‘n+1 p’] observations \(y_t\)
x_pred Float[Array, ‘n+1 m’] predicted states \(\hat X_{t+1 \mid t}\)
Xi_pred Float[Array, ‘n+1 m m’] predicted state covariances \(\Xi_{t+1 \mid t}\)
B Float[Array, ‘n+1 p m’] state observation matrices \(B_{t}\)
Omega Float[Array, ‘n+1 p p’] observation covariances \(\Omega_{t}\)
Returns Float Gaussian negative log-likelihood

MLE in GLSSMs

For a parameterized GLSSM, that is, a model that depends on parameters \(\theta\), we can use numerical optimization to find the maximum likelihood estimator.

Caution

With these methods, the user has to take care to provide a parametrization that is unconstrained, e.g. by using \(\log\) transformations for positive parameters, as in the parameterized_lcm example below.

Implementation Details
  1. For low-dimensional state space models, obtaining the gradient of the negative log-likelihood by automatic differentiation may be feasible; in this case, use the mle_glssm_ad method. Otherwise, the derivative-free Nelder-Mead method in mle_glssm may be preferable.
  2. To stabilize numerical results we minimize \(-\frac{1}{(n + 1)p} \log p_{\theta}(y)\) instead of \(-\log p_{\theta}(y)\).

source

mle_glssm_ad

 mle_glssm_ad (y:jaxtyping.Float[Array,'n+1 p'], model_fn,
               theta0:jaxtyping.Float[Array,'k'], aux, options=None)

Maximum likelihood estimation for GLSSM using automatic differentiation

Type Default Details
y Float[Array, ‘n+1 p’] observations \(y_t\)
model_fn parameterized GLSSM
theta0 Float[Array, ‘k’] initial parameter guess
aux auxiliary data for the model
options NoneType None options for the optimizer
Returns OptimizeResults result of MLE optimization

source

mle_glssm

 mle_glssm (y:jaxtyping.Float[Array,'n+1 p'], model_fn,
            theta0:jaxtyping.Float[Array,'k'], aux, options=None)

Maximum likelihood estimation for GLSSM

Type Default Details
y Float[Array, ‘n+1 p’] observations \(y_t\)
model_fn parameterized GLSSM
theta0 Float[Array, ‘k’] initial parameter guess
aux auxiliary data for the model
options NoneType None options for the optimizer
Returns OptimizeResult result of MLE optimization
As an example, consider the model returned by lcm with its two variance parameters \(\sigma^2_\varepsilon\) and \(\sigma^2_\eta\) estimated on the log scale, so that the parametrization is unconstrained.

def parameterized_lcm(theta, aux):
    # unconstrained parameters: log-variances
    log_s2_eps, log_s2_eta = theta
    n, x0, s2_x0 = aux

    # transform back to positive variances
    return lcm(n, x0, s2_x0, jnp.exp(log_s2_eps), jnp.exp(log_s2_eta))


theta = jnp.log(jnp.array([2.0, 3.0]))
aux = (100, 0.0, 1.0)
true_model = parameterized_lcm(theta, aux)
_, (y,) = simulate_glssm(true_model, 1, jrn.PRNGKey(15435324))

# start far away from true parameter
result_bfgs = mle_glssm(
    y, parameterized_lcm, 2 * jnp.ones(2), aux, options={"return_all": True}
)
result_ad = mle_glssm_ad(y, parameterized_lcm, 2 * jnp.ones(2), aux)

result_bfgs.x - result_ad.x
Array([ 1.86792884e-05, -9.92347117e-06], dtype=float64)

Numerical differentiation is much faster here, and as accurate as automatic differentiation.

import matplotlib.pyplot as plt
# 2d grid on the log scale
k = 21  # number of evaluations in each dimension
log_s2_eps, log_s2_eta = jnp.meshgrid(
    jnp.linspace(-3, 3, k) + theta[0], jnp.linspace(-3, 3, k) + theta[1]
)
# flatten
thetas = jnp.vstack([log_s2_eps.ravel(), log_s2_eta.ravel()]).T


def gnll_theta(theta):
    return gnll_full(y, parameterized_lcm(theta, aux))


nlls = vmap(gnll_theta)(thetas)
# location of minimum in nlls
i = jnp.argmin(nlls)
# location of minimum in the grid
i_eps, i_eta = i // k, i % k

plt.contourf(log_s2_eps, log_s2_eta, nlls.reshape(k, k), alpha=0.5)
plt.scatter(
    log_s2_eps[i_eps, i_eta],
    log_s2_eta[i_eps, i_eta],
    c="white",
    marker="x",
    label="min_grid",
)
plt.scatter(theta[0], theta[1], c="r", marker="x", label="true")
plt.scatter(*result_bfgs.x, c="g", marker="x", label="$\\hat\\theta$")
plt.legend()
plt.xlabel("$\\log(\\sigma^2_\\varepsilon)$")
plt.ylabel("$\\log(\\sigma^2_\\eta)$")
plt.colorbar()
plt.show()

Inference for Log-Concave State Space Models

For non-Gaussian state space models we cannot evaluate the likelihood analytically and instead resort to simulation methods, more specifically importance sampling.

Importance sampling is performed using a surrogate Gaussian model that shares the state density, \(g(x) = p(x)\), and is parameterized by synthetic observations \(z\) and their covariance matrices \(\Omega\). In this surrogate model the likelihood \(g(z)\) and the posterior distribution \(g(x \mid z)\) are tractable, and we can simulate from the posterior.

Having obtained \(N\) independent samples \(X^i\), \(i = 1, \dots, N\), from this surrogate posterior, we can approximate the likelihood \(p(y)\) by Monte Carlo integration:

\[ \begin{align*} p(y) &= \int p(x, y) \,\mathrm dx \\ &=\int \frac{p(x,y)}{g(x|z)} g(x|z) \,\mathrm dx \\ &= g(z) \int \frac{p(y|x)}{g(z|x)} g(x|z)\,\mathrm dx \\ &\approx g(z) \frac 1 N \sum_{i =1}^N w(X^i) \end{align*} \]

where \(w(X^i) = \frac{p\left(y \mid X^i\right)}{g\left(z \mid X^i\right)}\) are the unnormalized importance sampling weights. Additionally, we use the bias correction term \(\frac{s^{2}_{w}}{2 N \bar w^{2}}\) from (Durbin and Koopman 1997), where \(s^2_w\) is the empirical variance of the weights and \(\bar w\) is their mean.

In total, writing \(\ell_g = -\log g(z)\) for the negative log-likelihood of the surrogate model, we estimate the negative log-likelihood by

\[ - \log p(y) \approx \ell_g - \log \left(\sum_{i=1}^N w(X^i) \right) + \log N - \frac{s^{2}_{w}}{2 N \bar w^{2}} \]
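
For reference, a minimal sketch of this estimator (not the library's pgnll implementation), given unnormalized log-weights log_w of shape (N,) and the negative log-likelihood nll_g \(= -\log g(z)\) of the surrogate model; the names are illustrative only:

import jax.numpy as jnp
from jax.scipy.special import logsumexp


def is_nll_sketch(nll_g, log_w):
    N = log_w.shape[0]
    # rescale the weights for numerical stability; the correction term is scale-invariant
    w = jnp.exp(log_w - jnp.max(log_w))
    w_bar, s2_w = jnp.mean(w), jnp.var(w, ddof=1)
    # -log p(y) is approximately nll_g - log(sum_i w_i) + log N - s2_w / (2 N w_bar^2)
    return nll_g - logsumexp(log_w) + jnp.log(N) - s2_w / (2 * N * w_bar**2)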

Implementation Details
  1. Similar to MLE in GLSSMs, we minimize \(-\frac{1}{(n + 1)p} \log p(y)\) instead of \(-\log p(y)\).

source

pgnll

 pgnll (y:jaxtyping.Float[Array,'n+1 p'], model:isssm.typing.PGSSM,
        z:jaxtyping.Float[Array,'n+1 p'],
        Omega:jaxtyping.Float[Array,'n+1 p p'], N:int,
        key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])

Log-Concave Negative Log-Likelihood

Type Details
y Float[Array, ‘n+1 p’] observations
model PGSSM the model
z Float[Array, ‘n+1 p’] synthetic observations
Omega Float[Array, ‘n+1 p p’] covariance of synthetic observations
N int number of samples
key Union random key
Returns Float the approximate negative log-likelihood

To perform maximum likelihood estimation in a parameterized log-concave state space model, we have to evaluate the likelihood many times. Evaluating the likelihood at a given \(\theta\) requires the following steps:

  1. Find a surrogate Gaussian model \(g(x,z)\) for \(p_\theta(x,y)\) (e.g. Laplace approximation or efficient importance sampling).
  2. Generate importance samples from this surrogate model using the Kalman smoother.
  3. Approximate the negative log likelihood using the methods of this module.

This makes maximum likelihood estimation a computationally intensive task for these kinds of models; the helper pgnll_full defined in the example further below implements exactly these steps.

For an initial guess we optimize the approximate log-likelihood with the importance weight fixed at its value at the mode; see Eq. (21) in (Durbin and Koopman 1997) for further details.
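
Concretely, with \(\hat x\) denoting the mode of the surrogate posterior \(g(x \mid z)\), this amounts to the following approximation (a sketch in terms of the quantities introduced above; the exact objective used by initial_theta may differ in its details):

\[ -\log p_\theta(y) \approx -\log g(z) - \log w(\hat x) = -\log g(z) - \log p_\theta(y \mid \hat x) + \log g(z \mid \hat x), \]

which requires no simulation and is therefore cheap to optimize.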


source

initial_theta

 initial_theta (y:jaxtyping.Float[Array,'n+1 p'], model_fn,
                theta0:jaxtyping.Float[Array,'k'], aux, n_iter_la:int,
                options=None, jit_target=True)

Initial value for Maximum Likelihood Estimation for PGSSMs

Type Default Details
y Float[Array, ‘n+1 p’] observations \(y_t\)
model_fn parameterized PGSSM
theta0 Float[Array, ‘k’] initial parameter guess
aux auxiliary data for the model
n_iter_la int number of LA iterations
options NoneType None options for the optimizer
jit_target bool True whether to jit the function

As an example, consider a parameterized version of the running example with unknown parameters \(\sigma^2_\varepsilon\) and \(r\).

def model_fn(theta, aux) -> PGSSM:
    log_s2_eps, log_r = theta

    n, x0 = aux

    r = jnp.exp(log_r)
    s2_eps = jnp.exp(log_s2_eps)
    return nb_pgssm_running_example(
        s_order=0,
        Sigma0_seasonal=jnp.eye(0),
        x0_seasonal=jnp.zeros(0),
        s2_speed=s2_eps,
        r=r,
        n=n,
    )


n = 100
theta_lc = jnp.array([jnp.log(1), jnp.log(20.0)])
aux = (n, jnp.ones(2))
model = model_fn(theta_lc, aux)
key = jrn.PRNGKey(512)
key, subkey = jrn.split(key)
_, (y,) = simulate_pgssm(model, 1, subkey)
theta_lc
Array([0.        , 2.99573227], dtype=float64)
initial_result = initial_theta(y, model_fn, theta_lc, aux, 10)
theta0 = initial_result.x
initial_result
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: 0.6582286439541333
        x: [ 3.894e-01  4.153e+00]
      nit: 18
      jac: [-7.251e-06  8.361e-08]
 hess_inv: [[ 2.320e+01  5.623e+02]
            [ 5.623e+02  3.945e+04]]
     nfev: 110
     njev: 22

source

mle_pgssm

 mle_pgssm (y:jaxtyping.Float[Array,'n+1 p'], model_fn,
            theta0:jaxtyping.Float[Array,'k'], aux, n_iter_la:int, N:int,
            key:jax.Array, options=None)

Maximum Likelihood Estimation for PGSSMs

Type Default Details
y Float[Array, ‘n+1 p’] observations \(y_t\)
model_fn parameterized PGSSM
theta0 Float[Array, ‘k’] initial parameter guess
aux auxiliary data for the model
n_iter_la int number of LA iterations
N int number of importance samples
key Array random key
options NoneType None options for the optimizer
Returns OptimizeResult result of MLE optimization
key, subkey = jrn.split(key)
result = mle_pgssm(y, model_fn, theta0, aux, 10, 1000, subkey)
theta_hat = result.x
result
  message: Optimization terminated successfully.
  success: True
   status: 0
      fun: 0.6555639773794885
        x: [ 4.335e-01  3.698e+00]
      nit: 19
      jac: [ 6.729e-06 -1.113e-06]
 hess_inv: [[ 1.908e+01  1.101e+02]
            [ 1.101e+02  1.307e+04]]
     nfev: 105
     njev: 21
@jit
def pgnll_full(theta, key):
    model = model_fn(theta, aux)

    # Laplace approximation as an initial Gaussian proposal
    proposal, info = laplace_approximation(y, model, 10)
    key, subkey = jrn.split(key)
    # refine the proposal with modified efficient importance sampling
    proposal_meis, _ = modified_efficient_importance_sampling(
        y, model, proposal.z, proposal.Omega, 10, 100, subkey
    )

    key, subkey = jrn.split(key)
    # importance sampling estimate of the negative log-likelihood, scaled by 1 / ((n + 1) p)
    return pgnll(y, model, proposal_meis.z, proposal_meis.Omega, 100, subkey) / y.size
# 2d grid on the log scale
all_thetas = jnp.vstack((theta_lc, theta0, theta_hat))
log_sigma_min, log_r_min = jnp.min(all_thetas, axis=0)
log_sigma_max, log_r_max = jnp.max(all_thetas, axis=0)
k = 20  # number of evaluations in each dimension
delta = 0.5
log_sigma, log_r = jnp.meshgrid(
    jnp.linspace(log_sigma_min - delta, log_sigma_max + delta, k),
    jnp.linspace(log_r_min - delta, log_r_max + delta, k),
)
# flatten
thetas = jnp.vstack([log_sigma.ravel(), log_r.ravel()]).T

key, subkey = jrn.split(key)
nlls = vmap(pgnll_full, (0, None))(thetas, subkey)
# location of minimum in nlls
i = jnp.argmin(nlls)
# location of minimum in the grid
i_sigma, i_r = i // k, i % k

plt.contourf(log_sigma, log_r, nlls.reshape(k, k))
plt.scatter(
    log_sigma[i_sigma, i_r],
    log_r[i_sigma, i_r],
    c="white",
    marker="x",
    label="min_grid",
)
plt.scatter(theta_lc[0], theta_lc[1], c="r", marker="x", label="true")
plt.scatter(theta0[0], theta0[1], c="black", marker="x", label="$\\theta_0$")
plt.scatter(*theta_hat, c="g", marker="o", label="min_mle")
plt.legend()
plt.xlabel("$\\log(\\sigma^2_\\varepsilon)$")
plt.ylabel("$\\log(r)$")
plt.colorbar()
plt.show()

From the above picture we see that \(\log r\) is hard to determine: the likelihood is very flat in the \(r\) direction, which explains why the optimizer may warn about precision loss. Nevertheless, our estimate \(\hat\theta\) seems to have converged to a reasonable value.

References

Durbin, James, and Siem Jan Koopman. 1997. “Monte Carlo Maximum Likelihood Estimation for Non-Gaussian State Space Models.” Biometrika 84 (3): 669–84. https://doi.org/10.1093/biomet/84.3.669.