Comparison of EIS and the CEM for SSMs

Simplified version of regional model in Chapter 4.1, keeping only \(\log I_t\) and \(\log \rho_t\) in the states.

  • States \(X_t = \left(\log I_{t}, \log \rho_{t + 1}\right)\)
  • Observations \(Y_t | X_t \sim \operatorname{Pois} \left( \exp \log I_{t}\right)\)

Varying \(n\) (the grids actually run below are \(n = 1, 10, 20, 50, 100\) for the efficiency factors and \(n = 1, 2, 5, 10\) for the asymptotic variances). Initialize \(\log \rho_0 = 0\) with small variance and \(\log I_0 = \log 1000\) with small variance as well.

Let \(\sigma^2_\rho = \frac{1}{n}0.05\), s.t. \(\operatorname{Var} (\log \rho_{n +1}) = 0.05\) and approx. \(\mathbf P(\log \rho_{n + 1} \in [-0.1, 0.1]) \geq 0.95\), so approx. \(\rho_{n +1} \in [0.9, 1.1]\), ensuring stability of infection counts (don’t go to \(0\) or \(\infty\)).

from pyprojroot import here
from isssm.laplace_approximation import posterior_mode
from isssm.laplace_approximation import posterior_mode
from isssm.importance_sampling import ess_pct
import pandas as pd
from isssm.importance_sampling import pgssm_importance_sampling
from isssm.ce_method import log_weight_cem, simulate_cem
from jax import vmap
from functools import partial
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
    modified_efficient_importance_sampling as MEIS,
)
from isssm.ce_method import cross_entropy_method as CEM
from isssm.pgssm import simulate_pgssm
import jax.random as jrn
import jax.numpy as jnp
import jax
from isssm.typing import PGSSM
from tensorflow_probability.substrates.jax.distributions import Poisson

from tqdm.notebook import tqdm
jax.config.update("jax_enable_x64", True)
def _model(n, I0):
    """Build the simplified two-state epidemic PGSSM.

    States are ``(log I_t, log rho_{t+1})``; observations are Poisson with
    log-rate ``log I_t``.  Only the ``log rho`` component receives Gaussian
    innovations.

    Parameters
    ----------
    n : int
        Number of transitions; the model has ``n + 1`` time points.
    I0 : float
        Initial infection count; ``log I0`` is the prior mean of the first
        state component.

    Returns
    -------
    PGSSM
        The assembled model.
    """
    np1 = n + 1
    # Innovation variance of log rho.  NOTE(review): the accompanying text
    # specifies 0.05 / n throughout, but the n <= 1 case uses 1 instead —
    # confirm this special case is intended.
    s2_rho = 0.05 / n if n > 1 else 1

    m, p, l = 2, 1, 1  # state dim, observation dim, innovation dim

    # State mean offsets: only the initial log-infection count is non-zero.
    u = jnp.zeros((np1, m)).at[0, 0].set(jnp.log(I0))

    # log I_{t+1} = log I_t + log rho_{t+1}; log rho follows a random walk.
    transition = jnp.array([[1.0, 1.0], [0.0, 1.0]])
    A = jnp.broadcast_to(transition, (n, m, m))
    # Innovations enter through the second state component (log rho) only.
    D = jnp.broadcast_to(jnp.eye(m)[:, 1:2], (n, m, l))

    Sigma0 = jnp.array([[1.0, 0.0], [0.0, 0.1]])
    Sigma = jnp.broadcast_to(s2_rho * jnp.eye(1), (n, l, l))

    # Observations read off the first state component, log I_t.
    B = jnp.broadcast_to(jnp.eye(m)[:1], (np1, p, m))
    v = jnp.zeros((np1, p))

    def dist(s, xi):
        # Poisson observation with log-rate s; xi is unused here but
        # required by the PGSSM interface.
        return Poisson(log_rate=s)

    xi = jnp.empty((np1, p, 1))
    return PGSSM(u, A, D, Sigma0, Sigma, v, B, dist, xi)
def determine_efficiency_factor(n, key):
    """Estimate the sampling efficiency (ESS %) of the LA, MEIS and CEM proposals.

    Simulates one observation sequence from the model with ``n`` transitions,
    fits the three proposals, then draws ``N_ef`` fresh importance samples
    from each and reports the effective sample size in percent.

    Parameters
    ----------
    n : int
        Number of transitions in the state-space model.
    key : jax PRNG key
        Randomness source; split internally for simulation, fitting and
        evaluation, so the evaluation keys are independent of the fitting keys.

    Returns
    -------
    pd.Series
        Columns ``n``, ``N_samples``, ``N_iter``, ``EF_LA``, ``EF_MEIS``,
        ``EF_CEM``.
    """
    pgssm = _model(n, I0=1000)
    key, subkey = jrn.split(key)

    _, (Y,) = simulate_pgssm(pgssm, 1, subkey)

    N_iter = 1000
    N_samples = 10000

    key, sk_meis, sk_cem = jrn.split(key, 3)
    prop_la, _ = LA(Y, pgssm, N_iter)
    prop_meis, _ = MEIS(Y, pgssm, prop_la.z, prop_la.Omega, N_iter, N_samples, sk_meis)
    # The log-weights produced during CEM fitting are discarded; the
    # efficiency is evaluated below with fresh samples for all three methods.
    prop_cem, _ = CEM(pgssm, Y, N_samples, sk_cem, N_iter)

    N_ef = 10000
    key, sk_la, sk_meis, sk_cem = jrn.split(key, 4)
    _, lw_la = pgssm_importance_sampling(
        Y, pgssm, prop_la.z, prop_la.Omega, N_ef, sk_la
    )
    _, lw_meis = pgssm_importance_sampling(
        Y, pgssm, prop_meis.z, prop_meis.Omega, N_ef, sk_meis
    )

    # Use N_ef (not N_samples) so all three evaluations draw the same number
    # of samples even if the two constants are changed independently.
    lw_cem = vmap(partial(log_weight_cem, y=Y, model=pgssm, proposal=prop_cem))(
        simulate_cem(prop_cem, N_ef, sk_cem)
    )

    result = pd.Series(
        {
            "n": n,
            "N_samples": N_samples,
            "N_iter": N_iter,
            "EF_LA": ess_pct(lw_la),
            "EF_MEIS": ess_pct(lw_meis),
            "EF_CEM": ess_pct(lw_cem),
        }
    )

    return result
# Efficiency-factor experiment: 10 independent replications per model size n.
key = jrn.PRNGKey(140235293)
ns_ef = jnp.repeat(jnp.array([1, 10, 20, 50, 100]), 10)
key, *keys_ef = jrn.split(key, len(ns_ef) + 1)
ef_rows = []
for n, k in zip(ns_ef, keys_ef):
    ef_rows.append(determine_efficiency_factor(n, k))
results_ef = pd.DataFrame(ef_rows)
results_ef.to_csv(here("data/figures/ef_meis_cem_ssms.csv"), index=False)
def asymptotic_det_meis(Y, pgssm, prop_la, N_iter, N_samples, key, M: int):
    """Log-determinant of the scaled sampling covariance of MEIS posterior modes.

    Fits the MEIS proposal ``M`` times with independent subkeys, stacks the
    flattened posterior modes, and returns ``log det`` of their empirical
    covariance multiplied by ``N_samples`` (asymptotic scaling).
    """
    key, *subkeys = jrn.split(key, 1 + M)
    modes = []
    for sk in subkeys:
        proposal, _ = MEIS(Y, pgssm, prop_la.z, prop_la.Omega, N_iter, N_samples, sk)
        modes.append(posterior_mode(proposal).reshape(-1))
    scaled_cov = N_samples * jnp.cov(jnp.array(modes), rowvar=False)
    _, logdet = jnp.linalg.slogdet(scaled_cov)
    return logdet

def asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key, M: int):
    """Log-determinant of the scaled sampling covariance of CEM proposal means.

    Fits the CEM proposal ``M`` times with independent subkeys and returns
    ``log det`` of the empirical covariance of the proposal means, scaled by
    ``N_samples``.
    """
    key, *subkeys = jrn.split(key, 1 + M)
    means = []
    for sk in subkeys:
        proposal, _ = CEM(pgssm, Y, N_samples, sk, N_iter)
        # NOTE(review): uses the first column of the proposal mean, whereas
        # the MEIS variant flattens the full posterior mode — confirm the
        # two covariance matrices have matching dimensions before comparing
        # their determinants.
        means.append(proposal.mean[:, 0])
    scaled_cov = N_samples * jnp.cov(jnp.array(means), rowvar=False)
    _, logdet = jnp.linalg.slogdet(scaled_cov)
    return logdet


def asymptotic_variance(n: int, key: jrn.PRNGKey, N_var: int = 10):
    """Compare the asymptotic sampling variability of CEM and MEIS proposals.

    Simulates one observation sequence, then runs ``N_var`` independent
    replications of each method and computes the log-determinants of the
    scaled covariances of their fitted-proposal summaries, plus their ratio
    (asymptotic relative efficiency, ARE).

    Parameters
    ----------
    n : int
        Number of transitions in the state-space model.
    key : jax PRNG key
        Randomness source; split internally.
    N_var : int, default 10
        Number of replications per method.

    Returns
    -------
    pd.Series
        Columns ``n``, ``N_samples``, ``N_iter``, ``log_DET_CEM``,
        ``log_DET_MEIS``, ``ARE``.
    """
    pgssm = _model(n, I0=1000)
    key, subkey = jrn.split(key)

    _, (Y,) = simulate_pgssm(pgssm, 1, subkey)

    N_iter = 1000
    N_samples = 10000

    prop_la, _ = LA(Y, pgssm, N_iter)

    # Distinct keys for the two methods.  Previously the SAME key was passed
    # to both asymptotic_det_* calls (the split subkeys were computed but
    # never used), so both methods consumed identical random streams and
    # their replications were correlated.
    key, key_cem, key_meis = jrn.split(key, 3)

    logdet_cem = asymptotic_det_cem(Y, pgssm, N_iter, N_samples, key_cem, M=N_var)
    logdet_meis = asymptotic_det_meis(
        Y, pgssm, prop_la, N_iter, N_samples, key_meis, M=N_var
    )

    result = pd.Series(
        {
            "n": n,
            "N_samples": N_samples,
            "N_iter": N_iter,
            "log_DET_CEM": logdet_cem,
            "log_DET_MEIS": logdet_meis,
            "ARE": jnp.exp(logdet_cem - logdet_meis),
        }
    )

    return result
# Asymptotic-relative-efficiency experiment: 10 replications per model size n.
key = jrn.PRNGKey(140235293)
ns_are = jnp.repeat(jnp.array([1, 2, 5, 10]), 10)
key, *keys_are = jrn.split(key, len(ns_are) + 1)
are_rows = [asymptotic_variance(n, k) for n, k in zip(ns_are, keys_are)]
results_are = pd.DataFrame(are_rows)

