from ssm4epi.models.regional_growth_factor import (
key,
n_iterations,
N_mle,
N_meis,
N_posterior,
percentiles_of_interest,
make_aux,
dates_full,
cases_full,
n_ij,
n_tot,
n_pop,
account_for_nans,
growth_factor_model,
)
from tqdm.notebook import tqdm
import pandas as pd
import jax.numpy as jnp
import jax
import jax.random as jrn
from isssm.importance_sampling import prediction
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
modified_efficient_importance_sampling as MEIS,
)
from pyprojroot.here import here
jax.config.update("jax_enable_x64", True)
from isssm.estimation import initial_theta
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from datetime import date

# Data
# County and state lookup tables. AGS county codes encode the state in
# their leading digits, so ags // 1000 yields the state id.
county_dict = pd.read_csv(here("data/processed/ags_county_dict.csv"))
state_dict = pd.read_csv(here("data/processed/ags_state_dict.csv"))
county_dict["ags_state"] = county_dict["ags"] // 1000
county_dict = county_dict.rename(columns={"name": "name_county", "ags": "ags_county"})
state_dict = state_dict.rename(columns={"ags": "ags_state"})
# State id for every county, in county order; used as bincount indices below.
ags_state_order = pd.merge(county_dict, state_dict, on="ags_state")[
    "ags_state"
].to_numpy()


def incidence_per_state(I_county):
    """Aggregate county-level incidence to state level.

    Sums the entries of ``I_county`` whose county belongs to the same
    state and drops the unused 0 bin (assumes state ids start at 1 —
    TODO confirm against the AGS dictionaries).
    """
    return jnp.bincount(ags_state_order, I_county)[1:]
# Parameters
min_date = date(2020, 10, 17)
max_date = date(2022, 1, 22)
np1 = 10  # forecast horizon in weeks

# Dates
# Forecast origins: np1 weeks before each Saturday in [min_date, max_date].
initial_dates = [
    (d - pd.Timedelta(weeks=np1)).strftime("%Y-%m-%d")
    for d in pd.date_range(start=min_date, end=max_date, freq="W-SAT", inclusive="both")
]


def initial_to_final_date(initial_date: date) -> date:
    """Return the final forecast date, np1 weeks after *initial_date*."""
    return initial_date + pd.Timedelta(weeks=np1)


final_dates = [
    initial_to_final_date(pd.to_datetime(d)).strftime("%Y-%m-%d") for d in initial_dates
]
len(initial_dates)  # == 67 weekly forecast origins (original notebook output)
def make_aux(date, cases_full, n_ij, n_tot, np1):
    """Rebuild the model auxiliary tuple for the forecast origin *date*.

    Reads the RKI county case file as it was available 5 days before the
    final forecast date and keeps only the last ``np1 + 1`` weekly rows.

    NOTE(review): this redefines the ``make_aux`` imported at the top of
    the file, and the parameter name ``date`` shadows ``datetime.date``
    — confirm both are intentional.
    """
    final_date = initial_to_final_date(pd.to_datetime(date))
    df_weekly_cases = pd.read_csv(
        here()
        / f"data/processed/RKI_county_{(final_date + pd.Timedelta(days=-5)).strftime('%Y-%m-%d')}.csv"
    ).pivot(index="date", columns="ags", values="cases")
    # cases_full argument is intentionally overwritten with the vintage data
    cases_full = jnp.asarray(df_weekly_cases.to_numpy(), dtype=jnp.float64)
    cases = cases_full[-(np1 + 1) :]
    return cases, n_ij, n_tot


# Fixed seed for all importance-sampling predictions below.
GLOBAL_KEY = jrn.PRNGKey(4534365463653)
def f_pred(x, s, y):
    """Prediction functional: last week's county cases plus their
    state-level aggregates, concatenated into one vector."""
    y_county = y[-1]
    y_state = incidence_per_state(y_county)
    return jnp.concatenate([y_county, y_state])


def create_regional_forecast_dataframe(initial_date):
    """Compute and save county/state forecasts for one forecast origin.

    Loads the fitted model pickle for *initial_date*, masks the final
    week of observations so it becomes a forecast target, draws
    importance-sampling predictive summaries (mean, sd, percentiles) for
    400 counties and 16 states, and writes them to a per-date CSV.
    """
    with open(
        here() / f"data/results/4_local_outbreak_model/results_{initial_date}.pickle",
        "rb",
    ) as f:
        result = pickle.load(f)
    (theta0, proposal_la, _, dates, _) = result
    # Locate the forecast origin on the full date axis; assumes a unique match.
    (dates_index,) = jnp.where(dates_full == initial_date)[0]
    dates = dates_full[dates_index : dates_index + np1]
    aux = make_aux(initial_date, cases_full, n_ij, n_tot, np1)
    y = aux[0][1:]
    # Mask the final week so it is predicted rather than conditioned on.
    y_nan = y.at[-1].set(jnp.nan)
    missing_inds = jnp.isnan(y_nan)
    _, y_miss = account_for_nans(growth_factor_model(theta0, aux), y_nan, missing_inds)
    _model_miss = lambda theta, aux: account_for_nans(
        growth_factor_model(theta, aux), y_nan, missing_inds
    )[0]
    fitted_model = _model_miss(theta0, aux)
    # NOTE(review): GLOBAL_KEY is split afresh on every call, so each
    # initial_date reuses the identical PRNG subkey — confirm intended.
    key, subkey = jrn.split(GLOBAL_KEY)
    preds = prediction(
        f_pred,
        y_miss,
        proposal_la,
        fitted_model,
        1000,
        subkey,
        percentiles_of_interest,
        growth_factor_model(theta0, aux),
    )
    df = pd.DataFrame(
        {
            "variable": [
                *[f"y_county_{c}" for c in range(1, 401)],
                *[f"y_state_{s}" for s in range(1, 17)],
            ],
            "c": [*range(1, 401), *[c for c in range(1, 17)]],
            # All rows refer to the last forecast week (t = 10).
            "t": [
                *jnp.repeat(10, 400),
                *jnp.repeat(10, 16),
            ],
            "mean": preds[0],
            "sd": preds[1],
            **{
                f"{p * 100:.1f} %": preds[2][i, :]
                for i, p in enumerate(percentiles_of_interest)
            },
        }
    )
    df["date"] = [dates[t - 1] for t in df["t"]]
    df.to_csv(
        here(
            f"data/results/4_local_outbreak_model/regional_forecasts_{initial_date}.csv"
        ),
        index=False,
    )


# Driver: write one forecast CSV per weekly forecast origin.
for initial_date in tqdm(initial_dates):
    create_regional_forecast_dataframe(initial_date)