
Conversation

@mmcky
Contributor

@mmcky mmcky commented Nov 27, 2025

Summary

This PR enables GPU support for building lectures using RunsOn with AWS GPU instances.

Changes

  • Add scripts/test-jax-install.py - Script to verify JAX/GPU installation works correctly
  • Add .github/runs-on.yml - RunsOn configuration with QuantEcon Ubuntu 24.04 AMI (ami-0edec81935264b6d3) in us-west-2
  • Update cache.yml - Use RunsOn g4dn.2xlarge GPU runner with JAX CUDA 13 support
  • Update ci.yml - Use RunsOn g4dn.2xlarge GPU runner with JAX CUDA 13 support
  • Update publish.yml - Use RunsOn g4dn.2xlarge GPU runner with JAX CUDA 13 support

Configuration Details

  • Runner: runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large
  • JAX: jax[cuda13] with Numpyro
  • Python: 3.13
  • AMI: Custom QuantEcon Ubuntu 24.04 image
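
For reference, here is a minimal sketch of the kind of check scripts/test-jax-install.py performs (illustrative only; the actual script may differ):

```python
# Hedged sketch of a JAX/GPU installation check (not necessarily the exact
# contents of scripts/test-jax-install.py).
import jax
import jax.numpy as jnp

print("JAX version:    ", jax.__version__)
print("Default backend:", jax.default_backend())  # expect "gpu" on the RunsOn runner
print("Devices:        ", jax.devices())          # expect a CUDA device (Tesla T4 on g4dn)

# Run a small computation end to end to confirm the CUDA path works.
x = jnp.ones((1000, 1000))
y = (x @ x).block_until_ready()
print("Matmul OK, sum =", float(y.sum()))
```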

Reference

This configuration mirrors the setup used in lecture-python.myst.

- Add scripts/test-jax-install.py to verify JAX/GPU installation
- Add .github/runs-on.yml with QuantEcon Ubuntu 24.04 AMI configuration
- Update cache.yml to use RunsOn g4dn.2xlarge GPU runner
- Update ci.yml to use RunsOn g4dn.2xlarge GPU runner
- Update publish.yml to use RunsOn g4dn.2xlarge GPU runner
- Install JAX with CUDA 13 support and Numpyro on all workflows
- Add nvidia-smi check to verify GPU availability

This mirrors the setup used in lecture-python.myst repository.
@mmcky mmcky changed the title from "Enable RunsOn GPU support for lecture builds" to "ENH: Enable RunsOn GPU support for lecture builds" Nov 27, 2025
mmcky added a commit that referenced this pull request Nov 27, 2025
Add .github/runs-on.yml with QuantEcon Ubuntu 24.04 AMI configuration.
This file must be on main branch for RunsOn to recognize the custom image.

AMI: ami-0edec81935264b6d3
Region: us-west-2

This is a prerequisite for PR #437 which enables GPU builds.
@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 04:27 Inactive
- Add standard GPU admonition to jax_intro.md and numpy_vs_numba_vs_jax.md
- Update introduction in jax_intro.md to reflect GPU access
- Update conditional GPU language to reflect lectures now run on GPU
- Following QuantEcon style guide for JAX lectures
@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Lecture Narrative Updates

Added GPU admonitions and updated narrative text for lectures that use JAX:

| Lecture | Status | Changes Made |
|---|---|---|
| jax_intro.md | ✅ Updated | GPU admonition added, introduction rewritten, conditional GPU language updated |
| numpy_vs_numba_vs_jax.md | ✅ Updated | GPU admonition added, GPU narrative updated |
| getting_started.md | ✅ No change needed | Just mentions Colab/GPUs as setup options |
| need_for_speed.md | ✅ No change needed | Educational discussion only, no JAX code |
| about_py.md | ✅ No change needed | Just mentions JAX in ecosystem overview |

The GPU admonition follows the QuantEcon style guide and provides readers with information about GPU acceleration and alternative options for running locally.

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

@jstac I am a bit concerned by these status.html results

GPU backend:

[Screenshot: status.html execution times on the GPU build]

CPU backend (live site):

[Screenshot: status.html execution times on the live CPU build]

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

  • check status.html of the next build and cross-check.

https://6927d34136e0d837e87477e1--epic-agnesi-957267.netlify.app/numpy_vs_numba_vs_jax#jax-version is showing 80s for the JAX version using lax.scan!

@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 04:44 Inactive
- Add benchmark-hardware.py with CPU, NumPy, Numba, and JAX benchmarks
- Works on both GPU (RunsOn) and CPU-only (GitHub Actions) runners
- Include warm-up vs compiled timing to isolate JIT overhead
- Add system info collection (CPU model, frequency, GPU detection)
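
As context for the tables below, a minimal sketch of the warm-up vs compiled timing pattern, assuming jax.jit plus block_until_ready (illustrative; the actual benchmark-hardware.py may differ):

```python
# Minimal sketch of the warm-up vs compiled timing pattern
# (illustrative; not the actual benchmark-hardware.py).
import time

import jax.numpy as jnp
from jax import jit

@jit
def matmul(a, b):
    return a @ b

a = jnp.ones((3000, 3000))
b = jnp.ones((3000, 3000))

# Warm-up call: includes tracing plus XLA/CUDA kernel compilation.
t0 = time.perf_counter()
matmul(a, b).block_until_ready()
print(f"warm-up:  {time.perf_counter() - t0:.3f} s")

# Compiled call: reuses the cached executable, isolating pure execution time.
t0 = time.perf_counter()
matmul(a, b).block_until_ready()
print(f"compiled: {time.perf_counter() - t0:.3f} s")
```
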
@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 05:01 Inactive
@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Hardware Benchmark Investigation Results

We ran identical benchmarks on both GitHub Actions (CPU) and RunsOn (GPU) to diagnose the performance differences.

System Information

| Metric | GitHub Actions (CPU) | RunsOn (GPU) |
|---|---|---|
| Platform | Linux (Azure) | Linux (AWS) |
| CPU | AMD EPYC 7763 64-Core @ 3281 MHz | Intel Xeon Platinum 8259CL @ 2500 MHz |
| CPU Cores | 4 | 8 |
| GPU | None | Tesla T4 (15360 MiB) |

CPU Performance (Pure Python)

| Benchmark | GitHub Actions | RunsOn | Winner |
|---|---|---|---|
| Integer sum (10M) | 0.599 sec | 0.881 sec | GitHub Actions (1.5x faster) |
| Float sqrt (1M) | 0.077 sec | 0.107 sec | GitHub Actions (1.4x faster) |

CPU Performance (NumPy)

| Benchmark | GitHub Actions | RunsOn | Winner |
|---|---|---|---|
| Matrix multiply (3000×3000) | 0.642 sec | 0.224 sec | RunsOn (2.9x faster) |
| Element-wise (50M) | 1.686 sec | 1.768 sec | ~Same |

CPU Performance (Numba)

| Benchmark | GitHub Actions | RunsOn | Winner |
|---|---|---|---|
| Integer sum warm-up | 0.320 sec | 0.300 sec | ~Same |
| Integer sum compiled | 0.000 sec | 0.000 sec | Same |
| Parallel sum warm-up | 0.344 sec | 0.382 sec | ~Same |
| Parallel sum compiled | 0.012 sec | 0.010 sec | ~Same |

JAX Performance

| Benchmark | GitHub Actions (CPU) | RunsOn (GPU) | Winner |
|---|---|---|---|
| 1000×1000 warm-up | 0.030 sec | 0.079 sec | GitHub Actions |
| 1000×1000 compiled | 0.011 sec | 0.001 sec | GPU (11x faster) |
| 3000×3000 warm-up | 0.426 sec | 0.645 sec | GitHub Actions |
| 3000×3000 compiled | 0.276 sec | 0.009 sec | GPU (30x faster) |
| 50M element-wise warm-up | 0.816 sec | 0.118 sec | GPU (7x faster) |
| 50M element-wise compiled | 0.381 sec | 0.002 sec | GPU (190x faster) |

Key Findings

  1. Pure Python is slower on RunsOn - The Intel Xeon @ 2.5 GHz is slower than the AMD EPYC @ 3.3 GHz for single-threaded Python. This contributes to slower lecture execution times.

  2. NumPy matrix ops are faster on RunsOn - Likely due to 8 cores vs 4 cores for BLAS parallelization.

  3. GPU (JAX compiled) is massively faster - 30-190x faster for compiled operations! ✅

  4. JIT compilation overhead is higher on GPU - Warm-up times are actually longer on GPU due to CUDA kernel compilation.

Why numpy_vs_numba_vs_jax took 176 sec (GPU) vs 15 sec (CPU)

The lecture execution includes:

  • Pure Python setup code (slower on Intel Xeon)
  • JIT compilation time for both Numba and JAX (compilation overhead is higher for GPU kernels)
  • Multiple recompilations when array sizes change (the lecture uses different grid sizes)
  • 15 timed benchmark blocks that each trigger fresh compilations

Conclusion

The GPU is working correctly and provides massive speedups for compiled JAX operations. The slower total execution time is due to:

  1. Slower single-threaded Python on the Intel Xeon CPU
  2. Higher JIT compilation overhead for GPU kernels
  3. The lecture code structure (many small benchmarks with varying sizes)

For production lectures that run cached notebooks, the GPU benefits will be realized. The initial build with JIT compilation will be slower, but subsequent cached runs will benefit from the pre-compiled kernels.

@jstac
Contributor

jstac commented Nov 27, 2025

This is looking good @mmcky ---- are you still concerned?

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

@jstac yes I am concerned -- the results are strange on the status.html page.

I am now checking whether there are timing issues due to Jupyter/Jupyter kernels or Jupyter Book.

[Screenshot: status.html execution times]

@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 05:44 Inactive
- Update benchmark-hardware.py to save results to JSON
- Update benchmark-jupyter.ipynb to save results to JSON
- Update benchmark-jupyterbook.md to save results to JSON
- Add CI step to collect and display benchmark results
- Add CI step to upload benchmark results as artifact
- Remove extra triple quote at start of file
- Remove stray parentheses in result assignments
- Copy benchmark-hardware.py from debug/benchmark-github-actions
- Copy benchmark-jupyter.ipynb from debug/benchmark-github-actions
- Copy benchmark-jupyterbook.md from debug/benchmark-github-actions
- Update ci.yml to use matching file names

The test scripts are now identical between both branches,
only the CI workflow differs (runner type and JAX installation).
@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 06:14 Inactive
@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Investigation: lax.scan Performance on GPU

Finding

lax.scan with millions of iterations performs extremely poorly on GPU due to a known XLA limitation. In the numpy_vs_numba_vs_jax lecture, the qm_jax function with n=10,000,000 takes:

  • CPU: ~0.06 seconds
  • GPU: ~81.6 seconds (1,340x slower)
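
For context, here is a hedged reconstruction of the kind of lax.scan loop in question (hypothetical code; the lecture's actual qm_jax may differ in details):

```python
# Hypothetical reconstruction of a quadratic-map trajectory via lax.scan
# (not the exact qm_jax from the lecture).
from functools import partial

import jax
from jax import lax

@partial(jax.jit, static_argnums=(1,))
def qm(x0, n, alpha=4.0):
    # x_{t+1} = alpha * x_t * (1 - x_t); store the whole length-n trajectory.
    def step(x, _):
        x_new = alpha * x * (1 - x)
        return x_new, x_new          # (new carry, per-step output)
    _, trajectory = lax.scan(step, x0, None, length=n)
    return trajectory

# Each iteration does only a few scalar FLOPs, so on GPU the fixed
# per-iteration overhead swamps the actual compute.
out = qm(0.1, 10_000_000).block_until_ready()
```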

Root Cause

XLA GPU executes dynamic control flow on CPU, requiring CPU-GPU synchronization on every loop iteration. This is documented in:

  1. JAX Issue #2491 (open since March 2020, P1 priority): lax.scan / lax.fori_loop is slow on the GPU

    "XLA GPU currently always executes dynamic control flow on the CPU. So small loop iterations end up much slower, due to the need to frequently synchronize between the CPU/GPU."@shoyer (JAX collaborator)

  2. JAX Issue #29946 (July 2025): Same pattern reported with identical symptoms

Documentation Gap

The official JAX lax.scan documentation does not warn about this GPU performance limitation. It only mentions the unroll parameter for partial loop unrolling, without explaining the GPU synchronization overhead.

Impact on This PR

The GPU build is slower than CPU primarily because of this single lecture. Other JAX operations (compiled array operations, matrix multiplications) are 28-207x faster on GPU as expected.

Potential Solutions

  1. Reduce n for GPU builds (e.g., n=100,000 instead of n=10,000,000)
  2. Skip GPU execution for this specific lecture
  3. Add documentation in the lecture explaining this is a known XLA limitation
  4. Accept the limitation and let the build run longer

Recommendation

Given this is an educational resource, Option 3 (add documentation) may be most valuable—turning this limitation into a teaching moment about GPU programming patterns.

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

@jstac what is your take on this?

Here are the results of testing cpu and gpu benchmarks in both GitHub Actions (CPU) and AWS EC2 RunsOn (GPU) for the three execution pathways:

  1. bare metal
  2. notebook
  3. jupyter-book

They all show consistent results: QuantEcon/meta#262 (comment)

I iterated with Claude and the results suggest it comes down to lax.scan on GPU, along with some links to "evidence":
#437 (comment). Essentially, the available evidence suggests lax.scan has to interface with the CPU on every iteration, which massively slows down computations when run on the GPU.

@jstac I am super confused by one thing though: your local results don't show the same behaviour when running on GPU?

cc: @HumphreyYang interested in your thoughts on this if you have time.

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Nsight Systems Artifact Available 🎉

@HumphreyYang the Nsight profile artifact is now available for download. Would love your help analyzing it!

How to Use the Artifact

  1. Download: Go to Actions → latest workflow run → Artifacts section → download nsight-profile.zip
  2. Extract: Get lax_scan_trace.nsys-rep
  3. Open: Use Nsight Systems UI (nsys-ui lax_scan_trace.nsys-rep)

What to Look For

The key question: Is the ~8µs per iteration from CPU-GPU sync, kernel launch overhead, or single-threaded GPU execution?

Timeline patterns to identify:

| Pattern | Interpretation |
|---|---|
| Long continuous GPU bar | Loop runs entirely on GPU (single-threaded GPU issue) |
| Many tiny GPU bars with gaps | Kernel launch overhead per iteration |
| CPU activity interleaved with GPU | CPU-GPU synchronization per iteration |
| Repeated cudaMemcpy calls | Data transfers between host and device |
| cudaDeviceSynchronize calls | Explicit sync points |

Expected if CPU-GPU sync per iteration:

CPU:  ████░░████░░████░░████░░...  (busy-wait-busy-wait pattern)
GPU:     █     █     █     █  ...  (tiny kernels with gaps)

Expected if single-threaded GPU (no sync):

CPU:  █░░░░░░░░░░░░░░░░░░░░░░░█    (idle during execution)
GPU:  ██████████████████████████   (one long continuous kernel)

Key Metrics to Note

  • Total CUDA API calls count
  • Average kernel duration
  • GPU utilization percentage
  • Any cudaStreamSynchronize or cudaDeviceSynchronize counts

This should definitively show whether the ~8µs per iteration overhead is from CPU-GPU sync, kernel launch overhead, or something else!

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Analysis of @HumphreyYang's TensorBoard Results 🔍

Great analysis @HumphreyYang! The TensorBoard data is very revealing.

Key Findings from the Operation Statistics

| Operation | Location | Count | Time (µs) | % of Device Time |
|---|---|---|---|---|
| dynamic_update_slice | Device | 1,000 | 851 | 33.7% |
| mul | Device | 1,000 | 845 | 33.4% |
| add | Device | 1,000 | 827 | 32.7% |
| Compiling IR | Host | 1 | 17,282 | (compilation) |
| HLO Transforms | Host | 1 | 8,841 | (compilation) |
HLO Transforms Host 1 8,841 (compilation)

What This Tells Us

  1. Operations ARE on Device ✅ - The mul, add, and dynamic_update_slice all run on GPU with 1,000 calls each (matching iteration count)

  2. Each iteration launches 3 separate kernels:

    • dynamic_update_slice (writing result to array)
    • mul (the α * x * (1-x) computation)
    • add (loop counter increment)
  3. ~0.85µs per kernel average - Each operation takes roughly 0.85µs

  4. Math check: 3 kernels × ~0.85µs × 1000 = ~2.5ms device time, but we measure ~8ms total

Root Cause: Kernel Launch Overhead

The gap between device time (~2.5ms) and wall time (~8ms) is kernel launch overhead!

  • GPU kernel launch overhead is typically ~2-10µs per launch
  • With 3 kernels per iteration: 3 × ~2-3µs overhead = ~6-9µs per iteration
  • This matches our measured ~8µs per iteration!

Why This Happens

The HLO while loop doesn't fuse into a single kernel. XLA generates separate kernels for each operation inside the loop body, and each kernel launch has fixed overhead regardless of how tiny the actual compute is.
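
One way to see those separate per-iteration operations before they become kernels is to inspect the traced program; a small sketch, assuming a scan body like the hypothetical one above:

```python
# Sketch: inspect the traced scan to see the separate per-iteration ops
# that XLA lowers to individual kernels (hypothetical loop body).
import jax
from jax import lax

def qm(x0, n=1000, alpha=4.0):
    def step(x, _):
        x_new = alpha * x * (1 - x)
        return x_new, x_new
    _, traj = lax.scan(step, x0, None, length=n)
    return traj

# The jaxpr shows the scan body's elementwise ops (mul, sub); the loop-counter
# add and the dynamic_update_slice that writes each result appear once the
# scan is lowered to an HLO while loop.
print(jax.make_jaxpr(qm)(0.1))
# The post-optimization HLO can be dumped for inspection with the XLA flag
# XLA_FLAGS="--xla_dump_to=/tmp/hlo" (stated here as an assumption about the tooling).
```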

Conclusion

  • Not CPU-GPU data synchronization per iteration
  • Kernel launch overhead from launching 3 tiny kernels per iteration
  • The compute itself is fast (~0.85µs), but launching 3 kernels adds ~6µs overhead per iteration

This is actually a cleaner explanation! The device=cpu fix avoids the kernel launch overhead entirely by running the sequential loop on CPU where there's no kernel launch cost.

Thanks for the great detective work! 🕵️

@jstac
Contributor

jstac commented Nov 27, 2025

This is great guys. Fascinating discussion. It's fun working with such a good team :-)

One small comment but other than that the changes are really nice.

@jstac
Contributor

jstac commented Nov 27, 2025

The HLO while loop doesn't fuse into a single kernel. XLA generates separate kernels for each operation inside the loop body, and each kernel launch has fixed overhead regardless of how tiny the actual compute is.

😱

@HumphreyYang
Member

The HLO while loop doesn't fuse into a single kernel. XLA generates separate kernels for each operation inside the loop body, and each kernel launch has fixed overhead regardless of how tiny the actual compute is.

😱

Yeah this sounds really bad : (

We need a large enough computation in each loop iteration to be worth the overhead.

@HumphreyYang
Member

HumphreyYang commented Nov 27, 2025

Hi @mmcky,

For your interest, here is the output from nsight!

[Nsight Systems timeline screenshot]

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Nsight Systems Confirmation

Thanks @HumphreyYang for running the Nsight visualization! 🎉

The timeline view perfectly confirms our kernel launch overhead hypothesis:

  • Each lax.scan iteration launches 3 separate GPU kernels (mul, add, dynamic_update_slice)
  • The visualization shows the characteristic pattern of many tiny kernel launches with gaps between them
  • Those gaps represent the kernel launch latency (~2-3µs per launch)
  • With 3 kernels per iteration × ~2-3µs each = ~6-9µs overhead per iteration

This matches our measured ~8µs per iteration and explains the 1000x+ slowdown for scalar operations.

Key takeaway for the lectures: Sequential scalar operations in lax.scan should use device=cpu to avoid this overhead. GPU excels when each kernel does substantial parallel work, not when launching millions of tiny kernels.
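
A hedged sketch of one way to pin such a loop to the CPU backend, using jax.default_device (illustrative; not necessarily the exact change adopted in the lectures):

```python
# Hedged sketch: pin a sequential lax.scan loop to the CPU backend so each
# iteration avoids GPU kernel-launch overhead (names are illustrative).
import jax
from jax import lax

def qm(x0, n, alpha=4.0):
    def step(x, _):
        x_new = alpha * x * (1 - x)
        return x_new, x_new
    _, traj = lax.scan(step, x0, None, length=n)
    return traj

cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    # Inputs are placed on CPU and the scan compiles for the CPU backend here.
    traj = jax.jit(qm, static_argnums=(1,))(0.1, 10_000_000)
    traj.block_until_ready()
```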

@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 10:59 Inactive
@mmcky
Contributor Author

mmcky commented Nov 27, 2025

thanks @jstac and @HumphreyYang -- learnt lots today. Been fun. Appreciate iterating with you both.

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

@jstac In the numpy_vs_numba_vs_jax lecture, the statement about vmap version 2:

"We don't get much speed gain but we do save some memory."

This appears to be incorrect based on the GPU build results:

  • vmap Version 1: 0.00111055 seconds (1.1 ms)
  • vmap Version 2: 0.00046182 seconds (0.46 ms)

Version 2 is actually ~2.4x faster than Version 1, in addition to being more memory efficient.

This makes sense—by taking the max inside the inner function, we avoid creating large intermediate arrays and the GPU can compute row maxes immediately.
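
For reference, a hedged sketch of the two patterns being compared, with hypothetical names and an arbitrary f (not the exact lecture code):

```python
# Hedged sketch of the two vmap variants (hypothetical f and grid sizes).
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)

x = jnp.linspace(-3, 3, 10_000)
y = jnp.linspace(-3, 3, 10_000)

# Version 1: materialize the full 10_000 x 10_000 matrix, then take the max.
f_mat = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
v1 = jnp.max(f_mat(x, y))

# Version 2: take the max over y inside the mapped function, so only a
# length-10_000 vector of row maxes is ever materialized.
f_row_max = jax.vmap(lambda xi: jnp.max(jax.vmap(lambda yj: f(xi, yj))(y)))
v2 = jnp.max(f_row_max(x))
```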

  • update to reflect that Version 2 is both faster AND more memory efficient.

Is this a different result to what we had on cpu?

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Is this a different result to what we had on cpu?

Ah yes this is different between cpu and gpu. When run on cpu they have the same times.

https://python-programming.quantecon.org/numpy_vs_numba_vs_jax.html#vmap-version-2

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Final TODO:

  • clean up scripts and debug infrastructure from ci.yml
  • minor update to vmap version 2 narrative.

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Timing Comparison Review - numpy_vs_numba_vs_jax Lecture

I reviewed all timing comparisons in the lecture. Here's the summary:

✅ Statements that are correct:

  1. NumPy vs Numba (non-parallel): "can be a bit slower or a bit faster" - appropriately hedged
  2. Numba parallel: "should see at least some benefits" - correct (4.7x speedup)
  3. JAX meshgrid: "significantly faster than NumPy" - correct (13x faster)
  4. Sequential operations: "Both JAX and Numba deliver strong performance" - correct (0.065s vs 0.069s)

⚠️ Potentially questionable:

vmap Version 1 states: "The execution time is similar to the mesh operation"

  • Meshgrid JAX: 0.01924109 seconds
  • vmap Version 1: 0.00111055 seconds
  • vmap v1 is actually ~17x faster than meshgrid, not "similar"

❌ Incorrect (already reported):

vmap Version 2 states: "We don't get much speed gain but we do save some memory"

  • vmap v1: 0.00111055 seconds
  • vmap v2: 0.00046182 seconds
  • vmap v2 is actually ~2.4x faster, not "not much speed gain"

Recommendation: Update both vmap statements to reflect the actual performance improvements shown in the GPU build results.

- Reorganize jax_intro.md to introduce JAX features upfront with clearer structure
- Expand JAX introduction with bulleted list of key capabilities (parallelization, JIT, autodiff)
- Add explicit GPU performance notes in vmap sections
- Enhance vmap explanation with detailed function composition breakdown
- Clarify memory efficiency tradeoffs between different vmap approaches

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
@jstac
Contributor

jstac commented Nov 27, 2025

Pedagogical improvements to JAX lectures

This commit improves the structure and clarity of the JAX introduction and vmap sections:

jax_intro.md:

  • Reorganized to introduce JAX's key features (parallelization, JIT compilation, autodiff) upfront before installation
  • Improved flow by presenting "what is JAX" before "how to install JAX"
  • Refined language around JAX as a NumPy drop-in replacement

numpy_vs_numba_vs_jax.md:

  • Added explicit GPU performance notes for vmap implementations
  • Enhanced explanation of vmap version 2 with detailed breakdown of function composition (f_vec_x_max → f_vec_max → jnp.max)
  • Clarified memory efficiency tradeoffs between different vmap approaches
  • Updated performance expectations to reflect GPU execution context

These changes make the lectures more accessible while better highlighting the GPU acceleration benefits that this PR enables.

@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 20:17 Inactive
mmcky added a commit to QuantEcon/benchmarks that referenced this pull request Nov 27, 2025
Add benchmarking and profiling tools developed during GPU support investigation:

JAX benchmarks:
- lax.scan performance profiler with multiple analysis modes
- Documents kernel launch overhead issue and solution

Hardware benchmarks:
- Cross-platform benchmark comparing Pure Python, NumPy, Numba, JAX
- JAX installation verification script

Notebook benchmarks:
- MyST Markdown and Jupyter notebook for execution pathway comparison

Documentation:
- Detailed investigation report on lax.scan GPU performance issue
- README files with usage instructions for each category

Reference: QuantEcon/lecture-python-programming.myst#437
- Remove profile_lax_scan.py, benchmark-hardware.py, benchmark-jupyter.ipynb, benchmark-jupyterbook.md
- Remove profiling/benchmarking steps from ci.yml
- Keep test-jax-install.py for JAX installation verification

Benchmark scripts are now maintained in: https://github.com/QuantEcon/benchmarks
@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 21:31 Inactive
@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 21:46 Inactive
- Add nvidia-smi output to show GPU availability
- Add JAX backend check to confirm GPU usage
- Matches format used in lecture-python.myst
@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 21:55 Inactive
@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Thanks @jstac for your edits. This is ready to go.

I will:

  1. merge
  2. manually run the cache to get the latest execution cache in the correct environment
  3. run a publish

@mmcky mmcky merged commit 9045b9f into main Nov 27, 2025
5 checks passed
@mmcky mmcky deleted the feature/runson-gpu-support branch November 27, 2025 21:57