from jax import config
config.update("jax_enable_x64", True)
from ssm4epi.models.hospitalization import (
hospitalization_model,
dates,
unique_a,
h_by_age,
I_by_age,
)
from ssm4epi.models.util import from_consecutive_logits
from isssm.importance_sampling import mc_integration
from isssm.kalman import state_mode
from jax import vmap
from ssm4epi.patch import full_deps
import jax.numpy as jnp
from pyprojroot import here
import matplotlib.pyplot as plt
from isssm.estimation import initial_theta, mle_pgssm
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
modified_efficient_importance_sampling as MEIS,
)
from isssm.importance_sampling import pgssm_importance_sampling, ess_pct
import jax.random as jrn
import pandas as pd

# Application 2: Nowcasting hospitalizations
from typing import NamedTuple
from jaxtyping import Float, Array, PRNGKeyArray
from scipy.optimize import OptimizeResult
from isssm.typing import GLSSMProposal
class PredictionResult(NamedTuple):
    """Summary of a prediction: moments plus a set of quantiles.

    The leading axis of ``quantiles`` is annotated with the same length ``k``
    as ``quantiles_of_interest``, i.e. one slice per listed level.
    """

    # Point summaries.
    mean: Float[Array, "..."]
    sd: Float[Array, "..."]

    # Quantiles and the levels they belong to.
    quantiles: Float[Array, "k ..."]
    quantiles_of_interest: Float[Array, "k"]
class HospitalizationNowcastingResult(NamedTuple):
    """Everything produced by one nowcasting run for a single age group.

    Instances are assembled and pickled by ``nowcast_hospitalizations``.
    """

    a_index: int  # index of the age group in the per-age data
    h: Float[Array, "..."]  # per-age hospitalizations, restricted to the delays used
    I: Float[Array, "..."]  # per-age incidences for the nowcasting window
    dates: pd.DatetimeIndex  # slice of the global dates covered by the window
    y_miss: Float[Array, "np1 p"]  # observations as returned by account_for_nans
    y_true: Float[Array, "np1 p"]  # unmasked observations for the same window
    theta0: Float[Array, "5"]  # parameters taken from the initial_theta result
    theta0_result: OptimizeResult  # full result object returned by initial_theta
    proposal: GLSSMProposal  # proposal used for importance sampling and prediction
    key: PRNGKeyArray  # PRNG key left over after the splits made during the run
    ess_pct: Float  # ess_pct(...) of the importance sampling log weights
    y_sum_predict: PredictionResult  # output of `prediction` for the row sums


# Previously fused onto the last field of the class above, which turned it
# into a bogus field default and left `start_date` undefined.
start_date = "2021-11-22"
# Module-level settings for the hospitalization nowcasts.
end_date = "2022-04-29"

# Model window and optimizer budget.
n_days_back = 100  # number of rows in the window used by nowcast_hospitalizations
n_weekday = 2  # forwarded to hospitalization_model as part of `aux`
n_iter = 100  # iteration count handed to initial_theta and the Laplace approximation

# Monte Carlo sample sizes.
N_meis = 10_000
N_ess = 10_000
N_fit = 10_000
# Number of samples handed to `prediction`.
N_predict = 10000


class HospitalizationResult(NamedTuple):
    """Container for the outcome of fitting the hospitalization model.

    NOTE(review): presumably the type stored in the "showcase" pickles that
    ``nowcast_hospitalizations`` loads (it reads ``theta0`` from them) —
    confirm against the code that writes those files.
    """

    a_index: int  # index of the age group
    h: Float[Array, "..."]
    I: Float[Array, "..."]
    y: Float[Array, "np1 p"]  # observations the model was fitted to
    theta0: Float[Array, "5"]  # initial parameter estimate
    theta0_result: OptimizeResult  # optimizer result belonging to theta0
    theta_hat: Float[Array, "5"]  # final parameter estimate
    theta_hat_result: OptimizeResult  # optimizer result belonging to theta_hat
    proposal: GLSSMProposal
    key: PRNGKeyArray
    ess_pct: Float


