ENH: Enable RunsOn GPU support for lecture builds #437
- Add scripts/test-jax-install.py to verify JAX/GPU installation
- Add .github/runs-on.yml with QuantEcon Ubuntu 24.04 AMI configuration
- Update cache.yml to use RunsOn g4dn.2xlarge GPU runner
- Update ci.yml to use RunsOn g4dn.2xlarge GPU runner
- Update publish.yml to use RunsOn g4dn.2xlarge GPU runner
- Install JAX with CUDA 13 support and Numpyro on all workflows
- Add nvidia-smi check to verify GPU availability

This mirrors the setup used in the lecture-python.myst repository.
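For reference, a minimal sketch of the kind of check such a verification script can run (the actual scripts/test-jax-install.py may differ):

```python
# Minimal sketch of a JAX/GPU installation check (illustrative only; the
# real scripts/test-jax-install.py may differ).
import jax
import jax.numpy as jnp

print("Backend:", jax.default_backend())   # expect "gpu" on the RunsOn runner
print("Devices:", jax.devices())

# Run a small computation to confirm kernels actually execute on the device
x = jnp.ones((1000, 1000))
y = (x @ x).block_until_ready()
print("Matmul OK, sum =", float(y.sum()))
```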
Add .github/runs-on.yml with the QuantEcon Ubuntu 24.04 AMI configuration. This file must be on the main branch for RunsOn to recognize the custom image.

AMI: ami-0edec81935264b6d3
Region: us-west-2

This is a prerequisite for PR #437, which enables GPU builds.
- Add standard GPU admonition to jax_intro.md and numpy_vs_numba_vs_jax.md
- Update the introduction in jax_intro.md to reflect GPU access
- Update conditional GPU language to reflect that lectures now run on GPU
- Follow the QuantEcon style guide for JAX lectures
Lecture Narrative Updates

Added GPU admonitions and updated narrative text for lectures that use JAX:
The GPU admonition follows the QuantEcon style guide and provides readers with information about GPU acceleration and alternative options for running locally.
@jstac I am a bit concerned by these status.html results (GPU backend vs. CPU backend on the live site).

https://6927d34136e0d837e87477e1--epic-agnesi-957267.netlify.app/numpy_vs_numba_vs_jax#jax-version is showing 80s for the JAX version.
- Add benchmark-hardware.py with CPU, NumPy, Numba, and JAX benchmarks
- Works on both GPU (RunsOn) and CPU-only (GitHub Actions) runners
- Include warm-up vs compiled timing to isolate JIT overhead
- Add system info collection (CPU model, frequency, GPU detection)
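For context, the warm-up vs compiled split typically looks like the following minimal sketch (an illustration, not the actual benchmark-hardware.py code):

```python
# Illustrative warm-up vs compiled timing: the first call pays tracing and
# compilation costs, the second measures steady-state execution.
import time

import jax
import jax.numpy as jnp

@jax.jit
def workload(x):
    return jnp.sin(x) @ jnp.cos(x).T   # arbitrary dense workload

x = jnp.ones((2000, 2000))

t0 = time.perf_counter()
workload(x).block_until_ready()        # includes JIT compilation
warmup = time.perf_counter() - t0

t0 = time.perf_counter()
workload(x).block_until_ready()        # compiled execution only
compiled = time.perf_counter() - t0

print(f"warm-up (with JIT): {warmup:.4f}s, compiled: {compiled:.4f}s")
```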
Hardware Benchmark Investigation Results

We ran identical benchmarks on both GitHub Actions (CPU) and RunsOn (GPU) to diagnose the performance differences.

System Information
CPU Performance (Pure Python)
CPU Performance (NumPy)
CPU Performance (Numba)
JAX Performance
Key Findings
Why
This is looking good @mmcky -- are you still concerned?
@jstac yes I am concerned -- the results are strange on the GPU backend.

I am now checking whether there are timing issues due to Jupyter/Jupyter kernels or Jupyter Book.
- Update benchmark-hardware.py to save results to JSON
- Update benchmark-jupyter.ipynb to save results to JSON
- Update benchmark-jupyterbook.md to save results to JSON
- Add CI step to collect and display benchmark results
- Add CI step to upload benchmark results as artifact
- Remove extra triple quote at start of file
- Remove stray parentheses in result assignments
- Copy benchmark-hardware.py from debug/benchmark-github-actions
- Copy benchmark-jupyter.ipynb from debug/benchmark-github-actions
- Copy benchmark-jupyterbook.md from debug/benchmark-github-actions
- Update ci.yml to use matching file names

The test scripts are now identical between the two branches; only the CI workflow differs (runner type and JAX installation).
Investigation:
@jstac what is your take on this? Here are the results of testing; they all show consistent results: QuantEcon/meta#262 (comment)

I iterated with Claude, and the results suggest it is to do with lax.scan.

@jstac I am super confused by one thing though: your local results don't show the same slowdown.

cc: @HumphreyYang -- interested in your thoughts on this if you have time.
Nsight Systems Artifact Available 🎉

@HumphreyYang the Nsight profile artifact is now available for download. Would love your help analyzing it!

How to Use the Artifact
What to Look For

The key question: is the ~8µs per iteration from CPU-GPU sync, kernel launch overhead, or single-threaded GPU execution?

Timeline patterns to identify:
Expected if CPU-GPU sync per iteration:

Expected if single-threaded GPU (no sync):

Key Metrics to Note
This should definitively show whether the ~8µs per iteration overhead is from CPU-GPU sync, kernel launch overhead, or something else!
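For anyone reproducing this, here is a minimal sketch of the kind of scalar lax.scan workload under investigation, wrapped in jax.profiler.trace so the resulting trace can be opened in TensorBoard (the output directory and loop body are illustrative, not the actual profiling script):

```python
# Minimal sketch: a tiny scalar update per scan iteration, traced with the
# JAX profiler. Directory name and loop body are illustrative assumptions.
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=1)
def scan_scalar(x0, n):
    def step(x, _):
        return 0.5 * x + 1.0, None   # tiny scalar update each iteration
    return lax.scan(step, x0, None, length=n)[0]

# Compile outside the trace so only steady-state execution is captured
scan_scalar(jnp.float32(1.0), 10_000).block_until_ready()

with jax.profiler.trace("/tmp/jax-trace"):
    scan_scalar(jnp.float32(1.0), 10_000).block_until_ready()
```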
Analysis of @HumphreyYang's TensorBoard Results 🔍

Great analysis @HumphreyYang! The TensorBoard data is very revealing.

Key Findings from the Operation Statistics
What This Tells Us
Root Cause: Kernel Launch Overhead

The gap between device time (~2.5ms) and wall time (~8ms) is kernel launch overhead!
Why This Happens

The HLO while loop produced by lax.scan runs its iterations sequentially, launching separate kernels for each step.
This is actually a cleaner explanation! Thanks for the great detective work! 🕵️
This is great, guys. Fascinating discussion. It's fun working with such a good team :-)

One small comment, but other than that the changes are really nice.
😱
Yeah, this sounds really bad :( We need a large computation in each loop iteration to be worth the overhead.
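To illustrate the point, a rough sketch (illustrative sizes and workloads, not from the actual benchmarks) comparing a scan with trivial per-step work against one whose steps are large enough to amortize the launch cost:

```python
# Per-iteration launch overhead is roughly fixed, so it dominates tiny
# steps but is amortized when each step does substantial work.
import time
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=1)
def scan_scalar(x, n):
    # trivial scalar update per step: launch overhead dominates on GPU
    return lax.scan(lambda c, _: (0.9 * c + 0.1, None), x, None, length=n)[0]

@partial(jax.jit, static_argnums=1)
def scan_matmul(m, n):
    # a 512x512 matmul per step: enough work to amortize the launch cost
    return lax.scan(lambda c, _: ((c @ c) * 1e-3, None), m, None, length=n)[0]

def timed(f, *args):
    f(*args).block_until_ready()          # compile outside the timing
    t0 = time.perf_counter()
    f(*args).block_until_ready()
    return time.perf_counter() - t0

n = 2000
print(f"scalar steps: {timed(scan_scalar, jnp.float32(1.0), n) / n * 1e6:.1f} µs/step")
print(f"matmul steps: {timed(scan_matmul, jnp.ones((512, 512)), n) / n * 1e6:.1f} µs/step")
```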
Hi @mmcky, for your interest, here is the output from
Nsight Systems Confirmation

Thanks @HumphreyYang for running the Nsight visualization! 🎉

The timeline view perfectly confirms our kernel launch overhead hypothesis:
This matches our measured ~8µs per iteration and explains the 1000x+ slowdown for scalar operations.

Key takeaway for the lectures: sequential scalar operations in lax.scan are dominated by kernel launch overhead on the GPU.
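As a concrete illustration of the takeaway (a hypothetical example, not code from the lectures): when the iterations are independent, a sequential scan over scalars can be replaced by a single vectorized operation:

```python
# Hypothetical illustration: both functions compute the same value, but on
# GPU the vectorized form avoids the per-iteration kernel launches.
import jax
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(100_000, dtype=jnp.float32)

@jax.jit
def slow_sum(xs):
    # sequential: the compiled HLO while loop launches kernels per iteration
    def step(acc, x):
        return acc + jnp.sin(x), None
    return lax.scan(step, 0.0, xs)[0]

@jax.jit
def fast_sum(xs):
    # vectorized: a single fused kernel over the whole array
    return jnp.sum(jnp.sin(xs))
```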
thanks @jstac and @HumphreyYang -- learnt lots today. Been fun. Appreciate iterating with you both.
@jstac In the numpy_vs_numba_vs_jax lecture, the timing comparison appears to be incorrect based on the GPU build results:

Version 2 is actually ~2.4x faster than Version 1, in addition to being more memory efficient. This makes sense: by taking the max inside the vmapped function, we avoid materializing the full two-dimensional intermediate array.

Is this a different result to what we had on the CPU builds?
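For readers following along, a hedged sketch of the two vmap patterns being compared (the function f and grid sizes here are hypothetical, not the lecture's exact code):

```python
# Version 1 materializes the full |x| x |y| grid before reducing; Version 2
# reduces over y inside the vmapped function, so only a length-|x| vector
# is ever materialized.
import jax
import jax.numpy as jnp

def f(x, y):                          # hypothetical scalar function
    return jnp.cos(x**2 + y**2) / (1.0 + x**2 + y**2)

x = jnp.linspace(0, 10, 5_000)
y = jnp.linspace(0, 10, 5_000)

@jax.jit
def max_v1(x, y):
    grid = jax.vmap(lambda xi: jax.vmap(lambda yi: f(xi, yi))(y))(x)
    return jnp.max(grid)              # reduces a 5000 x 5000 intermediate

@jax.jit
def max_v2(x, y):
    row_max = jax.vmap(lambda xi: jnp.max(f(xi, y)))(x)
    return jnp.max(row_max)           # only length-5000 intermediates
```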
Ah yes, this is different between the builds: https://python-programming.quantecon.org/numpy_vs_numba_vs_jax.html#vmap-version-2
Final TODO:
Timing Comparison Review -
- Reorganize jax_intro.md to introduce JAX features upfront with clearer structure
- Expand JAX introduction with bulleted list of key capabilities (parallelization, JIT, autodiff)
- Add explicit GPU performance notes in vmap sections
- Enhance vmap explanation with detailed function composition breakdown
- Clarify memory efficiency tradeoffs between different vmap approaches

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Pedagogical improvements to JAX lectures

This commit improves the structure and clarity of the JAX introduction and vmap sections:

jax_intro.md:
numpy_vs_numba_vs_jax.md:
These changes make the lectures more accessible while better highlighting the GPU acceleration benefits that this PR enables.
Add benchmarking and profiling tools developed during the GPU support investigation:

JAX benchmarks:
- lax.scan performance profiler with multiple analysis modes
- Documents kernel launch overhead issue and solution

Hardware benchmarks:
- Cross-platform benchmark comparing Pure Python, NumPy, Numba, JAX
- JAX installation verification script

Notebook benchmarks:
- MyST Markdown and Jupyter notebook for execution pathway comparison

Documentation:
- Detailed investigation report on the lax.scan GPU performance issue
- README files with usage instructions for each category

Reference: QuantEcon/lecture-python-programming.myst#437
- Remove profile_lax_scan.py, benchmark-hardware.py, benchmark-jupyter.ipynb, benchmark-jupyterbook.md
- Remove profiling/benchmarking steps from ci.yml
- Keep test-jax-install.py for JAX installation verification

Benchmark scripts are now maintained in: https://github.com/QuantEcon/benchmarks
- Add nvidia-smi output to show GPU availability
- Add JAX backend check to confirm GPU usage
- Matches format used in lecture-python.myst
Thanks @jstac for your edits. This is ready to go. I will:

Summary
This PR enables GPU support for building lectures using RunsOn with AWS GPU instances.
Changes
- `scripts/test-jax-install.py` - Script to verify JAX/GPU installation works correctly
- `.github/runs-on.yml` - RunsOn configuration with QuantEcon Ubuntu 24.04 AMI (ami-0edec81935264b6d3) in `us-west-2`
- `cache.yml` - Use RunsOn `g4dn.2xlarge` GPU runner with JAX CUDA 13 support
- `ci.yml` - Use RunsOn `g4dn.2xlarge` GPU runner with JAX CUDA 13 support
- `publish.yml` - Use RunsOn `g4dn.2xlarge` GPU runner with JAX CUDA 13 support

Configuration Details
- Runner label: `runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large`
- JAX installed as `jax[cuda13]` with Numpyro
This configuration mirrors the setup used in lecture-python.myst.