Skip to content

Feat: JAX integration via XLA FFI custom calls (supersedes #173)#190

Open
jhchouuu wants to merge 3 commits intomainfrom
jiahzhou/jax_integration
Open

Feat: JAX integration via XLA FFI custom calls (supersedes #173)#190
jhchouuu wants to merge 3 commits intomainfrom
jiahzhou/jax_integration

Conversation

@jhchouuu
Copy link
Collaborator

@jhchouuu jhchouuu commented Mar 9, 2026

Motivation

Supersedes PR #173 (by @i-chaochen and @pemeliya ), adapted for the refactored architecture (PR #182).

Add JAX support to MORI via XLA FFI custom calls, enabling JAX users to use mori kernels and shmem without any torch dependency.

Technical Details

XLA FFI Framework

  • FFI handlers for dispatch / combine / dispatch_recv / combine_recv / reset with hipModuleLaunchKernel
  • HandleManager singleton for thread-safe handle lifecycle
  • KernelManager singleton for JIT-compiled .hsaco module loading
  • New ENABLE_XLA_FFI cmake option (default OFF), builds libmori_xla_ffi_ops.so

Torch-Free JAX Path

  • All submodules in mori/__init__.py are now lazy-imported via __getattr__, so import mori.jax.* does not trigger import torch
  • 15 shmem C API functions exported as extern "C" from libmori_xla_ffi_ops.so
  • mori.jax.shmem is fully self-contained: calls shmem via ctypes, includes inline DLPack/GpuTensorView, never touches mori.cpp or pybind

JAX Array Interop

  • jax_data_ptr(arr) — extract raw GPU pointer from jax.Array
  • shmem_ptr_to_jax(ptr, shape, dtype) — wrap shmem memory as jax.Array via DLPack (zero-copy)
  • shmem_register_jax_array(arr) — register JAX array for RDMA
  • Supports float32, bfloat16, float16, int32, uint8 and more

JIT + globalGpuStates Integration

  • EpDispatchCombineOp auto-compiles .hsaco kernels and registers with FFI KernelManager
  • shmem_module_init_for_kernel() initializes globalGpuStates in loaded kernel modules

TODO

  • Set up multi-process JAX + shmem environment (resolve PJRT device model vs shmem single-GPU binding conflict), then:
    • Verify shmem_jax_init() multi-process UniqueId broadcast
    • Implement dispatch() / combine() in ops.py via jax.ffi.ffi_call, add shmem output
    • End-to-end dispatch/combine kernel launch test
  • Add mori-io extern "C" exports to libmori_xla_ffi_ops.so if JAX users need IO

Key differences from #173

PR #173 This PR
Python FFI registration _jax.register_custom_call_target (internal API) jax.ffi.register_ffi_target (public API, JAX 0.7+)
Shmem for JAX Not included Full ctypes-based shmem API (15 functions), torch-free
DLPack Not included GpuTensorView with GC-safe capsule lifecycle
torch dependency Reduced but still imported Fully eliminated for mori.jax.*

Test Plan

Verified on ROCm 7.1 + JAX 0.7.1 + 8x MI300X (gfx942):

  • C++ contract test: compile-time type check + linker symbol verification
  • Python shim parity (4 tests): FFI and pybind expose same op set
  • JAX array interop (4 tests): jax_data_ptr + shmem_ptr_to_jax round-trip, multi-dtype
  • FFI registration (3 tests): library loading, 8 symbols, jax.ffi.register_ffi_target
  • FFI lifecycle (4 tests): JIT compile + register, handle create/destroy, shmem malloc DLPack, GPU kernel launch (shmem barrier kernel, fully torch-free path)
  • torch-free: from mori.jax.* does not import torch
  • torch regression: mori.ops / mori.cpp / mori.shmem / mori.io / mori.ir unchanged

Acknowledgements

This PR is based on the original work in #173 . The XLA FFI framework design, vendored headers, and handle manager are from their contribution.

Co-authored-by: @i-chaochen cchen104@amd.com
Co-authored-by: @pemeliya pavel.emeliyanenko@amd.com
Co-authored-by: @jhchouuu jiahzhou@amd.com

jhchouuu added 3 commits March 6, 2026 17:33
Based on PR #173 by Chao Chen <cchen104@amd.com>, adapted for the
refactored architecture (raw pointer args + Python-side kernel launch).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant