diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index c7e8d4c8..1fa83116 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -155,22 +155,13 @@ x, y = np.meshgrid(grid, grid)
 with qe.Timer(precision=8):
     z_max_numpy = np.max(f(x, y))
-print(f"NumPy result: {z_max_numpy}")
+print(f"NumPy result: {z_max_numpy:.6f}")
 ```

 In the vectorized version, all the looping takes place in compiled code.

 Moreover, NumPy uses implicit multithreading, so that at least some parallelization occurs.

-```{note}
-If you have a system monitor such as htop (Linux/Mac) or perfmon
-(Windows), then try running this and then observing the load on your CPUs.
-
-(You will probably need to bump up the grid size to see large effects.)
-
-The output typically shows that the operation is successfully distributed across multiple threads.
-```
-
 (The parallelization cannot be highly efficient because the binary is
 compiled before it sees the size of the arrays `x` and `y`.)
@@ -195,15 +186,18 @@ def compute_max_numba(grid):
 grid = np.linspace(-3, 3, 3_000)

 with qe.Timer(precision=8):
-    compute_max_numba(grid)
+    z_max_numba = compute_max_numba(grid)
+
+print(f"Numba result: {z_max_numba:.6f}")
 ```

+Let's run again to eliminate compile time.
+
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     compute_max_numba(grid)
 ```
-
 Depending on your machine, the Numba version can be a bit slower
 or a bit faster than NumPy.
@@ -240,17 +234,14 @@ Usually this returns an incorrect result:
 ```{code-cell} ipython3
 z_max_parallel_incorrect = compute_max_numba_parallel(grid)
-print(f"Incorrect parallel Numba result: {z_max_parallel_incorrect}")
-print(f"NumPy result: {z_max_numpy}")
+print(f"Numba result: {z_max_parallel_incorrect} 😱")
 ```

-The incorrect parallel implementation typically returns `-inf` (the initial value of `m`) instead of the correct maximum value of approximately `0.9999979986680024`.
-
-The reason is that the variable $m$ is shared across threads and not properly controlled.
+The reason is that the variable `m` is shared across threads and not properly controlled.

-When multiple threads try to read and write `m` simultaneously, they interfere with each other, causing a race condition.
+When multiple threads try to read and write `m` simultaneously, they interfere with each other.

-This results in lost updates—threads read stale values of `m` or overwrite each other's updates—and the variable often never gets updated from its initial value of `-inf`.
+Threads read stale values of `m` or overwrite each other's updates, and `m` often never gets updated from its initial value.

 Here's a more carefully written version.
@@ -274,30 +265,31 @@ def compute_max_numba_parallel(grid):
 Now the code block that `for i in numba.prange(n)` acts over is independent across `i`.

-Each thread writes to a separate element of the array `row_maxes`.
-
-Hence the parallelization is safe.
-
-Here's the timings.
+Each thread writes to a separate element of the array `row_maxes`, so
+the parallelization is safe.

 ```{code-cell} ipython3
-with qe.Timer(precision=8):
-    compute_max_numba_parallel(grid)
+z_max_parallel = compute_max_numba_parallel(grid)
+print(f"Numba result: {z_max_parallel:.6f}")
 ```

+Here's the timing.
+
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     compute_max_numba_parallel(grid)
 ```

-If you have multiple cores, you should see at least some benefits from parallelization here.
+If you have multiple cores, you should see at least some benefits from
+parallelization here.

-For more powerful machines and larger grid sizes, parallelization can generate major speed gains, even on the CPU.
+For more powerful machines and larger grid sizes, parallelization can generate
+major speed gains, even on the CPU.

 ### Vectorized code with JAX

-In most ways, vectorization is the same in JAX as it is in NumPy.
+On the surface, vectorized code in JAX is similar to NumPy code.

 But there are also some differences, which we highlight here.
@@ -319,14 +311,18 @@ grid = jnp.linspace(-3, 3, 3_000)
 x_mesh, y_mesh = np.meshgrid(grid, grid)

 with qe.Timer(precision=8):
-    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
+    z_max = jnp.max(f(x_mesh, y_mesh))
+    z_max.block_until_ready()
+
+print(f"Plain vanilla JAX result: {z_max:.6f}")
 ```

 Let's run again to eliminate compile time.

 ```{code-cell} ipython3
 with qe.Timer(precision=8):
-    z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
+    z_max = jnp.max(f(x_mesh, y_mesh))
+    z_max.block_until_ready()
 ```

 Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.
@@ -374,6 +370,8 @@ Let's see the timing:
 with qe.Timer(precision=8):
     z_max = jnp.max(f_vec(grid))
     z_max.block_until_ready()
+
+print(f"JAX vmap v1 result: {z_max:.6f}")
 ```

 ```{code-cell} ipython3
@@ -429,6 +427,8 @@ Let's try it.
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
+
+print(f"JAX vmap v2 result: {z_max:.6f}")
 ```

 Let's run it again to eliminate compilation time:
@@ -445,19 +445,19 @@ If you are running this on a GPU, as we are, you should see another nontrivial s
 In our view, JAX is the winner for vectorized operations.

-It dominates NumPy both in terms of speed (via JIT-compilation and parallelization) and memory efficiency (via vmap).
+It dominates NumPy both in terms of speed (via JIT-compilation and
+parallelization) and memory efficiency (via vmap).

 Moreover, the `vmap` approach can sometimes lead to significantly clearer code.

 While Numba is impressive, the beauty of JAX is that, with fully vectorized
-operations, we can run exactly the
-same code on machines with hardware accelerators and reap all the benefits
-without extra effort.
+operations, we can run exactly the same code on machines with hardware
+accelerators and reap all the benefits without extra effort.

 Moreover, JAX already knows how to effectively parallelize many common array
 operations, which is key to fast execution.

-For almost all cases encountered in economics, econometrics, and finance, it is
+For most cases encountered in economics, econometrics, and finance, it is
 far better to hand over to the JAX compiler for efficient parallelization than
 to try to hand code these routines ourselves.
@@ -537,9 +537,11 @@ This code is not easy to read but, in essence, `lax.scan` repeatedly calls `upda
 ```{note}
 Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.

-The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.
+The computation consists of many small sequential operations, leaving little
+opportunity for the GPU to exploit parallelism.

-As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.
+As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU
+a better fit for this workload.

 Curious readers can try removing this option to see how performance changes.
 ```
@@ -558,16 +560,17 @@ with qe.Timer(precision=8):
     x_jax = qm_jax(0.1, n).block_until_ready()
 ```

-JAX is also efficient for this sequential operation.
+JAX is also quite efficient for this sequential operation.

 Both JAX and Numba deliver strong performance after compilation, with Numba
 typically (but not always) offering slightly better speeds on purely
 sequential operations.

+
 ### Summary

 While both Numba and JAX deliver strong performance for sequential operations,
-there are significant differences in code readability and ease of use.
+*there are significant differences in code readability and ease of use*.

 The Numba version is straightforward and natural to read: we simply allocate an
 array and fill it element by element using a standard Python loop.
@@ -580,3 +583,4 @@ Additionally, JAX's immutable arrays mean we cannot simply update array elements

 For this type of sequential operation, Numba is the clear winner in terms of
 code clarity and ease of implementation, as well as high performance.
+
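
Note for reviewers: the body of the safe `compute_max_numba_parallel` is referenced but not visible in these hunks. Below is a minimal standalone sketch (not part of the patch) of the row-wise `prange` pattern the prose describes, assuming the lecture's test function is `f(x, y) = cos(x**2 + y**2) / (1 + x**2 + y**2)`; the helper name `compute_max_rowwise` is ours and details may differ from the lecture's actual code.

```python
# Illustrative sketch only: race-free parallel reduction with Numba.
import numpy as np
import numba

@numba.jit(nopython=True)
def f(x, y):
    # Assumed form of the lecture's test function
    return np.cos(x**2 + y**2) / (1 + x**2 + y**2)

@numba.jit(nopython=True, parallel=True)
def compute_max_rowwise(grid):
    n = len(grid)
    row_maxes = np.empty(n)        # one slot per outer iteration
    for i in numba.prange(n):      # parallel loop over rows
        m = -np.inf                # thread-local running maximum
        for j in range(n):
            z = f(grid[i], grid[j])
            if z > m:
                m = z
        row_maxes[i] = m           # each thread writes only its own slot
    return np.max(row_maxes)       # final reduction is sequential and safe

grid = np.linspace(-3, 3, 3_000)
print(compute_max_rowwise(grid))   # ~0.999998, the maximum quoted in the lecture
```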