JAX Validation Tutorial¶
Validate JAX custom ops and primitives end-to-end using gpuemu -- including JIT safety, vmap compatibility, and gradient correctness.
Prerequisites¶
Before you begin, make sure the following are in place:
- gpuemu CLI installed and on your PATH (Installation)
- gpuemu daemon running (gpuemu daemon start --background)
- Python 3.9+ with a virtual environment activated
- gpuemu with JAX adapter installed:
Verify your setup
Setup¶
Initialize a new gpuemu project configured for JAX:
This generates the following project structure:
The generated gpuemu.toml includes JAX-specific defaults:
[project]
name = "my-jax-ops"
version = "0.1.0"
framework = "jax"
[validation]
dtypes = ["float32", "float16", "bfloat16"]
check_nan = true
check_inf = true
[validation.tolerances]
float32 = { atol = 1.5e-5, rtol = 1.5e-5 }
float16 = { atol = 1.5e-2, rtol = 1.5e-2 }
bfloat16 = { atol = 1.5e-2, rtol = 1.5e-2 }
[[ops]]
name = "my_op"
module = "my_jax_ops.ops"
reference = "scripts/my_op_ref.py"
execution_mode = "script_based"
Why 1.5x tolerance for JAX?
JAX compiles operations through XLA, which applies aggressive optimizations such as operator fusion and layout transformations. These optimizations can change the order of floating-point operations, introducing small numerical differences compared to a plain NumPy reference. The default tolerances are set to 1.5x the PyTorch defaults to account for this.
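To make the tolerance values concrete, the sketch below assumes the comparison follows the familiar `numpy.allclose` rule, `|actual - expected| <= atol + rtol * |expected|`. That rule is an assumption for illustration; the exact formula gpuemu applies is not stated here. The helper `within_tolerance` is hypothetical. The second half shows why operation order matters at all: float32 addition is not associative.

```python
import numpy as np

def within_tolerance(actual, expected, atol, rtol):
    """Hypothetical helper: |actual - expected| <= atol + rtol * |expected|."""
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    return bool(np.all(np.abs(actual - expected) <= atol + rtol * np.abs(expected)))

# A deviation of 2.5e-5 on a value of 1.0 passes the JAX float32 defaults
# (budget 3e-5) but fails tolerances that are 1.5x tighter (budget 2e-5).
passes_jax_default = within_tolerance(1.000025, 1.0, atol=1.5e-5, rtol=1.5e-5)
passes_tighter = within_tolerance(1.000025, 1.0, atol=1e-5, rtol=1e-5)

# Reordering the same three float32 additions changes the result.
a, b, c = np.float32(1e8), np.float32(-1e8), np.float32(1.0)
left_to_right = (a + b) + c   # 0 + 1
right_to_left = a + (b + c)   # b + c rounds back to -1e8 in float32
print(passes_jax_default, passes_tighter, left_to_right, right_to_left)
```

The reordering case is deliberately extreme; fused kernels usually shift results by a few units in the last place, which is the scale the default tolerances target.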
Reference Script¶
Write a NumPy-based reference implementation. The protocol is the same as for other frameworks -- JSON+base64 over stdin/stdout.
"""Reference implementation for softmax."""
import json
import base64
import sys
import numpy as np
def decode_tensor(encoded: dict) -> np.ndarray:
    """Decode a base64-encoded tensor from the input payload."""
    data = base64.b64decode(encoded["data"])
    dtype = np.dtype(encoded["dtype"])
    shape = tuple(encoded["shape"])
    return np.frombuffer(data, dtype=dtype).reshape(shape)

def encode_tensor(arr: np.ndarray) -> dict:
    """Encode a numpy array as a base64 JSON-serializable dict."""
    return {
        "data": base64.b64encode(arr.tobytes()).decode("ascii"),
        "dtype": str(arr.dtype),
        "shape": list(arr.shape),
    }

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax."""
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def main():
    request = json.loads(sys.stdin.read())
    x = decode_tensor(request["inputs"]["x"])
    axis = request["inputs"].get("axis", -1)
    result = softmax(x, axis=axis)
    response = {"outputs": {"result": encode_tensor(result)}}
    json.dump(response, sys.stdout)

if __name__ == "__main__":
    main()
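Before wiring the script into a project, it can help to check the wire format in isolation. The sketch below repeats the two helpers from the script and round-trips a small tensor through JSON. The request layout (`inputs` with `x` and an optional `axis`) is taken from `main()` above; the sample values are arbitrary.

```python
import base64
import json

import numpy as np

def encode_tensor(arr: np.ndarray) -> dict:
    """Same encoding as the reference script: raw bytes as base64 plus metadata."""
    return {
        "data": base64.b64encode(arr.tobytes()).decode("ascii"),
        "dtype": str(arr.dtype),
        "shape": list(arr.shape),
    }

def decode_tensor(encoded: dict) -> np.ndarray:
    """Same decoding as the reference script."""
    data = base64.b64decode(encoded["data"])
    return np.frombuffer(data, dtype=np.dtype(encoded["dtype"])).reshape(tuple(encoded["shape"]))

x = np.arange(6, dtype=np.float32).reshape(2, 3)
request = {"inputs": {"x": encode_tensor(x), "axis": -1}}

# Serialize to the text the script would read from stdin, then parse it back.
wire = json.dumps(request)
decoded = decode_tensor(json.loads(wire)["inputs"]["x"])
print(decoded.dtype, decoded.shape, np.array_equal(decoded, x))
```

One detail worth knowing: `np.frombuffer` returns a read-only view of the decoded bytes, so a reference that modifies its input in place needs to call `.copy()` first.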
Wire up the reference in gpuemu.toml:
[[ops]]
name = "softmax"
module = "my_jax_ops.ops.softmax"
reference = "scripts/softmax_ref.py"
execution_mode = "script_based"
[ops.tolerances]
float32 = { atol = 1.5e-5, rtol = 1.5e-5 }
float16 = { atol = 1.5e-2, rtol = 1.5e-2 }
Single-Shot Validation¶
The validate_jax() context manager validates a single JAX op invocation against its reference.
import jax
import jax.numpy as jnp
from gpuemu import Client
from gpuemu.frameworks.jax import validate_jax
client = Client()
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 64))
with validate_jax(client, "softmax", {"x": x}) as ctx:
    ctx["output"] = jax.nn.softmax(x, axis=-1)
The context manager:
- Converts JAX arrays to the JSON+base64 wire format.
- Sends inputs to the daemon, which runs the reference script.
- Compares ctx["output"] against the reference using configured tolerances.
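The "assign inside the block, compare on exit" pattern described in these steps can be illustrated with a small stand-in. Everything here is hypothetical: `validate_against` is not part of gpuemu, it runs the reference in-process instead of through the daemon, and it uses `numpy.testing.assert_allclose` as an assumed comparison rule.

```python
from contextlib import contextmanager

import numpy as np

def reference_softmax(x):
    """NumPy reference, same math as the reference script."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

@contextmanager
def validate_against(reference_fn, inputs, atol, rtol):
    """Hypothetical stand-in: hand out a dict, then check ctx['output'] on exit."""
    ctx = {}
    yield ctx  # the body of the `with` block runs here
    expected = reference_fn(**inputs)
    np.testing.assert_allclose(np.asarray(ctx["output"]), expected, atol=atol, rtol=rtol)

x = np.random.default_rng(0).standard_normal((4, 64)).astype(np.float32)
with validate_against(reference_softmax, {"x": x}, atol=1.5e-5, rtol=1.5e-5) as ctx:
    ctx["output"] = reference_softmax(x)  # replace with the op under test
```

Because the comparison happens when the block exits, a mismatch surfaces as an exception at the end of the `with` statement rather than at the assignment.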
Use jnp arrays, not NumPy arrays
Always pass jax.numpy arrays (not plain numpy arrays) to validate_jax(). The adapter relies on JAX array metadata for dtype handling and device transfer. Passing NumPy arrays may produce incorrect dtype conversions.
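A quick way to see the kind of dtype drift this warning is about: NumPy creates float64 arrays by default, while JAX uses float32 unless 64-bit mode has been enabled. The snippet assumes a default JAX configuration (no `jax_enable_x64`).

```python
import numpy as np
import jax.numpy as jnp

x_np = np.random.default_rng(0).standard_normal((4, 64))  # NumPy default: float64
x_jnp = jnp.asarray(x_np)                                  # JAX default: float32
x_bf16 = x_jnp.astype(jnp.bfloat16)                        # explicit low-precision cast

print(x_np.dtype, x_jnp.dtype, x_bf16.dtype)
```

Converting once, up front, makes the dtype that reaches the validator explicit instead of leaving it to an implicit conversion.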
JAX-Specific Checks¶
JAX programs are expected to work correctly under several transformations: jit, vmap, pmap, and grad. gpuemu provides dedicated checks for each.
JIT Safety¶
check_jit_safe() verifies that your op produces the same results when run eagerly versus under jax.jit. Differences indicate reliance on Python-level side effects or tracing-time values.
from gpuemu.frameworks.jax import check_jit_safe
result = check_jit_safe(
    func=my_custom_softmax,
    sample_inputs={"x": x},
    atol=1e-6,
)
assert result.passed, f"JIT safety check failed: {result.message}"
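The eager-versus-JIT comparison can also be reproduced by hand, which is useful when debugging a failure. This is a sketch of the idea, not gpuemu's implementation; `my_softmax` is a stand-in for the op under test, and the 1e-6 bound mirrors the `atol` used above.

```python
import jax
import jax.numpy as jnp

def my_softmax(x):
    """Stand-in op: numerically stable softmax over the last axis."""
    x_max = jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(x - x_max)
    return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 64))

eager = my_softmax(x)              # op-by-op execution
compiled = jax.jit(my_softmax)(x)  # traced once, compiled by XLA

max_abs_diff = float(jnp.max(jnp.abs(eager - compiled)))
jit_consistent = max_abs_diff <= 1e-6
print(f"max |eager - jit| = {max_abs_diff:.3e}, consistent: {jit_consistent}")
```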
vmap Compatibility¶
check_vmap_compatible() verifies that your op can be batched with jax.vmap and produces correct results across the batch dimension.
from gpuemu.frameworks.jax import check_vmap_compatible
result = check_vmap_compatible(
    func=my_custom_softmax,
    sample_inputs={"x": x},  # x has shape (4, 64)
    vmap_axis=0,             # Batch over the first dimension
)
assert result.passed, f"vmap check failed: {result.message}"
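One way to reason about what "correct results across the batch dimension" means: batching with `jax.vmap` should match calling the op on each example in a Python loop. The sketch below checks that by hand; `row_softmax` is a hypothetical op written for a single 1-D row, and the comparison is an illustration rather than gpuemu's own check.

```python
import jax
import jax.numpy as jnp
import numpy as np

def row_softmax(row):
    """Hypothetical op written for one 1-D row."""
    e = jnp.exp(row - jnp.max(row))
    return e / jnp.sum(e)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 64))

batched = jax.vmap(row_softmax, in_axes=0)(x)        # batch over axis 0
looped = jnp.stack([row_softmax(row) for row in x])  # one call per row

vmap_matches_loop = bool(np.allclose(np.asarray(batched), np.asarray(looped), atol=1e-6))
print(batched.shape, vmap_matches_loop)
```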
pmap Compatibility¶
check_pmap_compatible() verifies that your op works correctly under jax.pmap for multi-device execution. On a single-device machine, this simulates multi-device behavior.
from gpuemu.frameworks.jax import check_pmap_compatible
result = check_pmap_compatible(
    func=my_custom_softmax,
    sample_inputs={"x": x},
    num_devices=2,  # Simulated device count
)
assert result.passed, f"pmap check failed: {result.message}"
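For a manual sanity check outside gpuemu, `jax.pmap` can be compared against `jax.vmap` on the same data, since both map the function over the leading axis. The sketch sizes that axis to however many local devices JAX reports, so it also runs on a single-device machine. `row_softmax` is a hypothetical per-shard op.

```python
import jax
import jax.numpy as jnp
import numpy as np

def row_softmax(row):
    """Hypothetical per-shard op: softmax over one 1-D row."""
    e = jnp.exp(row - jnp.max(row))
    return e / jnp.sum(e)

n = jax.local_device_count()  # pmap needs a leading axis no larger than this
x = jax.random.normal(jax.random.PRNGKey(0), (n, 64))

per_device = jax.pmap(row_softmax)(x)  # one row per device
reference = jax.vmap(row_softmax)(x)   # same mapping on a single device

pmap_matches_vmap = bool(np.allclose(np.asarray(per_device), np.asarray(reference), atol=1e-6))
print(n, per_device.shape, pmap_matches_vmap)
```

On a CPU-only machine, setting `XLA_FLAGS=--xla_force_host_platform_device_count=2` before importing `jax` makes XLA expose two host devices. Whether gpuemu's simulated device count relies on that flag is not stated here.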
Gradient Safety¶
check_grad_safe() verifies that jax.grad can be computed for your op and that the resulting gradients are finite and consistent with a numerical approximation.
from gpuemu.frameworks.jax import check_grad_safe
# Use a scalar-output wrapper for grad (JAX requires scalar output)
def scalar_softmax(x):
    return jnp.sum(my_custom_softmax(x))

result = check_grad_safe(
    func=scalar_softmax,
    sample_inputs=(x,),
    eps=1e-4,
)
assert result.passed, f"Gradient check failed: {result.message}"
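The comparison against a numerical approximation can be sketched with central differences. This is an illustration, not gpuemu's algorithm. Two choices are assumptions of the sketch: a weighted sum is used as the scalar wrapper, because the plain sum of a softmax is constant along the normalized axis and therefore has a zero gradient everywhere; and the step is 1e-2 rather than 1e-4, since float32 rounding dominates at very small steps.

```python
import jax
import jax.numpy as jnp
import numpy as np

def my_softmax(x):
    e = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return e / jnp.sum(e, axis=-1, keepdims=True)

weights = jnp.linspace(-1.0, 1.0, 8)

def weighted_loss(x):
    """Scalar wrapper with a non-trivial gradient."""
    return jnp.sum(weights * my_softmax(x))

x = jax.random.normal(jax.random.PRNGKey(0), (8,))
analytic = jax.grad(weighted_loss)(x)

# Central differences, one coordinate at a time.
eps = 1e-2
numeric = []
for i in range(x.shape[0]):
    step = jnp.zeros_like(x).at[i].set(eps)
    numeric.append((weighted_loss(x + step) - weighted_loss(x - step)) / (2 * eps))
numeric = jnp.stack(numeric)

grads_agree = bool(np.allclose(np.asarray(analytic), np.asarray(numeric), atol=1e-3))

# The unweighted wrapper differentiates fine but yields (numerically) zero.
plain_sum_grad = jax.grad(lambda v: jnp.sum(my_softmax(v)))(x)
print(grads_agree, float(jnp.max(jnp.abs(plain_sum_grad))))
```

A wrapper whose gradient is identically zero still exercises differentiability and finiteness, but it cannot reveal a wrong backward rule, so a weighted or otherwise non-constant reduction gives a stronger test.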
All four checks at a glance
| Check | What it validates | Common failure causes |
|---|---|---|
| check_jit_safe() | Eager vs JIT produce same results | Python side effects during tracing |
| check_vmap_compatible() | Correct batching under vmap | Hardcoded shapes, non-batched indexing |
| check_pmap_compatible() | Multi-device correctness under pmap | Device-specific state, collective ops |
| check_grad_safe() | Gradient correctness | Non-differentiable ops, numerical instability |
Primitive Validation¶
If you have implemented a custom JAX primitive (using jax.core.Primitive, which newer JAX releases expose as jax.extend.core.Primitive), use validate_jax_primitive() for a comprehensive validation that covers the forward pass, JVP rule, and vmap rule in one call.
from gpuemu.frameworks.jax import validate_jax_primitive
result = validate_jax_primitive(
    client,
    primitive=my_custom_primitive,
    sample_inputs={"x": x},
    op_name="my_primitive",
    check_jvp=True,   # Validate the JVP (forward-mode AD) rule
    check_vmap=True,  # Validate the vmap batching rule
)
assert result.passed, f"Primitive validation failed: {result.message}"
This validates:
- Forward evaluation matches the reference implementation
- JVP rule produces correct tangent outputs (if check_jvp=True)
- vmap batching rule correctly handles the batch dimension (if check_vmap=True)
- Composed transformations (e.g., jit(vmap(grad(...)))) are consistent
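The JVP and composed-transformation checks in this list can be illustrated on a smaller scale. To stay short, the sketch uses `jax.custom_jvp` on an ordinary function instead of a full primitive, so it is a substitute technique rather than a primitive definition. The function `softplus` and its hand-written rule are examples chosen for the sketch.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.custom_jvp
def softplus(x):
    return jnp.logaddexp(x, 0.0)

@softplus.defjvp
def softplus_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Hand-written rule: d/dx softplus(x) = sigmoid(x)
    return softplus(x), t * jax.nn.sigmoid(x)

x = jnp.linspace(-2.0, 2.0, 5)
t = jnp.ones_like(x)

# JVP rule vs. a central finite difference along the tangent direction.
_, tangent_out = jax.jvp(softplus, (x,), (t,))
eps = 1e-2
finite_diff = (softplus(x + eps * t) - softplus(x - eps * t)) / (2 * eps)
jvp_ok = bool(np.allclose(np.asarray(tangent_out), np.asarray(finite_diff), atol=1e-3))

# Composed transformations should agree with the uncompiled composition.
per_example = jax.vmap(jax.grad(softplus))(x)
composed = jax.jit(jax.vmap(jax.grad(softplus)))(x)
composition_ok = bool(np.allclose(np.asarray(per_example), np.asarray(composed), atol=1e-6))
print(jvp_ok, composition_ok)
```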
Fuzzing¶
Use fuzz_jax_op() to stress-test your op with randomized inputs across many shapes and dtypes, optionally including vmap and JIT checks on every iteration.
from gpuemu import Client
from gpuemu.frameworks.jax import fuzz_jax_op
client = Client()
def my_softmax(inputs):
    """The op under test."""
    import jax.numpy as jnp
    x = inputs["x"]
    x_max = jnp.max(x, axis=-1, keepdims=True)
    exp_x = jnp.exp(x - x_max)
    return {"result": exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)}

results = fuzz_jax_op(
    client,
    op_name="softmax",
    op_fn=my_softmax,
    iterations=100,
    check_vmap=True,  # Run vmap check on each iteration
    check_jit=True,   # Run JIT safety check on each iteration
)
print(f"Passed: {results.passed}, Failed: {results.failed}")
for failure in results.failures:
    print(f"  Seed {failure.seed}: {failure.message}")
Fuzzing flags
| Flag | Effect |
|---|---|
| check_vmap=True | Verifies vmap compatibility on each fuzz iteration |
| check_jit=True | Verifies JIT safety on each fuzz iteration |
| Both enabled | Catches transformation-related bugs alongside numerical errors |
Enabling both flags increases the time per iteration but provides much stronger coverage.
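The idea behind seed-driven fuzzing can be shown with a small hand-rolled loop. This is a hypothetical sketch, not how fuzz_jax_op() is implemented: the shape ranges, the `fuzz` function, and the use of `numpy.allclose` are all assumptions. Each iteration derives its shape and data from the seed alone, so any failure can be replayed from that number.

```python
import numpy as np
import jax.numpy as jnp

def reference_softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def jax_softmax(x):
    e = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return e / jnp.sum(e, axis=-1, keepdims=True)

def fuzz(iterations, atol=1.5e-5, rtol=1.5e-5):
    """Hypothetical fuzz loop: random rank (1-3) and dims (1-16) per seed."""
    failures = []
    for seed in range(iterations):
        rng = np.random.default_rng(seed)
        ndim = int(rng.integers(1, 4))
        shape = tuple(int(d) for d in rng.integers(1, 17, size=ndim))
        x = rng.standard_normal(shape).astype(np.float32)
        got = np.asarray(jax_softmax(jnp.asarray(x)))
        want = reference_softmax(x)
        if not np.allclose(got, want, atol=atol, rtol=rtol):
            failures.append((seed, shape))
    return failures

failures = fuzz(20)
print(f"failures: {failures}")
```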
Tips¶
JAX tolerance is 1.5x the PyTorch default
Due to XLA optimizations (operator fusion, layout changes, constant folding), JAX outputs can differ slightly more from a NumPy reference than PyTorch outputs do. The default gpuemu tolerances for JAX are set 1.5x higher to account for this. If you are seeing false positives, consider whether XLA is legitimately reordering operations.
Always use jnp arrays
The JAX adapter expects jax.numpy arrays. Passing plain NumPy arrays will bypass JAX's type system and may cause subtle dtype mismatches, especially with bfloat16 which NumPy does not natively support.
JAX tracing and side effects
JAX traces functions to build computation graphs. Operations that depend on concrete Python values at trace time (such as if x > 0: where x is a traced value) will fail under jit. The check_jit_safe() function specifically catches these issues by comparing eager and JIT-compiled results.
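A minimal reproduction of this failure mode, together with one common rewrite, is sketched below. Both function names are illustrative. In current JAX releases the Python `if` on a traced value raises an error that derives from `jax.errors.ConcretizationTypeError` when the function is jitted, and `jnp.where` expresses the same branch in a traceable way.

```python
import jax
import jax.numpy as jnp

def branch_on_value(x):
    if x.sum() > 0:  # needs a concrete bool, unavailable while tracing
        return x
    return -x

def branch_with_where(x):
    return jnp.where(x.sum() > 0, x, -x)  # both branches stay in the graph

x = jnp.array([1.0, -3.0])
eager = branch_on_value(x)  # works eagerly: values are concrete

try:
    jax.jit(branch_on_value)(x)
    jit_error = None
except jax.errors.ConcretizationTypeError as err:
    jit_error = type(err).__name__

jitted = jax.jit(branch_with_where)(x)
print(jit_error, bool(jnp.allclose(jitted, eager)))
```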
Reproducible keys
When writing tests, always use a fixed jax.random.PRNGKey to ensure reproducibility:
gpuemu's seed-based reproduction is separate from JAX's PRNG, but using fixed keys in your test scripts makes manual debugging easier.
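A short illustration of fixed keys in a test script follows; the seed value and shapes are arbitrary. Splitting the key gives each random draw its own subkey, and rebuilding the same key reproduces the same data exactly.

```python
import jax

key = jax.random.PRNGKey(0)          # fixed seed
key, subkey = jax.random.split(key)  # fresh subkey for this draw
x = jax.random.normal(subkey, (4, 64))

# Re-deriving the subkey from the same seed yields identical values.
_, subkey_again = jax.random.split(jax.random.PRNGKey(0))
x_again = jax.random.normal(subkey_again, (4, 64))

print(bool((x == x_again).all()))
```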
Next Steps¶
- PyTorch Validation Tutorial -- Validate PyTorch custom ops.
- TensorFlow Validation Tutorial -- Validate TensorFlow custom ops.
- CI Integration -- Run gpuemu validations in your CI pipeline.
- Configuration -- Fine-tune tolerances, dtypes, and policies.