diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index b4114630..f8fe265d 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -832,16 +832,31 @@ def compute_call_price_jax(β=β, s = jnp.full(M, np.log(S0)) h = jnp.full(M, h0) - for t in range(n): + + def update(i, loop_state): + s, h, key = loop_state key, subkey = jax.random.split(key) Z = jax.random.normal(subkey, (2, M)) s = s + μ + jnp.exp(h) * Z[0, :] h = ρ * h + ν * Z[1, :] + new_loop_state = s, h, key + return new_loop_state + + initial_loop_state = s, h, key + final_loop_state = jax.lax.fori_loop(0, n, update, initial_loop_state) + s, h, key = final_loop_state + expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0)) return β**n * expectation ``` +```{note} +We use `jax.lax.fori_loop` instead of a Python `for` loop. +This allows JAX to compile the loop efficiently without unrolling it, +which significantly reduces compilation time for large arrays. +``` + Let's run it once to compile it: ```{code-cell} ipython3