Model Developer Guide¶
A guide for ML engineers who write PyTorch, JAX, or TensorFlow training code and want to validate custom ops without GPU access.
Who This Is For¶
Target audience
You are a model developer who:
- Writes training pipelines using PyTorch, JAX, or TensorFlow
- Uses custom ops (fused kernels, custom autograd functions, third-party extensions)
- Wants to validate that those ops produce correct results without needing a GPU
- Works in a team where GPU machines are shared or limited
If you are integrating third-party kernels you did not write (FlashAttention, xFormers, etc.), see the Custom Op Integrator Guide. If you are writing CUDA kernels directly, see the Kernel Author Guide.
Day 1 Setup¶
1. Initialize your project¶
Run the init command with your framework of choice:
This creates a gpuemu.toml configuration file in your project root with sensible defaults for your framework.
2. Start the daemon¶
The daemon is the background process that handles validation, fuzzing, and result storage:
Auto-start in VS Code
If you use the VS Code extension, the daemon starts automatically when your workspace contains a gpuemu.toml file. You can control this with the gpuemu.autoStartDaemon setting.
3. Register your first op¶
Open gpuemu.toml and add an op entry:
[[ops]]
name = "my_custom_op"
module = "my_module"
reference = "scripts/ref_my_custom_op.py"
[ops.tolerances]
float32 = 1e-5
float16 = 1e-3
[ops.invariants]
no_nan = true
no_inf = true
4. Write the reference script¶
Create scripts/ref_my_custom_op.py. This is a pure NumPy implementation of your op that gpuemu uses as ground truth:
# scripts/ref_my_custom_op.py
import numpy as np
def reference(inputs: dict, **kwargs) -> np.ndarray:
"""Reference implementation of my_custom_op.
Args:
inputs: Dict mapping input names to numpy arrays.
Returns:
The expected output as a numpy array.
"""
x = inputs["x"]
# Your op's math in pure numpy
return np.square(x) + 2.0 * x + 1.0
5. Install the Python client¶
Daily Workflow¶
The core development loop with gpuemu is: Write op, write reference, validate, fix, repeat.
The validate_op() context manager¶
The primary way to validate an op is the validate_op() context manager. It captures your op's output and sends it to the daemon for comparison against the reference:
import torch
from gpuemu.client import Client
from gpuemu.frameworks.pytorch import validate_pytorch
client = Client()
x = torch.randn(32, 128)
with validate_pytorch(client, "my_custom_op", {"x": x}) as ctx:
ctx["output"] = my_custom_op(x)
# If the output does not match the reference, ValidationError is raised.
# If it passes, execution continues normally.
The low-level validate_op() helper¶
For framework-agnostic usage or more control, use the generic validate_op context manager from gpuemu.validate:
from gpuemu.client import Client
from gpuemu.validate import validate_op
import numpy as np
client = Client()
inputs = {"x": np.random.randn(32, 128).astype(np.float32)}
with validate_op(client, "my_custom_op", inputs=inputs) as ctx:
# Run your op and store the result
ctx["output"] = my_custom_op_numpy(inputs["x"])
# Access the result object after the context exits
result = ctx["result"]
print(f"Passed: {result.passed}, max_diff: {result.max_diff:.2e}")
Validating gradients¶
For PyTorch ops with custom backward passes, pass check_backward=True:
x = torch.randn(32, 128, requires_grad=True)
with validate_pytorch(client, "my_custom_op", {"x": x}, check_backward=True) as ctx:
ctx["output"] = my_custom_op(x)
This validates both the forward output and the analytical gradients against numerical finite differences.
Using Fuzzing¶
Manual tests with a single shape and dtype are not enough. Use fuzzing to test across diverse configurations and catch edge cases.
Shape fuzzing with fuzz_shapes()¶
The fuzz_shapes() generator produces all combinations of the dimension values you provide:
from gpuemu.validate import fuzz_shapes
for batch, seq in fuzz_shapes(batch=[1, 2, 4, 8, 16], seq=[64, 128, 256, 512]):
x = torch.randn(batch, seq, 512)
with validate_pytorch(client, "my_custom_op", {"x": x}) as ctx:
ctx["output"] = my_custom_op(x)
Dtype fuzzing with fuzz_dtypes()¶
The fuzz_dtypes() generator iterates over dtype strings:
from gpuemu.validate import fuzz_dtypes
for dtype in fuzz_dtypes(["float32", "float16", "bfloat16"]):
x = torch.randn(8, 256, dtype=getattr(torch, dtype))
with validate_pytorch(client, "my_custom_op", {"x": x}) as ctx:
ctx["output"] = my_custom_op(x)
Combined shape + dtype fuzzing¶
from gpuemu.validate import fuzz_shapes, fuzz_dtypes
for batch, seq in fuzz_shapes(batch=[1, 4, 16], seq=[64, 256]):
for dtype in fuzz_dtypes(["float32", "float16"]):
x = torch.randn(batch, seq, 512, dtype=getattr(torch, dtype))
with validate_pytorch(client, "my_custom_op", {"x": x}) as ctx:
ctx["output"] = my_custom_op(x)
Drop-in fuzzing with fuzz_op_client_side()¶
For the simplest possible fuzzing integration, use client.fuzz_op_client_side(). The daemon generates random inputs, you provide a callable that runs your op, and gpuemu validates every output:
from gpuemu.client import Client
client = Client()
results = client.fuzz_op_client_side(
"my_custom_op",
run_op=lambda inputs: my_custom_op(torch.from_numpy(inputs["x"])).numpy(),
iterations=100,
seed=42,
)
print(f"Passed: {results.passed}/{results.total}")
for failure in results.failures:
print(f" Seed {failure.seed}: {failure.failures[0]['message']}")
Deterministic seeds
Every fuzz iteration has a unique seed. If a test fails, record the seed. You can reproduce the exact same inputs later with client.reproduce(seed).
Daemon-side fuzzing¶
If you do not need client-side GPU execution (e.g., your op can run on CPU), you can let the daemon handle everything:
results = client.fuzz_op(
"my_custom_op",
seed=42,
iterations=200,
fail_fast=False,
dtypes=["float32", "float16"],
layouts=["Contiguous", "Strided", "Transposed"],
)
print(f"Passed: {results.passed}/{results.total}, Failed: {results.failed}")
Reproducing Failures¶
When a fuzz run or validation fails, every failure has a seed that fully determines the inputs, shape, dtype, and layout that triggered the failure.
CLI reproduction¶
This regenerates the exact inputs and re-runs validation, printing a detailed report of what went wrong.
Python API reproduction¶
repro = client.reproduce(8374629105)
print(f"Op: {repro.result.op_name}")
print(f"Shape: {repro.inputs['x'].shape}")
print(f"Dtype: {repro.inputs['x'].dtype}")
print(f"Max diff: {repro.result.max_diff:.2e}")
# You now have the exact input tensors to debug with
x = torch.from_numpy(repro.inputs["x"])
output = my_custom_op(x)
Cross-language RNG
gpuemu uses a bit-for-bit identical xorshift128+ PRNG in both Rust and Python. Seeds are fully reproducible across the CLI, daemon, and Python client regardless of which component generated the original test case.
CI Integration¶
For full CI pipeline setup, see the CI Integration Tutorial. Here is the quick version.
Add quick validation to PR checks¶
Add this to your CI script (GitHub Actions, GitLab CI, etc.):
The --quick flag runs a reduced set of iterations per op for fast feedback. A typical --quick run completes in under 30 seconds.
Example GitHub Actions step¶
Full fuzzing before merge
--quick is designed for fast PR feedback. Always run full fuzzing (gpuemu fuzz --iterations 500) before merging to main. Consider running the full suite as a nightly CI job.
Tips and Best Practices¶
-
Use
--quickfor fast iteration. During development, rungpuemu test --quickorgpuemu ci --quickto get feedback in seconds. Save full fuzzing for pre-merge and nightly runs. -
Run full fuzzing before merge. Shape and dtype edge cases (batch=1, float16, strided layouts) are where most bugs hide. A 500-iteration fuzz run catches issues that a handful of manual tests will miss.
-
Store baselines for regression detection. Before a major refactor or library upgrade, snapshot your current results:
After the change, compare:
-
Keep reference scripts simple. Your
ref_*.pyscripts should be the simplest possible NumPy implementation of the operation. Avoid optimizations -- clarity is more important than speed in a reference. -
Use per-dtype tolerances.
float16andbfloat16need much wider tolerances thanfloat32. Configure them separately ingpuemu.toml: -
Enable invariant checks. Beyond numerical tolerances, invariants catch structural bugs:
Next Steps¶
- Custom Op Integrator Guide -- Validating third-party kernels
- Kernel Author Guide -- Writing and testing CUDA kernels
- VS Code Extension -- Editor integration with live diagnostics
- CI Integration Tutorial -- Full CI pipeline setup
- Python API Reference -- Complete client API documentation