Utilities

import jax

jax.config.update("jax_enable_x64", True)
Sampling from a degenerate multivariate normal

The MultivariateNormalFullCovariance distribution from tfp only supports non-singular covariance matrices for sampling, because internally a Cholesky decomposition is used, which is ambiguous for singular symmetric matrices. Instead, we use an eigenvalue decomposition and compute a valid Cholesky root by QR decomposition.
MVN_degenerate
MVN_degenerate (loc:jax.Array, cov:jax.Array)
degenerate_cholesky
degenerate_cholesky (Sigma)
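The eigenvalue/QR construction described above can be sketched as follows. `degenerate_cholesky_sketch` is a hypothetical stand-in for the exported `degenerate_cholesky`, written only from the description in the text; the actual implementation may differ:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def degenerate_cholesky_sketch(Sigma):
    # eigendecomposition of the symmetric, possibly singular covariance
    vals, vecs = jnp.linalg.eigh(Sigma)
    # clamp tiny negative eigenvalues caused by floating-point error
    vals = jnp.clip(vals, 0.0, None)
    # Sigma = root @ root.T, but root is in general not triangular
    root = vecs * jnp.sqrt(vals)
    # QR of root.T gives root.T = Q R, hence Sigma = R.T @ R,
    # so L = R.T is a lower-triangular root with Sigma = L @ L.T
    _, R = jnp.linalg.qr(root.T)
    return R.T

Sigma = jnp.array([[1.0, 1.0], [1.0, 1.0]])  # singular: rank 1
L = degenerate_cholesky_sketch(Sigma)
```

Since `Q` is orthogonal it drops out of `root @ root.T`, so the result is a valid triangular root even though the classical Cholesky factorization of a singular matrix is not unique.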
import jax.numpy as jnp
import jax.random as jrn
import matplotlib.pyplot as plt
import fastcore.test as fct
mu = jnp.zeros(2)
Sigma = jnp.array([[1.0, 1.0], [1.0, 1.0]])

N = 1000
key = jrn.PRNGKey(1423423)
key, subkey = jrn.split(key)
samples = MVN_degenerate(mu, Sigma).sample(seed=subkey, sample_shape=(N,))

plt.scatter(samples[:, 0], samples[:, 1])
plt.title("Samples from degenerate 2D Gaussian")
plt.show()

fct.test_close(samples @ jnp.array([[1.0], [-1.0]]), jnp.zeros(N))
L = degenerate_cholesky(Sigma)

# ensure the Cholesky root is correct
fct.test_close(Sigma, L @ L.T)
fct.test_ne(Sigma, L.T @ L)
Optimization
converged
converged (new:jaxtyping.Float[Array,'...'], old:jaxtyping.Float[Array,'...'], eps:jaxtyping.Float)
Check that the sup-norm of the relative change is smaller than the tolerance.
|  | Type | Details |
|---|---|---|
| new | Float[Array, '...'] | the new array |
| old | Float[Array, '...'] | the old array |
| eps | Float | tolerance |
| **Returns** | **Bool** | **whether the arrays are close enough** |
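A minimal sketch of such a convergence check, assuming the relative change is measured element-wise against `old` (the exported `converged` may guard against division by zero or handle pytrees differently; `converged_sketch` is a hypothetical name):

```python
import jax.numpy as jnp

def converged_sketch(new, old, eps):
    # sup-norm (maximum absolute value) of the element-wise relative change
    rel_change = jnp.abs(new - old) / jnp.abs(old)
    return jnp.max(rel_change) < eps

old = jnp.array([1.0, 2.0, 4.0])
new = jnp.array([1.0001, 2.0002, 4.0004])  # relative change ~1e-4 everywhere
```

With these values the check passes for a tolerance of `1e-3` but fails for `1e-5`.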
Vmapped utilities
Throughout the package we make extensive use of matrix-vector multiplication. Depending on the algorithm, different vectorizations are helpful.
Let \(B \in \mathbf R^{(n+1)\times p \times m}\) be a list of \(n + 1\) matrices, let \(X \in \mathbf R^{(n + 1) \times m}\) be a set of states and let \(\mathbf X \in \mathbf R^{N \times (n + 1) \times p}\) be \(N\) simulations of \(X\).
mm_sim
multiplies, at a single time point \(t\), the single matrix \(B_t\) with all \(X_t^i\), i.e. maps \[\mathbf R^{p \times m} \times \mathbf R^{N \times m} \to \mathbf R^{N \times p}.\]
mm_time
maps the single sample \(X\) for each time \(t\) to \((B_tX_t)_{t = 0, \dots, n}\), i.e. maps \[\mathbf R^{(n +1) \times p \times m} \times \mathbf R^{(n + 1) \times m} \to \mathbf R^{(n+1) \times p}.\]
mm_time_sim
multiplies all samples \(\mathbf X\) for all times with the matrices \(B\), i.e. maps \[\mathbf R^{(n + 1) \times p \times m}\times \mathbf R^{N \times (n+1) \times m} \to \mathbf R^{N \times (n + 1) \times p}.\]
Exported source

# multiply $B_t$ and $X^i_t$
mm_sim = vmap(jnp.matmul, (None, 0))

# matmul with $(B_t)_{t}$ and $(X_t)_{t}$
mm_time = vmap(jnp.matmul, (0, 0))

# matmul with $(B_t)_{t}$ and $(X^i_t)_{i,t}$
mm_time_sim = vmap(mm_time, (None, 0))
N, np1, p, m = 1000, 100, 3, 5

key, subkey = jrn.split(key)
B = jrn.normal(subkey, (np1, p, m))
key, subkey = jrn.split(key)
X = jrn.normal(subkey, (N, np1, m))

fct.test_eq(mm_sim(B[0], X[:, 0, :]).shape, (N, p))
fct.test_eq(mm_time(B, X[0]).shape, (np1, p))
fct.test_eq(mm_time_sim(B, X).shape, (N, np1, p))
Appending to the front of an array
append_to_front
append_to_front (a0:jaxtyping.Float[Array,'...'], a:jaxtyping.Float[Array,'n ...'])
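Judging from the signature, this helper prepends a single array `a0` along the leading axis of `a`. A minimal sketch of that behaviour (`append_to_front_sketch` is a hypothetical stand-in for the exported function):

```python
import jax.numpy as jnp

def append_to_front_sketch(a0, a):
    # add a leading axis to a0 and concatenate it before a along axis 0:
    # shapes (...) and (n, ...) combine to (n + 1, ...)
    return jnp.concatenate([a0[None], a], axis=0)

a0 = jnp.zeros((3,))
a = jnp.ones((5, 3))
result = append_to_front_sketch(a0, a)  # shape (6, 3), first row zeros
```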
Antithetic variables
To improve the efficiency of importance sampling, Durbin and Koopman (1997) recommend using antithetic variables. These are a device to reduce Monte-Carlo variance by introducing negative correlations. We use both location- and scale-balanced antithetic variables.
scale_antithethic
scale_antithethic (u:jaxtyping.Float[Array,'N n+1 k'], samples:jaxtyping.Float[Array,'N n+1 p'], mean:jaxtyping.Float[Array,'n+1 p'])
location_antithetic
location_antithetic (samples:jaxtyping.Float[Array,'N ...'], mean:jaxtyping.Float[Array,'N ...'])
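For illustration, a location-balanced antithetic variable reflects each sample around the mean, so each pair \((x, 2\mu - x)\) averages exactly to \(\mu\). A minimal sketch of this idea, assuming a plain reflection (the exported `location_antithetic` may differ in details, and the scale-balanced variant additionally rescales the samples):

```python
import jax.numpy as jnp

def location_antithetic_sketch(samples, mean):
    # reflect every sample around the mean: x_antithetic = 2 * mean - x
    return 2.0 * mean - samples

mean = jnp.array([1.0, -1.0])
samples = jnp.array([[1.5, -0.5], [0.0, 0.0]])
anti = location_antithetic_sketch(samples, mean)
```

Averaging each sample with its antithetic counterpart recovers the mean exactly, which is the source of the variance reduction.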