Traditional value assertions in jitted JAX lead to performance degradation. A new (not yet public) JAX API fixes this.

Jitted JAX does not support traditional Python asserts that access JAX arrays, because inside a jitted function the arrays are abstract tracers with no concrete value to test. Chex and jax.experimental.checkify (with checkify.check) provide ways of wrapping a jitted function with decorators to enable value assertions, but they lead to performance degradation, making them unusable in practical settings.
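For reference, the decorator-style wrapping described above might look roughly as follows with checkify. This is a minimal sketch rather than a benchmark: the toy function f and its check are illustrative assumptions, and the snippet only shows the shape of the API, not its runtime cost.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def f(x, y):
    # Value assertion on a traced array. A plain `assert x == 0` would fail
    # under jit because x has no concrete value at trace time.
    checkify.check(x == 0, "x must be 0")
    return jnp.multiply(x, y)


# checkify turns the check into a functional error value, so the wrapped
# function returns a pair (error, output) instead of raising inside jit.
checked_f = jax.jit(checkify.checkify(f))

err, out = checked_f(1, 0)

# err.get() returns the error message, or None if every check passed.
# err.throw() would raise the stored error instead.
message = err.get()
print(message)
```

Note that the caller has to thread the returned error value through its own code and inspect it explicitly.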

To use value assertions in jitted JAX without this performance degradation, we can use a new (as of today still private) JAX API, error_check:

    import jax
    from jax._src.error_check import set_error_if, raise_if_error
    import jax.numpy as jnp

    @jax.jit
    def f(x, y):
        set_error_if(x != 0, 'x must be 0')
        return jnp.multiply(x, y)

    f(1, 0)

    raise_if_error()

    Traceback (most recent call last):
      File "/home/ubuntu/code/temp.py", line 12, in <module>
        raise_if_error()
      File "/home/ubuntu/code/.venv/lib/python3.10/site-packages/jax/_src/error_check.py", line 93, in raise_if_error
        raise exc.with_traceback(filtered_traceback)
      File "/home/ubuntu/code/temp.py", line 10, in <module>
        f(1, 0)
      File "/home/ubuntu/code/temp.py", line 7, in f
        set_error_if(x != 0, 'x must be 0')
    jax._src.error_check.JaxValueError: x must be 0

This pattern exploits the fact that it suffices to raise an assertion error post hoc, in this case after the jitted function has been computed. Inside the jitted function, the implementation merely stores the error conditionally in JAX-managed context. Such purely functional conditional computation is fully supported by JAX and XLA, and is thus fully compatible with XLA graph compilation. The error itself is only raised outside of the jitted function, which avoids the typical performance overhead of value assertions.
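The idea of "store conditionally inside, raise outside" can be illustrated with public JAX operations only. The following is a hand-rolled sketch and not how error_check is implemented internally: the names ERROR_MESSAGES and raise_if_flagged are hypothetical, and the error state is returned as an explicit extra output rather than kept in JAX-managed context.

```python
import jax
import jax.numpy as jnp

# Hypothetical lookup table mapping error codes to messages (host side only).
ERROR_MESSAGES = ("x must be 0",)


@jax.jit
def f(x, y):
    # Purely functional conditional: select code 0 if the check fails,
    # or -1 for "no error". Nothing is raised inside the compiled function.
    error_code = jnp.where(x != 0, 0, -1)
    return jnp.multiply(x, y), error_code


def raise_if_flagged(error_code):
    # Runs on the host, outside jit. Converting to int reads the value back
    # from the device, so the error surfaces only at this point.
    code = int(error_code)
    if code >= 0:
        raise ValueError(ERROR_MESSAGES[code])


out, code = f(1, 0)

try:
    raise_if_flagged(code)
except ValueError as exc:
    print(f"caught: {exc}")
```

In this sketch the check is an ordinary select on device, and the host decides when to inspect the flag. Whether error_check achieves its behavior in a comparable way is an assumption here, not something the example demonstrates.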

Contributions

MM and FS worked on research and analysis; FS wrote the manuscript.