from jax import vmap, jvp
from functools import partial
def grad_from_consecutive_logits(
    primals: Float[Array, "k-1"], tangents: Float[Array, "k-1"]
) -> tuple[Float[Array, "k"], Float[Array, "k"]]:
    """Hand-derived forward-mode derivative (JVP) of ``from_consecutive_logits``.

    Builds the ``(l + 1, l)`` Jacobian (``l = k - 1``) of
    ``from_consecutive_logits`` with respect to the consecutive logits
    ``primals`` and applies it to ``tangents``. With ``q = expit(primals)``
    and 0-based indices, the entries are

    - ``p[i] * (1 - q[i])`` on the diagonal (``i < l``),
    - ``-p[i] * q[m]`` strictly below the diagonal (``m < i``),
    - ``0`` above the diagonal,

    where the last row (``i = l``) uses ``prod(1 - q)`` as its scale.

    Args:
        primals: consecutive logits, shape ``(k-1,)``.
        tangents: tangent direction in logit space, shape ``(k-1,)``.

    Returns:
        ``(p, jac @ tangents)``: the output of ``from_consecutive_logits`` and
        the directional derivative (length ``k``). This mirrors the
        ``(primals_out, tangents_out)`` pair returned by ``jax.jvp``.
    """
    (l,) = primals.shape
    p = from_consecutive_logits(primals)
    q = jsp.special.expit(primals)
    # Per-row scale of the off-diagonal entries: p_i for the first l rows,
    # and prod(1 - q) (analytically p_k) for the final row.
    row_scale = jnp.concatenate([p[:l], jnp.prod(1 - q)[None]])
    # Strictly lower-triangular part, -p_i * q_m for m < i. On the
    # non-square (l + 1, l) matrix, tril(k=-1) keeps the whole last row.
    jac = jnp.tril(-row_scale[:, None] * q[None, :], k=-1)
    # Diagonal part: d p_i / d logit(q_i) = p_i * (1 - q_i) for i < l.
    diag = jnp.arange(l)
    jac = jac.at[diag, diag].set(p[:l] * (1 - q))
    return p, jac @ tangents
# Fixed seed so the JVP check is reproducible. `jrn` is presumably
# `jax.random` (its import is not visible here) — TODO confirm.
key = jrn.PRNGKey(0)
# Random primal point of length 5, i.e. k - 1 = 5, so k = 6 categories.
key, subkey = jrn.split(key)
rand_primal = jrn.normal(subkey, (5,))
# Random tangent direction of the same shape, drawn from a fresh subkey.
key, subkey = jrn.split(key)
rand_tangent = jrn.normal(subkey, (5,))
# Autodiff JVP of `from_consecutive_logits` minus the hand-derived JVP.
# Both calls return (primal_out, tangent_out), so `[1]` selects the tangent.
# Entries near zero mean the hand-derived Jacobian matches autodiff. This
# bare expression has no effect in a script; presumably it is shown as
# notebook cell output.
(
jvp(from_consecutive_logits, (rand_primal,), (rand_tangent,))[1]
- grad_from_consecutive_logits(rand_primal, rand_tangent)[1]
)
# relative error
def rel_error(a, b, *, eps=1e-10):
    """Return the elementwise symmetric relative error between ``a`` and ``b``.

    Computes ``|a - b| / (|a| + |b| + eps)``. Because ``|a - b| <= |a| + |b|``,
    the result lies in ``[0, 1)`` for ``eps > 0``. It is scale-invariant up to
    ``eps``, which keeps the quotient finite when both inputs are zero.

    Args:
        a, b: scalars or arrays; the operands broadcast under ``jnp`` rules.
        eps: small non-negative constant added to the denominator
            (default ``1e-10``).

    Returns:
        Array of elementwise relative errors with the broadcast shape of
        ``a`` and ``b``.
    """
    return jnp.abs(a - b) / (jnp.abs(a) + jnp.abs(b) + eps)
fct.test_close(
rel_error(
jvp(from_consecutive_logits, (rand_primal,), (rand_tangent,))[1],
grad_from_consecutive_logits(rand_primal, rand_tangent)[1],
),
jnp.zeros_like(6),
)Utilities for the models of this thesis
Visualization
Computation
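For \(p \in \mathbf R^{k}_{>0}\) with \(\sum_{i = 1}^k p_{i} = 1\), let \(\log q_i = \log \frac{p_{i}}{p_{k}}\) for \(i = 1, \dots, k - 1\). Since \(1 = \sum_{i = 1}^{k} p_{i} = p_{k}\left(1 + \sum_{i = 1}^{k-1} q_{i}\right)\), we get \[ p_{k} = \frac{1}{1 + \sum_{i = 1}^{k-1}q_{i}}, \] so $$ p_{i} = q_{i} p_{k} = \frac{q_{i}}{1 + \sum_{j = 1}^{k-1} q_{j}}, \quad i = 1, \dots, k - 1.
$$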
Another parametrization uses consecutive conditional probabilities, expressed as logits so that the problem is unconstrained.
Thus for \(p\in \mathbf R^k\) we have \[ q_{i} = \frac{p_{i}}{1 - \sum_{j = 1}^{i - 1} p_{j}} = \frac{p_{i}}{\sum_{j = i}^k p_{j}}, \] for \(i = 1, \dots, k - 1\) (\(q_k\) is \(1\) and can be discarded).
Then for \(i = 1, \dots, k\) \[ p_{i} = q_{i} \prod_{j = 1}^{i - 1}(1 - q_j). \]
Checking the derivative
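Write \(q_j = \operatorname{expit}(\operatorname{logit}(q_j))\) and use \(\operatorname{expit}' = \operatorname{expit}\,(1 - \operatorname{expit})\). Then for \(i = 1, \dots, k\) and \(m = 1, \dots, k - 1\) we have \[ \partial_{\operatorname{logit} (q_{m})}(p_{i}) = \partial_{\operatorname{logit} (q_{m})} \left( \operatorname{expit}(\operatorname{logit}(q_{i})) \prod_{j= 1}^{i - 1} (1 - \operatorname{expit}(\operatorname{logit}(q_{j}))) \right) = \begin{cases} p_{i} (1 - q_{m}) & m = i \\ -p_{i}q_{m} & m < i \\ 0 & \text{else} \end{cases} \] For \(i = k\), the leading factor is \(q_k = 1\), so only the case \(m < i\) contributes.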