diff --git a/.github/workflows/cache.yml b/.github/workflows/cache.yml
index 138ee3fb..c9325914 100644
--- a/.github/workflows/cache.yml
+++ b/.github/workflows/cache.yml
@@ -6,7 +6,7 @@ on:
   workflow_dispatch:
 jobs:
   cache:
-    runs-on: ubuntu-latest
+    runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
     steps:
       - uses: actions/checkout@v6
       - name: Setup Anaconda
@@ -18,6 +18,16 @@ jobs:
           python-version: "3.13"
           environment-file: environment.yml
           activate-environment: quantecon
+      - name: Install JAX and Numpyro
+        shell: bash -l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
+      - name: Check nvidia drivers
+        shell: bash -l {0}
+        run: |
+          nvidia-smi
       - name: Build HTML
         shell: bash -l {0}
         run: |
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 51755ea9..58f69dcc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -2,7 +2,7 @@ name: Build Project [using jupyter-book]
 on: [pull_request]
 jobs:
   preview:
-    runs-on: ubuntu-latest
+    runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
     steps:
       - uses: actions/checkout@v6
         with:
@@ -16,6 +16,15 @@ jobs:
           python-version: "3.13"
           environment-file: environment.yml
           activate-environment: quantecon
+      - name: Check nvidia Drivers
+        shell: bash -l {0}
+        run: nvidia-smi
+      - name: Install JAX and Numpyro
+        shell: bash -l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
       - name: Install latex dependencies
         run: |
           sudo apt-get -qq update
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 258fbe54..5622ef7a 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -6,7 +6,7 @@ on:
 jobs:
   publish:
     if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
-    runs-on: ubuntu-latest
+    runs-on: "runs-on=${{ github.run_id }}/family=g4dn.2xlarge/image=quantecon_ubuntu2404/disk=large"
     steps:
       - name: Checkout
        uses: actions/checkout@v6
@@ -21,6 +21,16 @@ jobs:
           python-version: "3.13"
           environment-file: environment.yml
           activate-environment: quantecon
+      - name: Install JAX and Numpyro
+        shell: bash -l {0}
+        run: |
+          pip install -U "jax[cuda13]"
+          pip install numpyro
+          python scripts/test-jax-install.py
+      - name: Check nvidia drivers
+        shell: bash -l {0}
+        run: |
+          nvidia-smi
       - name: Install latex dependencies
         run: |
           sudo apt-get -qq update
diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md
index 0d890d8f..b4114630 100644
--- a/lectures/jax_intro.md
+++ b/lectures/jax_intro.md
@@ -13,6 +13,18 @@ kernelspec:
 
 # JAX
 
+This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
+
+JAX is a high-performance scientific computing library that provides
+
+* a NumPy-like interface that can automatically parallelize across CPUs and GPUs,
+* a just-in-time compiler for accelerating a large range of numerical
+  operations, and
+* automatic differentiation.
+
+Increasingly, JAX also maintains and provides more specialized scientific
+computing routines, such as those originally found in SciPy.
+
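+As a tiny preview of these features, here is a sketch (the function `f` below is
+purely illustrative and is not part of the lecture's own examples):
+
+```python
+import jax
+import jax.numpy as jnp
+
+def f(x):
+    return jnp.sum(x**2)      # an ordinary NumPy-style function
+
+x = jnp.arange(3.0)
+
+print(jax.jit(f)(x))          # just-in-time compiled evaluation: 5.0
+print(jax.grad(f)(x))         # automatic differentiation: [0. 2. 4.]
+```
+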
 In addition to what's in Anaconda, this lecture will need the following libraries:
 
 ```{code-cell} ipython3
@@ -21,28 +33,24 @@ In addition to what's in Anaconda, this lecture will need the following librarie
 !pip install jax quantecon
 ```
 
-This lecture provides a short introduction to [Google JAX](https://github.com/jax-ml/jax).
-
-Here we are focused on using JAX on the CPU, rather than on accelerators such as
-GPUs or TPUs.
-
-This means we will only see a small amount of the possible benefits from using JAX.
-
-However, JAX seamlessly handles transitions across different hardware platforms.
+```{admonition} GPU
+:class: warning
 
-As a result, if you run this code on a machine with a GPU and a GPU-aware
-version of JAX installed, your code will be automatically accelerated and you
-will receive the full benefits.
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and targets JAX for GPU programming.
 
-For a discussion of JAX on GPUs, see [our JAX lecture series](https://jax.quantecon.org/intro.html).
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.
 
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the CPU only, you can use `pip install jax[cpu]`.
+```
 
 ## JAX as a NumPy Replacement
 
-One of the attractive features of JAX is that, whenever possible, it conforms to
-the NumPy API for array operations.
+One of the attractive features of JAX is that, whenever possible, its array
+processing operations conform to the NumPy API.
 
-This means that, to a large extent, we can use JAX is as a drop-in NumPy replacement.
+This means that, in many cases, we can use JAX as a drop-in NumPy replacement.
 
 Let's look at the similarities and differences between JAX and NumPy.
@@ -523,16 +531,9 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-If you are running this on a GPU the code will run much faster than its NumPy
-equivalent, which ran on the CPU.
-
-Even if you are running on a machine with many CPUs, the second JAX run should
-be substantially faster with JAX.
-
-Also, typically, the second run is faster than the first.
+On a GPU, this code runs much faster than its NumPy equivalent.
 
-(This might not be noticable on the CPU but it should definitely be noticable on
-the GPU.)
+Also, typically, the second run is faster than the first due to JIT compilation.
 
 This is because even built in functions like `jnp.cos` are JIT-compiled --- and
 the first run includes compile time.
@@ -634,8 +635,7 @@ with qe.Timer():
     jax.block_until_ready(y);
 ```
 
-The outcome is similar to the `cos` example --- JAX is faster, especially if you
-use a GPU and especially on the second run.
+The outcome is similar to the `cos` example --- JAX is faster, especially on the second run after JIT compilation.
 
 Moreover, with JAX, we have another trick up our sleeve:
diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md
index 883f2d14..c7e8d4c8 100644
--- a/lectures/numpy_vs_numba_vs_jax.md
+++ b/lectures/numpy_vs_numba_vs_jax.md
@@ -48,6 +48,18 @@ tags: [hide-output]
 !pip install quantecon jax
 ```
 
+```{admonition} GPU
+:class: warning
+
+This lecture is accelerated via [hardware](status:machine-details) that has access to a GPU and targets JAX for GPU programming.
+
+Free GPUs are available on Google Colab.
+To use this option, please click on the play icon at the top right, select Colab, and set the runtime environment to include a GPU.
+
+Alternatively, if you have your own GPU, you can follow the [instructions](https://github.com/google/jax) for installing JAX with GPU support.
+If you would like to install JAX running on the CPU only, you can use `pip install jax[cpu]`.
+```
+
 We will use the following imports.
 
 ```{code-cell} ipython3
@@ -317,7 +329,7 @@ with qe.Timer(precision=8):
     z_max = jnp.max(f(x_mesh, y_mesh)).block_until_ready()
 ```
 
-Once compiled, JAX will be significantly faster than NumPy, especially if you are using a GPU.
+Once compiled, JAX is significantly faster than NumPy due to GPU acceleration.
 
 The compilation overhead is a one-time cost that pays off when the function is called repeatedly.
 
@@ -370,23 +382,29 @@ with qe.Timer(precision=8):
     z_max.block_until_ready()
 ```
 
-The execution time is similar to the mesh operation but, by avoiding the large input arrays `x_mesh` and `y_mesh`,
-we are using far less memory.
+By avoiding the large input arrays `x_mesh` and `y_mesh`, this `vmap` version uses far less memory.
+
+When run on a CPU, its runtime is similar to that of the meshgrid version.
 
-In addition, `vmap` allows us to break vectorization up into stages, which is
-often easier to comprehend than the traditional approach.
+When run on a GPU, it is usually significantly faster.
 
-This will become more obvious when we tackle larger problems.
+In fact, using `vmap` has another advantage: it allows us to break vectorization up into stages.
+
+This leads to code that is often easier to comprehend than traditional vectorized code.
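+
+To make the idea of staged vectorization concrete, here is a small sketch (the
+function `f` and the grid below are stand-ins for the objects defined earlier in
+the lecture):
+
+```python
+import jax
+import jax.numpy as jnp
+
+def f(x, y):
+    return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)
+
+grid = jnp.linspace(-3, 3, 1_000)
+
+f_row  = jax.vmap(f, in_axes=(None, 0))      # stage 1: fix x, sweep over all y
+f_mesh = jax.vmap(f_row, in_axes=(0, None))  # stage 2: sweep the row map over all x
+
+z = f_mesh(grid, grid)                       # shape (1000, 1000), no meshgrid required
+print(z.shape, jnp.max(z))
+```
+
+This produces the same values as the meshgrid approach without ever building `x_mesh` or `y_mesh`.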
+
+We will investigate these ideas more when we tackle larger problems.
 
 ### vmap version 2
 
 We can be still more memory efficient using vmap.
 
-While we avoided large input arrays in the preceding version,
+While we avoid large input arrays in the preceding version,
 we still create the large output array `f(x,y)` before we compute the max.
 
-Let's use a slightly different approach that takes the max to the inside.
+Let's try a slightly different approach that takes the max to the inside.
+
+Because of this change, we never compute the two-dimensional array `f(x,y)`.
 
 ```{code-cell} ipython3
 @jax.jit
@@ -399,14 +417,20 @@ def compute_max_vmap_v2(grid):
     return jnp.max(f_vec_max(grid))
 ```
 
-Let's try it
+Here
+
+* `f_vec_x_max` computes the max along any given row
+* `f_vec_max` is a vectorized version that can compute the max of all rows in parallel.
+
+We apply this function to all rows and then take the max of the row maxes.
+
+Let's try it.
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
-
 Let's run it again to eliminate compilation time:
 
 ```{code-cell} ipython3
 with qe.Timer(precision=8):
@@ -414,8 +438,7 @@ with qe.Timer(precision=8):
     z_max = compute_max_vmap_v2(grid).block_until_ready()
 ```
 
-We don't get much speed gain but we do save some memory.
-
+If you are running this on a GPU, as we are, you should see another nontrivial speed gain.
 
 ### Summary
 
@@ -497,7 +520,9 @@ Now let's create a JAX version using `lax.scan`:
 from jax import lax
 from functools import partial
 
-@partial(jax.jit, static_argnums=(1,))
+cpu = jax.devices("cpu")[0]
+
+@partial(jax.jit, static_argnums=(1,), device=cpu)
 def qm_jax(x0, n, α=4.0):
     def update(x, t):
         x_new = α * x * (1 - x)
@@ -509,6 +534,16 @@ def qm_jax(x0, n, α=4.0):
 This code is not easy to read but, in essence, `lax.scan` repeatedly calls
 `update` and accumulates the returns `x_new` into an array.
 
+```{note}
+Sharp readers will notice that we specify `device=cpu` in the `jax.jit` decorator.
+
+The computation consists of many very small `lax.scan` iterations that must run sequentially, leaving little opportunity for the GPU to exploit parallelism.
+
+As a result, kernel-launch overhead tends to dominate on the GPU, making the CPU a better fit for this workload.
+
+Curious readers can try removing this option to see how performance changes.
+```
+
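+If you prefer not to hard-code the device in the decorator, one alternative sketch
+is to commit the input to a CPU device with `jax.device_put` and let the jitted
+computation follow its input (the function below is a hypothetical variant of
+`qm_jax` without the `device` argument):
+
+```python
+import jax
+import jax.numpy as jnp
+from jax import lax
+from functools import partial
+
+@partial(jax.jit, static_argnums=(1,))   # note: no device= argument here
+def qm_jax_flexible(x0, n, α=4.0):
+    def update(x, t):
+        x_new = α * x * (1 - x)
+        return x_new, x_new
+    _, x = lax.scan(update, x0, jnp.arange(n))
+    return x
+
+cpu = jax.devices("cpu")[0]
+x0 = jax.device_put(0.1, cpu)            # commit the initial condition to the CPU
+x = qm_jax_flexible(x0, 1_000)           # the compiled scan follows its committed input
+print(x.devices())                       # reports the device holding the result (recent JAX)
+```
+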
 Let's time it with the same parameters:
 
 ```{code-cell} ipython3
diff --git a/lectures/status.md b/lectures/status.md
index 3ada25f0..2ec414c4 100644
--- a/lectures/status.md
+++ b/lectures/status.md
@@ -31,4 +31,18 @@ and the following package versions
 ```{code-cell} ipython
 :tags: [hide-output]
 !conda list
+```
+
+This lecture series has access to the following GPU:
+
+```{code-cell} ipython
+!nvidia-smi
+```
+
+You can check the backend used by JAX using:
+
+```{code-cell} ipython3
+import jax
+# Check if JAX is using GPU
+print(f"JAX backend: {jax.devices()[0].platform}")
 ```
\ No newline at end of file
diff --git a/scripts/test-jax-install.py b/scripts/test-jax-install.py
new file mode 100644
index 00000000..c2be1d3d
--- /dev/null
+++ b/scripts/test-jax-install.py
@@ -0,0 +1,21 @@
+import jax
+import jax.numpy as jnp
+
+devices = jax.devices()
+print(f"The available devices are: {devices}")
+
+@jax.jit
+def matrix_multiply(a, b):
+    return jnp.dot(a, b)
+
+# Example usage:
+key = jax.random.PRNGKey(0)
+x = jax.random.normal(key, (1000, 1000))
+y = jax.random.normal(key, (1000, 1000))
+z = matrix_multiply(x, y)
+
+# Now the function is JIT compiled and will likely run on GPU (if available)
+print(z)
+
+devices = jax.devices()
+print(f"The available devices are: {devices}")