Application 1: Showcase

Imports

from ssm4epi.models.regional_growth_factor import (
    key,
    n_iterations,
    N_mle,
    N_meis,
    N_posterior,
    percentiles_of_interest,
    make_aux,
    dates_full,
    cases_full,
    n_ij,
    n_tot,
    n_pop,
    account_for_nans,
    growth_factor_model,
)

import jax.numpy as jnp
import jax
import jax.random as jrn

from isssm.importance_sampling import prediction
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
    modified_efficient_importance_sampling as MEIS,
)

from pyprojroot.here import here

jax.config.update("jax_enable_x64", True)
from isssm.estimation import initial_theta
import pickle
import matplotlib.pyplot as plt

Parameters

# parameters
# Hyperparameters for the showcase, hand-tuned to balance cost vs. accuracy
# (see the generated LaTeX text below).
N_ITER_LA = 100  # iterations for the Laplace approximation (LA)
N_MAX_ITER_THETA0 = 10  # max optimizer iterations for the initial theta_0

N_MLE = 10_000  # importance samples for the MLE objective
N_MLE_TEXT = f"{N_MLE:,}".replace(",", "\\,")  # LaTeX thin-space thousands separator
N_MAX_ITER_MLE = 100  # max optimizer iterations for the MLE
N_PRED = 10_000  # importance samples for prediction
text_parameters = f"""As hyperparameters for the showcase, we use ${N_ITER_LA}$ iterations for the LA, a maximum of ${N_MAX_ITER_THETA0}$ iterations for the optimization of $\\theta_0$ in the LA, ${N_MLE_TEXT}$ samples and a maximum of ${N_MAX_ITER_MLE}$ iterations for the estimation of $\\hat\\theta$, and ${N_PRED}$ samples for the prediction. These values have been chosen by hand to balance computational cost and accuracy. 
"""

# write the hyperparameter description so the thesis text stays in sync
with open(
    here(
        "./chapters/04_epidemiological_ssms/04_02_weekly_GF_showcase_hyperparameter_text.tex"
    ),
    "w",
) as f:
    f.write(text_parameters)

Model with missing observations

from datetime import date
import pandas as pd

# week at which the analysis window is anchored
initial_date = "2020-04-18"
np1 = 10  # window length passed to make_aux; presumably n + 1 — confirm there
(dates_index,) = jnp.where(dates_full == initial_date)[0]
dates = dates_full[dates_index + 1 : dates_index + np1 + 1]
aux = make_aux(initial_date, cases_full, n_ij, n_tot, np1)

y = aux[0][1:]  # observations after the initial week
# artificially mark the final week as unobserved to showcase missing-data handling
y_nan = y.at[-1].set(jnp.nan)
missing_inds = jnp.isnan(y_nan)
# original result from 10_model.ipynb
theta_manual = jnp.array(
    [-0.0236392, -2.0838978, -5.31651543, -2.62109273, -0.3461143, 0.57673125]
)
_, y_miss = account_for_nans(
    growth_factor_model(theta_manual, aux), y_nan, missing_inds
)
# model factory that always applies the same missingness pattern as above
_model_miss = lambda theta, aux: account_for_nans(
    growth_factor_model(theta, aux), y_nan, missing_inds
)[0]
plt.figure(figsize=(12, 5))
plt.ylabel("$I$")
# country-wide weekly incidence (sum over counties); the NaN week is not drawn
plt.plot(dates, y_nan.sum(axis=-1))
plt.scatter(dates, y_nan.sum(axis=-1))
plt.show()

Initial value \(\theta_0\)

theta0_result = initial_theta(
    y_miss,
    _model_miss,
    theta_manual,  # warm start from the previously obtained estimate
    aux,
    n_iter_la=N_ITER_LA,
    options={"maxiter": N_MAX_ITER_THETA0},
)
theta0 = theta0_result.x  # initial parameter vector for the MLE below
from ssm4epi.models.regional_growth_factor import theta_to_par
import jax.scipy as jsp

# first component is logit-type: expit(.) * 2 - 1 maps it to (-1, 1)
jsp.special.expit(theta0[0]) * 2 - 1
Array(-0.01538746, dtype=float64)

Maximum Likelihood estimate \(\hat \theta\)

Reimplement the MLE routine from `isssm` so that it relies only on the Laplace approximation (LA) proposal.

from isssm.laplace_approximation import laplace_approximation
from isssm.estimation import pgnll
from scipy.optimize import minimize as minimize_scipy
from jaxtyping import Float, Array
from jax import jit


def mle_pgssm(
    y: Float[Array, "n+1 p"],  # observations $y_t$
    model_fn,  # parameterized LCSSM
    theta0: Float[Array, "k"],  # initial parameter guess
    aux,  # auxiliary data for the model
    n_iter_la: int,  # number of LA iterations
    N: int,  # number of importance samples
    key: Array,  # random key
    options=None,  # options for the optimizer
):  # scipy.optimize.OptimizeResult; the MLE is its ``.x`` attribute
    """Maximum Likelihood Estimation for PGSSMs using only the LA proposal.

    Minimizes the importance-sampling estimate of the negative log-likelihood
    (``pgnll``), normalized by the number of observations, with BFGS and
    finite-difference ("3-point") gradients.

    The same ``subkey`` is passed to every objective evaluation (common random
    numbers), so the stochastic objective is a deterministic function of
    ``theta`` — this is what makes finite-difference gradients usable.

    Returns the full ``OptimizeResult`` rather than just the parameter vector,
    so convergence diagnostics remain available to the caller.
    """

    @jit
    def f(theta, key):
        model = model_fn(theta, aux)

        # Gaussian importance proposal from the Laplace approximation
        proposal_la, _ = laplace_approximation(y, model, n_iter_la)

        key, subkey = jrn.split(key)
        # improve numerical stability by dividing by number of observations
        n_obs = y.size
        return pgnll(y, model, proposal_la.z, proposal_la.Omega, N, subkey) / n_obs

    # draw the single subkey shared by all objective evaluations
    key, subkey = jrn.split(key)
    result = minimize_scipy(
        f, theta0, method="BFGS", jac="3-point", options=options, args=(subkey,)
    )
    return result


theta_result = mle_pgssm(
    y_miss,
    _model_miss,
    theta0,  # start BFGS from the initial estimate computed above
    aux,
    n_iter_la=N_ITER_LA,
    N=N_MLE,
    key=key,
    options={"maxiter": N_MAX_ITER_MLE},
)
theta = theta_result.x  # the MLE itself
theta_result  # display full optimizer diagnostics
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[8], line 38
     32     result = minimize_scipy(
     33         f, theta0, method="BFGS", jac="3-point", options=options, args=(subkey,)
     34     )
     35     return result
---> 38 theta_result = mle_pgssm(
     39     y_miss,
     40     _model_miss,
     41     theta0,
     42     aux,
     43     n_iter_la=N_ITER_LA,
     44     N=N_MLE,
     45     key=key,
     46     options={"maxiter": N_MAX_ITER_MLE},
     47 )
     48 theta = theta_result.x
     49 theta_result

Cell In[8], line 32, in mle_pgssm(y, model_fn, theta0, aux, n_iter_la, N, key, options)
     29     return pgnll(y, model, proposal_la.z, proposal_la.Omega, N, subkey) / n_obs
     31 key, subkey = jrn.split(key)
---> 32 result = minimize_scipy(
     33     f, theta0, method="BFGS", jac="3-point", options=options, args=(subkey,)
     34 )
     35 return result

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_minimize.py:733, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    731     res = _minimize_cg(fun, x0, args, jac, callback, **options)
    732 elif meth == 'bfgs':
--> 733     res = _minimize_bfgs(fun, x0, args, jac, callback, **options)
    734 elif meth == 'newton-cg':
    735     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    736                              **options)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_optimize.py:1419, in _minimize_bfgs(fun, x0, args, jac, callback, gtol, norm, eps, maxiter, disp, return_all, finite_diff_rel_step, xrtol, c1, c2, hess_inv0, **unknown_options)
   1416 pk = -np.dot(Hk, gfk)
   1417 try:
   1418     alpha_k, fc, gc, old_fval, old_old_fval, gfkp1 = \
-> 1419              _line_search_wolfe12(f, myfprime, xk, pk, gfk,
   1420                                   old_fval, old_old_fval, amin=1e-100,
   1421                                   amax=1e100, c1=c1, c2=c2)
   1422 except _LineSearchError:
   1423     # Line search failed to find a better solution.
   1424     warnflag = 2

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_optimize.py:1155, in _line_search_wolfe12(f, fprime, xk, pk, gfk, old_fval, old_old_fval, **kwargs)
   1141 """
   1142 Same as line_search_wolfe1, but fall back to line_search_wolfe2 if
   1143 suitable step length is not found, and raise an exception if a
   (...)
   1150 
   1151 """
   1153 extra_condition = kwargs.pop('extra_condition', None)
