Feat: JAX integration via XLA FFI custom calls (supersedes #173)#190
Open
Feat: JAX integration via XLA FFI custom calls (supersedes #173)#190
Conversation
Based on PR #173 by Chao Chen <cchen104@amd.com>, adapted for the refactored architecture (raw pointer args + Python-side kernel launch).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Supersedes PR #173 (by @i-chaochen and @pemeliya ), adapted for the refactored architecture (PR #182).
Add JAX support to MORI via XLA FFI custom calls, enabling JAX users to use mori kernels and shmem without any torch dependency.
Technical Details
XLA FFI Framework
dispatch/combine/dispatch_recv/combine_recv/resetwithhipModuleLaunchKernelHandleManagersingleton for thread-safe handle lifecycleKernelManagersingleton for JIT-compiled.hsacomodule loadingENABLE_XLA_FFIcmake option (default OFF), buildslibmori_xla_ffi_ops.soTorch-Free JAX Path
mori/__init__.pyare now lazy-imported via__getattr__, soimport mori.jax.*does not triggerimport torchextern "C"fromlibmori_xla_ffi_ops.somori.jax.shmemis fully self-contained: calls shmem via ctypes, includes inline DLPack/GpuTensorView, never touchesmori.cppor pybindJAX Array Interop
jax_data_ptr(arr)— extract raw GPU pointer fromjax.Arrayshmem_ptr_to_jax(ptr, shape, dtype)— wrap shmem memory asjax.Arrayvia DLPack (zero-copy)shmem_register_jax_array(arr)— register JAX array for RDMAJIT + globalGpuStates Integration
EpDispatchCombineOpauto-compiles.hsacokernels and registers with FFIKernelManagershmem_module_init_for_kernel()initializesglobalGpuStatesin loaded kernel modulesTODO
shmem_jax_init()multi-process UniqueId broadcastdispatch()/combine()inops.pyviajax.ffi.ffi_call, add shmem outputextern "C"exports tolibmori_xla_ffi_ops.soif JAX users need IOKey differences from #173
_jax.register_custom_call_target(internal API)jax.ffi.register_ffi_target(public API, JAX 0.7+)GpuTensorViewwith GC-safe capsule lifecyclemori.jax.*Test Plan
Verified on ROCm 7.1 + JAX 0.7.1 + 8x MI300X (gfx942):
jax_data_ptr+shmem_ptr_to_jaxround-trip, multi-dtypejax.ffi.register_ffi_targetfrom mori.jax.*does not import torchmori.ops/mori.cpp/mori.shmem/mori.io/mori.irunchangedAcknowledgements
This PR is based on the original work in #173 . The XLA FFI framework design, vendored headers, and handle manager are from their contribution.
Co-authored-by: @i-chaochen cchen104@amd.com
Co-authored-by: @pemeliya pavel.emeliyanenko@amd.com
Co-authored-by: @jhchouuu jiahzhou@amd.com