results_are.to_csv(here("data/figures/are_meis_cem_ssms.csv"), index=False)
results_are
n N_samples N_iter DET_CEM DET_MEIS ARE
0 1 10000 1000 1.2501975940617859e-21 7.362584944337338e-21 0.16980416572624124
1 1 10000 1000 2.1315459186618234e-19 7.12489799034234e-18 0.029916862270183968
2 1 10000 1000 1.5207781065517603e-20 6.692526415181031e-20 0.22723527890784181
3 1 10000 1000 6.007385714992874e-17 6.660606838997179e-11 9.01927686201239e-07
4 1 10000 1000 6.624824337829013e-22 3.177769216726953e-22 2.084740547852769
5 1 10000 1000 1.151610362262488e-21 3.231203621180358e-21 0.3564029065558564
6 1 10000 1000 3.083248944208273e-23 1.2071821492071727e-22 0.2554087588383595
7 1 10000 1000 3.1796825882208896e-21 4.0446667403011136e-21 0.7861420463986537
8 1 10000 1000 2.180298654142417e-18 8.979117586628375e-15 0.00024281881076925632
9 1 10000 1000 1.877286375520868e-20 2.2784117172564035e-18 0.008239451901087663
10 2 10000 1000 1.9273993170437574e-28 1.0798989361480693e-28 1.784796014262842
11 2 10000 1000 1.1446337747514061e-32 2.8882261782989135e-32 0.39631029707845256
12 2 10000 1000 6.159994369289026e-29 7.289892425672587e-29 0.8450048381503636
13 2 10000 1000 1.4290304502401327e-31 2.0813887733628598e-32 6.86575457948337
14 2 10000 1000 5.42336248299976e-30 7.953995529684445e-30 0.6818412787333963
15 2 10000 1000 5.877859598370439e-32 6.436371648504608e-32 0.9132256369527805
16 2 10000 1000 5.379361864671192e-33 1.0620511446164679e-32 0.5065068562789232
17 2 10000 1000 4.887804261821233e-31 5.855830511801467e-30 0.08346901864681132
18 2 10000 1000 2.0359864128830622e-32 9.830573463933316e-32 0.2071075935043613
19 2 10000 1000 7.752964455709755e-33 9.875675986456281e-34 7.850565841105299
20 5 10000 1000 1.3067690649574617e-54 3.883233797706624e-60 336515.6807527836
21 5 10000 1000 7.110526124954396e-50 1.6548473953519135e-59 4296786606.98638
22 5 10000 1000 3.682259048961097e-69 7.34212755723125e-71 50.152479921633144
23 5 10000 1000 1.2714534177185017e-57 3.3668411607110587e-62 37763.98579640679
24 5 10000 1000 1.5892555050155927e-62 2.4410366744665834e-65 651.0576107435492
25 5 10000 1000 7.400743476215611e-67 5.401150252263782e-68 13.702161818426992
26 5 10000 1000 1.6715727758016275e-55 2.2222863809909505e-61 752186.03241237
27 5 10000 1000 6.099313166607847e-58 1.3031186392933132e-63 468055.09358039196
28 5 10000 1000 9.018652040498798e-68 1.9240738181124852e-70 468.7269249028131
29 5 10000 1000 2.414396183688779e-61 8.786311541305962e-65 2747.906413673459
30 10 10000 1000 1.0521835449122195e-139 -3.827549513153355e-153 -27489743536860.75
31 10 10000 1000 -1.867709109170755e-124 5.903576866165244e-146 -3.1636906768759553e+21
32 10 10000 1000 -7.623069051137711e-129 1.2997503577057545e-146 -5.865025545823676e+17
33 10 10000 1000 3.335174248271962e-127 -7.751308305009268e-148 -4.3027242847721744e+20
34 10 10000 1000 -2.491652350847499e-142 -3.4584538590464226e-152 7204526798.384138
35 10 10000 1000 -9.69823912484352e-153 2.342987700145343e-159 -4139261.6462484663
36 10 10000 1000 -4.840099789083058e-124 7.639541159271866e-146 -6.335589648874114e+21
37 10 10000 1000 -8.222821803800646e-109 -4.9680496813661975e-143 1.6551408160514604e+34
38 10 10000 1000 -5.2093692056011e-142 -7.930262068834436e-154 656897484646.0221
39 10 10000 1000 -7.507560673550981e-142 4.376429708052626e-155 -17154532745578.154
The Kernel crashed while executing code in the current cell or a previous cell. 

Please review the code in the cell(s) to identify a possible cause of the failure. 

Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. 

View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.