JAXMg provides a C++ interface between JAX and cuSolverMg, NVIDIA’s multi-GPU linear solver. We provide a jittable API for the following routines.
- `cusolverMgPotrs`: Solves the system of linear equations $Ax = b$, where $A$ is an $N\times N$ symmetric (Hermitian) positive-definite matrix, via a Cholesky decomposition.
- `cusolverMgPotri`: Computes the inverse of an $N\times N$ symmetric (Hermitian) positive-definite matrix via a Cholesky decomposition.
- `cusolverMgSyevd`: Computes the eigenvalues and eigenvectors of an $N\times N$ symmetric (Hermitian) matrix.
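As a point of reference, the three routines compute the same quantities as the following single-device JAX calls (the helpers below are illustrative stand-ins run on one device, not the `jaxmg` API):

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # symmetric positive definite
b = jnp.ones((2, 1))

# potrs: solve A x = b via a Cholesky factorization
c = cho_factor(A)
x = cho_solve(c, b)

# potri: A^{-1} via Cholesky (solve against the identity)
A_inv = cho_solve(c, jnp.eye(2))

# syevd: eigenvalues and eigenvectors of a symmetric matrix
w, v = jnp.linalg.eigh(A)

print(jnp.allclose(A @ x, b))            # True
print(jnp.allclose(A @ A_inv, jnp.eye(2)))  # True
print(jnp.allclose(A @ v, v * w))        # True
```

The multi-GPU routines distribute these same computations across devices.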
For more details, see the API.
The package is available on PyPI and can be installed with

```bash
pip install jaxmg[cuda12]
```

This will install a GPU-compatible version of JAX.

- `pip install "jaxmg[cuda12]"`: Use CUDA 12 (only works for `jax>=0.6.2`).
- `pip install "jaxmg[cuda12-local]"`: Use a locally available CUDA 12 installation.
- `pip install "jaxmg[cuda13]"`: Use CUDA 13 (only works for `jax>=0.7.2`).
- `pip install "jaxmg[cuda13-local]"`: Use a locally available CUDA 13 installation.
The provided binaries are compiled with:

| JAXMg | CUDA | cuDNN |
|---|---|---|
| `cuda12`, `cuda12-local` | 12.8.0 | 9.17.1.4 |
| `cuda13`, `cuda13-local` | 13.0.0 | 9.17.1.4 |
Details for compiling from source can be found in CONTRIBUTING.md.
Note: `pip install jaxmg` will install a CPU-only version of JAX. Since `jaxmg` is a GPU-only package, you will receive a warning to install a GPU-compatible version of JAX.
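To confirm that your JAX installation can actually see a GPU before calling into `jaxmg`, a quick check along these lines can help (an illustrative sketch; `jax.devices("gpu")` raises `RuntimeError` when no GPU backend is available):

```python
import jax

try:
    gpus = jax.devices("gpu")
    print(f"Found {len(gpus)} GPU(s): {gpus}")
except RuntimeError:
    print("No GPU backend found; reinstall JAX with GPU support, "
          "e.g. via `pip install jaxmg[cuda12]`.")
```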
A minimal working example:

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding

from jaxmg import potrs

print(f"Devices: {jax.devices()}")
# Assumes we have at least one GPU available
devices = jax.devices("gpu")

N = 12
T_A = 3
dtype = jnp.float64
# Create a diagonal matrix A and a vector b of all ones
A = jnp.diag(jnp.arange(N, dtype=dtype) + 1)
b = jnp.ones((N, 1), dtype=dtype)

ndev = len(devices)
# Make a mesh and place the data (rows of A sharded across devices)
mesh = jax.make_mesh((ndev,), ("x",))
A = jax.device_put(A, NamedSharding(mesh, P("x", None)))
b = jax.device_put(b, NamedSharding(mesh, P(None, None)))

# Call potrs
out = potrs(A, b, T_A=T_A, mesh=mesh, in_specs=(P("x", None),))
print(out)
expected_out = 1.0 / (jnp.arange(N, dtype=dtype) + 1)
print(jnp.allclose(out.flatten(), expected_out))
```

which gives
```
[[1.        ]
 [0.5       ]
 [0.33333333]
 [0.25      ]
 [0.2       ]
 [0.16666667]
 [0.14285714]
 [0.125     ]
 [0.11111111]
 [0.1       ]
 [0.09090909]
 [0.08333333]]
True
```

as expected.
- JAXMg Benchmarks: Benchmarks for various multi-GPU setups.
- JAXMg + NetKet: Implementation of the MinSR NetKet driver that uses JAXMg for inverting the S-matrix. Tested in multi-node settings.
- JAXMg for blurred sampling: Implementation of t-VMC that makes use of JAXMg for inverting the QGT.
As of CUDA 13, there is a new distributed linear algebra library called cuSolverMp with capabilities similar to cuSolverMg's, which additionally supports multi-node computation and more than 16 devices. Given the similarities in syntax, it should be straightforward to eventually switch to this API. This will require sharding the data into a 2D block-cyclic layout and handling the solver orchestration with MPI.
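The 2D block-cyclic layout mentioned above deals fixed-size matrix blocks round-robin over a process grid. A minimal sketch of the index-to-owner mapping (block size and grid shape here are illustrative assumptions, not cuSolverMp defaults):

```python
def block_cyclic_owner(i, j, nb, P_r, P_c):
    """Process-grid coordinates (row, col) owning global matrix entry (i, j),
    for nb x nb blocks dealt cyclically over a P_r x P_c process grid."""
    return (i // nb) % P_r, (j // nb) % P_c

# 8x8 matrix, 2x2 blocks, 2x2 process grid: block rows/columns alternate owners.
N, nb, P_r, P_c = 8, 2, 2, 2
grid = [[block_cyclic_owner(i, j, nb, P_r, P_c) for j in range(N)] for i in range(N)]
print(grid[0][0])  # (0, 0)
print(grid[2][2])  # (1, 1)
print(grid[4][4])  # (0, 0) -- the pattern cycles
```

This cyclic dealing balances work across processes during factorizations, in contrast to the contiguous row sharding used in the `potrs` example above.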
```bibtex
@misc{2601.14466,
  Author = {Roeland Wiersema},
  Title  = {JAXMg: A multi-GPU linear solver in JAX},
  Year   = {2026},
  Eprint = {arXiv:2601.14466},
}
```
I acknowledge support from the Flatiron Institute. The Flatiron Institute is a division of the Simons Foundation.
