Merged
20 commits
2e1684b
Enable RunsOn GPU support for lecture builds
mmcky Nov 27, 2025
1671eb7
Merge branch 'main' into feature/runson-gpu-support
mmcky Nov 27, 2025
849cf16
DOC: Update JAX lectures with GPU admonition and narrative
mmcky Nov 27, 2025
0b6a567
DEBUG: Add hardware benchmark script to diagnose performance
mmcky Nov 27, 2025
6d1b9c3
Add multi-pathway benchmark tests (bare metal, Jupyter, jupyter-book)
mmcky Nov 27, 2025
d129f79
Fix: Add content to benchmark-jupyter.ipynb (was empty)
mmcky Nov 27, 2025
2da4e0c
Fix: Add benchmark content to benchmark-jupyter.ipynb
mmcky Nov 27, 2025
2bda114
Add JSON output to benchmarks and upload as artifacts
mmcky Nov 27, 2025
922b24c
Fix syntax errors in benchmark-hardware.py
mmcky Nov 27, 2025
10627ef
Sync benchmark scripts with CPU branch for comparable results
mmcky Nov 27, 2025
54b2b34
ENH: Force lax.scan sequential operation to run on CPU
mmcky Nov 27, 2025
6bf345a
update note
HumphreyYang Nov 27, 2025
8fbb9a7
Add lax.scan profiler to CI for GPU debugging
mmcky Nov 27, 2025
1bfbaf9
Add diagnostic mode to lax.scan profiler
mmcky Nov 27, 2025
8c32d7c
Add Nsight Systems profiling to CI
mmcky Nov 27, 2025
a623fb7
address @jstac comment
mmcky Nov 27, 2025
1739f51
Improve JAX lecture content and pedagogy
jstac Nov 27, 2025
350da37
Remove benchmark scripts (moved to QuantEcon/benchmarks)
mmcky Nov 27, 2025
e2939c2
Update lectures/numpy_vs_numba_vs_jax.md
mmcky Nov 27, 2025
56047ab
Add GPU and JAX hardware details to status page
mmcky Nov 27, 2025
12 changes: 11 additions & 1 deletion .github/workflows/cache.yml
@@ -6,7 +6,7 @@ on:
workflow_dispatch:
jobs:
cache:
runs-on: ubuntu-latest
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
steps:
- uses: actions/checkout@v6
- name: Setup Anaconda
@@ -18,6 +18,16 @@ jobs:
python-version: "3.13"
environment-file: environment.yml
activate-environment: quantecon
- name: Install JAX and Numpyro
shell: bash -l {0}
run: |
pip install -U "jax[cuda13]"
pip install numpyro
python scripts/test-jax-install.py
- name: Check nvidia drivers
shell: bash -l {0}
run: |
nvidia-smi
- name: Build HTML
shell: bash -l {0}
run: |
11 changes: 10 additions & 1 deletion .github/workflows/ci.yml
@@ -2,7 +2,7 @@ name: Build Project [using jupyter-book]
on: [pull_request]
jobs:
preview:
runs-on: ubuntu-latest
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
steps:
- uses: actions/checkout@v6
with:
@@ -16,6 +16,15 @@ jobs:
python-version: "3.13"
environment-file: environment.yml
activate-environment: quantecon
- name: Check nvidia Drivers
shell: bash -l {0}
run: nvidia-smi
- name: Install JAX and Numpyro
shell: bash -l {0}
run: |
pip install -U "jax[cuda13]"
pip install numpyro
python scripts/test-jax-install.py
- name: Install latex dependencies
run: |
sudo apt-get -qq update
12 changes: 11 additions & 1 deletion .github/workflows/publish.yml
@@ -6,7 +6,7 @@ on:
jobs:
publish:
if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
runs-on: ubuntu-latest
runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
steps:
- name: Checkout
uses: actions/checkout@v6
@@ -21,6 +21,16 @@ jobs:
python-version: "3.13"
environment-file: environment.yml
activate-environment: quantecon
- name: Install JAX and Numpyro
shell: bash -l {0}
run: |
pip install -U "jax[cuda13]"
pip install numpyro
python scripts/test-jax-install.py
- name: Check nvidia drivers
shell: bash -l {0}
run: |
nvidia-smi
- name: Install latex dependencies
run: |
sudo apt-get -qq update
52 changes: 26 additions & 26 deletions lectures/jax_intro.md
@@ -13,6 +13,18 @@ kernelspec:

# JAX

This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).

JAX is a high-performance scientific computing library that provides

* a NumPy-like interface that can automatically parallelize across CPUs and GPUs,
* a just-in-time compiler for accelerating a large range of numerical
operations, and
* automatic differentiation.

Increasingly, JAX also maintains and provides more specialized scientific
computing routines, such as those originally found in SciPy.
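
To give a quick flavor of these features, here is a minimal sketch (illustrative only) combining the NumPy-like API, just-in-time compilation via `jax.jit`, and automatic differentiation via `jax.grad`:

```python
import jax
import jax.numpy as jnp

# NumPy-like interface: familiar call signatures, arrays live on the default device
x = jnp.linspace(0.0, 10.0, 1_000)

# Just-in-time compilation: the first call compiles, later calls reuse the compiled code
@jax.jit
def f(x):
    return jnp.sum(jnp.cos(x) ** 2)

# Automatic differentiation: gradient of the scalar-valued function f
df = jax.grad(f)

print(f(x))
print(df(x).shape)   # (1000,)
```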

In addition to what's in Anaconda, this lecture will need the following libraries:

```{code-cell} ipython3
@@ -21,28 +33,24 @@ In addition to what's in Anaconda, this lecture will need the following librarie
!pip install jax quantecon
```

This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).

Here we are focused on using JAX on the CPU, rather than on accelerators such as
GPUs or TPUs.

This means we will only see a small amount of the possible benefits from using JAX.

However, JAX seamlessly handles transitions across different hardware platforms.
```{admonition} GPU
:class: warning

As a result, if you run this code on a machine with a GPU and a GPU-aware
version of JAX installed, your code will be automatically accelerated and you
will receive the full benefits.
This lecture is accelerated via [hardware](status:machine-details) with access to a GPU, and the code targets JAX for GPU programming.

For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
Free GPUs are available on Google Colab.
To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.

Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
If you would like to install a CPU-only version of JAX, you can use `pip install "jax[cpu]"`.
```
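
Whichever option you choose, a quick way to confirm which backend JAX is actually using is the following check (a minimal sketch):

```python
import jax

# Lists the devices JAX can see, e.g. a CUDA device on a GPU machine
# or a single CPU device on a CPU-only installation
print(jax.devices())

# Reports the default backend as a string: 'gpu', 'cpu', or 'tpu'
print(jax.default_backend())
```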

## JAX as a NumPy Replacement

One of the attractive features of JAX is that, whenever possible, it conforms to
the NumPy API for array operations.
One of the attractive features of JAX is that, whenever possible, its array
processing operations conform to the NumPy API.

This means that, to a large extent, we can use JAX as a drop-in NumPy replacement.
This means that, in many cases, we can use JAX as a drop-in NumPy replacement.

Let's look at the similarities and differences between JAX and NumPy.
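
As a minimal sketch of the drop-in behavior (the examples that follow go into more detail):

```python
import numpy as np
import jax.numpy as jnp

a_np = np.linspace(0, 1, 5)
a_jax = jnp.linspace(0, 1, 5)

# The same API works with both libraries...
print(np.sum(a_np ** 2), jnp.sum(a_jax ** 2))

# ...but JAX returns its own (immutable) array type
print(type(a_np), type(a_jax))
```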

@@ -523,16 +531,9 @@ with qe.Timer():
jax.block_until_ready(y);
```

If you are running this on a GPU the code will run much faster than its NumPy
equivalent, which ran on the CPU.

Even if you are running on a machine with many CPUs, the second JAX run should
be substantially faster with JAX.

Also, typically, the second run is faster than the first.
On a GPU, this code runs much faster than its NumPy equivalent.

(This might not be noticeable on the CPU but it should definitely be noticeable on
the GPU.)
Also, typically, the second run is faster than the first due to JIT compilation.

This is because even built-in functions like `jnp.cos` are JIT-compiled --- and the
first run includes compile time.
@@ -634,8 +635,7 @@ with qe.Timer():
jax.block_until_ready(y);
```

The outcome is similar to the `cos` example --- JAX is faster, especially if you
use a GPU and especially on the second run.
The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.

Moreover, with JAX, we have another trick up our sleeve:

61 changes: 48 additions & 13 deletions lectures/numpy_vs_numba_vs_jax.md
@@ -48,6 +48,18 @@ tags: [hide-output]
!pip install quantecon jax
```

```{admonition} GPU
:class: warning

This lecture is accelerated via [hardware](status:machine-details) with access to a GPU, and the code targets JAX for GPU programming.

Free GPUs are available on Google Colab.
To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.

Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
If you would like to install a CPU-only version of JAX, you can use `pip install "jax[cpu]"`.
```

We will use the following imports.

```{code-cell} ipython3
@@ -317,7 +329,7 @@ with qe.Timer(precision=8):
z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
```

Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.

The compilation overhead is a one-time cost that pays off when the function is called repeatedly.
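
A common pattern, sketched here with a hypothetical jitted function `f_jit` and input `x`, is to trigger compilation with a warm-up call and time only the subsequent calls:

```python
# Warm-up: the first call compiles f_jit for this input shape and dtype
f_jit(x).block_until_ready()

# Later calls reuse the compiled code, so this timing excludes compilation
with qe.Timer(precision=8):
    f_jit(x).block_until_ready()
```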

@@ -370,23 +382,29 @@ with qe.Timer(precision=8):
z_max.block_until_ready()
```

The execution time is similar to the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
we are using far less memory.
By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version uses far less memory.

When run on a CPU, its runtime is similar to that of the meshgrid version.

In addition, `vmap` allows us to break vectorization up into stages, which is
often easier to comprehend than the traditional approach.
When run on a GPU, it is usually significantly faster.

This will become more obvious when we tackle larger problems.
In fact, using `vmap` has another advantage: it allows us to break vectorization up into stages.

This leads to code that is often easier to comprehend than traditional vectorized code.

We will investigate these ideas more when we tackle larger problems.
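
As a concrete illustration of staging (using an illustrative function `g` rather than the one studied in this lecture), a fully vectorized bivariate evaluation can be built in two steps:

```python
import jax
import jax.numpy as jnp

def g(x, y):
    # An illustrative bivariate function
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

# Stage 1: vectorize over y, holding x fixed
g_vec_y = jax.vmap(g, in_axes=(None, 0))

# Stage 2: vectorize the result over x
g_vec = jax.vmap(g_vec_y, in_axes=(0, None))

grid = jnp.linspace(-3, 3, 1_000)
z = g_vec(grid, grid)     # shape (1000, 1000), no meshgrid required
print(jnp.max(z))
```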


### vmap version 2

We can be still more memory efficient using vmap.

While we avoided large input arrays in the preceding version,
While we avoid large input arrays in the preceding version,
we still create the large output array `f(x,y)` before we compute the max.

Let's use a slightly different approach that takes the max to the inside.
Let's try a slightly different approach that takes the max to the inside.

Because of this change, we never compute the two-dimensional array `f(x,y)`.

```{code-cell} ipython3
@jax.jit
@@ -399,23 +417,28 @@ def compute_max_vmap_v2(grid):
return jnp.max(f_vec_max(grid))
```

Let's try it
Here

* `f_vec_x_max` computes the max along any given row
* `f_vec_max` is a vectorized version that can compute the max of all rows in parallel.

We apply this function to all rows and then take the max of the row maxes.

Let's try it.

```{code-cell} ipython3
with qe.Timer(precision=8):
z_max = compute_max_vmap_v2(grid).block_until_ready()
```


Let's run it again to eliminate compilation time:

```{code-cell} ipython3
with qe.Timer(precision=8):
z_max = compute_max_vmap_v2(grid).block_until_ready()
```

We don't get much speed gain but we do save some memory.

If you are running this on a GPU, as we are, you should see another nontrivial speed gain.


### Summary
Expand Down Expand Up @@ -497,7 +520,9 @@ Now let's create a JAX version using `lax.scan`:
from jax import lax
from functools import partial

@partial(jax.jit, static_argnums=(1,))
cpu = jax.devices("cpu")[0]

@partial(jax.jit, static_argnums=(1,), device=cpu)
def qm_jax(x0, n, α=4.0):
def update(x, t):
x_new = α * x * (1 - x)
@@ -509,6 +534,16 @@ def qm_jax(x0, n, α=4.0):

This code is not easy to read but, in essence, `lax.scan` repeatedly calls `update` and accumulates the returned values `x_new` into an array.
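
For intuition, here is a pure-Python sketch of what `lax.scan` does in this pattern (ignoring compilation and device placement):

```python
def scan_sketch(update, init, ts):
    # Thread a carry value through repeated calls to `update`, collecting
    # the second return value at each step (lax.scan additionally stacks
    # the collected values into an array)
    carry = init
    ys = []
    for t in ts:
        carry, y = update(carry, t)
        ys.append(y)
    return carry, ys
```

Unlike this Python loop, `lax.scan` stages the whole iteration out as a single compiled operation, avoiding Python-level overhead at each step.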

```{note}
Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.

The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.

As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.

Curious readers can try removing this option to see how performance changes.
```

Let's time it with the same parameters:

```{code-cell} ipython3
14 changes: 14 additions & 0 deletions lectures/status.md
@@ -31,4 +31,18 @@ and the following package versions
```{code-cell} ipython
:tags: [hide-output]
!conda list
```

This lecture series has access to the following GPU:

```{code-cell} ipython
!nvidia-smi
```

You can check which backend JAX is using as follows:

```{code-cell} ipython3
import jax
# Check if JAX is using GPU
print(f"JAX backend: {jax.devices()[0].platform}")
```
21 changes: 21 additions & 0 deletions scripts/test-jax-install.py
@@ -0,0 +1,21 @@
import jax
import jax.numpy as jnp

devices = jax.devices()
print(f"The available devices are: {devices}")

@jax.jit
def matrix_multiply(a, b):
return jnp.dot(a, b)

# Example usage:
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 1000))
y = jax.random.normal(key, (1000, 1000))
z = matrix_multiply(x, y)

# Now the function is JIT compiled and will likely run on GPU (if available)
print(z)

devices = jax.devices()
print(f"The available devices are: {devices}")