flatironinstitute/jaxmg

JAXMg: A multi-GPU linear solver in JAX

JAXMg

JAXMg provides a C++ interface between JAX and cuSolverMg, NVIDIA’s multi-GPU linear solver. We provide a jittable API for the following routines:

  • cusolverMgPotrs: Solves the system of linear equations $Ax=b$, where $A$ is an $N\times N$ symmetric (Hermitian) positive-definite matrix, via a Cholesky decomposition.
  • cusolverMgPotri: Computes the inverse of an $N\times N$ symmetric (Hermitian) positive-definite matrix via a Cholesky decomposition.
  • cusolverMgSyevd: Computes the eigenvalues and eigenvectors of an $N\times N$ symmetric (Hermitian) matrix.

For more details, see the API.
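
To make the three operations concrete, here is a small single-device sketch of the same mathematics using only standard JAX functions (`cho_factor`, `cho_solve`, `jnp.linalg.eigh`). It does not call JAXMg or cuSolverMg and runs on any backend, including CPU. The test matrix and the variable names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)

# Build a small symmetric positive-definite matrix A = M M^T + N * I.
N = 6
M = jax.random.normal(jax.random.PRNGKey(0), (N, N), dtype=jnp.float64)
A = M @ M.T + N * jnp.eye(N, dtype=jnp.float64)
b = jnp.ones((N, 1), dtype=jnp.float64)

# potrs-style operation: solve A x = b from a Cholesky factorization.
factor = cho_factor(A)
x = cho_solve(factor, b)

# potri-style operation: inverse of A from the same factorization.
A_inv = cho_solve(factor, jnp.eye(N, dtype=jnp.float64))

# syevd-style operation: eigenvalues and eigenvectors of a symmetric matrix.
eigvals, eigvecs = jnp.linalg.eigh(A)

# Check each result against its defining equation.
solve_ok = bool(jnp.allclose(A @ x, b))
inverse_ok = bool(jnp.allclose(A @ A_inv, jnp.eye(N, dtype=jnp.float64)))
eig_ok = bool(jnp.allclose(A @ eigvecs, eigvecs * eigvals))
print(solve_ok, inverse_ok, eig_ok)
```

A single-device reference like this can also serve as a correctness check for multi-GPU results on small problem sizes.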

Installation

The package is available on PyPI and can be installed with

pip install "jaxmg[cuda12]"

This installs a GPU-compatible version of JAX. The following extras are available:

  1. pip install "jaxmg[cuda12]": Use CUDA 12 (only works for jax>=0.6.2).

  2. pip install "jaxmg[cuda12-local]": Use locally available CUDA 12 installation.

  3. pip install "jaxmg[cuda13]": Use CUDA 13 (only works for jax>=0.7.2).

  4. pip install "jaxmg[cuda13-local]": Use locally available CUDA 13 installation.

The provided binaries are compiled with the following versions:

JAXMg extra            CUDA     cuDNN
cuda12, cuda12-local   12.8.0   9.17.1.4
cuda13, cuda13-local   13.0.0   9.17.1.4

Details for compiling from source can be found in CONTRIBUTING.md.

Note: pip install jaxmg (without an extra) will install a CPU-only version of JAX. Since jaxmg is a GPU-only package, you will receive a warning asking you to install a GPU-compatible version of JAX.
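
One way to confirm that the installed JAX can see a GPU is sketched below. It uses only `jax.devices`, which raises a `RuntimeError` when the requested backend is unavailable. The helper name `gpu_devices` is a hypothetical name chosen for this illustration and is not part of JAXMg.

```python
import jax


def gpu_devices():
    """Return the GPU devices visible to JAX, or an empty list if there are none.

    Hypothetical helper for illustration; not part of the JAXMg API.
    """
    try:
        return list(jax.devices("gpu"))
    except RuntimeError:
        # Raised by JAX when no GPU backend is available (e.g. CPU-only install).
        return []


gpus = gpu_devices()
if gpus:
    print(f"JAX sees {len(gpus)} GPU(s): {gpus}")
else:
    print('No GPU backend found; try: pip install "jaxmg[cuda12]"')
```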

Example

A minimal example that runs the code is:

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding
from jaxmg import potrs
print(f"Devices: {jax.devices()}")
# Assumes we have at least one GPU available
devices = jax.devices("gpu")
N = 12
T_A = 3
dtype = jnp.float64
# Create diagonal matrix and `b` all equal to one
A = jnp.diag(jnp.arange(N, dtype=dtype) + 1)
b = jnp.ones((N, 1), dtype=dtype)
ndev = len(devices)
# Make mesh and place data (rows sharded)
mesh = jax.make_mesh((ndev,), ("x",))
A = jax.device_put(A, NamedSharding(mesh, P("x", None)))
b = jax.device_put(b, NamedSharding(mesh, P(None, None)))
# Call potrs
out = potrs(A, b, T_A=T_A, mesh=mesh, in_specs=(P("x", None), ))
print(out)
expected_out = 1.0 / (jnp.arange(N, dtype=dtype) + 1)
print(jnp.allclose(out.flatten(), expected_out))

which gives

[[1.        ]
 [0.5       ]
 [0.33333333]
 [0.25      ]
 [0.2       ]
 [0.16666667]
 [0.14285714]
 [0.125     ]
 [0.11111111]
 [0.1       ]
 [0.09090909]
 [0.08333333]]
True

as expected.

Projects that use JAXMg

  • JAXMg Benchmarks: Benchmarks for various multi-GPU setups.
  • JAXMg + NetKet: Implementation of the MinSR NetKet driver that uses JAXMg for inverting the S-matrix. Tested in multi-node settings.
  • JAXMg for blurred sampling: Implementation of t-VMC that makes use of JAXMg for inverting the QGT.

cuSolverMp

As of CUDA 13, there is a new distributed linear algebra library called cuSolverMp, with capabilities similar to those of cuSolverMg. Unlike cuSolverMg, it supports multi-node computations and more than 16 devices. Given the similarities in syntax, it should be straightforward to eventually switch to this API. Doing so will require sharding data into a cyclic 2D layout and handling the solver orchestration with MPI.
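
To illustrate what a cyclic 2D layout involves, the sketch below maps each matrix entry to a process in a 2D grid using the ScaLAPACK-style block-cyclic rule. Reading "cyclic 2D" as this rule is an assumption, and the function names, block size and grid shape are hypothetical. The code does not use the cuSolverMp API.

```python
def block_cyclic_owner(i, j, block_rows, block_cols, grid_rows, grid_cols):
    """Process-grid coordinates that own matrix entry (i, j).

    Assumes a ScaLAPACK-style 2D block-cyclic distribution: the matrix is cut
    into block_rows x block_cols tiles, and tiles are dealt out round-robin
    along both dimensions of a grid_rows x grid_cols process grid.
    """
    return ((i // block_rows) % grid_rows, (j // block_cols) % grid_cols)


def ownership_map(n, block, grid_rows, grid_cols):
    """Flat process rank (row-major over the grid) for every entry of an n x n matrix."""
    ranks = []
    for i in range(n):
        row = []
        for j in range(n):
            pr, pc = block_cyclic_owner(i, j, block, block, grid_rows, grid_cols)
            row.append(pr * grid_cols + pc)
        ranks.append(row)
    return ranks


# An 8 x 8 matrix with 2 x 2 tiles on a 2 x 2 process grid.
ranks = ownership_map(n=8, block=2, grid_rows=2, grid_cols=2)
for row in ranks:
    print(" ".join(str(r) for r in row))
```

In contrast to the contiguous row blocks used in the example above, each process here owns tiles scattered across the whole matrix.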

Citations

@misc{2601.14466,
Author = {Roeland Wiersema},
Title = {JAXMg: A multi-GPU linear solver in JAX},
Year = {2026},
Eprint = {arXiv:2601.14466},
}

Acknowledgements

I acknowledge support from the Flatiron Institute. The Flatiron Institute is a division of the Simons Foundation.