Utilities

import jax
jax.config.update("jax_enable_x64", True)

Sampling from a degenerate multivariate normal

The MultivariateNormalFullCovariance distribution from tfp only supports non-singular covariance matrices for sampling, because internally a Cholesky decomposition is used, which is not well-defined for singular symmetric matrices. Instead, we use an eigenvalue decomposition and compute a valid Cholesky root via a QR decomposition.
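The construction can be sketched as follows: an eigendecomposition \(\Sigma = V\,\mathrm{diag}(w)\,V^\top\) yields a (generally non-triangular) root \(A = V\,\mathrm{diag}(\sqrt{w})\), and a QR decomposition \(A^\top = QR\) turns it into a lower-triangular root \(L = R^\top\) with \(LL^\top = \Sigma\). The snippet below is a minimal sketch of this idea (the _sketch name is hypothetical); the actual degenerate_cholesky may differ in details such as how round-off negatives are handled.

import jax.numpy as jnp

def degenerate_cholesky_sketch(Sigma):
    # eigendecomposition of the symmetric (possibly singular) covariance
    w, V = jnp.linalg.eigh(Sigma)
    # clip tiny negative eigenvalues caused by round-off
    w = jnp.maximum(w, 0.0)
    # A @ A.T == Sigma, but A is in general not triangular
    A = V * jnp.sqrt(w)
    # A.T = Q R  =>  Sigma = A @ A.T = R.T @ R, so L = R.T is lower triangular
    _, R = jnp.linalg.qr(A.T)
    return R.T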


source

MVN_degenerate

 MVN_degenerate (loc:jax.Array, cov:jax.Array)

source

degenerate_cholesky

 degenerate_cholesky (Sigma)
import jax.numpy as jnp
import jax.random as jrn
import matplotlib.pyplot as plt
import fastcore.test as fct
mu = jnp.zeros(2)
Sigma = jnp.array([[1.0, 1.0], [1.0, 1.0]])

N = 1000
key = jrn.PRNGKey(1423423)
key, subkey = jrn.split(key)
samples = MVN_degenerate(mu, Sigma).sample(seed=subkey, sample_shape=(N,))
plt.title("Samples from degenerate 2D Gaussian")
plt.scatter(samples[:, 0], samples[:, 1])
plt.show()

# the covariance is singular with null vector (1, -1): all samples lie on the line x = y
fct.test_close(samples @ jnp.array([1.0, -1.0]), jnp.zeros(N))

L = degenerate_cholesky(Sigma)
# ensure the Cholesky root reproduces Sigma ...
fct.test_close(Sigma, L @ L.T)
# ... but only in the order L @ L.T, not L.T @ L
fct.test_ne(Sigma, L.T @ L)

Optimization


source

converged

 converged (new:jaxtyping.Float[Array,'...'],
            old:jaxtyping.Float[Array,'...'], eps:jaxtyping.Float)

Check that the sup-norm of the relative change is smaller than the tolerance.

         Type                 Details
new      Float[Array, '...']  the new array
old      Float[Array, '...']  the old array
eps      Float                tolerance
Returns  Bool                 whether the arrays are close enough
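A minimal sketch of such a check, assuming an element-wise relative change with no special handling of zeros in old (the _sketch suffix marks the name as hypothetical):

import jax.numpy as jnp

def converged_sketch(new, old, eps):
    # sup-norm (largest absolute entry) of the element-wise relative change
    rel_change = jnp.max(jnp.abs(new - old) / jnp.abs(old))
    return rel_change < eps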

Vmapped utilities

Throughout the package we make extensive use of matrix-vector multiplication. Depending on the algorithm, different vectorizations are helpful.

Let \(B \in \mathbf R^{(n+1)\times p \times m}\) be a list of \(n + 1\) matrices, let \(X \in \mathbf R^{(n + 1) \times m}\) be a set of states and let \(\mathbf X \in \mathbf R^{N \times (n + 1) \times p}\) be \(N\) simulations of \(X\).

mm_sim multiplies, at a single time point \(t\), the single matrix \(B_t\) with all simulated states \(X_t^i\), i.e. it maps \[\mathbf R^{p \times m} \times \mathbf R^{N \times m} \to \mathbf R^{N \times p}.\]

mm_time maps a single sample \(X\) to \((B_tX_t)_{t = 0, \dots, n}\), i.e. it maps \[\mathbf R^{(n +1) \times p \times m} \times \mathbf R^{(n + 1) \times m} \to \mathbf R^{(n+1) \times p}.\]

mm_time_sim multiplies all samples \(\mathbf X\) for all times with the matrices \(B\), i.e. it maps \[\mathbf R^{(n + 1) \times p \times m}\times \mathbf R^{N \times (n+1) \times m} \to \mathbf R^{N \times (n + 1) \times p}.\]

Exported source
from jax import vmap

# multiply $B_t$ and $X^i_t$
mm_sim = vmap(jnp.matmul, (None, 0))
# matmul with $(B_t)_{t}$ and $(X_t)_{t}$
mm_time = vmap(jnp.matmul, (0, 0))
# matmul with $(B_t)_{t}$ and $(X^i_t)_{i,t}$
mm_time_sim = vmap(mm_time, (None, 0))
N, np1, p, m = 1000, 100, 3, 5
key, subkey = jrn.split(key)
B = jrn.normal(subkey, (np1, p, m))
key, subkey = jrn.split(key)
X = jrn.normal(subkey, (N, np1, m))

fct.test_eq(mm_sim(B[0], X[:, 0, :]).shape, (N, p))
fct.test_eq(mm_time(B, X[0]).shape, (np1, p))
fct.test_eq(mm_time_sim(B, X).shape, (N, np1, p))

Appending to the front of an array


source

append_to_front

 append_to_front (a0:jaxtyping.Float[Array,'...'],
                  a:jaxtyping.Float[Array,'n ...'])
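A minimal sketch of what this helper is assumed to do, namely prepend a0 as a new first entry along the leading axis (the _sketch name is hypothetical):

import jax.numpy as jnp

def append_to_front_sketch(a0, a):
    # prepend a0 as the new first entry along the leading axis
    return jnp.concatenate([a0[None], a], axis=0)

# e.g. turning an initial state and n subsequent states into one array
x0 = jnp.zeros(3)
xs = jnp.ones((5, 3))
append_to_front_sketch(x0, xs).shape  # (6, 3)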

Antithetic variables

To improve the efficiency of importance sampling, Durbin and Koopman (1997) recommend using antithetic variables. These are a device for reducing Monte Carlo variance by introducing negative correlations between samples. We use both location- and scale-balanced antithetic variables.


source

scale_antithethic

 scale_antithethic (u:jaxtyping.Float[Array,'N n+1 k'],
                    samples:jaxtyping.Float[Array,'N n+1 p'],
                    mean:jaxtyping.Float[Array,'n+1 p'])

source

location_antithetic

 location_antithetic (samples:jaxtyping.Float[Array,'N ...'],
                      mean:jaxtyping.Float[Array,'N ...'])
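As an illustration, a location-balanced antithetic reflects every sample around the mean, so the pair (X, 2*mean - X) has the same marginal distribution but perfectly negatively correlated deviations. The following is a minimal sketch under that assumption (hypothetical name); the scale-balanced variant additionally rescales the samples and is not reproduced here.

import jax.numpy as jnp

def location_antithetic_sketch(samples, mean):
    # reflect each sample around the mean: same mean, negated deviations
    return 2.0 * mean - samples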

References

Durbin, James, and Siem Jan Koopman. 1997. “Monte Carlo Maximum Likelihood Estimation for Non-Gaussian State Space Models.” Biometrika 84 (3): 669–84. https://doi.org/10.1093/biomet/84.3.669.