# Mid-file import kept where the original placed it; it was fused onto the
# last field of the class above.
import pickle
from isssm.util import degenerate_cholesky
from ssm4epi.models.hospitalization import account_for_nans, make_y_nan
from isssm.importance_sampling import prediction
def nowcast_hospitalizations(a_index, n_delay, date_to_nowcast):
    """Nowcast hospitalizations of one age group and pickle the outcome.

    The ``n_days_back`` rows that precede ``date_to_nowcast`` form the model
    window. Entries that ``make_y_nan`` turns into NaN are treated as missing,
    the initial parameters are re-estimated on the masked data (starting from
    the ``theta0`` stored in the matching "showcase" pickle), a Laplace
    approximation provides the importance sampling proposal, and the row sums
    of the completed observations are predicted.

    Args:
        a_index: Index of the age group in ``h_by_age`` / ``I_by_age``.
        n_delay: Number of leading delay columns of the hospitalization
            array to keep.
        date_to_nowcast: Value that is looked up in the module-level
            ``dates``.

    Returns:
        HospitalizationNowcastingResult: the fitted quantities and the
        prediction. The same object is also written to
        ``data/results/4_hospitalizations/nowcast/`` under the project root.

    Raises:
        FileNotFoundError: If the showcase pickle for ``(a_index, n_delay)``
            does not exist.
    """
    results_dir = here() / "data" / "results" / "4_hospitalizations"

    # Position of the nowcast date; the window is the n_days_back rows
    # directly before it. Fails on the [0][0] lookup if the date is absent.
    s = int(jnp.where(dates == date_to_nowcast)[0][0])
    i_start = s - n_days_back
    np1 = n_days_back
    window = slice(i_start, i_start + np1)

    h = h_by_age[a_index][:, :n_delay]
    I = I_by_age[a_index][window]
    y = h[window]
    aux = (np1, n_delay, n_weekday, I)
    key = jrn.PRNGKey(4234809124)

    y_nan = make_y_nan(y)
    missing_y_indices = jnp.isnan(y_nan)
    # The observation mask shifted one column to the right; the first column
    # is never flagged.
    missing_s_indicies = jnp.concatenate(
        (jnp.full((np1, 1), False, dtype=bool), missing_y_indices[:, :-1]), axis=-1
    )

    # NOTE(review): pickle.load can execute arbitrary code. This is assumed to
    # read only files written by this project's own pipeline.
    showcase_path = results_dir / "showcase" / f"result_A{a_index}_D{n_delay}.pkl"
    with open(showcase_path, "rb") as f:
        theta_manual = pickle.load(f).theta0

    def _model_miss(theta, aux):
        # Model with the missing entries accounted for (first return value).
        return account_for_nans(
            hospitalization_model(theta, aux),
            y_nan,
            missing_y_indices,
            missing_s_indicies,
        )[0]

    # Only the adjusted observations (second return value) are needed here.
    _, y_miss = account_for_nans(
        hospitalization_model(theta_manual, aux),
        y_nan,
        missing_y_indices,
        missing_s_indicies,
    )

    theta0_missing_result = initial_theta(
        y_miss, _model_miss, theta_manual, aux, n_iter
    )
    theta0_missing = theta0_missing_result.x
    model_miss0 = _model_miss(theta0_missing, aux)

    proposal_la, _ = LA(y_miss, model_miss0, n_iter)
    # Replace Omega by the product of its degenerate Cholesky factor with its
    # transpose (presumably to clean up the rows belonging to missing
    # observations — confirm against isssm.util.degenerate_cholesky).
    cOmega = degenerate_cholesky(proposal_la.Omega)
    Omega = cOmega @ cOmega.transpose((0, 2, 1))
    proposal_la = GLSSMProposal(
        u=proposal_la.u,
        A=proposal_la.A,
        D=proposal_la.D,
        Sigma0=proposal_la.Sigma0,
        Sigma=proposal_la.Sigma,
        v=proposal_la.v,
        B=proposal_la.B,
        Omega=Omega,
        z=proposal_la.z,
    )
    # NOTE: a MEIS refinement of this proposal, i.e.
    #   MEIS(y_miss, model_miss0, proposal_la.z, proposal_la.Omega,
    #        n_iter, N_meis, subkey)
    # is currently disabled; the Laplace proposal is used as is.

    key, subkey = jrn.split(key)
    _, log_weights = pgssm_importance_sampling(
        y_miss,
        model_miss0,
        proposal_la.z,
        proposal_la.Omega,
        N_fit,
        subkey,
    )

    def f_nowcast(x, s, y):
        # Row sums of the completed data: sampled values where an entry is
        # missing, the observed values everywhere else.
        return jnp.sum(
            (missing_y_indices * y) + (1 - missing_y_indices) * y_miss, axis=-1
        )

    key, subkey = jrn.split(key)
    preds = prediction(
        f_nowcast,
        y_miss,
        proposal_la,
        model_miss0,  # same model as _model_miss(theta0_missing, aux)
        N_predict,
        subkey,
        jnp.array([0.025, 0.5, 0.975]),
        hospitalization_model(theta0_missing, aux),
    )

    result = HospitalizationNowcastingResult(
        a_index=a_index,
        h=h,
        I=I,
        y_miss=y_miss,
        y_true=y,
        dates=dates[window],
        theta0=theta0_missing,
        theta0_result=theta0_missing_result,
        proposal=proposal_la,
        key=key,
        ess_pct=ess_pct(log_weights),
        y_sum_predict=preds,
    )

    nowcast_path = (
        results_dir
        / "nowcast"
        / f"nowcast_A{a_index}_D{n_delay}_{date_to_nowcast}.pkl"
    )
    # Create the target directory first so a missing folder cannot throw away
    # the finished computation.
    nowcast_path.parent.mkdir(parents=True, exist_ok=True)
    with open(nowcast_path, "wb") as f:
        pickle.dump(result, f)

    return result


