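Jitted JAX code does not support traditional Python `assert` statements that inspect the values of JAX arrays. Inside `jax.jit`, arrays are abstract tracers without concrete values, so the conversion to a Python `bool` that `assert` performs fails at trace time.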
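A minimal sketch of the failure (the function name `g` is hypothetical): the `assert` fails while the function is being traced, whatever value is passed in. Recent JAX versions raise `jax.errors.TracerBoolConversionError`, which is a subclass of `jax.errors.ConcretizationTypeError`, so catching the latter should cover older and newer releases alike.

```python
import jax
import jax.numpy as jnp


@jax.jit
def g(x):
    # x is a tracer here; `assert` needs bool(x == 0), which has no
    # concrete value at trace time, so tracing fails.
    assert x == 0, "x must be 0"
    return x * 2


try:
    g(jnp.array(1))
except jax.errors.ConcretizationTypeError as e:
    # Prints the concrete exception class raised by this JAX version.
    print(type(e).__name__)
```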
Chex and `jax.experimental.checkify` (with `checkify.check`) provide decorators that wrap a jitted function to enable value assertions, but they degrade performance, which makes them impractical in real-world settings.
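For comparison, a minimal sketch of the `checkify` approach, assuming the public `jax.experimental.checkify` API (`checkify.check`, `checkify.checkify` and `Error.get`). The wrapped function returns an error value next to its regular output, and the caller inspects or throws it afterwards.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def f(x, y):
    # Records a failed predicate instead of raising in place.
    checkify.check(x == 0, "x must be 0")
    return jnp.multiply(x, y)


# checkify.checkify functionalizes the checks: the wrapped function
# returns (Error, output), and the result can be jitted.
checked_f = jax.jit(checkify.checkify(f))

err, out = checked_f(1, 0)
print(err.get())  # error message string if a check failed, else None
# err.throw() would raise the recorded error on the host instead.
```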
To use value assertions in jitted JAX without degrading performance, we can use a new JAX API, `error_check`, which is still private at the time of writing (`jax._src.error_check`):
```python
import jax
from jax._src.error_check import set_error_if, raise_if_error  # private module
import jax.numpy as jnp

@jax.jit
def f(x, y):
    set_error_if(x != 0, 'x must be 0')
    return jnp.multiply(x, y)

f(1, 0)

raise_if_error()
```
Running the script produces:

```text
Traceback (most recent call last):
  File "/home/ubuntu/code/temp.py", line 12, in <module>
    raise_if_error()
  File "/home/ubuntu/code/.venv/lib/python3.10/site-packages/jax/_src/error_check.py", line 93, in raise_if_error
    raise exc.with_traceback(filtered_traceback)
  File "/home/ubuntu/code/temp.py", line 10, in <module>
    f(1, 0)
  File "/home/ubuntu/code/temp.py", line 7, in f
    set_error_if(x != 0, 'x must be 0')
jax._src.error_check.JaxValueError: x must be 0
```
This pattern exploits the fact that it suffices to raise an assertion error post hoc, in this case after the jitted function has finished executing. Inside the jitted function, `set_error_if` therefore merely stores the error conditionally in a JAX-managed context. Such purely functional conditional computation is fully supported by JAX and XLA, and is thus fully compatible with XLA graph compilation. The error itself is raised only outside the jitted function, by `raise_if_error()`, which avoids the typical performance overhead of value assertions.
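To make the principle concrete without relying on the private module, here is a simplified sketch that threads an explicit boolean error flag through a jitted function and checks it on the host afterwards. It illustrates the idea only and is not how `jax._src.error_check` is implemented; `f_with_flag` and `raise_if_flagged` are hypothetical names.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f_with_flag(x, y, err):
    # err is a boolean scalar carried through the computation; updating it
    # is an ordinary, branch-free array op that XLA compiles like any other.
    err = jnp.logical_or(err, x != 0)
    return jnp.multiply(x, y), err


def raise_if_flagged(err, msg):
    # Runs on the host, outside jit; bool() waits for the device result.
    if bool(err):
        raise ValueError(msg)


err = jnp.array(False)
for x in [0, 0, 1, 0]:
    out, err = f_with_flag(x, 2, err)

try:
    raise_if_flagged(err, "x must be 0")
except ValueError as e:
    print(e)  # x must be 0
```

In this sketch the host synchronizes only once, when `bool(err)` is evaluated after all calls, rather than after every check.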
Contributions
MM and FS worked on research and analysis; FS wrote the manuscript.