84 changes: 44 additions & 40 deletions lectures/numpy_vs_numba_vs_jax.md

```{code-cell} ipython3
x, y = np.meshgrid(grid, grid)

with qe.Timer(precision=8):
    z_max_numpy = np.max(f(x, y))

print(f"NumPy result: {z_max_numpy:.6f}")
```

In the vectorized version, all the looping takes place in compiled code.

Moreover, NumPy uses implicit multithreading, so that at least some parallelization occurs.

```{note}
If you have a system monitor such as htop (Linux/Mac) or perfmon (Windows),
try running this code while watching the load on your CPUs.

(You will probably need to bump up the grid size to see large effects.)

The output typically shows that the operation is successfully distributed across multiple threads.
```

(The parallelization cannot be highly efficient because the binary is compiled
before it sees the size of the arrays `x` and `y`.)

```{code-cell} ipython3
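# The definition of compute_max_numba lies outside this excerpt; the sketch
# below (a jitted double loop over the grid, assuming f is the same
# Numba-compatible function used in the NumPy version) stands in for it so
# that the timing code underneath is runnable.
@numba.jit
def compute_max_numba(grid):
    m = -np.inf
    for x in grid:
        for y in grid:
            z = f(x, y)
            if z > m:
                m = z
    return m
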
grid = np.linspace(-3, 3, 3_000)

with qe.Timer(precision=8):
    z_max_numba = compute_max_numba(grid)

print(f"Numba result: {z_max_numba:.6f}")
```

Let's run again to eliminate compile time.

```{code-cell} ipython3
with qe.Timer(precision=8):
    compute_max_numba(grid)
```


Depending on your machine, the Numba version can be a bit slower or a bit faster
than NumPy.
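
Numba can also parallelize the outer loop with `numba.prange`. The definition of
`compute_max_numba_parallel` used below is not shown in this excerpt; the
hypothetical sketch that follows (assuming `f` is Numba-compatible as before)
shows the problematic pattern the discussion refers to, with a single running
maximum `m` shared across parallel iterations.

```{code-cell} ipython3
@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    m = -np.inf
    # All parallel iterations read and write the same variable m.
    for i in numba.prange(len(grid)):
        for j in range(len(grid)):
            z = f(grid[i], grid[j])
            if z > m:
                m = z
    return m
```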

Usually this returns an incorrect result:

```{code-cell} ipython3
z_max_parallel_incorrect = compute_max_numba_parallel(grid)
print(f"Parallel Numba result: {z_max_parallel_incorrect} 😱")
```

The incorrect parallel implementation typically returns `-inf` (the initial value of `m`) instead of the correct maximum value of approximately `0.9999979986680024`.

The reason is that the variable `m` is shared across threads and not properly controlled.

When multiple threads try to read and write `m` simultaneously, they interfere with each other, causing a race condition.

Threads read stale values of `m` or overwrite each other's updates, so `m` often never gets updated from its initial value of `-inf`.

Here's a more carefully written version.

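The corrected cell is likewise not shown in this excerpt; a minimal hypothetical
sketch consistent with the description below (again assuming `f` is
Numba-compatible) gives each `prange` iteration its own slot in `row_maxes`:

```{code-cell} ipython3
@numba.jit(parallel=True)
def compute_max_numba_parallel(grid):
    n = len(grid)
    row_maxes = np.empty(n)
    # Iterations over i are distributed across threads; each iteration
    # writes only to row_maxes[i], so no state is shared between threads.
    for i in numba.prange(n):
        row_max = -np.inf
        for j in range(n):
            z = f(grid[i], grid[j])
            if z > row_max:
                row_max = z
        row_maxes[i] = row_max
    return np.max(row_maxes)
```
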
Now the block of code that the `for i in numba.prange(n)` loop executes is
independent across `i`.

Each thread writes to a separate element of the array `row_maxes`, so the
parallelization is safe.

```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max_parallel = compute_max_numba_parallel(grid)

print(f"Parallel Numba result: {z_max_parallel:.6f}")
```

Let's run it again to eliminate compile time and see the timing.

```{code-cell} ipython3
with qe.Timer(precision=8):
    compute_max_numba_parallel(grid)
```

If you have multiple cores, you should see at least some benefits from
parallelization here.

For more powerful machines and larger grid sizes, parallelization can generate
major speed gains, even on the CPU.


### Vectorized code with JAX

On the surface, vectorized code in JAX is similar to NumPy code.

But there are also some differences, which we highlight here.

```{code-cell} ipython3
grid = jnp.linspace(-3, 3, 3_000)
x_mesh, y_mesh = np.meshgrid(grid, grid)

with qe.Timer(precision=8):
    z_max = jnp.max(f(x_mesh, y_mesh))
    z_max.block_until_ready()

print(f"Plain vanilla JAX result: {z_max:.6f}")
```

Let's run again to eliminate compile time.

```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max = jnp.max(f(x_mesh, y_mesh))
    z_max.block_until_ready()
```

Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.

Let's see the timing of the first `vmap`-based version, which uses the function `f_vec`:

```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max = jnp.max(f_vec(grid))
    z_max.block_until_ready()

print(f"JAX vmap v1 result: {z_max:.6f}")
```

```{code-cell} ipython3
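# Second run of the same computation, to time it without compilation
# overhead (a sketch; assumes f_vec and grid as defined above).
with qe.Timer(precision=8):
    z_max = jnp.max(f_vec(grid))
    z_max.block_until_ready()
```
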
Let's try the second `vmap`-based version, `compute_max_vmap_v2`.

```{code-cell} ipython3
with qe.Timer(precision=8):
    z_max = compute_max_vmap_v2(grid).block_until_ready()

print(f"JAX vmap v2 result: {z_max:.6f}")
```

Let's run it again to eliminate compilation time:
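
A second timed call, reusing `compute_max_vmap_v2` and `grid` from above, is all that is needed:

```{code-cell} ipython3
with qe.Timer(precision=8):
    compute_max_vmap_v2(grid).block_until_ready()
```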

In our view, JAX is the winner for vectorized operations.

It dominates NumPy both in terms of speed (via JIT-compilation and
parallelization) and memory efficiency (via vmap).

Moreover, the `vmap` approach can sometimes lead to significantly clearer code.
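
For example, evaluating `f` over the whole grid can be written as a composition
of two `vmap`s, so the two-dimensional structure never has to be spelled out
with a hand-built meshgrid (a sketch with hypothetical names, assuming `jax`,
`jnp`, and `f` as above, and not necessarily the exact `compute_max_vmap_v2`
used earlier):

```{code-cell} ipython3
# Map f over y for each fixed x, then map that over x.
f_xy = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))

@jax.jit
def compute_max_vmap_sketch(grid):
    return jnp.max(f_xy(grid, grid))
```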

While Numba is impressive, the beauty of JAX is that, with fully vectorized
operations, we can run exactly the same code on machines with hardware
accelerators and reap all the benefits without extra effort.

Moreover, JAX already knows how to effectively parallelize many common array
operations, which is key to fast execution.

For most cases encountered in economics, econometrics, and finance, it is far
better to hand these computations over to the JAX compiler for efficient
parallelization than to try to hand-code such routines ourselves.

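The lecture's `lax.scan`-based implementation is not shown in this hunk. That
code is not easy to read: in essence, `lax.scan` repeatedly calls an update
function, threading the state through as a carry. As a rough, hypothetical
illustration (using a stand-in quadratic map and made-up names rather than the
lecture's actual `qm_jax`), it might look like this:

```{code-cell} ipython3
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def qm_jax_sketch(x0, n):
    def update(x, _):
        x_new = 4.0 * x * (1.0 - x)    # stand-in recursion
        return x_new, x_new            # (new carry, value to record)
    _, path = jax.lax.scan(update, x0, xs=None, length=n - 1)
    return jnp.concatenate((jnp.array([x0]), path))
```

In the lecture itself, the jitted function is additionally pinned to the CPU,
as the note below explains.
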
```{note}
Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.

The computation consists of many small sequential operations, leaving little
opportunity for the GPU to exploit parallelism.

As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU
a better fit for this workload.

Curious readers can try removing this option to see how performance changes.
```

```{code-cell} ipython3
with qe.Timer(precision=8):
    x_jax = qm_jax(0.1, n).block_until_ready()
```

JAX is also quite efficient for this sequential operation.

Both JAX and Numba deliver strong performance after compilation, with Numba
typically (but not always) offering slightly better speeds on purely sequential
operations.


### Summary

While both Numba and JAX deliver strong performance for sequential operations,
*there are significant differences in code readability and ease of use*.

The Numba version is straightforward and natural to read: we simply allocate an
array and fill it element by element using a standard Python loop.
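
In schematic form, the Numba pattern is just the following (a hypothetical
sketch reusing the stand-in recursion from above, not the lecture's actual
`qm` function):

```{code-cell} ipython3
@numba.jit
def qm_numba_sketch(x0, n):
    x = np.empty(n)
    x[0] = x0
    for t in range(n - 1):
        x[t + 1] = 4.0 * x[t] * (1.0 - x[t])   # stand-in recursion
    return x
```
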
Additionally, JAX's immutable arrays mean we cannot simply update array elements.

For this type of sequential operation, Numba is the clear winner in terms of
code clarity and ease of implementation, as well as high performance.