# Previously fused onto the `return result` line of the function above.
from ssm4epi.models.hospitalization import delays_per_age
# Run the nowcast for every row of `delays_per_age`; the first two entries of
# a row are passed on as (a_index, n_delay).
# NOTE(review): if the columns of `delays_per_age` carry string labels,
# row[0] / row[1] rely on pandas' positional fallback for Series indexing,
# which recent pandas versions deprecate — consider row.iloc[...] once the
# column layout is confirmed.
initial_date = "2021-11-22"
results = delays_per_age.apply(
    lambda row: nowcast_hospitalizations(row[0], row[1], initial_date), axis=1
)

# Previously fused onto the closing parenthesis of the call above.
import matplotlib.pyplot as plt
from matplotlib import dates as mdates
def visualize(result: HospitalizationNowcastingResult, ax: plt.Axes = None):
    """Draw one nowcast: observed vs. true row sums plus a prediction band.

    Args:
        result: Nowcast to plot. ``y_miss`` and ``y_true`` are summed over
            their last axis; ``y_sum_predict[2]`` supplies the band, using its
            first and last entry (the 2.5% and 97.5% quantiles when the result
            comes from ``nowcast_hospitalizations``).
        ax: Axes to draw on. Defaults to the current axes (``plt.gca()``).

    Returns:
        None. The axes is modified in place.
    """
    if ax is None:
        # The signature advertises None as the default, so fall back to the
        # current axes instead of failing on `None.set_title`.
        ax = plt.gca()

    ax.set_title(f"ESS: {result.ess_pct * 100 :.2f}%")

    # Local name chosen so the module-level `dates` is not shadowed.
    plot_dates = pd.to_datetime(result.dates)
    quantiles = result.y_sum_predict[2]

    ax.plot(plot_dates, result.y_miss.sum(axis=-1), label="$y_{miss}$")
    ax.plot(plot_dates, result.y_true.sum(axis=-1), label="$y_{true}$")
    # The predictive mean (result.y_sum_predict[0]) is currently not drawn.
    ax.plot(plot_dates, quantiles[0], "--", color="gray")
    ax.plot(plot_dates, quantiles[-1], "--", color="gray")
    ax.legend()

    # One tick per month, labelled as month/day/year.
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%m/%d/%Y"))
    ax.xaxis.set_major_locator(mdates.MonthLocator())
# One panel per nowcast on a 4x2 grid; the last axes is excluded from the
# pairing and is therefore never passed to `visualize`.
_, axs = plt.subplots(4, 2, figsize=(12, 12))
panels = axs.flatten()[:-1]
for nowcast, panel in zip(results, panels):
    visualize(nowcast, panel)