Cross-Entropy method

See also the corresponding section in my thesis
from typing import Tuple

# | export
import jax.numpy as jnp
import jax.random as jrn
import jax.scipy as jsp
import tensorflow_probability.substrates.jax.distributions as tfd
from jax import jit, vmap
from jax.lax import fori_loop, scan, while_loop
from jaxtyping import Array, Float, PRNGKeyArray
from tensorflow_probability.substrates.jax.distributions import \
    MultivariateNormalFullCovariance as MVN

from isssm.importance_sampling import ess_pct, normalize_weights
from isssm.laplace_approximation import laplace_approximation
from isssm.pgssm import log_prob as log_prob_joint
from isssm.util import converged
import fastcore.test as fct
# |hide
import jax
import matplotlib.pyplot as plt

from isssm.importance_sampling import pgssm_importance_sampling
from isssm.pgssm import nb_pgssm_running_example, simulate_pgssm

The Cross-Entropy method (Rubinstein 1997; Rubinstein and Kroese 2004) is a method for determining good importance sampling proposals. Given a parametric family of proposals \(g_\theta(x)\) and target \(p(x)\), the Cross-Entropy method aims at choosing \(\theta\) such that the (negative) cross-entropy \[ \mathcal H_{\text{CE}} \left( p \middle|\middle| g_{\theta} \right) = \int p(x) \log g_{\theta}(x) \mathrm d x \] is maximized. Since \(\int p(x) \log p(x) \mathrm d x\) does not depend on \(\theta\), this is equivalent to minimizing the Kullback-Leibler divergence between \(p\) and \(g_\theta\). As \(\mathcal H_{\text{CE}}\) is not analytically available, it is approximated by importance sampling, starting with a suitable proposal \(g_{\hat\theta_0}\). The approximate optimization problem is then solved, yielding \(\hat \theta_1\). These steps are iterated until convergence, using common random numbers across iterations to ensure convergence.
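The iteration can be illustrated on a toy problem. The following is a minimal sketch, assuming a hypothetical one-dimensional target \(\mathcal N(2, 0.5^2)\) known only up to a constant and a Gaussian proposal family; the names `log_p_unnormalized` and `ce_step` are illustrative and not part of `isssm`.

```python
import jax.numpy as jnp
import jax.random as jrn
from jax.scipy.stats import norm


def log_p_unnormalized(x):
    # hypothetical target: N(2, 0.5^2) without its normalizing constant
    return -0.5 * ((x - 2.0) / 0.5) ** 2


def ce_step(theta, eps):
    """One Cross-Entropy update for a Gaussian proposal N(mu, sigma^2)."""
    mu, sigma = theta
    x = mu + sigma * eps  # samples from the current proposal
    log_w = log_p_unnormalized(x) - norm.logpdf(x, mu, sigma)
    w = jnp.exp(log_w - log_w.max())
    w = w / w.sum()  # self-normalized importance weights
    # for a Gaussian family the optimum matches the weighted moments
    mu_new = jnp.sum(w * x)
    sigma_new = jnp.sqrt(jnp.sum(w * (x - mu_new) ** 2))
    return mu_new, sigma_new


# common random numbers: the same standard normal draws in every iteration
eps = jrn.normal(jrn.PRNGKey(0), (10_000,))
theta = (jnp.array(0.0), jnp.array(2.0))  # initial proposal N(0, 2^2)
for _ in range(10):
    theta = ce_step(theta, eps)
mu_hat, sigma_hat = theta
```

Because `eps` is fixed, each update is a deterministic function of the previous parameters, so the sequence of estimates can settle at a fixed point instead of fluctuating with fresh noise.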

Considering the Cross-Entropy method with a Gaussian proposal \(g_\theta\), we see that the optimal \(\theta\) only depends on the first and second order moments of \(p\); indeed, the optimal Gaussian is the one that matches these moments. Unfortunately, this approach is not feasible for the models we consider in this package, as the dimensionality (\(n \cdot m\)) is likely too high to act on the joint distribution directly: matching means is feasible, but simulating from the distribution and evaluating the likelihood is infeasible. However, we can exploit the Markov structure of our models:
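
The moment-matching property can be checked in a few lines. As a sketch, assume \(g_\theta = \mathcal N(\mu, \Sigma)\) on \(\mathbf R^d\) and that \(p\) has mean \(\mu_p\) and covariance \(\Sigma_p\) (these symbols are introduced here for illustration only). Then \[ \int p(x) \log g_{\theta}(x) \mathrm d x = -\frac{1}{2} \log \det (2\pi\Sigma) - \frac{1}{2} \operatorname{tr} \left( \Sigma^{-1} \Sigma_p \right) - \frac{1}{2} (\mu_p - \mu)^T \Sigma^{-1} (\mu_p - \mu). \] For any fixed \(\Sigma\) the last term is maximized by \(\mu = \mu_p\), and the remaining expression \(-\frac{1}{2} \log \det \Sigma - \frac{1}{2} \operatorname{tr} \left( \Sigma^{-1} \Sigma_p \right)\) is maximized by \(\Sigma = \Sigma_p\), which is the same calculation as for the Gaussian maximum likelihood estimator.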

For the class of state space models treated in this package, it can be shown that the smoothing distribution, the target of our inference, \(p(x|y)\), is again a Markov process, see Chapter 5 in (Chopin and Papaspiliopoulos 2020), so it makes sense to approximate this distribution with a Gaussian Markov process. Thus, we only need to find the closest (in terms of KL-divergence) Gaussian Markov process, which is feasible, and can be obtained by choosing the approximation to match the mean and consecutive covariances, i.e. match \[ \operatorname{Cov} \left( (X_{t}, X_{t + 1}) \right) \in \mathbf R^{2m \times 2m} \] for all \(t = 0, \dots, n - 1\). These are just \(\mathcal O(nm^2)\) many parameters, instead of the \(\mathcal O(n^2m^2)\) many parameters required to match the whole covariance matrix.
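
In practice the means and consecutive covariances have to be estimated from weighted samples. The following is a minimal sketch of such an estimator, assuming samples of shape `(N, n+1, m)` and normalized weights of shape `(N,)`; the function name `weighted_consecutive_moments` is hypothetical and this is not necessarily how `isssm` computes these quantities.

```python
import jax.numpy as jnp
import jax.random as jrn


def weighted_consecutive_moments(samples, weights):
    """Weighted mean (n+1, m) and covariances of (X_t, X_{t+1}), shape (n, 2m, 2m)."""
    mean = jnp.einsum("i,itk->tk", weights, samples)
    centered = samples - mean
    # stack each state with its successor: shape (N, n, 2m)
    pairs = jnp.concatenate([centered[:, :-1], centered[:, 1:]], axis=-1)
    covs = jnp.einsum("i,itk,itl->tkl", weights, pairs, pairs)
    return mean, covs


# usage with synthetic samples and weights
key_x, key_w = jrn.split(jrn.PRNGKey(0))
samples = jrn.normal(key_x, (1000, 5, 2))  # N=1000, n=4, m=2
raw = jnp.exp(jrn.normal(key_w, (1000,)))
weights = raw / raw.sum()
mean, covs = weighted_consecutive_moments(samples, weights)
```

The result has \(n \cdot (2m)^2\) entries, and the lower-right block of each matrix coincides with the upper-left block of the next one, since both are estimates of \(\operatorname{Cov}(X_{t+1})\).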


proposal_from_moments

 proposal_from_moments (mean:jaxtyping.Float[Array,'n+1 m'],
                        consecutive_covs:jaxtyping.Float[Array,'n 2*m 2*m'])

Find the unique Gaussian Markov proposal that matches means and consecutive covariances.

Parameter | Type | Details
--- | --- | ---
mean | Float[Array, ‘n+1 m’] | mean \(v\)
consecutive_covs | Float[Array, ‘n 2m 2m’] |
Returns | MarkovProposal | corresponding proposal

To verify that this produces the correct proposal, let us check it for a stationary AR(1) process with mean \(0\) and recurrence \[ \begin{align*} X_{t + 1} = \alpha X_{t} + \varepsilon_{t} && \varepsilon_{t} \sim \mathcal N(0, \sigma^{2}) \end{align*} \]

The initial variance is \(R_{0}^2= \tau^{2} = \operatorname{Var} (X_0) = \frac{\sigma^{2}}{1 - \alpha^{2}}\), the joint covariances are \[ \operatorname{Cov} \left( (X_{t}, X_{t + 1}) \right) = \tau^{2}\begin{pmatrix} 1 & \alpha \\ \alpha & 1 \end{pmatrix}, \] and the innovations have variance \(R_{t}^2 = \sigma^{2}\) for \(t \geq 1\).

n, m = 10, 1
alpha = 0.5
s2 = 1.0
tau2 = s2 / (1 - alpha**2)
mu = jnp.zeros((n + 1, 1))
consecutive_covs = tau2 * jnp.broadcast_to(
    jnp.array([[1.0, alpha], [alpha, 1.0]]), (n, 2, 2)
)

proposal = proposal_from_moments(mu, consecutive_covs)

fct.test_eq(proposal.R.shape, (n + 1, 1, 1))
fct.test_eq(proposal.J_tt.shape, (n, 1, 1))
fct.test_eq(proposal.J_tp1t.shape, (n, 1, 1))

fct.test_close(proposal.mean, mu)
fct.test_close(proposal.R[0], jnp.sqrt(tau2) * jnp.eye(1))
fct.test_close(proposal.R[1:], jnp.sqrt(s2) * jnp.eye(1))
fct.test_close(proposal.J_tp1t / proposal.J_tt, jnp.full((n, 1, 1), alpha))
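
The values tested above follow from Gaussian conditioning: given the joint covariance of \((X_t, X_{t+1})\), the transition matrix is \(\Sigma_{21}\Sigma_{11}^{-1}\) and the innovation covariance is the Schur complement \(\Sigma_{22} - \Sigma_{21}\Sigma_{11}^{-1}\Sigma_{12}\). The following sketch checks this for the AR(1) example; the helper `transition_from_joint_cov` is hypothetical and only illustrates the conditioning step, it is not the implementation of `proposal_from_moments`.

```python
import jax.numpy as jnp


def transition_from_joint_cov(joint_cov, m):
    """Transition matrix and innovation covariance of X_{t+1} | X_t."""
    S11 = joint_cov[:m, :m]
    S21 = joint_cov[m:, :m]
    S22 = joint_cov[m:, m:]
    # S21 @ inv(S11), computed with a linear solve (S11 is symmetric)
    C = jnp.linalg.solve(S11, S21.T).T
    Omega = S22 - C @ S21.T  # Schur complement
    return C, Omega


alpha, s2 = 0.5, 1.0
tau2 = s2 / (1 - alpha**2)
joint_cov = tau2 * jnp.array([[1.0, alpha], [alpha, 1.0]])
C, Omega = transition_from_joint_cov(joint_cov, 1)
```

Here `C` recovers \(\alpha\) and `Omega` recovers \(\sigma^2\), since \(\tau^2 - \alpha^2\tau^2 = \sigma^2\).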

Simulation

Given a Markov proposal, we can simulate from it by repeatedly applying its defining recurrence.


simulate_cem

 simulate_cem (proposal:isssm.typing.MarkovProposal, N:int,
               key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']])

Parameter | Type | Details
--- | --- | ---
proposal | MarkovProposal | proposal
N | int | number of samples
key | Union | random number seed
Returns | Float[Array, ‘N n+1 m’] |
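
The recurrence can be sketched with `jax.lax.scan`. The example below assumes the proposal is described by a mean `mean` of shape `(n+1, m)`, transition matrices `C` of shape `(n, m, m)` and Cholesky roots `R` of shape `(n+1, m, m)`, so that \(U_t = C_{t-1}U_{t-1} + \varepsilon_t\) with \(\varepsilon_t \sim \mathcal N(0, R_tR_t^T)\). This parametrization and the name `simulate_markov` are assumptions made for illustration; the fields of `MarkovProposal` and the implementation of `simulate_cem` may differ.

```python
import jax.numpy as jnp
import jax.random as jrn
from jax import vmap
from jax.lax import scan


def simulate_markov(mean, C, R, N, key):
    """Draw N paths of shape (n+1, m) from a Gaussian Markov process."""
    n_plus_1, m = mean.shape

    def single_path(path_key):
        z = jrn.normal(path_key, (n_plus_1, m))
        eps = jnp.einsum("tij,tj->ti", R, z)  # eps_t = R_t z_t

        def step(u_prev, inputs):
            C_t, eps_t = inputs
            u = C_t @ u_prev + eps_t
            return u, u

        # U_0 = eps_0, then apply the recurrence for t = 1, ..., n
        _, u_rest = scan(step, eps[0], (C, eps[1:]))
        u = jnp.concatenate([eps[:1], u_rest], axis=0)
        return mean + u

    return vmap(single_path)(jrn.split(key, N))


# usage: the stationary AR(1) process from the check above
n, alpha = 10, 0.5
tau2 = 1.0 / (1 - alpha**2)
mean = jnp.zeros((n + 1, 1))
C = jnp.full((n, 1, 1), alpha)
R = jnp.ones((n + 1, 1, 1)).at[0].set(jnp.sqrt(tau2))
x = simulate_markov(mean, C, R, 20_000, jrn.PRNGKey(0))
```

For this stationary example the sample variance of `x` at every time point should be close to \(\tau^2\).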

Probability density function

For importance sampling we need to evaluate the pdf of the proposal. We do this by first subtracting the mean \(v\), \(U = X - v\), and then recovering the innovations \[ \varepsilon_{t} = U_{t} - C_{t - 1}U_{t - 1} \sim \mathcal N(0, R_{t}R_{t}^T) \] for \(t = 1, \dots, n\), where \(\varepsilon_0 = U_0\). These are jointly independent, so their joint pdf is the product of their individual Gaussian densities.

For the log-weights, note that we cannot use the same weights as for MEIS as \(p(x) \neq g(x)\). Instead we have to calculate \[ \log w(x) = \log p(x,y) - \log g(x). \]


log_weight_cem

 log_weight_cem (x:jaxtyping.Float[Array,'n+1 m'],
                 y:jaxtyping.Float[Array,'n+1 p'],
                 model:isssm.typing.PGSSM,
                 proposal:isssm.typing.MarkovProposal)

Parameter | Type | Details
--- | --- | ---
x | Float[Array, ‘n+1 m’] | point at which to evaluate the weights
y | Float[Array, ‘n+1 p’] | observations
model | PGSSM | model
proposal | MarkovProposal | proposal
Returns | Float | log weights

log_pdf

 log_pdf (x:jaxtyping.Float[Array,'n+1 m'],
          proposal:isssm.typing.MarkovProposal)
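
The innovation-based evaluation of the proposal density can be sketched as follows, again under the assumed parametrization by a mean `mean`, transition matrices `C` and Cholesky roots `R` (an illustration only; `markov_log_pdf` is a hypothetical name and not the `log_pdf` of this package). The log weight is then obtained by subtracting this value from the joint log density \(\log p(x, y)\). As a check, the result is compared to the full multivariate normal density of a stationary AR(1) process.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn


def markov_log_pdf(x, mean, C, R):
    """Log density of a Gaussian Markov process at a path x of shape (n+1, m)."""
    m = mean.shape[-1]
    u = x - mean
    # eps_0 = U_0 and eps_t = U_t - C_{t-1} U_{t-1} for t >= 1
    eps_rest = u[1:] - jnp.einsum("tij,tj->ti", C, u[:-1])
    eps = jnp.concatenate([u[:1], eps_rest], axis=0)
    z = jnp.linalg.solve(R, eps[..., None])[..., 0]  # z_t = R_t^{-1} eps_t
    log_det = jnp.linalg.slogdet(R)[1]  # log |det R_t|
    per_step = -0.5 * m * jnp.log(2 * jnp.pi) - log_det - 0.5 * jnp.sum(z**2, axis=-1)
    return jnp.sum(per_step)


# check against the joint density of a stationary AR(1) process with n = 2
alpha = 0.5
tau2 = 1.0 / (1 - alpha**2)
x = jnp.array([[0.3], [-0.2], [1.0]])
mean = jnp.zeros((3, 1))
C = jnp.full((2, 1, 1), alpha)
R = jnp.ones((3, 1, 1)).at[0].set(jnp.sqrt(tau2))
idx = jnp.arange(3)
joint_cov = tau2 * alpha ** jnp.abs(idx[:, None] - idx[None, :])
log_g = markov_log_pdf(x, mean, C, R)
log_g_full = mvn.logpdf(x[:, 0], jnp.zeros(3), joint_cov)
```

Only \(n + 1\) linear systems of size \(m\) are solved, so the cost grows linearly in \(n\), whereas the full density requires factorizing an \((n + 1)m \times (n + 1)m\) matrix.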

SSM to Markov Model

To initialize the Cross-Entropy method, we will use the Laplace approximation, see [30_laplace_approximation.ipynb]. This approximates the true posterior by the posterior of a Gaussian state space model. To initiate the Cross-Entropy procedure, we determine the Cholesky root of this Gaussian posterior and use it as an initial value. To determine the diagonal and off-diagonal components of the Cholesky root, we calculate the joint covariance matrix \(\text{Cov} \left( X_t, X_{t + 1} | Y_0, \dots, Y_n \right)\) using the Kalman smoother and the FFBS, which results in \[ \text{Cov} \left( X_t, X_{t + 1} | Y_0, \dots, Y_n \right) = \begin{pmatrix} \Xi_{t|n} & \Xi_{t|t} A_t^T \Xi_{t + 1|t}^{-1} \Xi_{t + 1|n} \\ \left(\Xi_{t|t} A_t^T \Xi_{t + 1|t}^{-1} \Xi_{t + 1|n} \right)^T & \Xi_{t + 1 | n} \end{pmatrix}. \]

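# | export
from isssm.kalman import kalman, smoother

# MarkovProposal and Observations are used in the annotations below;
# Observations is assumed to be exported by isssm.typing as well
from isssm.typing import GLSSM, MarkovProposal, Observations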


def _joint_cov(Xi_smooth_t, Xi_smooth_tp1, Xi_filt_t, Xi_pred_tp1, A_t):
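    """Joint covariance of (X_t, X_{t+1}) given all observations"""
    # off-diagonal block Xi_{t|t} A_t^T Xi_{t+1|t}^{-1} Xi_{t+1|n}; the
    # pseudo-inverse is used in place of jnp.linalg.solve(Xi_pred_tp1, Xi_smooth_tp1)
    off_diag = (
        Xi_filt_t @ A_t.T @ jnp.linalg.pinv(Xi_pred_tp1, hermitian=True) @ Xi_smooth_tp1
    )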
    return jnp.block([[Xi_smooth_t, off_diag], [off_diag.T, Xi_smooth_tp1]])


def posterior_markov_proposal(
    y: Observations,  # observations
    model: GLSSM,  # model
) -> MarkovProposal:  # Markov proposal of posterior X|Y
    """Calculate the Markov proposal of the smoothing distribution using the Kalman smoother"""
    filtered = kalman(y, model)
    _, Xi_filter, _, Xi_pred = filtered
    x_smooth, Xi_smooth = smoother(filtered, model.A)

    covs = vmap(_joint_cov)(
        Xi_smooth[:-1], Xi_smooth[1:], Xi_filter[:-1], Xi_pred[1:], model.A
    )

    return proposal_from_moments(x_smooth, covs)
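
The covariance formula can be verified on a small example. The following sketch assumes a hypothetical scalar model with two states, \(X_0 \sim \mathcal N(0, s_0)\), \(X_1 = aX_0 + \varepsilon\) with \(\varepsilon \sim \mathcal N(0, q)\), and observations \(Y_t = X_t + \eta_t\) with \(\eta_t \sim \mathcal N(0, r)\). It compares the Kalman-based expression to direct conditioning of the joint Gaussian distribution; the filter and smoother recursions are written out by hand here and do not use `isssm.kalman`.

```python
import jax.numpy as jnp

a, s0, q, r = 0.8, 2.0, 0.5, 1.0  # hypothetical model parameters

# brute force: condition the joint Gaussian of (X_0, X_1, Y_0, Y_1) on Y
Sigma_x = jnp.array([[s0, a * s0], [a * s0, a**2 * s0 + q]])
Sigma_y = Sigma_x + r * jnp.eye(2)
post_cov = Sigma_x - Sigma_x @ jnp.linalg.solve(Sigma_y, Sigma_x)

# Kalman filter quantities
Xi_00 = s0 - s0**2 / (s0 + r)  # filtered variance Xi_{0|0}
Xi_10 = a**2 * Xi_00 + q  # predicted variance Xi_{1|0}
Xi_11 = Xi_10 - Xi_10**2 / (Xi_10 + r)  # Xi_{1|1}, equal to Xi_{1|n} as n = 1

# smoothed variance of X_0 and the off-diagonal block from the formula
gain = Xi_00 * a / Xi_10
Xi_0n = Xi_00 + gain**2 * (Xi_11 - Xi_10)
off_diag = Xi_00 * a / Xi_10 * Xi_11
joint = jnp.array([[Xi_0n, off_diag], [off_diag, Xi_11]])
```

Both `joint` and `post_cov` describe the covariance of \((X_0, X_1)\) given both observations and agree up to floating point error.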

The Cross-Entropy Method

Finally, we have all the ingredients to apply the CE-method to perform importance sampling in a PGSSM with observations \(y\). We start by calculating the Laplace approximation (LA), convert its posterior distribution to a Markov proposal, and then repeatedly sample from the proposal and update it.


cross_entropy_method

 cross_entropy_method (model:isssm.typing.PGSSM,
                       y:jaxtyping.Float[Array,'n+1 p'], N:int,
                       key:Union[jaxtyping.Key[Array,''],jaxtyping.UInt32[Array,'2']],
                       n_iter:int)

iteratively perform the CEM to find an optimal proposal

Parameter | Type | Details
--- | --- | ---
model | PGSSM | model
y | Float[Array, ‘n+1 p’] | observations
N | int | number of samples to use in the CEM
key | Union | random number seed
n_iter | int | number of iterations
Returns | MarkovProposal | the CEM proposal

Let us perform the CE-method on our example model. We’ll set the number of observations somewhat lower than for EIS, as the CE-method is less performant in this setting.

from isssm.importance_sampling import pgssm_importance_sampling
s_order = 5
model = nb_pgssm_running_example(
    n=100,
    s_order=s_order,
    Sigma0_seasonal=jnp.eye(s_order - 1),
    x0_seasonal=jnp.zeros(s_order - 1),
)
key = jrn.PRNGKey(511)
key, subkey = jrn.split(key)
_, (y,) = simulate_pgssm(model, 1, subkey)
proposal_la, info_la = laplace_approximation(y, model, 10)
key, subkey = jrn.split(key)
samples_la, log_w_la = pgssm_importance_sampling(
    y, model, proposal_la.z, proposal_la.Omega, 1000, subkey
)
key, subkey = jrn.split(subkey)
proposal, log_w = cross_entropy_method(model, y, 10000, subkey, 10)
ess_pct(log_w), ess_pct(log_w_la)
(Array(0.00499999, dtype=float64), Array(19.70405657, dtype=float64))

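In this run, the CEM proposal reaches an effective sample size of only about 0.005% with 10000 samples, compared to about 19.7% for the LA with 1000 samples. The CEM can improve on the LA, but requires more samples than MEIS to do so.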
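
For reference, the effective sample size as a percentage of the number of samples can be computed from log weights as in the following sketch. It uses the common definition \(\text{ESS} = \left( \sum_i w_i \right)^2 / \sum_i w_i^2\); the function `ess_percent` is a hypothetical stand-in, and `ess_pct` from `isssm.importance_sampling` may be implemented differently.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def ess_percent(log_w):
    """Effective sample size in percent of the number of samples."""
    log_w_norm = log_w - logsumexp(log_w)  # normalize in log space for stability
    ess = 1.0 / jnp.sum(jnp.exp(2 * log_w_norm))
    return 100.0 * ess / log_w.shape[0]


ess_equal = ess_percent(jnp.zeros(50))  # all weights equal
ess_single = ess_percent(jnp.array([0.0, -1000.0, -1000.0, -1000.0]))  # one dominant weight
```

Equal weights give 100%, while a single dominant weight among four samples gives 25%, the smallest value possible for that sample size under this definition.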

References

Chopin, Nicolas, and Omiros Papaspiliopoulos. 2020. An Introduction to Sequential Monte Carlo. Springer Series in Statistics. Cham, Switzerland: Springer.
Rubinstein, Reuven Y. 1997. “Optimization of Computer Simulation Models with Rare Events.” European Journal of Operational Research 99 (1): 89–112. https://doi.org/10.1016/S0377-2217(96)00385-2.
Rubinstein, Reuven Y., and Dirk P. Kroese. 2004. The Cross-Entropy Method: A Unified Approach to Combinatorial Optimization, Monte-Carlo Simulation and Machine Learning. New York, NY: Springer New York.