Skip to content

Conversation

@xuanguang-li
Copy link
Contributor

@xuanguang-li xuanguang-li commented Nov 17, 2025

Updated the ge_arrow.md to JAX and complemented the styling consistent with the operation manual.

Key changes:

  • Rewrite the RecurCompetitive class as a NamedTuple.
  • Complete all computations inside the compute_rc_model function. Inside this function, arguments of sub-functions can be written in the same way as the definitions in the theory part.
  • Partially jitted the main computation function, and used jax.lax.fori_loop to conduct loops.
  • Fixed some typos and styling.

Update: Runtime Comparison Between JAX (GPU), JAX (CPU), and NumPy

Methodology: nearly the same as in #654

  • The JAX version uses the code in this PR, while the NumPy version uses the code in main.
  • The runtime for JAX (GPU) is measured using Google Colab T4 GPU runtime.
  • Runtime is collected using qe.timeit over 1,000 iterations.
  • Each iteration consists of solving Example 3 (solving the wealth distribution for 100 kinds of transition matrices).

Results:

  • Average runtime: JAX (CPU) > NumPy > JAX (GPU).
  • More details are shown in the attached plots.
runtime_compare_average runtime_compare_boxplot

@xuanguang-li xuanguang-li marked this pull request as ready for review December 11, 2025 00:52
 - Replace `@partial(jax.jit)` with `jax.jit` on the main function `compute_rc_model`.

- Write a function to compute example 3 and add `jax.jit` decorator.
@xuanguang-li xuanguang-li changed the title [ge_arrow] Update to JAX [ge_arrow] Update to JAX and compare runtime Feb 10, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

1 participant