warp.jax
JAX integration for Warp.
This module provides helpers for converting arrays between Warp and JAX, along with JAX primitives for calling Warp kernels and Warp-backed Python functions from JAX.
The jax_kernel function wraps individual Warp kernels, while jax_callable wraps Python functions that launch one or more Warp kernels. Both support automatic differentiation, custom launch dimensions, and CUDA graph capture.
Usage: This module must be explicitly imported:
import warp.jax
See also
Using Warp Kernels as JAX Primitives in the user guide for detailed examples and usage patterns.
Additional Submodules
These modules must be explicitly imported (e.g., import warp.jax.custom_call).
JAX Array Interop
Return the Warp device corresponding to a JAX device.
Return the JAX device corresponding to a Warp device.
Return the Warp dtype corresponding to a JAX dtype.
Return the JAX dtype corresponding to a Warp dtype.
Convert a JAX array to a Warp array without copying the data.
Convert a Warp array to a JAX array without copying the data.