@mmcky
Contributor

@mmcky mmcky commented Nov 27, 2025

Problem

The compute_call_price_jax function in jax_intro.md was timing out during cache.yml builds (600s cell execution timeout).

fixes #441

Root Cause

JAX unrolls Python for loops during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time as JAX traces through each iteration separately.
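One way to see the unrolling is to count the operations in the traced program. The sketch below is illustrative only and is not part of the lecture code: `python_loop`, `fori_version`, and `num_eqns` are hypothetical helper names, and the update rule is a toy stand-in for the real loop body.

```python
import jax
import jax.numpy as jnp


def python_loop(x, n):
    # A plain Python loop runs at trace time, so every iteration
    # adds its own operations to the traced program.
    for _ in range(n):
        x = 0.5 * x + 1.0
    return x


def fori_version(x, n):
    # fori_loop traces the body once and records a single loop primitive.
    return jax.lax.fori_loop(0, n, lambda i, x: 0.5 * x + 1.0, x)


def num_eqns(f, n):
    # Number of top-level equations in the traced program (the jaxpr).
    x = jnp.ones(3)
    return len(jax.make_jaxpr(lambda x: f(x, n))(x).jaxpr.eqns)


for n in (5, 20):
    print(n, num_eqns(python_loop, n), num_eqns(fori_version, n))
```

The count for the Python loop grows with `n`, while the count for the `fori_loop` version stays the same.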

Solution

Replace the Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling:

# Before (Python for loop - gets unrolled)
for t in range(n):
    key, subkey = jax.random.split(key)
    Z = jax.random.normal(subkey, (2, M))
    s = s + μ + jnp.exp(h) * Z[0, :]
    h = ρ * h + ν * Z[1, :]

# After (JAX fori_loop - compiled efficiently)
def loop_body(i, state):
    s, h, key = state
    key, subkey = jax.random.split(key)
    Z = jax.random.normal(subkey, (2, M))
    s = s + μ + jnp.exp(h) * Z[0, :]
    h = ρ * h + ν * Z[1, :]
    return s, h, key

s, h, key = jax.lax.fori_loop(0, n, loop_body, (s, h, key))

Added an explanatory note for students about why we use fori_loop.
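For anyone who wants to try the pattern outside the lecture, here is a minimal self-contained sketch of the loop above wrapped in a jitted function. The function name `simulate_log_prices`, the parameter values, and the small `M` are assumptions for illustration; the pricing step of compute_call_price_jax is not reproduced here.

```python
from functools import partial

import jax
import jax.numpy as jnp


# M and n are marked static because they fix array shapes and the trip count.
@partial(jax.jit, static_argnames=("M", "n"))
def simulate_log_prices(key, M, n, s0, h0, μ, ρ, ν):
    # One entry per simulated path (hypothetical initial conditions)
    s = s0 * jnp.ones(M)
    h = h0 * jnp.ones(M)

    def loop_body(i, state):
        s, h, key = state
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
        return s, h, key

    s, h, key = jax.lax.fori_loop(0, n, loop_body, (s, h, key))
    return s, h


# Illustrative parameter values only, not the lecture's calibration
s, h = simulate_log_prices(
    jax.random.PRNGKey(0), M=1_000, n=20,
    s0=jnp.log(10.0), h0=0.0, μ=0.0001, ρ=0.1, ν=0.001,
)
print(s.shape, h.shape)
```

The random key travels in the loop state so that each iteration can split it and draw fresh shocks.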

Related

This is the same category of issue as the lax.scan GPU performance problem in PR #437, but with a different manifestation (compilation time vs runtime).

The compute_call_price_jax function was timing out during cache.yml builds
because JAX unrolls Python for loops during JIT compilation. With large
arrays (M=10M), this causes excessive compilation time.

Solution: Replace Python for loop with jax.lax.fori_loop, which compiles
the loop efficiently without unrolling.

Fixes cell execution timeout in jax_intro.md
@mmcky mmcky changed the title Fix jax_intro timeout: use lax.fori_loop instead of Python for loop FIX: jax_intro timeout: use lax.fori_loop instead of Python for loop Nov 27, 2025
@mmcky mmcky requested a review from HumphreyYang November 27, 2025 23:05
@github-actions

github-actions bot commented Nov 27, 2025

@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 23:12 Inactive
@jstac
Contributor

jstac commented Nov 27, 2025

Thanks @mmcky ! Useful catch!

As a standard convention, I've been using the following

loop_body -> update
state -> loop_state
I might also add new_loop_state = ... before return new_loop_state and also final_loop_state = ... followed by unpacking.

It's a bit verbose but this is the first time readers will see a fori_loop.
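As a rough illustration of that naming convention, here is a toy loop (not the lecture code) written in the verbose style. The update rule and the name `initial_loop_state` are made up for the example.

```python
import jax
import jax.numpy as jnp


def update(i, loop_state):
    # Unpack the state carried from the previous iteration
    x, total = loop_state
    # Toy update rule, chosen only for illustration
    x = 0.5 * x + 1.0
    total = total + x
    # Pack the updated values explicitly before returning
    new_loop_state = x, total
    return new_loop_state


initial_loop_state = jnp.float32(0.0), jnp.float32(0.0)
final_loop_state = jax.lax.fori_loop(0, 10, update, initial_loop_state)
x, total = final_loop_state
print(x, total)
```

Naming the packed state on both sides makes it easier to see that `update` takes a state and returns a state with the same structure.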

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Timing Comparison

PR #442 (fori_loop on GPU)

  • First run (compile + execute): 1.86 seconds
  • Second run (cached): 0.48 seconds

Production (Python for loop on CPU)

  • First run (compile + execute): 16.73 seconds
  • Second run (cached): 14.37 seconds

Summary

| Version | First Run | Second Run | Improvement |
|---|---|---|---|
| Old (for loop, CPU) | 16.73s | 14.37s | baseline |
| New (fori_loop, GPU) | 1.86s | 0.48s | ~9x faster first run, ~30x faster second run |

The fori_loop fix works well:

  1. First-run time (compile + execute) reduced from ~17s to ~2s - JAX no longer unrolls 20 iterations with M=10,000,000 arrays
  2. Runtime reduced from ~14s to ~0.5s - GPU acceleration now works effectively
  3. No more timeout - the old version was timing out at 600s in cache.yml builds; this completes in under 2.5s total
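For reference, first-run versus second-run timings of this kind can be collected along the following lines. This is a generic sketch, not the script behind the numbers above: the function `f`, the helper `timed`, and the array size are placeholders.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Placeholder workload; substitute the function being benchmarked
    return jnp.sum(jnp.sin(x) ** 2)


def timed(fn, *args):
    start = time.perf_counter()
    out = fn(*args)
    # JAX dispatches work asynchronously, so wait for the result
    # before stopping the clock.
    jax.block_until_ready(out)
    return time.perf_counter() - start


x = jnp.linspace(0.0, 1.0, 1_000_000)
first = timed(f, x)   # includes tracing and compilation
second = timed(f, x)  # reuses the cached compiled function
print(f"first run: {first:.4f}s, second run: {second:.4f}s")
```

The gap between the two runs gives a rough estimate of the compilation overhead on its own.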

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

Thanks @jstac this overhead thing in jax makes a pretty big difference.

- loop_body -> update
- state -> loop_state
- Added explicit new_loop_state and final_loop_state variables
- More verbose but clearer for first-time fori_loop readers
@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 23:36 Inactive
@jstac
Contributor

jstac commented Nov 27, 2025

> Thanks @jstac this overhead thing in jax makes a pretty big difference.

Thanks for all this analysis @mmcky . Super useful. JAX unrolls a regular for loop, leading to massive compile times. So glad you picked this up!

(Some things were written in the early times, when we just got started with JAX.)

@jstac
Contributor

jstac commented Nov 27, 2025

I'll stay out of this now, many thanks @mmcky !

@mmcky
Contributor Author

mmcky commented Nov 27, 2025

thanks @jstac. Done 9bceb2d.

Appreciate the review.


This process has led me to open this project idea: QuantEcon/meta#264. When using dask, I always appreciated the guideline of grouping execution into chunks > 10ms to improve efficiency. I think a detailed set of benchmarks, updated as new versions are released, would be really neat for documenting overheads and how to use packages most effectively.

@github-actions github-actions bot temporarily deployed to pull request November 27, 2025 23:48 Inactive
@github-actions github-actions bot temporarily deployed to pull request November 28, 2025 00:20 Inactive
@mmcky mmcky merged commit 2f103af into main Nov 28, 2025
5 checks passed
@mmcky mmcky deleted the fix/jax-intro-timeout branch November 28, 2025 00:27
mmcky added a commit to QuantEcon/lecture-jax that referenced this pull request Nov 28, 2025
The compute_call_price_jax function was timing out during builds due to
JAX unrolling the Python for loop during JIT compilation. With large
arrays (M=10,000,000), this causes excessive compilation time.

Solution: Replace the Python for loop with jax.lax.fori_loop, which
compiles the loop efficiently without unrolling.

Same fix as QuantEcon/lecture-python-programming.myst#442
mmcky added a commit to QuantEcon/lecture-jax that referenced this pull request Nov 28, 2025
…#249)

The compute_call_price_jax function was timing out during builds due to
JAX unrolling the Python for loop during JIT compilation. With large
arrays (M=10,000,000), this causes excessive compilation time.

Solution: Replace the Python for loop with jax.lax.fori_loop, which
compiles the loop efficiently without unrolling.

Same fix as QuantEcon/lecture-python-programming.myst#442

Development

Successfully merging this pull request may close these issues.

BUG: Cell timeout in jax_intro lecture

3 participants