
[FEATURE] Faster ray-triangle intersection test that does not allocate large arrays #313

@jeertmans


Terms

  • Checked the existing issues to make sure my suggestion has not already been made.

Description

The ray-intersect-any-triangle test is one of the main performance bottlenecks in this library when applied to a large number of rays and triangles. This is because jax.jit cannot optimize the jax.vmap-ed version of this check, which results in the allocation of an array whose size scales with num_rays × num_triangles. For large inputs, this array is too large to fit into any reasonable amount of memory.
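For reference, the naive approach can be sketched as follows. This is a minimal, hypothetical sketch and not the library's implementation: it assumes the Möller-Trumbore algorithm as the per-pair test, and the function names are made up for illustration. The two nested jax.vmap calls produce arrays of shape (num_rays, num_triangles) before the final `.any` reduction; per this issue, jax.jit does not fuse that reduction away, so the full array is materialized.

```python
import jax
import jax.numpy as jnp


def ray_intersects_triangle(origin, direction, triangle, eps=1e-9):
    """Möller-Trumbore test for one ray against one triangle (hypothetical helper).

    origin, direction: shape (3,); triangle: vertices, shape (3, 3).
    Returns a boolean scalar.
    """
    v0, v1, v2 = triangle[0], triangle[1], triangle[2]
    edge1, edge2 = v1 - v0, v2 - v0
    h = jnp.cross(direction, edge2)
    a = jnp.dot(edge1, h)
    parallel = jnp.abs(a) < eps  # ray parallel to the triangle, or degenerate triangle
    f = 1.0 / jnp.where(parallel, 1.0, a)  # guard against division by zero
    s = origin - v0
    u = f * jnp.dot(s, h)  # first barycentric coordinate
    q = jnp.cross(s, edge1)
    v = f * jnp.dot(direction, q)  # second barycentric coordinate
    t = f * jnp.dot(edge2, q)  # distance along the ray
    return ~parallel & (u >= 0) & (u <= 1) & (v >= 0) & (u + v <= 1) & (t > eps)


def rays_intersect_any_triangle_vmap(origins, directions, triangles):
    """origins, directions: (num_rays, 3); triangles: (num_triangles, 3, 3)."""
    # Inner vmap: one ray against every triangle -> shape (num_triangles,).
    one_ray = jax.vmap(ray_intersects_triangle, in_axes=(None, None, 0))
    # Outer vmap: every ray -> shape (num_rays, num_triangles).
    # This is the intermediate that grows as num_rays * num_triangles.
    hits = jax.vmap(one_ray, in_axes=(0, 0, None))(origins, directions, triangles)
    # Only now is it reduced to one boolean per ray.
    return hits.any(axis=-1)


# Two rays pointing along +z; one unit triangle in the z = 0 plane.
origins = jnp.array([[0.25, 0.25, -1.0], [2.0, 2.0, -1.0]])
directions = jnp.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
triangles = jnp.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
# The first ray hits the triangle; the second passes outside it.
print(jax.jit(rays_intersect_any_triangle_vmap)(origins, directions, triangles))
```

Printing `jax.make_jaxpr(rays_intersect_any_triangle_vmap)(origins, directions, triangles)` should list intermediates of shape `[2,1]`, one entry per ray-triangle pair, which is the scaling the issue is about.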

To work around this, the current implementation uses a jax.lax.scan-like approach that processes triangles (or rays) sequentially, in batches of user-defined size. The main drawbacks are:

  • The optimal batch size depends on many factors, including the number of rays, the number of triangles, and the available memory, making it difficult to choose good default values.
  • jax.lax.scan-like solutions are orders of magnitude slower than their jax.vmap counterparts when array sizes are not prohibitively large.
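A batched jax.lax.scan workaround of this kind might look like the sketch below. It is a hypothetical reconstruction, not the library's code: the function names, the choice to batch over triangles, and the padding with all-zero (degenerate) triangles are assumptions of this sketch, and the per-pair test is again assumed to be Möller-Trumbore. Each scan step only holds a (num_rays, batch_size) intermediate, and the carry is one boolean per ray.

```python
import jax
import jax.numpy as jnp


def ray_intersects_triangle(origin, direction, triangle, eps=1e-9):
    # Möller-Trumbore test for one ray against one triangle (hypothetical helper).
    v0, v1, v2 = triangle[0], triangle[1], triangle[2]
    edge1, edge2 = v1 - v0, v2 - v0
    h = jnp.cross(direction, edge2)
    a = jnp.dot(edge1, h)
    parallel = jnp.abs(a) < eps  # also true for all-zero (degenerate) triangles
    f = 1.0 / jnp.where(parallel, 1.0, a)
    s = origin - v0
    u = f * jnp.dot(s, h)
    q = jnp.cross(s, edge1)
    v = f * jnp.dot(direction, q)
    t = f * jnp.dot(edge2, q)
    return ~parallel & (u >= 0) & (u <= 1) & (v >= 0) & (u + v <= 1) & (t > eps)


def any_hit_in_batch(origins, directions, triangles):
    # Nested vmap over one batch only: intermediate is (num_rays, batch_size).
    one_ray = jax.vmap(ray_intersects_triangle, in_axes=(None, None, 0))
    hits = jax.vmap(one_ray, in_axes=(0, 0, None))(origins, directions, triangles)
    return hits.any(axis=-1)


def rays_intersect_any_triangle_scan(origins, directions, triangles, batch_size):
    num_triangles = triangles.shape[0]
    num_batches = -(-num_triangles // batch_size)  # ceiling division
    pad = num_batches * batch_size - num_triangles
    # All-zero triangles have a zero determinant, so the parallel check
    # rejects them and the padding can never produce a hit.
    padding = jnp.zeros((pad, 3, 3), dtype=triangles.dtype)
    batches = jnp.concatenate([triangles, padding]).reshape(
        num_batches, batch_size, 3, 3
    )

    def step(any_hit, batch):
        # Batches are processed one after another; OR the new hits into the carry.
        return any_hit | any_hit_in_batch(origins, directions, batch), None

    init = jnp.zeros(origins.shape[0], dtype=bool)
    any_hit, _ = jax.lax.scan(step, init, batches)
    return any_hit


# batch_size fixes array shapes, so it must be static under jit.
scan_fn = jax.jit(rays_intersect_any_triangle_scan, static_argnames="batch_size")

origins = jnp.array([[0.25, 0.25, -1.0], [2.0, 2.0, -1.0]])
directions = jnp.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
triangles = jnp.array([
    [[10.0, 10.0, 0.0], [11.0, 10.0, 0.0], [10.0, 11.0, 0.0]],  # far away
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],  # hit by the first ray
    [[-10.0, -10.0, 5.0], [-9.0, -10.0, 5.0], [-10.0, -9.0, 5.0]],  # far away
])
# 3 triangles with batch_size=2 -> 2 scan steps, 1 padded triangle.
print(scan_fn(origins, directions, triangles, batch_size=2))
```

Here batch_size is exactly the knob described in the drawbacks above: larger values mean fewer sequential scan steps but a larger (num_rays, batch_size) intermediate, and the best value depends on the problem size and available memory.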

Ideally, we would develop a jax.vmap solution that jax.jit can optimize so that the problematic intermediate array is never allocated. I have already spent quite some time investigating this, but so far without success.

If you have ideas for improving this or would like to contribute, please share your thoughts or open a PR!


Additional information

Related links:

Relevant functions:

Labels

help wanted: Extra attention is needed
nice-to-have: A nice to have feature (or else)
