-
-
Notifications
You must be signed in to change notification settings - Fork 31
FIX: jax_intro timeout: use lax.fori_loop instead of Python for loop #442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The compute_call_price_jax function was timing out during cache.yml builds because JAX unrolls Python for loops during JIT compilation. With large arrays (M=10M), this causes excessive compilation time. Solution: Replace Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Fixes cell execution timeout in jax_intro.md
|
Thanks @mmcky ! Useful catch! As a standard convention, I've been using the following
It's a bit verbose but this is the first time readers will see a fori_loop. |
Timing ComparisonPR #442 (
|
| Version | First Run | Second Run | Improvement |
|---|---|---|---|
Old (for loop, CPU) |
16.73s | 14.37s | baseline |
New (fori_loop, GPU) |
1.86s | 0.48s | ~9x faster compile, ~30x faster run |
The fori_loop fix works well:
- Compilation time reduced from ~17s to ~2s - JAX no longer unrolls 20 iterations with M=10,000,000 arrays
- Runtime reduced from ~14s to ~0.5s - GPU acceleration now works effectively
- No more timeout - the old version was timing out at 600s in
cache.ymlbuilds; this completes in under 2.5s total
|
Thanks @jstac this overhead thing in |
- loop_body -> update - state -> loop_state - Added explicit new_loop_state and final_loop_state variables - More verbose but clearer for first-time fori_loop readers
|
I'll stay out of this now, many thanks @mmcky ! |
|
Appreciate the review. This process has led me to open this project idea: QuantEcon/meta#264. When using |
The compute_call_price_jax function was timing out during builds due to JAX unrolling the Python for loop during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time. Solution: Replace the Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Same fix as QuantEcon/lecture-python-programming.myst#442
The compute_call_price_jax function was timing out during builds due to JAX unrolling the Python for loop during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time. Solution: Replace the Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Same fix as QuantEcon/lecture-python-programming.myst#442
…#249) The compute_call_price_jax function was timing out during builds due to JAX unrolling the Python for loop during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time. Solution: Replace the Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Same fix as QuantEcon/lecture-python-programming.myst#442
Problem
The
compute_call_price_jaxfunction injax_intro.mdwas timing out duringcache.ymlbuilds (600s cell execution timeout).fixes #441
Root Cause
JAX unrolls Python
forloops during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time as JAX traces through each iteration separately.Solution
Replace the Python
forloop withjax.lax.fori_loop, which compiles the loop efficiently without unrolling:Added an explanatory note for students about why we use
fori_loop.Related
This is the same category of issue as the
lax.scanGPU performance problem in PR #437, but with a different manifestation (compilation time vs runtime).