-> 1155 ret = line_search_wolfe1(f, fprime, xk, pk, gfk,
   1156                          old_fval, old_old_fval,
   1157                          **kwargs)
   1159 if ret[0] is not None and extra_condition is not None:
   1160     xp1 = xk + ret[0] * pk

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_linesearch.py:93, in line_search_wolfe1(f, fprime, xk, pk, gfk, old_fval, old_old_fval, args, c1, c2, amax, amin, xtol)
     89     return np.dot(gval[0], pk)
     91 derphi0 = np.dot(gfk, pk)
---> 93 stp, fval, old_fval = scalar_search_wolfe1(
     94         phi, derphi, old_fval, old_old_fval, derphi0,
     95         c1=c1, c2=c2, amax=amax, amin=amin, xtol=xtol)
     97 return stp, fc[0], gc[0], fval, old_fval, gval[0]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_linesearch.py:170, in scalar_search_wolfe1(phi, derphi, phi0, old_phi0, derphi0, c1, c2, amax, amin, xtol)
    167 maxiter = 100
    169 dcsrch = DCSRCH(phi, derphi, c1, c2, xtol, amin, amax)
--> 170 stp, phi1, phi0, task = dcsrch(
    171     alpha1, phi0=phi0, derphi0=derphi0, maxiter=maxiter
    172 )
    174 return stp, phi1, phi0

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_dcsrch.py:256, in DCSRCH.__call__(self, alpha1, phi0, derphi0, maxiter)
    254     alpha1 = stp
    255     phi1 = self.phi(stp)
--> 256     derphi1 = self.derphi(stp)
    257 else:
    258     break

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_linesearch.py:87, in line_search_wolfe1.<locals>.derphi(s)
     86 def derphi(s):
---> 87     gval[0] = fprime(xk + s*pk, *args)
     88     gc[0] += 1
     89     return np.dot(gval[0], pk)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:332, in ScalarFunction.grad(self, x)
    330 if not np.array_equal(x, self.x):
    331     self._update_x(x)
--> 332 self._update_grad()
    333 return self.g

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:307, in ScalarFunction._update_grad(self)
    305 if self._orig_grad in FD_METHODS:
    306     self._update_fun()
--> 307 self.g = self._wrapped_grad(self.x, f0=self.f)
    308 self.g_updated = True

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:48, in _wrapper_grad.<locals>.wrapped1(x, f0)
     46 def wrapped1(x, f0=None):
     47     ncalls[0] += 1
---> 48     return approx_derivative(
     49         fun, x, f0=f0, **finite_diff_options
     50     )

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_numdiff.py:523, in approx_derivative(fun, x0, method, rel_step, abs_step, f0, bounds, sparsity, as_linear_operator, args, kwargs)
    520     use_one_sided = False
    522 if sparsity is None:
--> 523     return _dense_difference(fun_wrapped, x0, f0, h,
    524                              use_one_sided, method)
    525 else:
    526     if not issparse(sparsity) and len(sparsity) == 2:

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_numdiff.py:608, in _dense_difference(fun, x0, f0, h, use_one_sided, method)
    606 x2[i] += h[i]
    607 dx = x2[i] - x1[i]
--> 608 f1 = fun(x1)
    609 f2 = fun(x2)
    610 df = f2 - f1

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_numdiff.py:474, in approx_derivative.<locals>.fun_wrapped(x)
    471 if xp.isdtype(x.dtype, "real floating"):
    472     x = xp.astype(x, x0.dtype)
--> 474 f = np.atleast_1d(fun(x, *args, **kwargs))
    475 if f.ndim > 1:
    476     raise RuntimeError("`fun` return value has "
    477                        "more than 1 dimension.")

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:25, in _wrapper_fun.<locals>.wrapped(x)
     23 if not np.isscalar(fx):
     24     try:
---> 25         fx = np.asarray(fx).item()
     26     except (TypeError, ValueError) as e:
     27         raise ValueError(
     28             "The user-provided objective function "
     29             "must return a scalar value."
     30         ) from e

KeyboardInterrupt: 
fitted_model = _model_miss(theta, aux)
# Laplace approximation at the MLE; tighter tolerance than during estimation
proposal_la, info_la = LA(y_miss, fitted_model, N_ITER_LA, eps=1e-5)
from isssm.importance_sampling import ess_pct, pgssm_importance_sampling


