Terms
- Checked the existing issues to make sure my suggestion has not already been made;
Description
The ray-intersect-any-triangle test is one of the main performance bottlenecks in this library when applied to a large number of rays and triangles. This is because `jax.jit` cannot optimize the `jax.vmap`-ed version of this check, so it allocates an intermediate array whose size scales with `num_rays` × `num_triangles`, which quickly becomes too large to fit into any reasonable amount of memory.
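To make the problem concrete, here is a minimal sketch of what a `jax.vmap`-ed "any triangle" check can look like. It is not this library's implementation: the function names `ray_hits_triangle` and `any_hit_dense` are hypothetical, and the per-pair test is assumed to be a standard Moller-Trumbore intersection.

```python
import jax
import jax.numpy as jnp


def ray_hits_triangle(origin, direction, triangle, eps=1e-6):
    """Moller-Trumbore test for one ray and one triangle of shape (3, 3)."""
    v0, v1, v2 = triangle
    e1, e2 = v1 - v0, v2 - v0
    p = jnp.cross(direction, e2)
    det = jnp.dot(e1, p)
    # Avoid dividing by ~0 when the ray is parallel to the triangle plane.
    safe_det = jnp.where(jnp.abs(det) > eps, det, 1.0)
    s = origin - v0
    u = jnp.dot(s, p) / safe_det
    q = jnp.cross(s, e1)
    v = jnp.dot(direction, q) / safe_det
    t = jnp.dot(e2, q) / safe_det
    return (jnp.abs(det) > eps) & (u >= 0) & (v >= 0) & (u + v <= 1) & (t > eps)


@jax.jit
def any_hit_dense(origins, directions, triangles):
    # Inner vmap: one ray against all triangles; outer vmap: all rays.
    per_triangle = jax.vmap(ray_hits_triangle, in_axes=(None, None, 0))
    per_ray = jax.vmap(per_triangle, in_axes=(0, 0, None))
    # Logical shape (num_rays, num_triangles): this is the intermediate
    # that, as described above, ends up being allocated before the reduction.
    hits = per_ray(origins, directions, triangles)
    return hits.any(axis=1)
```

Only the final `(num_rays,)` result is needed, but the reduction is applied to the full pairwise array, which is where the memory pressure comes from.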
To work around this, the current implementation uses a `jax.lax.scan`-like approach that processes triangles (or rays) sequentially, in user-defined batch sizes. The main drawbacks are:
- The optimal batch size depends on many factors, including the number of rays, the number of triangles, and the available memory, making it difficult to choose good default values.
- `jax.lax.scan`-like solutions are orders of magnitude slower than their `jax.vmap` counterparts when array sizes are not prohibitively large.
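The batched workaround can be sketched roughly as follows. This is an illustration under assumptions, not the code used in this library: `any_over_batches` is a hypothetical name, `pair_fn` is a placeholder for the ray-triangle test, and `rays` is assumed to be a single array with one row per ray.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("pair_fn", "batch_size"))
def any_over_batches(pair_fn, rays, triangles, *, batch_size):
    """Return, per ray, whether pair_fn(ray, triangle) holds for any triangle."""
    num_triangles = triangles.shape[0]
    num_batches = -(-num_triangles // batch_size)  # ceiling division
    pad = num_batches * batch_size - num_triangles
    # Pad to a whole number of batches; padded entries are masked out below.
    padding = jnp.zeros((pad, *triangles.shape[1:]), triangles.dtype)
    padded = jnp.concatenate([triangles, padding])
    valid = jnp.arange(num_batches * batch_size) < num_triangles
    batches = padded.reshape(num_batches, batch_size, *triangles.shape[1:])
    valid = valid.reshape(num_batches, batch_size)

    # Dense pairwise check, but only for one batch of triangles at a time.
    dense = jax.vmap(jax.vmap(pair_fn, in_axes=(None, 0)), in_axes=(0, None))

    def step(hit_so_far, batch):
        tris, mask = batch
        hits = dense(rays, tris) & mask  # (num_rays, batch_size)
        return hit_so_far | hits.any(axis=1), None

    init = jnp.zeros(rays.shape[0], dtype=bool)
    hit, _ = jax.lax.scan(step, init, (batches, valid))
    return hit
```

Peak memory now scales with `num_rays` × `batch_size` instead of `num_rays` × `num_triangles`, at the cost of running the batches sequentially, which is the trade-off behind both drawbacks listed above.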
Ideally, we would develop a `jax.vmap` solution that `jax.jit` can optimize in a way that avoids allocating the problematic intermediate array altogether. I have already spent quite some time investigating this, but so far, without success.
If you have ideas for improving this or would like to contribute, please share your thoughts or open a PR!
Screenshots
No response
Additional information
Related links:
- My discussion opened in JAX's repo: JAX unable to optimize reduce operation, leading to OOM issues or really slow `jax.lax.scan` / `jax.lax.map` jax-ml/jax#30470
- My issue opened in JAX's repo: `lax.reduce` NotImplementedError: Reduction computations can't close over Tracers jax-ml/jax#30841
- Example issue where JAX can optimize the reduction: Reduce functionality for vmap jax-ml/jax#9505
- JAX's `Ref` could be interesting to look at: https://docs.jax.dev/en/latest/array_refs.html
Relevant functions: