Monkey patching

def pgsssm_to_glssm(model: PGSSM, Omega: Float[Array, "n+1 p p"]) -> GLSSM:
    """Build a GLSSM from the state-space components of ``model`` and a given ``Omega``.

    The fields ``u``, ``A``, ``D``, ``Sigma0``, ``Sigma``, ``v`` and ``B`` are
    taken unchanged from ``model``; ``Omega`` is supplied by the caller.

    Args:
        model: the PGSSM whose components are reused.
        Omega: array used as the ``Omega`` field of the resulting GLSSM.

    Returns:
        A new ``GLSSM`` instance.
    """
    copied_fields = ("u", "A", "D", "Sigma0", "Sigma", "v", "B")
    components = {field: getattr(model, field) for field in copied_fields}
    return GLSSM(**components, Omega=Omega)
from isssm.typing import to_glssm


def clip_negative_evals(
    proposal: GLSSMProposal, eps: float = 1e-8
) -> GLSSMProposal:
    """Make every ``Omega[t]`` of a proposal positive semi-definite.

    If all eigenvalues of all ``Omega`` matrices are strictly positive,
    ``proposal`` is returned unchanged. Otherwise:

    * eigenvalues below ``eps`` are set to zero and ``Omega`` is rebuilt
      from its (batched) eigendecomposition, and
    * because the clipped ``Omega`` may be singular, ``z`` is adjusted so that
      ``z - s`` lies in the column space of the clipped ``Omega``. Here ``s``
      are the smoothed signals of the original proposal. This is done by
      orthogonally projecting ``z - s`` onto that column space.

    Note: the early-exit check uses Python control flow on array values, so
    this function cannot be traced by ``jax.jit``.

    Args:
        proposal: GLSSM proposal whose ``Omega`` may have non-positive
            eigenvalues.
        eps: eigenvalues strictly below this threshold are clipped to zero.

    Returns:
        ``proposal`` itself if no eigenvalue is non-positive, otherwise a new
        ``GLSSMProposal`` with the clipped ``Omega`` and the projected ``z``.
    """
    z = proposal.z
    Omega = proposal.Omega

    # Check the spectrum first so the Kalman smoother only runs when needed.
    evals, evecs = jnp.linalg.eigh(Omega)
    if (evals > 0).all():
        return proposal

    evecs_T = jnp.swapaxes(evecs, -1, -2)
    evals_clipped = jnp.where(evals < eps, 0.0, evals)
    # V diag(lambda) V^T for every leading index, via broadcasting over columns.
    Omega_new = (evecs * evals_clipped[..., None, :]) @ evecs_T

    glssm = to_glssm(proposal)
    filtered = kalman(z, glssm)
    s = smoothed_signals(filtered, z, glssm)

    # Orthogonal projector onto im(Omega_new): V diag(1[lambda > 0]) V^T.
    # This avoids Cholesky, which breaks down on the singular clipped Omega,
    # and avoids jnp.linalg.lstsq, which returns a tuple and only takes 2-D input.
    keep = (evals_clipped > 0).astype(evecs.dtype)
    projector = (evecs * keep[..., None, :]) @ evecs_T
    z_new = s + (projector @ (z - s)[..., None])[..., 0]

    # NOTE(review): assumes GLSSMProposal is a NamedTuple (has ``_asdict``) or
    # a mapping; `**proposal` on a NamedTuple raises TypeError — confirm type.
    fields = proposal._asdict() if hasattr(proposal, "_asdict") else dict(proposal)
    fields.update(z=z_new, Omega=Omega_new)
    return GLSSMProposal(**fields)