_, log_weights = pgssm_importance_sampling(
    y_miss, fitted_model, proposal_la.z, proposal_la.Omega, N_PRED, key
)
# NOTE(review): if ess_pct returns a percentage, this is 100x the effective
# sample size, not the ESS itself — confirm the intended quantity
ess_pct(log_weights) * N_PRED
Array(260.10875772, dtype=float64)
def f_pred(x, s, y):
    """Concatenated prediction targets: country-wide total cases, per-county
    cases, per-county log growth factors, the country-wide log growth factor,
    and the mean growth factor across counties (all for the sampled path)."""
    cases_by_county = y[-1]
    # cap each county at its population size before summing to the total
    total_cases = jnp.minimum(cases_by_county, n_pop).sum()[None]

    country_log_gf = x[:, 0]
    county_log_gfs = s.reshape(-1)
    mean_gf_over_counties = jnp.exp(s).mean(axis=-1)
    return jnp.concatenate(
        [
            total_cases,
            cases_by_county,
            county_log_gfs,
            country_log_gf,
            mean_gf_over_counties,
        ]
    )


key, subkey = jrn.split(key)

Make predictions

# fixed seed so this prediction is reproducible on its own
key = jrn.PRNGKey(12453452)
key, subkey = jrn.split(key)
preds = prediction(
    f_pred,
    y_miss,
    proposal_la,
    fitted_model,
    1_000,  # NOTE(review): 1,000 samples here instead of N_PRED — confirm intent
    subkey,
    percentiles_of_interest,
    growth_factor_model(theta, aux),
)
ESS:31.35 %
# persist everything the analysis section below needs
result = (theta, proposal_la, preds, dates, y)

with open(here() / "data/results/4_local_outbreak_model/results.pickle", "wb") as f:
    pickle.dump(result, f)

Analyze results

with open(here() / "data/results/4_local_outbreak_model/results.pickle", "rb") as f:
    result = pickle.load(f)
theta, proposal_la, preds, dates, y = result
from ssm4epi.models.regional_growth_factor import _P
import jax.scipy as jsp

# unpack the unconstrained parameter vector used by the optimizer
logit_alpha, log_s2_r, log_s2_spat, logit_q, log_Cm1, log_r = theta

# transform back to the natural (interpretable) parameter scale
params_ser = pd.Series(
    {
        "alpha": jsp.special.expit(logit_alpha) * 2 - 1,  # mapped to (-1, 1)
        "sigma_r": jnp.sqrt(jnp.exp(log_s2_r)),  # std. dev. from log-variance
        "sigma spatial": jnp.sqrt(jnp.exp(log_s2_spat)),
        "q": jsp.special.expit(logit_q),  # mapped to (0, 1)
        "C": jnp.exp(log_Cm1) + 1,  # > 1 by construction
        "r": jnp.exp(log_r),
    }
)
params_ser.to_csv(
    here() / "data/results/4_local_outbreak_model/estimated_parameters.csv",
    header=False,
)
params_ser.apply(lambda x: f"{x:.2f}")  # display rounded values
alpha            -0.22
sigma_r           0.18
sigma spatial     0.87
q                 0.05
C                 2.41
r                19.52
dtype: object
import numpy as np

# mixing matrix implied by the estimated C and q
P = _P(params_ser.C, params_ser.q, n_ij, n_tot)

np.savetxt(
    here() / "data/results/4_local_outbreak_model/showcase_P_matrix.csv", np.array(P)
)
import pandas as pd

# Long-format frame; row order mirrors the concatenation order in f_pred:
# [y_total (1), y_total per county (400), per-county log growth factors
#  (10 weeks x 400 counties), country-wide log_rho (10), mean_rho (10)]
df = pd.DataFrame(
    {
        "variable": [
            "y_total",
            *[f"y_total_{c}" for c in range(1, 401)],
            *[f"log_rho_{c}_{t}" for t in range(1, 11) for c in range(1, 401)],
            *["log_rho" for t in range(1, 11)],
            *["mean_rho" for t in range(1, 11)],
        ],
        "c": [  # county index; 0 marks country-wide aggregates
            0,
            *range(1, 401),
            *[c for t in range(1, 11) for c in range(1, 401)],
            *[0 for _ in range(1, 11)],
            *[0 for _ in range(1, 11)],
        ],
        "t": [  # week index within the prediction window
            10,
            *jnp.repeat(10, 400),
            *[t for t in range(1, 11) for c in range(1, 401)],
            *[t for t in range(1, 11)],
            *[t for t in range(1, 11)],
        ],
        "mean": preds[0],
        "sd": preds[1],
        **{
            f"{p * 100:.1f} %": preds[2][i, :]
            for i, p in enumerate(percentiles_of_interest)
        },
    }
)

