Skip to content

Install CuTeDSL into JAX containers and remove old CuTeDSL + JAX project#2013

Open
mgoldfarb-nvidia wants to merge 11 commits intomainfrom
mgoldfarb/update_cutedsl_testsd
Open

Install CuTeDSL into JAX containers and remove old CuTeDSL + JAX project#2013
mgoldfarb-nvidia wants to merge 11 commits intomainfrom
mgoldfarb/update_cutedsl_testsd

Conversation

@mgoldfarb-nvidia
Copy link
Copy Markdown
Contributor

CuTeDSL + JAX is now formally a part of CuTeDSL in the cutlass.jax sub-module. Users can now install directly from the official CuTeDSL release https://pypi.org/project/nvidia-cutlass-dsl/. Examples and documentation can be found at https://github.com/NVIDIA/cutlass/tree/main

Updated the path for Python examples in the unittest script.
@gpupuck
Copy link
Copy Markdown
Contributor

gpupuck commented Mar 23, 2026

There is job failure https://github.com/NVIDIA/JAX-Toolbox/actions/runs/23462382532/job/68271153801 that I don't know how to explain.

olupton
olupton previously approved these changes Mar 25, 2026
Copy link
Copy Markdown
Collaborator

@olupton olupton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, I think the tests all failed because of a JAX-side bug fixed by jax-ml/jax#36185

@gpupuck
Copy link
Copy Markdown
Contributor

gpupuck commented Mar 30, 2026

Ready to merge?

@mgoldfarb-nvidia
Copy link
Copy Markdown
Contributor Author

Ready to merge?

Appears the tests are failing for some reason:

=== /jax-cutlass-src/cutlass/examples/python/CuTeDSL/jax/cutlass_call_basic.py ===
Traceback (most recent call last):
  File "/jax-cutlass-src/cutlass/examples/python/CuTeDSL/jax/cutlass_call_basic.py", line 33, in <module>
    import cutlass
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/jax/types.py", line 65, in <module>
    @jax.tree_util.register_dataclass
     ^^^^^^^^^^^^^
AttributeError: module 'jax' has no attribute 'tree_util'

This is surprising since jax.tree_util.register_dataclass is a valid module inside jax.

@olupton
Copy link
Copy Markdown
Collaborator

olupton commented Mar 31, 2026

Ready to merge?

Appears the tests are failing for some reason:

=== /jax-cutlass-src/cutlass/examples/python/CuTeDSL/jax/cutlass_call_basic.py ===
Traceback (most recent call last):
  File "/jax-cutlass-src/cutlass/examples/python/CuTeDSL/jax/cutlass_call_basic.py", line 33, in <module>
    import cutlass
  File "/usr/local/lib/python3.12/dist-packages/nvidia_cutlass_dsl/python_packages/cutlass/jax/types.py", line 65, in <module>
    @jax.tree_util.register_dataclass
     ^^^^^^^^^^^^^
AttributeError: module 'jax' has no attribute 'tree_util'

This is surprising since jax.tree_util.register_dataclass is a valid module inside jax.

I suspect it's because PYTHONPATH=${CUTLASS_EXAMPLES_ROOT} puts a directory containing a jax folder first in the path.

Removed PYTHONPATH export from unittest.sh.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants