warp.jax

JAX integration for Warp.

This module provides helpers for converting arrays between Warp and JAX, along with JAX primitives for calling Warp kernels and Warp-backed Python functions from JAX.

The jax_kernel function wraps individual Warp kernels, while jax_callable wraps Python functions that launch one or more Warp kernels. Both support automatic differentiation, custom launch dimensions, and CUDA graph capture.

Usage:

This module must be explicitly imported:

import warp.jax

See also

The "Using Warp Kernels as JAX Primitives" section of the user guide, for detailed examples and usage patterns.

Additional Submodules

These modules must be explicitly imported (e.g., import warp.jax.custom_call).

JAX Array Interop

device_from_jax

Return the Warp device corresponding to a JAX device.

device_to_jax

Return the JAX device corresponding to a Warp device.

dtype_from_jax

Return the Warp dtype corresponding to a JAX dtype.

dtype_to_jax

Return the JAX dtype corresponding to a Warp dtype.

from_jax

Convert a JAX array to a Warp array without copying the data.

to_jax

Convert a Warp array to a JAX array without copying the data.

JAX Callable Interop