df["date"] = [dates[t - 1] for t in df["t"]]
df.to_csv(here("data/results/4_local_outbreak_model/showcase.csv"), index=False)

Country-wide growth factor

# observations shifted by one week: denominator of the observed growth factor
y_actual = aux[0][:-1]
fitted_model = _model_miss(theta, aux)


def f_pred(x, s, y):
    """Observed country-wide growth factor: this week's country-wide case
    total divided by last week's (`y_actual`); `x` and `s` are unused."""
    cases_now = y.sum(axis=-1)
    cases_before = y_actual.sum(axis=-1)
    return cases_now / cases_before


# re-seed so this prediction is reproducible independently of earlier cells
key = jrn.PRNGKey(12453452)
key, subkey = jrn.split(key)
preds = prediction(
    f_pred,
    y_miss,
    proposal_la,
    fitted_model,
    N_PRED,
    subkey,
    percentiles_of_interest,
    growth_factor_model(theta, aux),
)

# Summary frame for the observed country-wide growth factor, one row per week.
df = pd.DataFrame(
    {
        # constant label for all 10 weeks (the original f-string carried no
        # placeholders, so a plain repeated string is equivalent)
        "variable": ["rho_observed"] * 10,
        "t": list(range(1, 11)),  # week index within the prediction window
        "mean": preds[0],
        "sd": preds[1],
        **{
            f"{p * 100:.1f} %": preds[2][i, :]
            for i, p in enumerate(percentiles_of_interest)
        },
    }
)
df["date"] = [dates[t - 1] for t in df["t"]]
df.to_csv(
    here("data/results/4_local_outbreak_model/showcase_rho_observed.csv"), index=False
)
ESS:31.35 %
# plt.plot(preds[2][11])
# observed country-wide growth factor vs. two predicted percentile bands
plt.plot(aux[0][1:].sum(axis=-1) / aux[0][:-1].sum(axis=-1))
plt.plot(preds[2][2], color="grey", linestyle="dashed")
plt.plot(preds[2][-2], color="grey", linestyle="dashed")

Gütersloh's contribution to neighboring counties' incidence

# AGS (official German county keys) for Gütersloh and its neighbouring
# counties; note the first entry is Gütersloh itself
guetersloh_neighbor_ags = [
    "05754",  # Gütersloh
    "03459",  # Osnabrück
    "03404",  # Osnabrück Stadt
    "05566",  # Steinfurt
    "05515",  # Münster
    "05558",  # Coesfeld
    "05758",  # Herford
    "05711",  # Bielefeld
    "05766",  # Lippe
    "05774",  # Paderborn
    "05974",  # Soest
    "05570",  # Warendorf
]
df_ags = pd.read_csv(here("data/processed/ags_county_dict.csv"))
# AGS keys are 5-digit, zero-padded strings; restore leading zeros lost on read
df_ags["ags"] = df_ags["ags"].astype(str).str.zfill(5)
df_ags_neighbors = df_ags[df_ags["ags"].isin(guetersloh_neighbor_ags)]
neighboring_indices = df_ags_neighbors.index.to_numpy()
guetersloh_index = df_ags[df_ags["ags"] == "05754"].index.to_numpy()[0]
from ssm4epi.models.regional_growth_factor import _P


def f_pred_distribution_of_cases(x, s, y):
    """Flattened matrix of predicted case flows restricted to Gütersloh and
    its neighbours, derived from the final week's log growth factors."""
    predicted_cases = jnp.exp(s[-1]) * y_nan[-2]
    mixing = _P(params_ser.C, params_ser.q, n_ij, n_tot)
    flows = mixing.T * predicted_cases
    local_flows = flows[neighboring_indices][:, neighboring_indices]
    return local_flows.reshape(-1)
key, subkey = jrn.split(key)
# predict the Gütersloh-neighbourhood case flows
res = prediction(
    f_pred_distribution_of_cases,
    y_miss,
    proposal_la,
    fitted_model,
    N_PRED,
    subkey,
    percentiles_of_interest,
    growth_factor_model(theta, aux),
)
ESS:574.09
df_guetersloh_contrib = pd.DataFrame(
    {
        # all quantities refer to the final (predicted) week, t = 10
        "t": [*jnp.repeat(10, len(res[0]))],
        "mean": res[0],
        "sd": res[1],
        **{
            f"{p * 100:.1f} %": res[2][i, :]
            for i, p in enumerate(percentiles_of_interest)
        },
    }
)

df_guetersloh_contrib["date"] = [dates[t - 1] for t in df_guetersloh_contrib["t"]]
df_guetersloh_contrib.to_csv(
    here("data/results/4_local_outbreak_model/guetersloh_distribution_cases.csv"),
    index=False,
)
# width of the outermost percentile interval for each flow entry
res[2][-1] - res[2][0]
Array([6.73212908e+00, 8.63652793e-01, 7.73697856e-02, 1.09639916e-02,
       6.81631995e-01, 3.56302318e-01, 5.24354216e-02, 5.29234075e+00,
       1.96407338e-01, 8.77265876e-03, 1.50070732e-02, 6.40926398e-03,
       5.96636115e-01, 8.21259855e+00, 3.50298247e-02, 6.88414552e-03,
       4.01219226e-01, 1.08278335e+00, 9.76470123e-02, 2.65663729e+01,
       9.84542919e-01, 1.11872618e-02, 1.63687777e-02, 8.42236443e-03,
       4.92496924e-02, 3.08783106e-02, 2.39297884e+01, 8.93555249e-01,
       1.53671456e+00, 1.21265836e+01, 6.76227776e-02, 7.88301802e+00,
       9.18055204e-02, 1.29503052e-02, 3.98808762e-02, 4.27060751e-02,
       3.56148240e-03, 3.34114987e-03, 3.79823599e-01, 8.43886689e+00,
       1.40441341e-01, 9.29755192e-01, 9.44352129e-03, 1.02534290e+00,
       1.87870769e-02, 3.59850945e-03, 8.01698985e-03, 1.02524557e-02,
       1.91886055e-01, 1.50752170e-01, 5.59277146e-01, 1.05124986e-01,
       2.68655566e+01, 1.32199051e+00, 2.21775352e-02, 2.37401901e+00,
       5.76143445e-02, 4.67166634e-03, 1.07403989e-02, 8.84938574e-03,
       1.10712094e-02, 3.88064759e-02, 4.17965294e-01, 5.04643200e-02,
       1.19808073e-01, 1.92894692e+02, 6.00290996e-02, 4.07846193e+01,
       7.58689553e-02, 8.38938844e-03, 3.25276717e-02, 1.68677361e-01,
       2.73503595e-02, 4.44354733e-02, 5.20908638e-02, 8.50681159e-03,
       3.56775860e-02, 9.90966454e-01, 3.26831324e+01, 1.51772281e+02,
       3.67057852e+00, 5.80382013e-01, 3.76948141e-01, 5.25885682e-02,
       3.74279286e-02, 1.24087107e-01, 5.37880353e-02, 8.92406858e-03,
       4.27328325e-02, 9.23596234e+00, 1.93391519e+00, 2.65316408e+03,
       8.62266002e-01, 1.94275493e-01, 7.50236747e-01, 2.57741805e-01,
       2.12456782e-02, 7.02812921e-02, 1.22519977e-02, 3.12883261e-03,
       1.79063519e-02, 1.15484328e-01, 7.40189000e-01, 1.17842732e+01,
       6.06620837e+01, 2.79629793e-01, 3.77021489e-02, 7.62932486e-03,
       8.06731859e-03, 7.25237808e-03, 1.11800999e-02, 2.66521373e-03,
       9.45336784e-03, 1.45284547e-01, 6.77453537e-01, 1.47787913e+01,
       1.27951425e+00, 1.10458301e+01, 3.14672853e-01, 1.57427297e-02,
       3.65838210e-03, 4.79464684e-03, 1.51997165e-02, 4.89058434e-03,
       1.13170179e-02, 4.41675920e-01, 2.19963642e-01, 3.02542782e+01,
       1.41933261e-01, 1.97188348e-01, 2.50007354e+01, 3.22039013e-01,
       3.56148240e-03, 4.05468474e-03, 2.10058293e-02, 6.83778363e-03,
       1.67082911e-02, 2.18861312e+00, 3.78321946e-02, 1.42301773e+01,
       3.09568175e-02, 1.04590482e-02, 5.21561164e-01, 1.59825589e+01],      dtype=float64)
# arrange the lowest-percentile predicted case flows as a 12x12 county matrix
cases_dist = jnp.array(res[2][0]).reshape((12, 12))
# zero out the diagonal (within-county flows) to highlight cross-county spread
off_diagonal_flows = cases_dist - jnp.diag(jnp.diag(cases_dist))
plt.imshow(off_diagonal_flows)
plt.colorbar()