diff --git a/.github/container/Dockerfile.jax b/.github/container/Dockerfile.jax index da7c2a29e..81816e3e7 100644 --- a/.github/container/Dockerfile.jax +++ b/.github/container/Dockerfile.jax @@ -134,6 +134,9 @@ ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl echo "transformer-engine @ file://$(ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl)" > /opt/pip-tools.d/requirements-te.in EOF +## nvidia-cutlass-dsl +RUN echo "nvidia-cutlass-dsl" >> /opt/pip-tools.d/requirements-cutlass-dsl.in + ############################################################################### ## Install the nsys-jax JAX/XLA-aware profiling scripts, patch Nsight Systems ############################################################################### diff --git a/.github/container/cutlass_dsl_jax/LICENSE b/.github/container/cutlass_dsl_jax/LICENSE deleted file mode 100644 index 261eeb9e9..000000000 --- a/.github/container/cutlass_dsl_jax/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. 
- - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. 
If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. 
You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. 
Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. 
- - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/.github/container/cutlass_dsl_jax/README.md b/.github/container/cutlass_dsl_jax/README.md deleted file mode 100644 index f42e3ef89..000000000 --- a/.github/container/cutlass_dsl_jax/README.md +++ /dev/null @@ -1,174 +0,0 @@ -# Jax + CuTe DSL - -The experimental primitive `cutlass_call` provides a simple API to call kernels written with [CuTe DSL](https://docs.nvidia.com/cutlass/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.html). - -Note that CuTe DSL is still in its beta phase and may change significantly from release to release. As needed this API may be updated or modified as required. - -## Calling a Kernel - -`cutlass_call` enables compilation and execution of a kernel and a host function from `jax.jit` functions. We assume static layout and shape for `cute.Tensor`s passed from Jax programs. This allows for each instance of a program to be compiled against the exact shapes for maximum efficiency and minimal overhead between Jax and the CuTe DSL program. - -``` -@cute.kernel -def kernel(in: cute.Tensor, out: cute.Tensor): - ... 
- -@cute.jit -def launch(stream: cuda.CUstream, input: cute.Tensor, out: cute.Tensor): - kernel(a, b, c, const_a, const_b).launch( - grid=[a.shape[-1], 1, 1], - block=[a.shape[-1], 1, 1], - stream=stream) - -call = cutlass_call(launch, output_shape_dtype=jax.ShapeDtypeStruct((128, 64), jnp.float16)) -out = call(input) -``` - -## Host Function Signature - -`cutlass_call` requires a specific function signature to bind arrays and constant values provided by Jax. It is recommended that you annotate the signature with the appropriate or expected types. - -All functions must take as the first argument a `cuda.CUstream` that will be used to launch and synchronize the kernel. Following the stream must be input tensors, then input/output tensors, then output tensors. Constexpr values must be passed as named keyword arguments last. - -If the kernel signature does not exactly match there are two options: change the host function or use a small wrapper to aid in binding the parameters to the kernel. - -``` -@cute.jit -def launch(out: cute.Tensor, constval: cutlass.Constexpr, input: cute.Tensor, stream: cuda.CUstream): - ... - -x = cutlass_call( - lambda stream, input, output, **kwargs: launch(output, kwargs["constval"], input, stream), - constval=1.0, - ... -) -``` - -## Layouts and Modes - -`cutlass_call` accepts two optional parameters to aid in converting `jax.Array` into a `cute.Tensor`: - -_Layout_: Tuple of indexes that specify the physical order of axis strides for the source `jax.Array`. If omitted the array is assumed to have row-major layout. -_Mode_: Tuple of indexes that specify the order axes and strides of the `cute.Tensor`. If omitted the order is directly taken from the layout. - -Layout and mode are specified for each input and output `jax.Array` in the flattened pytree. `None` may be used to indicate a default entry for a specific `jax.Array`. - -The following example demonstrates how the layout and mode impact the `cute.Tensor` layout. 
- -``` -@cute.kernel -def kernel(input: cute.Tensor, out: cute.Tensor): - cute.printf(input.layout) - cute.printf(out.layout) - -@cute.jit -def launch(stream, input: cute.Tensor, out: cute.Tensor): - kernel(input, out).launch(grid=[1, 1, 1], block=[1, 1, 1], stream=stream) - -a = jnp.zeros((128, 512, 64)) # batch, row, column -call = cutlass_call(launch, output_shape_dtype=jax.ShapeDtypeStruct(a.shape, jnp.float32)) -out = call(a) - -(128,512,64):(32768,64,1) # default row major layout note shape and strides -(128,512,64):(32768,64,1) - -call = cutlass_call(launch, output_shape_dtype=jax.ShapeDtypeStruct(a.shape, jnp.float32), input_mode=((0, 2, 1),)) -out = call(a) - -(128,64,512):(32768,1,64) # shape and stride reordered to reflect mode. -(128,512,64):(32768,64,1) - -call = cutlass_call(launch, output_shape_dtype=jax.ShapeDtypeStruct(a.shape, jnp.float32), input_mode=((0, 2, 1),), output_mode=((2, 1, 0),)) -out = call(a) - -(128,64,512):(32768,1,64) -(64,512,128):(1,64,32768) # output and input layout can differ - -call = cutlass_call(launch, output_shape_dtype=jax.ShapeDtypeStruct(a.shape, jnp.float32), input_layout=((2, 0, 1),)) -out = call(a) - -(128,512,64):(32768,1,512) # Column major input array -(128,512,64):(32768,64,1) -``` - -A common example of when to use modes is to implement logical layouts for gemm operations. For example we want the `cute.Tensor`s to follow a consistent layout specification of `[M][K][L]` for A, `[N][K][L]` for B and `[M][N][L]` for C. Strides of each dimension are set to properly offset into physical memory. - - -### Complex Layouts and Modes - -If your kernel requires more complex layouts at the function boundary e.g. tiled, composed or hierarchical it is recommended that the kernel be wrapped in a `cute.jit` function. Once you are outside of Jax you can built arbitrary layouts as needed using CuTe DSL. - -## Compilation Cache - -We maintain a cache of the compiled functions. 
When a kernel is compiled we check if it was previously seen for the given shapes, dtypes and constant values. For example if your kernel is called in a series of homogeneous layer (e.g a transformer model) it will only need to compile once and that instance can be reused as needed. - -## Limitations - -There are several limitations to highlight to avoid unexpected errors or behavior. Over time we hope to improve these as CuTe DSL matures. - -### Jit Function Argument Types - -`cutlass_call` allows for the following types to be passed: - -* `jax.Array` -* `list[jax.Array]` and `tuple[jax.Array]` - -Non-array types can be passed as keyword arguments to `cutlass_call`. - -If the kernel interface depends on a complex Python type it is recommended that a wrapper function be used to bind together `jax.Array` and other compile time constants that can be passed. - -``` -@cute.jit -def launch(stream: cuda.CUstream, x: CustomType): - kernel(x).launch(grid=x.get_grid(), block=x.get_block(), stream=stream) - -@cute.jit -def wrapper(stream: cuda.CUstream, a: cute.Tensor, b: cute.Tensor, *, constval: float): - x = CustomType(a, b, constval) - launch(stream, x) - -out = cutlass_call(wrapper, ..., constval=1.0)(a, b) -``` - -A useful trick can be to flatten a `PyTree` structure outside the call then unflatten inside the call by passing the `PyTree` as a constexpr argument. - -#### Variable Length Arguments - -Variable length positional arguments are not supported in the `cute.jit` function signature however you can emulate variable length input using `list` or `tuple` types. Its important to keep in mind that these lists are static in length and do not behave like dynamic containers. - -#### kwargs - -kwargs may not be used to pass `jax.Array`s to the `cutlass_call` they are only used to pass constant/static values. - -#### Dictionary Types - -Dictionary types are not supported for passing `jax.Array` to the `cutlass_call`. 
- -### Closures and Nested Functions - -Closures are not fully supported and may result in unexpected memory consumption or tracer leaks. If you need to capture state, its recommend to wrap your function in a class with global scope. The class can be instantiated deeper into your program with the appropriate values provided. - -``` -class MyFunction: - def __init__(self, v0, v1): - self.v0 = v0 - self.v1 = v1 - - @cute.jit - def __call__(self, stream, ...): - # use self.v0, self.v1 - ... - -def make_function(v0, v1): - return MyFunction(v0, v1) -``` - -One common exception is simple lambda function which are generally safe to use as shown in examples above. - -### Autotuning - -`cutlass_call` will not autotune arguments to the function. If there are multiple possible configurations you will need to sweep them in a separate program to find the optimal settings for your kernel. - -### AoT Compilation and Cache Persistence - -There is no support for AoT compilation or compile cache persistence. diff --git a/.github/container/cutlass_dsl_jax/pyproject.toml b/.github/container/cutlass_dsl_jax/pyproject.toml deleted file mode 100644 index d7ffc523b..000000000 --- a/.github/container/cutlass_dsl_jax/pyproject.toml +++ /dev/null @@ -1,28 +0,0 @@ -[project] -name = "nvidia-cutlass-dsl-jax" -description = "Primitives for calling CuTe/CUTLASS DSL kernels in Jax." 
-readme = "README.md" -requires-python = ">=3.12" -dependencies = [ - "jax>=0.6.2", - "nvidia-cutlass-dsl>=4.3.1" -] -dynamic = ["version"] - -[project.optional-dependencies] -tests = ["pytest"] - -[build-system] -requires = ["setuptools", "setuptools-scm"] -build-backend = "setuptools.build_meta" - -[tool.setuptools.dynamic] -version = {attr = "jax_cutlass.version.__version__"} - -[tool.setuptools.packages] -find = {"where" = ["src"]} - -[tool.pytest.ini_options] -addopts = [ - "--import-mode=importlib", -] diff --git a/.github/container/cutlass_dsl_jax/src/jax_cutlass/__init__.py b/.github/container/cutlass_dsl_jax/src/jax_cutlass/__init__.py deleted file mode 100644 index 07cd4854d..000000000 --- a/.github/container/cutlass_dsl_jax/src/jax_cutlass/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .primitive import cutlass_call -from .types import jax_to_cutlass_dtype, from_dlpack, JaxArray, TensorMode -from .compile import release_compile_cache, initialize_cutlass_dsl -from .version import __version__, __version_info__ - -# This explicit init method ensures that we avoid initialization at -# unexpected times. TODO: try to remove the need for this initialization. 
-initialize_cutlass_dsl() - -__all__ = [ - "cutlass_call", - "jax_to_cutlass_dtype", - "from_dlpack", - "JaxArray", - "TensorMode", - "release_compile_cache", - "__version__", - "__version_info__", -] diff --git a/.github/container/cutlass_dsl_jax/src/jax_cutlass/compile.py b/.github/container/cutlass_dsl_jax/src/jax_cutlass/compile.py deleted file mode 100644 index 5cee36710..000000000 --- a/.github/container/cutlass_dsl_jax/src/jax_cutlass/compile.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -import gc -import ctypes -import inspect -from typing import Any, Callable, Optional -from dataclasses import dataclass -from functools import partial -from pathlib import Path - -import time -import logging -import threading - -import cuda.bindings.driver as cuda - -import jax -import jax.numpy as jnp -from jax.experimental.buffer_callback import ExecutionContext -import jaxlib - -from .types import ( - jax_to_cutlass_dtype, - from_dlpack, - JaxArray, - JaxArrayList, - TensorMode, - make_placeholder_array, - DEFAULT_CUTLASS_DEVICE_MEMSPACE, - DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT, -) - -import cutlass -import cutlass.cute as cute -from cutlass.cute import AddressSpace -from cutlass.cutlass_dsl.cutlass import CuTeDSL -from cutlass.base_dsl.runtime.cuda import unload_cubin_module - -logger = logging.getLogger(__name__) - -_CUTLASS_COMPILE_CACHE = {} - - -@dataclass(frozen=True) -class Arg: - idx: int # position in pytree - shape: tuple[int, ...] - dtype: jnp.dtype - layout: tuple[int, ...] - mode: TensorMode - - -@dataclass(frozen=True) -class FunctionSpec: - """Contains a specification of the inputs and outputs to the kernel.""" - - in_args: tuple[Arg, ...] - input_tree: Any - out_args: tuple[Arg, ...] - output_tree: Any - input_output_aliases: tuple[tuple[int, int], ...] - input_layout: tuple[tuple[int, ...]] - input_mode: tuple[TensorMode, ...] - output_layout: tuple[tuple[int, ...]] - output_mode: tuple[TensorMode, ...] 
- convert_tensors: bool - compile_options: str - use_static_tensors: bool - kwargs: tuple[tuple[str, Any]] - - def get_compile_args(self): - """Returns the arguments to provide to cute.compile.""" - compiler_ins = [ - make_placeholder_array( - jax_to_cutlass_dtype(leaf.dtype), - leaf.shape, - leaf.layout, - DEFAULT_CUTLASS_DEVICE_MEMSPACE, - mode.ptr_assumed_align, - ) - for leaf, mode in zip(self.in_args, self.input_mode) - ] - compiler_outs = [ - make_placeholder_array( - jax_to_cutlass_dtype(leaf.dtype), - leaf.shape, - leaf.layout, - DEFAULT_CUTLASS_DEVICE_MEMSPACE, - mode.ptr_assumed_align, - ) - for leaf, mode in zip(self.out_args, self.output_mode) - ] - return JaxArrayList(tuple(sum([compiler_ins, compiler_outs], []))) - - def get_runtime_args(self, out, *args): - """Returns the arguments to provide to the compiled function at runtime.""" - ins = [from_dlpack(args[i]).iterator for i, spec in enumerate(self.in_args)] - outs = [from_dlpack(out[i]).iterator for i, spec in enumerate(self.out_args)] - return JaxArrayList(tuple(sum([ins, outs], []))) - - -@cute.jit -def jit_wrapper( - stream: cuda.CUstream, - args: JaxArrayList, - *, - wrapped_fn: cutlass.Constexpr, - spec: cutlass.Constexpr, -): - # split buffer argument into inputs and outputs and return to tree - ins, outs = args[: len(spec.in_args)], args[(len(spec.in_args)) :] - if cutlass.const_expr(spec.convert_tensors): - ins = [ - x.get_tensor(a.mode, spec.use_static_tensors) - for x, a in zip(ins, spec.in_args) - ] - outs = [ - x.get_tensor(a.mode, spec.use_static_tensors) - for x, a in zip(outs, spec.out_args) - ] - ins = jax.tree.unflatten(spec.input_tree, ins) - outs = jax.tree.unflatten(spec.output_tree, outs) - wrapped_fn(stream, *ins, *outs, **dict(spec.kwargs)) - - -@dataclass -class CompileResult: - """Holds reference to the compiled kernel and arguments. - - compiled_fn: The compiled function (a JitExecutor). - This reference keeps CUDA modules alive. 
- - """ - - compiled_fn: cutlass.base_dsl.jit_executor.JitExecutor - spec: FunctionSpec - - def __call__(self, ctx: ExecutionContext, out, *args): - self.compiled_fn( - cuda.CUstream(ctx.stream), self.spec.get_runtime_args(out, *args) - ) - - -def _check_is_valid_type(x, is_input): - if not is_input: - if not isinstance(x, jax.ShapeDtypeStruct): - raise TypeError("Invalid output value passed.", x) - else: - if not isinstance(x, jax.Array): - raise TypeError("Invalid type passed.", x) - - -def _build_arg_tree(args, specs, is_input): - args = [] - for idx, (arg, layout) in enumerate(zip(args_flat, specs)): - _check_is_valid_type(arg, is_input) - args.append(Arg(idx, arg.shape, arg.dtype, layout)) - args = jax.tree.unflatten(args_tree, args) - - return args, args_tree, is_single_leaf_node - - -def build_function_spec( - ins, - in_tree, - outs, - out_tree, - input_layout, - output_layout, - input_mode, - output_mode, - input_output_aliases, - convert_tensors, - compile_options, - use_static_tensors, - kwargs, -): - # TODO: Improve type checking and validate pytree structures. - # TODO: Improve Pytree support for more complex or user defined structures. - - in_args = [] - for idx, (arg, layout, mode) in enumerate(zip(ins, input_layout, input_mode)): - _check_is_valid_type(arg, is_input=True) - in_args.append(Arg(idx, arg.shape, arg.dtype, layout, mode)) - - out_args = [] - for idx, (arg, layout, mode) in enumerate(zip(outs, output_layout, output_mode)): - _check_is_valid_type(arg, is_input=False) - out_args.append(Arg(idx, arg.shape, arg.dtype, layout, mode)) - - # Return the argument specs to the original pytree structure - # We need this structure to sanely match index positions of the - # arguments to the kernel. 
- ins_args_structured = jax.tree.unflatten(in_tree, in_args) - out_args_structured = jax.tree.unflatten(out_tree, out_args) - - # Assign per-leaf aliases - input_output_aliases_per_leaf = {} - for input_arg_alias_idx in input_output_aliases: - flat_in, _ = jax.tree.flatten(ins_args_structured[input_arg_alias_idx]) - flat_out, _ = jax.tree.flatten( - out_args_structured[input_output_aliases[input_arg_alias_idx]] - ) - for i, o in zip(flat_in, flat_out): - input_output_aliases_per_leaf[i.idx] = o.idx - - # Remove aliased arguments from output set since they are also provided - # as inputs. This is done at the very top level of the tree to simplify - # how we handle aliasing. The assumption is that the entire pytree is - # aliased. - out_args_structured = list(out_args_structured) - for out_idx in sorted(tuple(set(input_output_aliases.values())), reverse=True): - try: - out_args_structured.pop(out_idx) - except: - raise ValueError(f"Invalid output alias in input_output_aliases.") - out_args_structured = tuple(out_args_structured) - - in_args_flat, _ = jax.tree.flatten(ins_args_structured) - out_args_flat, out_tree = jax.tree.flatten(out_args_structured) - - spec = FunctionSpec( - tuple(in_args_flat), - in_tree, - tuple(out_args_flat), - out_tree, - tuple(input_output_aliases_per_leaf.items()), - tuple(input_layout), - tuple(input_mode), - tuple(output_layout), - tuple(output_mode), - convert_tensors, - compile_options, - use_static_tensors, - tuple((k, kwargs[k]) for k in kwargs), - ) - - return spec - - -_compile_lock = threading.Lock() - - -def get_or_compile_kernel(fn, spec, stream): - """Gets or compiles fn and returns a CutlassCompileResult. - - The function and its specification is used as a key to determine if a new - function must be compiled. - """ - cache_key = (fn, spec, stream) - if cache_key in _CUTLASS_COMPILE_CACHE: - return _CUTLASS_COMPILE_CACHE[cache_key] - - # Don't allow more than 1 thead to compile at any time. 
- # We assume that the cache key is per thread so we don't need to lock - # the above check in compile cache, - # TODO: ideally this lock would happen in cute.compile as needed. - compiled_fn = None - with _compile_lock: - start = time.time() - try: - cute_compile = cutlass.cute.compile - if spec.compile_options: - cute_compile = partial(cute_compile, options=spec.compile_options) - - compiled_fn = cute_compile( - jit_wrapper, - cuda.CUstream(stream), - spec.get_compile_args(), - wrapped_fn=fn, - spec=spec, - ) - except Exception as e: - # Log here because Jax can obscure the exception details. - logger.exception("Compilation failure for kernel.") - raise e - end = time.time() - logger.debug(f"Took {end-start} to compile cute kernel.") - - result = CompileResult(compiled_fn=compiled_fn, spec=spec) - _CUTLASS_COMPILE_CACHE[cache_key] = result - return result - - -def release_compile_cache(): - """Releases entries from the compile cache. - - Note that is may prevent cute dsl from saving its persistent compilation cache entries. - """ - _CUTLASS_COMPILE_CACHE.clear() - dsl = CuTeDSL._get_dsl() - dsl.jit_cache.clear() - # TODO: This is needed to release frames being held in the DSL - # We should avoid holding such references as they unexpectedly - # extend object lifetime. 
- dsl.frame = None - gc.collect() - - -class _DummyInitKernel: - @cute.kernel - def kernel(self): - pass - - @cute.jit - def init(self): - pass - - -_CUTLASS_DSL_INITIALIZED = False - - -def initialize_cutlass_dsl(): - """Initializes cutlass DSL.""" - global _CUTLASS_DSL_INITIALIZED - if _CUTLASS_DSL_INITIALIZED: - return - - kernel = _DummyInitKernel() - with _compile_lock: - logger.debug("Initializing cutlass dsl...") - _ = cutlass.cute.compile(kernel.init) - - _CUTLASS_DSL_INITIALIZED = True diff --git a/.github/container/cutlass_dsl_jax/src/jax_cutlass/primitive.py b/.github/container/cutlass_dsl_jax/src/jax_cutlass/primitive.py deleted file mode 100644 index 31b6abcdb..000000000 --- a/.github/container/cutlass_dsl_jax/src/jax_cutlass/primitive.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Any, Union, Sequence, Callable -from functools import partial -import logging - -import cuda.bindings.driver as cuda - -import jax, jax.numpy as jnp -import jax.extend -from jax.interpreters import mlir -from jax._src.interpreters import ad -from jax._src.interpreters import batching -from jax._src import ffi -from jax.tree import flatten, unflatten - -try: - from jax.experimental.buffer_callback import buffer_callback -except ImportError as e: - # buffer_callback is used until we implement C++/FFI interface - raise ImportError( - "A more recent version (>=0.6.2) of Jax is required for cutlass_call. Current version:" - f" {jax.__version__}" - ) from e - -import cutlass - -from .compile import get_or_compile_kernel, build_function_spec -from .types import row_major_layout, default_tensor_mode, TensorMode - -logger = logging.getLogger(__name__) - -cutlass_call_inner_p = jax.extend.core.Primitive("cutlass_call_inner") -cutlass_call_inner_p.multiple_results = True - - -def cutlass_call( - fn: Callable[..., None], - *, - output_shape_dtype: Any, - input_layout: Any = None, - output_layout: Any = None, - input_mode: Any = None, - output_mode: Any = None, - input_output_aliases={}, - convert_tensors=True, - allow_cuda_graph=True, - compile_options=None, - use_static_tensors=False, - **kwargs, -): - """Creates a callable that invokes a @cute.jit function. - - Args: - fn: A @cute.jit decorated function that launches a cutlass kernel. - output_shape_dtype: A pytree representing the shape and dtype of the output buffers. - input_output_aliases: A mapping of input to output aliases. Positions are specified assuming - a flattened input and output pytree. - input_layout: Specifies the Jax layout for input arrays. If None then the layout will - assume row-major order. - output_layout: Specifies the Jax layout for output arrays. If None then the layout will - assume row-major order. - input_mode: Specifies a cute.Tensor dimension order for input tensors. 
If None then the order - will assume the corresponding layout order specific by input_layout. - output_mode: Specifies a cute.Tensor dimension order for output tensors. If None then the order - will assume the corresponding layout order specific by output_layout. - convert_tensors: Jax array buffers will be converted to cute.Tensor with static shape and - layout. If disabled the kernel is instead given a JaxArray pointer directly. - allow_cuda_graph: If false will prevent XLA from building a cuda graph of for this call. - compile_options: Optional compiler arguments to pass into cute.compile. - use_static_tensors: If True, tensor shapes and strides are treated as constexpr values by - default. This can improve performance through compiler specialization but may not work - properly with all kernels. Specific tensors may be marked static or dynamic using the mode - and override this flag. - kwargs: Optional constexpr parameters to pass into the kernel fn. - - Note: This API is experimental and subject to change! 
- """ - output_shape_dtype = jax.tree.map( - lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), output_shape_dtype - ) - return _cutlass_call_impl( - fn, - output_shape_dtype=output_shape_dtype, - input_layout=input_layout, - output_layout=output_layout, - input_mode=input_mode, - output_mode=output_mode, - input_output_aliases=input_output_aliases, - convert_tensors=convert_tensors, - allow_cuda_graph=allow_cuda_graph, - compile_options=compile_options, - use_static_tensors=use_static_tensors, - **kwargs, - ) - - -def _normalize_tensor_mode(value: Any): - if value is None: - return [None] - elif isinstance(value, (tuple, list)): - if isinstance(value[0], int): # single tuple of modes - return TensorMode(tuple(value)) - else: - flat, _ = jax.tree.flatten( - [_normalize_tensor_mode(x) for x in value], - is_leaf=lambda x: x is None or isinstance(x, TensorMode), - ) - return flat - elif isinstance(value, TensorMode): - return [value] - else: - raise TypeError(f"Unexpected value for TensorMode {value} {type(value)}") - - -def _cutlass_call_impl( - fn, - *, - output_shape_dtype: Any, - input_layout: Any, - output_layout: Any, - input_mode: Any, - output_mode: Any, - input_output_aliases, - convert_tensors, - allow_cuda_graph, - compile_options, - use_static_tensors, - **kwargs, -): - multiple_results = isinstance(output_shape_dtype, Sequence) - if not multiple_results: - output_shape_dtype = (output_shape_dtype,) - output_shape_dtype_flat, output_tree = jax.tree.flatten(output_shape_dtype) - - @partial(jax.jit, inline=True) - def call_wrapper(*args): - args_flat, args_tree = jax.tree.flatten(args) - - if input_layout is None: - input_layout_flat = [row_major_layout(x) for x in args_flat] - else: - input_layout_flat = list(input_layout) - for idx, (layout, arg) in enumerate(zip(input_layout_flat, args_flat)): - if layout is None: - input_layout_flat[idx] = row_major_layout(arg) - input_layout_flat = tuple(input_layout_flat) - - if output_layout is None: - 
output_layout_flat = [row_major_layout(x) for x in output_shape_dtype_flat] - else: - output_layout_flat = list(output_layout) - for idx, (layout, arg) in enumerate( - zip(output_layout_flat, output_shape_dtype_flat) - ): - if layout is None: - output_layout_flat[idx] = row_major_layout(arg) - output_layout_flat = tuple(output_layout_flat) - - if len(input_layout_flat) != len(args_flat): - raise ValueError("Must has same number of input layouts as input arrays.") - - if len(output_layout_flat) != len(output_shape_dtype_flat): - raise ValueError("Must has same number of output layouts as output arrays.") - - if input_mode is None: - input_mode_flat = tuple(default_tensor_mode(x) for x in args_flat) - else: - input_mode_flat = _normalize_tensor_mode(input_mode) - for idx, (mode, arg) in enumerate(zip(input_mode_flat, args_flat)): - if mode is None: - input_mode_flat[idx] = default_tensor_mode(arg) - input_mode_flat = tuple(input_mode_flat) - - if output_mode is None: - output_mode_flat = tuple( - default_tensor_mode(x) for x in output_shape_dtype_flat - ) - else: - output_mode_flat = _normalize_tensor_mode(output_mode) - for idx, (mode, arg) in enumerate( - zip(output_mode_flat, output_shape_dtype_flat) - ): - if mode is None: - output_mode_flat[idx] = default_tensor_mode(arg) - output_mode_flat = tuple(output_mode_flat) - - if len(input_mode_flat) != len(args_flat): - raise ValueError( - f"Must has same number of input modes ({len(input_mode_flat)}) as input arrays ({len(args_flat)})." - ) - - if len(output_mode_flat) != len(output_shape_dtype_flat): - raise ValueError( - f"Must has same number of output modes ({len(output_mode_flat)}) as output arrays ({len(output_shape_dtype_flat)})." - ) - - # Validate dynamic mode settings match whatever static shape - # information we got as input. 
- for idx, (arg, mode) in enumerate(zip(args_flat, input_mode_flat)): - if mode.mode is not None and len(mode.mode) != len(arg.shape): - raise ValueError( - f"Input #{idx} has invalid mode {mode.mode} for shape {arg.shape}." - ) - for idx, (arg, mode) in enumerate( - zip(output_shape_dtype_flat, output_mode_flat) - ): - if mode.mode is not None and len(mode.mode) != len(arg.shape): - raise ValueError(f"Output #{idx} has invalid mode.") - - output_flat = cutlass_call_inner_p.bind( - *args_flat, - fn=fn, - args_tree=args_tree, - output_shape_dtype_flat=tuple(output_shape_dtype_flat), - output_tree=output_tree, - input_layout_flat=tuple(input_layout_flat), - output_layout_flat=tuple(output_layout_flat), - input_mode_flat=tuple(input_mode_flat), - output_mode_flat=tuple(output_mode_flat), - input_output_aliases=tuple(input_output_aliases.items()), - convert_tensors=convert_tensors, - allow_cuda_graph=allow_cuda_graph, - compile_options=compile_options, - use_static_tensors=use_static_tensors, - **kwargs, - ) - - output = jax.tree.unflatten(output_tree, output_flat) - return output if multiple_results else output[0] - - return call_wrapper - - -@cutlass_call_inner_p.def_abstract_eval -def cutlass_call_inner_p_abstract(*_, output_shape_dtype_flat, **__): - return [jax.core.ShapedArray(x.shape, x.dtype) for x in output_shape_dtype_flat] - - -def cutlass_call_inner_p_impl( - *args_flat, - fn, - args_tree: Any, - output_shape_dtype_flat: Any, - output_tree: Any, - input_layout_flat: Any, - output_layout_flat: Any, - input_mode_flat: Any, - output_mode_flat: Any, - input_output_aliases, - convert_tensors, - allow_cuda_graph, - compile_options, - use_static_tensors, - **kwargs, -): - input_output_aliases = dict(input_output_aliases) - - # TODO: Need to support device-less compilation and defer module load - # so we can compile at trace time. While we could compile here with one - # device, the JitExecutor does not support more than one device per - # instance. 
For now we explicitly compile and load under each context. - spec = build_function_spec( - args_flat, - args_tree, - output_shape_dtype_flat, - output_tree, - input_layout_flat, - output_layout_flat, - input_mode_flat, - output_mode_flat, - input_output_aliases, - convert_tensors, - compile_options, - use_static_tensors, - kwargs, - ) - - def wrap(fn, spec): - def _inner(ctx, *args): - kernel = get_or_compile_kernel(fn, spec, ctx.stream) - kernel(ctx, *args) - - return _inner - - fun = buffer_callback( - wrap(fn, spec), - result_shape_dtypes=output_shape_dtype_flat, - input_output_aliases=dict(spec.input_output_aliases), - command_buffer_compatible=allow_cuda_graph, - ) - return fun(*args_flat) - - -def _cutlass_call_jvp_rule(*args, **kwargs): - del args, kwargs - raise NotImplementedError( - "cutlass_call does not support VJP. Please use `jax.custom_jvp` for taking gradients." - ) - - -ad.primitive_jvps[cutlass_call_inner_p] = _cutlass_call_jvp_rule - - -def _cutlass_call_transpose_rule(*args, **kwargs): - del args, kwargs - raise NotImplementedError( - "cutlass_call does not support transpose. Please use `jax.custom_vjp` for taking gradients." - ) - - -ad.primitive_transposes[cutlass_call_inner_p] = _cutlass_call_transpose_rule - - -def _cutlass_call_vmap_rule(*args, **kwargs): - del args, kwargs - raise NotImplementedError( - "cutlass_call does not support batching with jax.vmap. Please " - "use jax.custom_batching.custom_vmap for applying vmap. 
" - ) - - -batching.primitive_batchers[cutlass_call_inner_p] = _cutlass_call_vmap_rule - -jax._src.dispatch.simple_impl(cutlass_call_inner_p) -jax._src.dispatch.prim_requires_devices_during_lowering.add(cutlass_call_inner_p) -lowering = mlir.lower_fun(cutlass_call_inner_p_impl, multiple_results=True) -mlir.register_lowering(cutlass_call_inner_p, lowering, platform="cuda") diff --git a/.github/container/cutlass_dsl_jax/src/jax_cutlass/types.py b/.github/container/cutlass_dsl_jax/src/jax_cutlass/types.py deleted file mode 100644 index 4b0f333b1..000000000 --- a/.github/container/cutlass_dsl_jax/src/jax_cutlass/types.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Type, Optional, Sequence, Union, Callable, Any, TypeVar -import sys -import ctypes -import math -import inspect -from dataclasses import dataclass, field -from functools import partial, reduce -from operator import mul -from itertools import chain -from typing import Annotated - -import cuda.bindings.driver as cuda - -import jax -import jax.numpy as jnp - -import cutlass -import cutlass.cute as cute -from cutlass.cute.runtime import from_dlpack as _from_dlpack -from cutlass.cute import AddressSpace, Numeric, IntTuple -from cutlass._mlir import ir - -JAX_DTYPE_TO_CUTLASS_DTYPE = { - # TODO(mgoldfarb-nvidia): Check passing boolean arrays via __dlpack__ - jnp.bool.dtype: cutlass.Boolean, - jnp.int8.dtype: cutlass.Int8, - jnp.int16.dtype: cutlass.Int16, - jnp.int32.dtype: cutlass.Int32, - jnp.int64.dtype: cutlass.Int64, - jnp.uint8.dtype: cutlass.Uint8, - jnp.uint16.dtype: cutlass.Uint16, - jnp.uint32.dtype: cutlass.Uint32, - jnp.uint64.dtype: cutlass.Uint64, - jnp.bfloat16.dtype: cutlass.BFloat16, - jnp.float16.dtype: cutlass.Float16, - jnp.float32.dtype: cutlass.Float32, - jnp.float64.dtype: cutlass.Float64, - jnp.float8_e8m0fnu.dtype: cutlass.Float8E8M0FNU, - jnp.float8_e5m2.dtype: cutlass.Float8E5M2, - jnp.float8_e4m3.dtype: cutlass.Float8E4M3, - jnp.float8_e4m3fn.dtype: cutlass.Float8E4M3FN, - jnp.float8_e4m3b11fnuz.dtype: cutlass.Float8E4M3B11FNUZ, - jnp.float4_e2m1fn.dtype: cutlass.Float4E2M1FN, -} - -DEFAULT_CUTLASS_DEVICE_MEMSPACE = AddressSpace.gmem -DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT = 256 - - -@jax.tree_util.register_dataclass -@dataclass(frozen=True) -class TensorMode: - """Provides a specification of cute.Tensor modes and additional metadata about - dynamic/static modes. - - Arguments: - mode : Specifies the position of each mode in the tensor (M0, M1, ... MN) - """ - - mode: tuple[int, ...] | None = field(metadata=dict(static=True), default=None) - # Indicates the shape and strides will be defined statically. 
Enabling may enable - # additional optimization. Kernels that do not support static shapes will generate - # compile errors if this is enabled so we leave it off by default. - static: bool = field(metadata=dict(static=True), default=None) - # Overrides the default pointer alignment. Generally this should not be changed - # but is left here to provide a hook. - ptr_assumed_align: int = field( - metadata=dict(static=True), default=DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT - ) - - def __post_init__(self): - if self.mode is not None: - if len(self.mode) != len(set(self.mode)): - raise ValueError( - f"Invalid mode {self.mode} contains duplicate entries." - ) - for m in self.mode: - if m < 0 or m >= len(self.mode): - raise ValueError( - f"Invalid mode {self.mode} contains out of range entires." - ) - if ( - self.ptr_assumed_align <= 0 - or not math.log2(self.ptr_assumed_align).is_integer() - ): - raise ValueError( - f"Invalid pointer alignment {self.ptr_assumed_align} must be power of 2." - ) - - -def row_major_layout(shaped): - """Returns a row major layout given a shaped value. - - Row major layout is (N-1, N-2, ... 1, 0) for an N-dimensional tensor. - """ - return tuple(reversed(range(len(shaped.shape)))) - - -def default_tensor_mode(shaped) -> TensorMode: - """Returns a default tensor mode given a shaped value. - - Default tensor mode is (0, 1, ... N-2, N-1) for an N_dimensional tensor. 
- """ - return TensorMode(tuple(range(len(shaped.shape)))) - - -def jax_to_cutlass_dtype(dtype): - """Gets the corresponding cutlass dtype given a jax dtype.""" - dtype = jnp.dtype(dtype) - if dtype not in JAX_DTYPE_TO_CUTLASS_DTYPE: - raise ValueError(f"Jax dtype [{dtype}] has no equivalent cutlass dtype.") - return JAX_DTYPE_TO_CUTLASS_DTYPE[dtype] - - -def from_dlpack(buffer, assumed_align: int = DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT): - """Convert device buffer to runtime Tensor.""" - return _from_dlpack(buffer, assumed_align=assumed_align) - - -class _JaxArrayBase(cute.Pointer): - """Base class for the JaxArray and JaxRuntimeArray types.""" - - def __init__( - self, - ptr: cute.Pointer, - shape: tuple[int, ...], - order: tuple[int, ...] | None = None, - ): - self.ptr = ptr - self._shape = tuple(shape) - if order is None: - order = tuple(reversed(range(len(self._shape)))) - if len(order) != len(shape): - raise ValueError(f"order must be same length as shape", order, shape) - for s in order: - if s < 0 or s > len(self._shape): - raise ValueError(f"Invalid index {s} in stride order", order, shape) - if len(tuple(set(order))) != len(order): - raise ValueError(f"order has duplicate indices", order) - self._order = tuple(order) - - @property - def shape(self) -> tuple[int, ...]: - """Returns physical shape of this jax array.""" - return self._shape - - @property - def ndim(self) -> int: - return len(self._shape) - - @property - def order(self) -> tuple[int, ...]: - """Returns stride order (layout) of this jax array.""" - return self._order - - @property - def dtype(self) -> Type[Numeric]: - """Returns cute dtype of this jax array.""" - return self.ptr.dtype - - @property - def memspace(self): - """Returns the address space of this jax array.""" - return self.ptr.memspace - - -class JaxArray(_JaxArrayBase): - """Represents a jax.Array IR value passed to a cute kernel or function. 
- - The JaxArray is a shaped pointer with physical dimension specified by the Jax program. - By default the data is assumed to follow row-major layout but a custom order - (e.g. column-major) can also be used. - - e.g. (8, 4, 2) row-major strides are (8, 2, 1) - - JaxArray always have statically know shapes and strides. - """ - - def __init__( - self, - ptr: cute.Pointer, - shape: tuple[int, ...], - order: tuple[int, ...] | None = None, - ): - """Creates a Jax array from a cute.Pointer and shape/stride information. - - Args: - ptr: The typed pointer. - shape: A tuple of shape dimensions associated with the . - order: An optional ordering of the dimensions in shape. If None the - shape is assumed to be row-major. - """ - if not hasattr(ptr, "value") or not isinstance(ptr.value, ir.Value): - raise ValueError("not an ir.Value", ptr) - super().__init__(ptr, shape, order) - - # - # Compile Time IR Value Properties - # - # These methods allow JaxArray look like core cute.Pointer. The ptr must - # be a cute.Pointer value - - @property - def value(self) -> cute.Pointer: - return self.ptr.value - - @property - def type(self) -> ir.Type: - return self.ptr.type - - @property - def alignment(self) -> int: - return self.ptr.alignment - - @property - def max_alignment(self) -> int: - return self.ptr.max_alignment - - def llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: - return self.ptr.llvm_ptr(loc, ip) - - def __add__(self, offset: IntTuple) -> "JaxArray": - return JaxArray(self.ptr + offset, self._shape, self._order) - - def toint(self, *, loc=None, ip=None): - return self.ptr.toint() - - def align(self, min_align: int, *, loc=None, ip=None) -> "JaxArray": - return JaxArray(self.ptr.align(min_align, loc, ip), self._shape, self._order) - - def get_layout( - self, - mode: tuple[int, ...] | TensorMode = None, - use_static_tensors: bool = False, - *, - loc=None, - ip=None, - ) -> cute.Layout: - """Create a cute.Layout from this JaxArray. 
- - Physical: (I, J, K) strides are (J*K, K, 1) in row-major order. - - mode = (2, 0, 1) : shape becomes (K, I, J) strides become (1, J*K, K) - mode = (1, 2, 0) : shape becomes (J, K, I) strides become (K, 1, J*K) - - :param mode: Maps the physical shape dimension to logical shape dimensions. If not given the physical layout is used. - :type tuple[int,...]: Tuple that is same size as shape. - """ - if isinstance(mode, (tuple, list)): - mode = TensorMode(mode, static=use_static_tensors) - - if (mode.static is None and use_static_tensors) or mode.static: - shape = self._shape - else: - shape = [cutlass.as_numeric(m) for m in self._shape] - - layout = cute.make_ordered_layout(tuple(shape), self._order, loc=loc, ip=ip) - if mode is not None and mode.mode is not None: - layout = cute.select(layout, mode.mode) - return layout - - def get_tensor( - self, - mode: tuple[int, ...] | TensorMode = None, - use_static_tensors: bool = False, - *, - loc=None, - ip=None, - ) -> cute.Tensor: - """Create a cute.Tensor from this JaxArray. - - :param mode: Maps the physical shape dimension to logical shape dimensions. If not given the physical layout is used. - :param use_static_tensors: Defaults tensor shape and stride to static if no mode is given. - :type tuple[int,...]: Tuple that is same size as shape. 
- :see get_layout - """ - layout = self.get_layout(mode, use_static_tensors, loc=loc, ip=ip) - return cute.make_tensor(self.ptr, layout) - - # Utility methods - - def __str__(self) -> str: - return f"JaxArray<{self.ptr}:{self.shape}:{self.order}>" - - def __repr__(self) -> str: - return str(self) - - # DynamicExpression Protocol - - def __extract_mlir_values__(self): - return [self.ptr.value] - - def __new_from_mlir_values__(self, values): - return JaxArray( - self.ptr.__new_from_mlir_values__(values), self._shape, self._order - ) - - -class JaxRuntimeArray(_JaxArrayBase): - """Runtime equivalent of jax.Array.""" - - def __init__( - self, - ptr: cute.Pointer | int, - shape: tuple[int, ...], - order: tuple[int, ...] | None = None, - ): - super().__init__(ptr, shape, order) - - @property - def alignment(self) -> int: - return self.ptr._assumed_align - - def __str__(self) -> str: - return f"JaxRuntimeArray<{self.ptr}:{self.shape}:{self.order}>" - - def __repr__(self) -> str: - return str(self) - - # JitArgument Protocol - - def __c_pointers__(self): - return self.ptr.__c_pointers__() - - def __get_mlir_types__(self): - return self.ptr.__get_mlir_types__() - - -def make_runtime_array( - value: Union[int, ctypes._Pointer], - dtype: Type[Numeric], - shape: tuple[int, ...], - order: tuple[int, ...] | None = None, - mem_space: AddressSpace = AddressSpace.generic, - assumed_align=DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT, -): - """Creates a JaxRuntimeArray and its underlying pointer.""" - ptr = cute.runtime.make_ptr(dtype, value, mem_space, assumed_align) - return JaxRuntimeArray(ptr, shape, order) - - -def make_placeholder_array( - dtype: Type[Numeric], - shape: tuple[int, ...], - order: tuple[int, ...] | None = None, - mem_space: AddressSpace = AddressSpace.generic, - assumed_align=DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT, -): - """Creates a JaxRuntimeArray that can be used as a placeholder for cute.compile.""" - # n.b. 
0 causes issues with c_types so we use a non-zero address that is aligned - # The value should not matter but we do it for good measure. - addr = 2 ** int(math.ceil(math.log2(assumed_align))) - return make_runtime_array(addr, dtype, shape, order, mem_space, assumed_align) - - -class JaxArrayList: - """Holds list of JaxArray or JaxRuntimeArray. - This class facilitates conversion of JaxRuntimeArray to JaxArray when crossing - the jit boundary. - """ - - def __init__(self, arrays: Sequence[JaxArray]): - self.arrays = tuple(arrays) - - def __getitem__(self, idx): - return self.arrays[idx] - - def __len__(self): - return len(self.arrays) - - def __iter__(self): - return iter(self.arrays) - - def __c_pointers__(self): - return [x.__c_pointers__()[0] for x in self.arrays] - - def __get_mlir_types__(self): - return [x.__get_mlir_types__()[0] for x in self.arrays] - - def __extract_mlir_values__(self): - return [x.__extract_mlir_values__()[0] for x in self.arrays] - - def __new_from_mlir_values__(self, values): - return JaxArrayList( - [JaxArray(v, x.shape, x.order) for x, v in zip(self.arrays, values)] - ) diff --git a/.github/container/cutlass_dsl_jax/src/jax_cutlass/version.py b/.github/container/cutlass_dsl_jax/src/jax_cutlass/version.py deleted file mode 100644 index e20c6a7a2..000000000 --- a/.github/container/cutlass_dsl_jax/src/jax_cutlass/version.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - - -__version_info__ = (0, 3, 0) -__version__ = ".".join(str(v) for v in __version_info__) diff --git a/.github/container/cutlass_dsl_jax/tests/__init__.py b/.github/container/cutlass_dsl_jax/tests/__init__.py deleted file mode 100644 index 070b8c0d7..000000000 --- a/.github/container/cutlass_dsl_jax/tests/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/.github/container/cutlass_dsl_jax/tests/benchmark.py b/.github/container/cutlass_dsl_jax/tests/benchmark.py deleted file mode 100644 index d0d0a2bf7..000000000 --- a/.github/container/cutlass_dsl_jax/tests/benchmark.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from itertools import groupby -from contextlib import contextmanager -from collections import defaultdict - -import jax.numpy as jnp - - -def cupti_profile(f): - """Profiles a callable `f` and returns CUPTI profiled timings. - - Returns (pytree_result_f, timings) - """ - - from jax._src.lib import mosaic_gpu as mosaic_gpu_lib - - def wrapped(*args): - result = None - timings = None - try: - ext = mosaic_gpu_lib._mosaic_gpu_ext - ext._cupti_init() - result = f(*args) - finally: - timings = ext._cupti_get_timings(True) - return result, timings - - return wrapped - - -class BenchmarkRunner: - def __init__(self, num_iterations): - self.num_iterations = num_iterations - - @property - def enabled(self): - return self.num_iterations > 0 - - def __call__(self, fn, *args, **kwargs): - """Calls the given function num_iterations times.""" - out = None # keep only last output - for _ in range(self.num_iterations): - out = fn(*args, **kwargs) - return out - - def __iter__(self): - """Returns an iterable for num_iterations.""" - return iter(range(self.num_iterations)) - - -@contextmanager -def cupti_benchmark_profiler_runner_context(request, filename, collector, iter_count): - """A context manager for collecting benchmark data with CUPTI for a number of iterations.""" - try: - from jax._src.lib import mosaic_gpu as mosaic_gpu_lib - - if iter_count > 0: - ext = mosaic_gpu_lib._mosaic_gpu_ext - ext._cupti_init() - yield BenchmarkRunner(iter_count) - finally: - if collector.enabled: - timings = ext._cupti_get_timings(True) - collector.record_timings(request, filename, timings) - - -class BenchmarkCollector: - def __init__(self, enabled, default_benchmark_iters=16): - # file name -> result dict - self.enabled = enabled - self.results = defaultdict(lambda: defaultdict(list)) - self.default_benchmark_iters = default_benchmark_iters - self.request = None - - def set_current_request(self, request): - """Sets the current pytest request.""" - self.request = request - - def 
_write_one_benchmark_result_csv(self, filename, results): - with open(filename, "w") as fp: - for key in results: - gkey = lambda x: x[0] - for kernel_key, group in groupby( - sorted(results[key], key=gkey), key=gkey - ): - # header - for key_entry in key: - if not isinstance(key_entry[1], (list, tuple)): - fp.write(f"{key_entry[1]},") - else: - for x in key_entry[1]: - if isinstance(x, jnp.dtype): - fp.write(f"{x.__nane__},") - else: - fp.write(f"{x},") - # data - values = list([x[1] for x in group]) - total, count, minv, maxv = ( - sum(values), - len(values), - min(values), - max(values), - ) - fp.write(f"{kernel_key},{count},{total/count},{minv},{maxv}\n") - - def save_csv(self): - """Save the recorded benchmark data files.""" - for filename in self.results: - self._write_one_benchmark_result_csv(filename, self.results[filename]) - - def _benchmark_key(self, request): - key = [("name", request.node.name)] - arg_names = sorted(list(request.node.callspec.params.keys())) - for arg in arg_names: - key.append((arg, request.node.callspec.params[arg])) - return tuple(key) - - def record_timings(self, request, filename, timings): - """Records the timings from the request to a specific file.""" - if not self.enabled: - raise RuntimeError("Collection is not enabled.") - key = self._benchmark_key(request) - self.results[filename][key].extend(timings) - - def runner(self, filename, num_iters=None): - """Returns a `cupti_benchmark_profiler_runner_context` for collecting. 
- - with benchmark.runner(request, "blackwell_dense_gemm.txt") as runner: - runner(launch, a, b) - """ - if self.request is None: - raise RuntimeError("No request was set.") - if num_iters is None: - num_iters = self.default_benchmark_iters - if not self.enabled: - num_iters = 0 - return cupti_benchmark_profiler_runner_context( - self.request, filename, self, num_iters - ) diff --git a/.github/container/cutlass_dsl_jax/tests/blackwell/__init__.py b/.github/container/cutlass_dsl_jax/tests/blackwell/__init__.py deleted file mode 100644 index 070b8c0d7..000000000 --- a/.github/container/cutlass_dsl_jax/tests/blackwell/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/.github/container/cutlass_dsl_jax/tests/blackwell/test_block_scaled_gemm.py b/.github/container/cutlass_dsl_jax/tests/blackwell/test_block_scaled_gemm.py deleted file mode 100644 index 59a477855..000000000 --- a/.github/container/cutlass_dsl_jax/tests/blackwell/test_block_scaled_gemm.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -from collections import defaultdict -from typing import List, Type, Tuple, Union, Optional -import os - -import pytest -import jax -import jax.numpy as jnp - -import cutlass -import cutlass.cute as cute -import cutlass.utils as utils - -from jax_cutlass import cutlass_call, jax_to_cutlass_dtype - -from ..tensor import ( - create_a_tensor, - create_b_tensor, - create_cd_tensor, - gemm_a_mode, - gemm_b_mode, - gemm_c_mode, - gemm_c_shape, - gemm_reference_einsum, -) - -from blackwell.dense_blockscaled_gemm_persistent import ( - Sm100BlockScaledPersistentDenseGemmKernel, -) - - -@pytest.mark.parametrize( - "problem_size", - [ - pytest.param((8 * 1024, 8 * 1024, 8 * 1024, 1), id="M8092-N8092-K8092-L1"), - pytest.param((8 * 1024, 4 * 1024, 4 * 1024, 1), id="M8092-N4096-K4096-L1"), - pytest.param((16 * 1024, 16 * 1024, 16 * 1024, 1), id="M16K-N16K-K16-L1"), - ], -) -@pytest.mark.parametrize( - "mma_tile_shape_mn", - [ - pytest.param((128, 128), id="MMA_128x128"), - # pytest.param((256, 128), id="MMA_256x128"), - # pytest.param((256, 256), id="MMA_256x256"), - ], -) -@pytest.mark.parametrize( - "is_2sm, cluster_shape_mn", - [ - # pytest.param(False, (1, 1), id="1SM-1x1"), - pytest.param(False, (2, 1), id="1SM-2x1"), - # pytest.param(False, (2, 2), id="1SM-2x2"), - # pytest.param(False, (4, 1), id="1SM-4x1"), - pytest.param(True, (2, 1), id="2SM-2x1"), - # pytest.param(True, (2, 2), id="2SM-2x2"), - # pytest.param(True, (4, 1), id="2SM-4x1"), - ], -) -@pytest.mark.parametrize( - "ab_dtype, c_dtype, sf_dtype, sf_vec_size", - [ - 
pytest.param( - "float4_e2m1fn", "float16", "float8_e8m0fnu", 16, id="mxfp4xmxfp4xf16" - ), - pytest.param( - "float4_e2m1fn", "float16", "float8_e4m3fn", 16, id="nvfp4xnvfp4xf16" - ), - ], -) -@pytest.mark.parametrize( - "a_major, b_major, c_major", - [ - # n.b. only k major a/b is supported by this test fixture. - pytest.param("k", "k", "n", id="kkn_major"), - ], -) -@pytest.mark.requires_device("B200") -def test_dense_block_scaled_gemm( - benchmark, - problem_size, - mma_tile_shape_mn, - is_2sm, - cluster_shape_mn, - ab_dtype, - c_dtype, - sf_dtype, - sf_vec_size, - a_major, - b_major, - c_major, -): - def ceil_div(a, b): - return (a + b - 1) // b - - m, n, k, l = problem_size - sf_k = ceil_div(k, sf_vec_size) - - if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( - jax_to_cutlass_dtype(ab_dtype), - jax_to_cutlass_dtype(sf_dtype), - sf_vec_size, - jax_to_cutlass_dtype(c_dtype), - mma_tile_shape_mn, - cluster_shape_mn, - m, - n, - k, - l, - a_major, - b_major, - c_major, - ): - pytest.skip( - f"Sm100BlockScaledPersistentDenseGemmKernel does not support test config." 
- ) - - if not is_2sm and mma_tile_shape_mn[0] not in (64, 128): - pytest.skip(f"Skipping {is_2sm=} {mma_tile_shape_mn=}") - - akey, asfkey, bkey, bsfkey = jax.random.split(jax.random.key(1337), 4) - a = create_a_tensor(l, m, k, a_major, ab_dtype, akey, minval=-1.0, maxval=1.0) - b = create_b_tensor(l, n, k, b_major, ab_dtype, bkey, minval=-2.0, maxval=2.0) - - assert a_major == "k", "a_major must be k" - assert b_major == "k", "b_major must be k" - - # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x - # Scale factors are using .scale_vec::4X / .block16 config to support nvfp4 and mxfp4 - atom_mn = (32, 4) - atom_k = 4 - - sfa = create_a_tensor(l, m, sf_k, a_major, sf_dtype, asfkey, minval=1.0, maxval=3.0) - sfa_ref = sfa - sfa = sfa.reshape( - l, - ceil_div(m, atom_mn[0] * atom_mn[1]), - atom_mn[1], - atom_mn[0], - ceil_div(sf_k, atom_k), - atom_k, - ) - # TODO: See if we can pass this layout mapping from jax primitive (it requires grouping) - sfa = sfa.transpose(0, 1, 4, 3, 2, 5) - - sfb = create_b_tensor(l, n, sf_k, b_major, sf_dtype, bsfkey, minval=1.0, maxval=3.0) - sfb_ref = sfb - sfb = sfb.reshape( - l, - ceil_div(n, atom_mn[0] * atom_mn[1]), - atom_mn[1], - atom_mn[0], - ceil_div(sf_k, atom_k), - atom_k, - ) - sfb = sfb.transpose(0, 1, 4, 3, 2, 5) - - gemm = Sm100BlockScaledPersistentDenseGemmKernel( - sf_vec_size, - mma_tile_shape_mn, - cluster_shape_mn, - ) - - hardware_info = cutlass.utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters( - cluster_shape_mn[0] * cluster_shape_mn[1] - ) - - def launch(a, b, sfa, sfb): - call = ( - lambda stream, a, b, sfa, sfb, c, *, max_active_clusters, epilogue_op: gemm( - a, b, sfa, sfb, c, max_active_clusters, stream, epilogue_op - ) - ) - return cutlass_call( - call, - input_mode=(gemm_a_mode(a_major), gemm_b_mode(b_major), None, None), - output_mode=(gemm_c_mode(c_major),), - output_shape_dtype=jax.ShapeDtypeStruct( - gemm_c_shape(l, m, n, 
c_major), c_dtype - ), - epilogue_op=lambda x: x, - max_active_clusters=max_active_clusters, - )(a, b, sfa, sfb) - - c = launch(a, b, sfa, sfb) - - c_ref = gemm_reference_einsum( - a, - b, - acc_dtype=jnp.float16, - c_dtype=c_dtype, - a_major=a_major, - b_major=b_major, - c_major=c_major, - sf_a=sfa_ref, - sf_b=sfb_ref, - ) - - assert jnp.allclose(c, c_ref) - - with benchmark.runner("blackwell_dense_block_scaled_gemm.txt") as runner: - runner(launch, a, b, sfa, sfb) diff --git a/.github/container/cutlass_dsl_jax/tests/blackwell/test_gemm.py b/.github/container/cutlass_dsl_jax/tests/blackwell/test_gemm.py deleted file mode 100644 index 72756c934..000000000 --- a/.github/container/cutlass_dsl_jax/tests/blackwell/test_gemm.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from functools import partial -from collections import defaultdict -from typing import List, Type, Tuple, Union, Optional -import os - -import pytest -import jax -import jax.numpy as jnp - -import cutlass -import cutlass.cute as cute -import cutlass.utils as utils - -from jax_cutlass import cutlass_call, jax_to_cutlass_dtype, TensorMode as T - -from ..tensor import ( - create_a_tensor, - create_b_tensor, - create_cd_tensor, - gemm_a_mode, - gemm_b_mode, - gemm_c_mode, - gemm_c_shape, - gemm_reference_einsum, -) - -from blackwell.dense_gemm_persistent import PersistentDenseGemmKernel - - -@pytest.mark.parametrize( - "problem_size", - [ - pytest.param((8 * 1024, 8 * 1024, 8 * 1024, 1), id="M8092-N8092-K8092-L1"), - # pytest.param((8 * 1024, 4 * 1024, 4 * 1024, 1), id="M8092-N4096-K4096-L1"), - # pytest.param((16 * 1024, 16 * 1024, 16 * 1024, 1), id="M16K-N16K-K16-L1"), - ], -) -@pytest.mark.parametrize( - "mma_tile_shape_mn", - [ - pytest.param((128, 128), id="MMA_128x128"), - # pytest.param((256, 128), id="MMA_256x128"), - # pytest.param((256, 256), id="MMA_256x256"), - ], -) -@pytest.mark.parametrize( - "is_2sm, cluster_shape_mn", - [ - pytest.param(False, (1, 1), id="1SM-1x1"), - # pytest.param(False, (2, 1), id="1SM-2x1"), - # pytest.param(False, (2, 2), id="1SM-2x2"), - # pytest.param(False, (4, 1), id="1SM-4x1"), - pytest.param(True, (2, 1), id="2SM-2x1"), - # pytest.param(True, (2, 2), id="2SM-2x2"), - # pytest.param(True, (4, 1), id="2SM-4x1"), - ], -) -@pytest.mark.parametrize( - "use_tma_store", - [ - pytest.param(False, id="NTS"), - pytest.param(True, id="TS"), - ], -) -@pytest.mark.parametrize( - "a_dtype, b_dtype, c_dtype, acc_dtype", - [ - pytest.param( - "float16", "float16", "float16", "float32", id="bf16xbf16xbf16xfp32" - ), - pytest.param( - "float8_e4m3fn", "float8_e4m3fn", "float16", "float32", id="fp8xfp8xf16xf32" - ), - ], -) -@pytest.mark.parametrize( - "a_major, b_major, c_major", - [ - pytest.param("k", "k", "n", id="kkn_major"), - # 
pytest.param("m", "n", "n", id="mnn_major"), - # pytest.param("m", "n", "m", id="mnm_major"), - ], -) -@pytest.mark.requires_device("B200") -def test_dense_gemm( - benchmark, - problem_size, - mma_tile_shape_mn, - is_2sm, - cluster_shape_mn, - use_tma_store, - a_dtype, - b_dtype, - c_dtype, - acc_dtype, - a_major, - b_major, - c_major, -): - if not is_2sm and mma_tile_shape_mn[0] not in (64, 128): - pytest.skip(f"Skipping {is_2sm=} {mma_tile_shape_mn=}") - - m, n, k, l = problem_size - - akey, bkey = jax.random.split(jax.random.key(1337), 2) - a = create_a_tensor(l, m, k, a_major, a_dtype, akey) - b = create_b_tensor(l, n, k, b_major, b_dtype, bkey) - - hardware_info = cutlass.utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters( - cluster_shape_mn[0] * cluster_shape_mn[1] - ) - gemm = PersistentDenseGemmKernel( - jax_to_cutlass_dtype(acc_dtype), - is_2sm, - mma_tile_shape_mn, - cluster_shape_mn, - use_tma_store, - ) - call = lambda stream, a, b, c, **kwargs: gemm( - a, b, c, max_active_clusters, stream, **kwargs - ) - - def launch(a, b): - return cutlass_call( - call, - input_mode=(gemm_a_mode(a_major), gemm_b_mode(b_major)), - output_mode=(gemm_c_mode(c_major),), - output_shape_dtype=jax.ShapeDtypeStruct( - gemm_c_shape(l, m, n, c_major), c_dtype - ), - epilogue_op=lambda x: x, - )(a, b) - - c = launch(a, b) - c_ref = gemm_reference_einsum( - a, - b, - acc_dtype=acc_dtype, - c_dtype=c_dtype, - a_major=a_major, - b_major=b_major, - c_major=c_major, - ) - assert jnp.allclose(c, c_ref) - - with benchmark.runner("blackwell_dense_gemm.txt") as runner: - runner(launch, a, b) diff --git a/.github/container/cutlass_dsl_jax/tests/blackwell/test_grouped_gemm.py b/.github/container/cutlass_dsl_jax/tests/blackwell/test_grouped_gemm.py deleted file mode 100644 index 6bed187c1..000000000 --- a/.github/container/cutlass_dsl_jax/tests/blackwell/test_grouped_gemm.py +++ /dev/null @@ -1,477 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. 
All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial, reduce -from collections import defaultdict -import pytest -import jax -import jax.numpy as jnp -import os - -import cuda.bindings.driver as cuda - -import cutlass -import cutlass.cute as cute -import cutlass.utils as utils - -from jax_cutlass import cutlass_call, jax_to_cutlass_dtype, TensorMode as TM - -from ..tensor import ( - create_a_tensor, - create_b_tensor, - create_cd_tensor, - gemm_reference_einsum, - gemm_a_mode, - gemm_b_mode, - gemm_c_mode, -) - -# Import from cutlass examples -from blackwell.grouped_gemm import GroupedGemmKernel - -# Needed for int64 types -jax.config.update("jax_enable_x64", True) - - -class JaxGroupGemmKernel: - """A Jax wrapper around GroupGemmKernel. - - The jax flavor of group gemm takes as input a single unified tensor and runs an aux - kernel to extract the addresses of the groups. This allows the use of the existing - group gemm kernel from cutlass w/o modification. 
- """ - - def __init__( - self, - a_mode, - b_mode, - c_mode, - group_count, - acc_dtype, - is_2sm, - mma_tile_shape_mn, - cluster_shape_mn, - tensormap_update_mode, - num_tensormap_buffers, - max_active_clusters, - total_num_clusters, - ): - self._gemm = GroupedGemmKernel( - jax_to_cutlass_dtype(acc_dtype), - is_2sm, - mma_tile_shape_mn, - cluster_shape_mn, - tensormap_update_mode, - ) - self._a_mode = a_mode - self._b_mode = b_mode - self._c_mode = c_mode - self._group_count = group_count - self._num_tensormap_buffers = num_tensormap_buffers - self._max_active_clusters = max_active_clusters - self._total_num_clusters = total_num_clusters - - @partial(jax.jit, static_argnums=[0], donate_argnums=[3]) - def __call__( - self, - tensor_a, - tensor_b, - tensor_c, - group_offsets, - problem_sizes_mnkl, - strides_abc, - ): - - # Storage for tensormap in gmem - tensormap = jnp.zeros( - ( - self._num_tensormap_buffers, - GroupedGemmKernel.num_tensormaps, - GroupedGemmKernel.bytes_per_tensormap // 8, - ), - dtype=jnp.int64, - ) - - # Storage for pointer offsets to each tensor. 
- ptrs_abc = jnp.zeros((group_offsets.shape[0], 3), jnp.int64) - - c, tmap, ptrs = cutlass_call( - fn=self.launch, - output_shape_dtype=(tensor_c, tensormap, ptrs_abc), - input_output_aliases={2: 0, 7: 1, 6: 2}, - group_count=self._group_count, - total_num_clusters=self._total_num_clusters, - max_active_clusters=self._max_active_clusters, - input_mode=( - self._a_mode, - self._b_mode, - self._c_mode, - None, - None, - None, - None, - None, - ), - output_mode=(self._c_mode, None, None), - use_static_tensors=True, - )( - tensor_a, - tensor_b, - tensor_c, - group_offsets, - problem_sizes_mnkl, - strides_abc, - ptrs_abc, - tensormap, - ) - return c - - @cute.jit - def launch( - self, - stream: cuda.CUstream, - initial_a: cute.Tensor, - initial_b: cute.Tensor, - initial_c: cute.Tensor, - group_offsets: cute.Tensor, - problem_shape_mnkl: cute.Tensor, - strides_abc: cute.Tensor, - tensor_address_abc: cute.Tensor, - tensormap_cute_tensor: cute.Tensor, - *, - group_count: cutlass.Constexpr[int], - total_num_clusters: cutlass.Constexpr[int], - max_active_clusters: cutlass.Constexpr[int], - ): - extract_tensor_address_kernel( - group_offsets, initial_a, initial_b, initial_c, tensor_address_abc - ).launch( - stream=stream, grid=[tensor_address_abc.shape[0], 1, 1], block=[1, 1, 1] - ) - - self._gemm( - initial_a, - initial_b, - initial_c, - group_count, - problem_shape_mnkl, - strides_abc, - tensor_address_abc, - total_num_clusters, - tensormap_cute_tensor, - max_active_clusters, - stream, - ) - - @cute.kernel - def extract_tensor_address_kernel( - group_offsets: cute.Tensor, - tensor_a: cute.Tensor, - tensor_b: cute.Tensor, - tensor_c: cute.Tensor, - dst: cute.Tensor, - ): - # mkl, nkl, mnl - bidx, _, _ = cute.arch.block_idx() - - num_groups = group_offsets.shape[0] - group_offset = group_offsets[bidx] - per_expert_size = tensor_b.shape[0] // num_groups - - a_offset = ( - cute.Int64(group_offset) - * tensor_a.stride[0] - * tensor_a.element_type.width - // 8 - ) - a_ptr = 
tensor_a.iterator.toint() + a_offset - dst[bidx, 0] = a_ptr - - b_offset = ( - cute.Int64(bidx) - * per_expert_size - * tensor_b.stride[0] - * tensor_b.element_type.width - // 8 - ) - b_ptr = tensor_b.iterator.toint() + b_offset - dst[bidx, 1] = b_ptr - - c_offset = ( - cute.Int64(group_offset) - * tensor_c.stride[0] - * tensor_c.element_type.width - // 8 - ) - c_ptr = tensor_c.iterator.toint() + c_offset - dst[bidx, 2] = c_ptr - - -@partial(jax.jit, static_argnums=[0, 1, 3]) -def generate_group_sizes( - expert_count, token_count, key, uniform_group_size=False, round_group_sizes=8 -): - if uniform_group_size: - return jnp.array([token_count // expert_count] * expert_count) - round_group_sizes = float(round_group_sizes) - key1, key2 = jax.random.split(key, 2) - v = jax.random.truncated_normal(key1, -2.0, 2.0, expert_count) + 2.0 - expert_probs = v / jnp.sum(v) - expert_assignment = jax.random.choice( - key2, expert_count, (token_count,), p=expert_probs - ) - group_sizes = jnp.bincount(expert_assignment, length=expert_count) - group_sizes = round_group_sizes * jnp.floor( - group_sizes.astype(jnp.float32) / round_group_sizes - ) - group_sizes = group_sizes.at[0].add(token_count - group_sizes.sum()) - return group_sizes.astype(jnp.int32) - - -@pytest.mark.parametrize( - "uniform_groups", - [pytest.param(True, id="UNIFORM"), pytest.param(False, id="RANDOM")], -) -@pytest.mark.parametrize( - "problem_size", - [ - pytest.param( - (16, 8 * 1024, int(1.5 * 1024), 3 * 1024, 1), id="E16-M8192-N1536-K3072-L1" - ), - pytest.param( - (128, 32 * 1024, int(1.5 * 1024), 2048, 1), id="E128-M32768-N1536-K2048-L1" - ), - ], -) -@pytest.mark.parametrize( - "tensormap_update_mode", - [ - # pytest.param(utils.TensorMapUpdateMode.GMEM, id="GMEM"), - pytest.param(utils.TensorMapUpdateMode.SMEM, id="SMEM"), - ], -) -@pytest.mark.parametrize( - "mma_tile_shape_mn", - [ - pytest.param((128, 128), id="MMA_128x128"), - # pytest.param((256, 128), id="MMA_256x128"), - # pytest.param((256, 256), 
id="MMA_256x256"), - ], -) -@pytest.mark.parametrize( - "is_2sm, cluster_shape_mn", - [ - pytest.param(False, (1, 1), id="1SM-1x1"), - # pytest.param(False, (2, 1), id="1SM-2x1"), - # pytest.param(False, (2, 2), id="1SM-2x2"), - # pytest.param(False, (4, 1), id="1SM-4x1"), - pytest.param(True, (2, 1), id="2SM-2x1"), - # pytest.param(True, (2, 2), id="2SM-2x2"), - # pytest.param(True, (4, 1), id="2SM-4x1"), - ], -) -@pytest.mark.parametrize( - "a_dtype, b_dtype, c_dtype, acc_dtype", - [ - pytest.param( - jnp.float16, - jnp.float16, - jnp.float16, - jnp.float32, - id="bf16xbf16xbf16xfp32", - ), - pytest.param( - jnp.float8_e4m3fn, - jnp.float8_e4m3fn, - jnp.float16, - jnp.float32, - id="fp8xfp8xf16xf32", - ), - ], -) -@pytest.mark.parametrize( - "a_major, b_major, c_major", - [ - pytest.param("k", "k", "n", id="kkn_major"), - # pytest.param("k", "n", "n", id="knn_major"), - ], -) -@pytest.mark.requires_device("B200") -def test_grouped_gemm( - benchmark, - problem_size, - uniform_groups, - mma_tile_shape_mn, - cluster_shape_mn, - tensormap_update_mode, - is_2sm, - a_dtype, - b_dtype, - c_dtype, - acc_dtype, - a_major, - b_major, - c_major, -): - key = jax.random.key(1337) - - num_groups, m, n, k, l = problem_size - - # Skip invalid mma tile shape - if not ( - (not is_2sm and mma_tile_shape_mn[0] in [64, 128]) - or (is_2sm and mma_tile_shape_mn[0] in [128, 256]) - ): - raise pytest.skip(f"Skip invalid mma tiler M {mma_tile_shape_mn[0]}") - - if mma_tile_shape_mn[1] not in range(32, 257, 32): - raise pytest.skip(f"Skip invalid mma tiler N {mma_tile_shape_mn[1]}") - - if m % (mma_tile_shape_mn[0] * cluster_shape_mn[0]) != 0: - pytest.skip(f"Problem too small for M tiling.") - - if n % (mma_tile_shape_mn[1] * cluster_shape_mn[1]) != 0: - pytest.skip(f"Problem too small for N tiling.") - - # Skip illegal cluster shape - if cluster_shape_mn[0] % (2 if is_2sm else 1) != 0: - raise pytest.skip( - f"cluster_shape_m need align with is_2sm config {cluster_shape_mn}" - ) - - 
tensors_abc = [] - problem_sizes_mnkl = [] - strides_abc = [] - - gkey, key = jax.random.split(key) - group_sizes = generate_group_sizes(num_groups, m, gkey, uniform_groups) - assert group_sizes.sum() == m, "unexpected group sizes" - - # Build separate tensors for each expert. It is expected that the total tokens will - # sum to m. n is uniform across all experts. - for idx in range(num_groups): - sub_m = int(group_sizes[idx]) - akey, bkey, ckey, key = jax.random.split(key, 4) - - tensor_a = create_a_tensor(l, sub_m, k, a_major, a_dtype, akey) - tensor_b = create_b_tensor(l, n, k, b_major, b_dtype, bkey) - tensor_c = create_cd_tensor(l, sub_m, n, c_major, c_dtype, ckey, fill_value=0.0) - tensors_abc.append((tensor_a, tensor_b, tensor_c)) - - stride_mk_a = (k, 1) if a_major == "k" else (1, m) # mkl - stride_nk_b = (k, 1) if b_major == "k" else (1, n * num_groups) # nkl - stride_mn_c = (n, 1) if c_major == "n" else (1, m) # mnl - - strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) - problem_sizes_mnkl.append(((sub_m, n, k, l))) - - # layout (num_groups, 3, 2):(6, 2, 1) - strides_abc_tensor = jnp.array(strides_abc, dtype=jnp.int32) - problem_sizes_mnkl_tensor = jnp.array(problem_sizes_mnkl, dtype=jnp.int32) - group_offsets = jnp.cumsum(group_sizes) - group_sizes - - # get number of SMs by querying max active clusters with 1x1 cluster shape - hardware_info = cutlass.utils.HardwareInfo() - num_sms = hardware_info.get_device_multiprocessor_count() - max_active_clusters = hardware_info.get_max_active_clusters( - cluster_shape_mn[0] * cluster_shape_mn[1] - ) - num_tensormap_buffers = num_sms - - def compute_total_num_clusters(problem_sizes_mnkl, cga_tile_shape_mn): - total_num_clusters = 0 - for m, n, _, _ in problem_sizes_mnkl: - num_clusters_mn = tuple( - (x + y - 1) // y for x, y in zip((m, n), cga_tile_shape_mn) - ) - total_num_clusters += reduce(lambda x, y: x * y, num_clusters_mn) - return total_num_clusters - - def compute_cga_tile_shape(mma_tile_shape_mn, 
cluster_shape_mn, is_2sm): - cta_tile_shape_mn = list(mma_tile_shape_mn) - if is_2sm: - cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 - return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) - - cga_tile_shape_mn = compute_cga_tile_shape( - mma_tile_shape_mn, cluster_shape_mn, is_2sm - ) - total_num_clusters = compute_total_num_clusters( - problem_sizes_mnkl, cga_tile_shape_mn - ) - - gemm = JaxGroupGemmKernel( - gemm_a_mode(a_major), - gemm_b_mode(b_major), - gemm_c_mode(c_major), - num_groups, - acc_dtype, - is_2sm, - mma_tile_shape_mn, - cluster_shape_mn, - tensormap_update_mode, - num_tensormap_buffers, - max_active_clusters, - total_num_clusters, - ) - - # Create the combined tensors by concatenating along the appropriate axis - am_axis = gemm_a_mode(a_major)[0] # mkl - bn_axis = gemm_b_mode(b_major)[0] # nkl - cm_axis = gemm_c_mode(c_major)[0] # mnl - tensor_a_device = jnp.concatenate([x[0] for x in tensors_abc], axis=am_axis) - tensor_b_device = jnp.concatenate([x[1] for x in tensors_abc], axis=bn_axis) - tensor_c_device = jnp.concatenate([x[2] for x in tensors_abc], axis=cm_axis) - - # Note: this call setup is a bit tricky because we need to extract addresses - # from tensor_c. To do this we donate tensor_c so we can treat it as both an - # input and output ensuring it has a stable allocation. - tensor_c_device = gemm( - tensor_a_device, - tensor_b_device, - tensor_c_device, - group_offsets, - problem_sizes_mnkl_tensor, - strides_abc_tensor, - ) - - c_ref = [] - for idx in range(num_groups): - c_ref.append( - gemm_reference_einsum( - tensors_abc[idx][0], - tensors_abc[idx][1], - acc_dtype=acc_dtype, - c_dtype=c_dtype, - a_major=a_major, - b_major=b_major, - c_major=c_major, - ) - ) - c_ref = jnp.concatenate(c_ref, axis=cm_axis).astype(jnp.float32) - tensor_c_device = tensor_c_device.astype(jnp.float32) - - # Tolerance from cutedsl tests. 
- assert jnp.allclose(c_ref, tensor_c_device, atol=0.1) - - with benchmark.runner("blackwell_grouped_gemm.txt") as runner: - for _ in runner: - tensor_c_device = gemm( - tensor_a_device, - tensor_b_device, - tensor_c_device, - group_offsets, - problem_sizes_mnkl_tensor, - strides_abc_tensor, - ) diff --git a/.github/container/cutlass_dsl_jax/tests/conftest.py b/.github/container/cutlass_dsl_jax/tests/conftest.py deleted file mode 100644 index aeee5a3c5..000000000 --- a/.github/container/cutlass_dsl_jax/tests/conftest.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import jax -import sys -import re -from unittest.mock import MagicMock, patch - -from jax_cutlass import release_compile_cache -from .benchmark import cupti_profile, BenchmarkCollector - - -def pytest_configure(config): - config.addinivalue_line("markers", "requires_sm(arg): Specify required SM type.") - - -def pytest_addoption(parser): - parser.addoption("--benchmark_iters", default=16, action="store", type=int) - parser.addoption("--benchmark", action="store_true") - parser.addoption("--check_tracer_leaks", action="store_true") - - -def pytest_sessionstart(session): - # Mock torch so that import of CuteDSL examples does not - # break on platforms without torch. 
- mock_modules = ("torch", "torch.nn", "torch.nn.functional") - for m in mock_modules: - sys.modules.update({m: MagicMock()}) - - session.stash["collector"] = BenchmarkCollector( - session.config.option.benchmark, session.config.option.benchmark_iters - ) - - if session.config.option.check_tracer_leaks: - jax.check_tracer_leaks(True) - - -def pytest_sessionfinish(session): - session.stash["collector"].save_csv() - - -def pytest_runtest_setup(item): - requires_device = item.get_closest_marker("requires_device") - if requires_device: - arg_value = requires_device.args[0] if requires_device.args else "" - for d in jax.devices(): - if not re.search(arg_value, d.device_kind): - pytest.skip( - f"Skipping test because device {d} is '{d.device_kind}' but requires '{arg_value}'" - ) - - -@pytest.fixture -def benchmark(request): - collector = request.session.stash["collector"] - collector.set_current_request(request) - yield collector - collector.set_current_request(None) - - -@pytest.fixture(scope="function", autouse=True) -def clear_cache_and_live_arrays_after_test(): - yield - jax.clear_caches() - release_compile_cache() - for a in jax.live_arrays(): - a.delete() diff --git a/.github/container/cutlass_dsl_jax/tests/tensor.py b/.github/container/cutlass_dsl_jax/tests/tensor.py deleted file mode 100644 index 3eeaf7be8..000000000 --- a/.github/container/cutlass_dsl_jax/tests/tensor.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial -import jax -import jax.numpy as jnp - - -def reorder_modes(src: str, target: str) -> tuple[int, ...]: - """Computes the mode given a source and target order.""" - src = tuple(src) - target = tuple(target) - src_map = {} - for idx, s in enumerate(src): - src_map[s] = idx - return tuple([src_map[d] for d in target]) - - -def gemm_a_major(d: str): - """Returns order for A tensor major mode.""" - return {"k": "lmk", "m": "lkm"}[d] - - -def gemm_a_mode(d: str) -> tuple[int, ...]: - """Returns mode for A tensor major mode.""" - return reorder_modes(gemm_a_major(d), "mkl") - - -def gemm_b_major(d: str): - """Returns order for B tensor major mode.""" - return {"k": "lnk", "n": "lkn"}[d] - - -def gemm_b_mode(d: str) -> tuple[int, ...]: - """Returns mode for B tensor major mode.""" - return reorder_modes(gemm_b_major(d), "nkl") - - -def gemm_c_major(d: str): - """Returns order for C tensor major mode.""" - return {"n": "lmn", "m": "lnm"}[d] - - -def gemm_c_mode(d: str) -> tuple[int, ...]: - """Returns mode for C tensor major mode.""" - return reorder_modes(gemm_c_major(d), "mnl") - - -def gemm_a_shape(l, m, k, major) -> tuple[int, ...]: - """Returns shape for A tensor given major mode.""" - assert major in ("k", "m") - shape = (l, m, k) if major == "k" else (l, k, m) - return shape - - -def gemm_b_shape(l, n, k, major) -> tuple[int, ...]: - """Returns shape for B tensor given major mode.""" - assert major in ("k", "n") - shape = (l, n, k) if major == "k" else (l, k, n) - return shape - - -def gemm_c_shape(l, m, n, major) -> tuple[int, ...]: - """Returns shape for C tensor given major mode.""" - assert major in ("m", "n") - shape = (l, m, n) if major == "n" else (l, n, m) - return shape - - -def create_tensor( - shape, dtype, key, *, minval=-2.0, maxval=2.0, fill_value=None, fill_arange=False -): - if fill_arange: - tensor = 
jnp.ones(shape, dtype=dtype) - tensor = tensor * jnp.arange(tensor.size, dtype=tensor.dtype).reshape( - tensor.shape - ) - elif fill_value is not None: - tensor = jnp.full(shape, fill_value, dtype=dtype) - else: - tensor = jax.random.uniform( - key, shape, dtype=jnp.float32, minval=minval, maxval=maxval - ) - tensor = tensor.astype(dtype) - return tensor - - -def create_a_tensor( - l, - m, - k, - major, - dtype, - key, - minval=-2.0, - maxval=2.0, - fill_value=None, - fill_arange=False, -): - shape = gemm_a_shape(l, m, k, major) - tensor = create_tensor( - shape, - dtype, - key, - minval=minval, - maxval=maxval, - fill_value=fill_value, - fill_arange=fill_arange, - ) - return tensor - - -def create_b_tensor( - l, - n, - k, - major, - dtype, - key, - minval=-2.0, - maxval=2.0, - fill_value=None, - fill_arange=False, -): - shape = gemm_b_shape(l, n, k, major) - tensor = create_tensor( - shape, - dtype, - key, - minval=minval, - maxval=maxval, - fill_value=fill_value, - fill_arange=fill_arange, - ) - return tensor - - -def create_cd_tensor( - l, - m, - n, - major, - dtype, - key, - *, - minval=-2.0, - maxval=2.0, - fill_value=None, - fill_arange=False, -): - shape = gemm_c_shape(l, m, n, major) - tensor = create_tensor( - shape, - dtype, - key, - minval=minval, - maxval=maxval, - fill_value=fill_value, - fill_arange=fill_arange, - ) - return tensor - - -def gemm_reference_einsum( - a, - b, - acc_dtype, - c_dtype, - a_major, - b_major, - c_major, - sf_a=None, - sf_b=None, - precision="highest", -): - a_idx = gemm_a_major(a_major) - b_idx = gemm_b_major(b_major) - c_idx = gemm_c_major(c_major) - spec = f"{a_idx},{b_idx}->{c_idx}" - - # If block scaled pre-scale input at higher precision - # Assumes we only use it for fp8 and smaller. 
- if sf_a is not None: - sf_vec_size = int(a.shape[-1] // sf_a.shape[-1]) - sf_a = jnp.repeat(sf_a, sf_vec_size, axis=-1) - a = a.astype(jnp.float16) * sf_a.astype(jnp.float16) - - if sf_b is not None: - sf_vec_size = int(b.shape[-1] // sf_b.shape[-1]) - sf_b = jnp.repeat(sf_b, sf_vec_size, axis=-1) - b = b.astype(jnp.float16) * sf_b.astype(jnp.float16) - - return jax.jit( - lambda a, b: jnp.einsum( - spec, a, b, preferred_element_type=acc_dtype, precision=precision - ).astype(c_dtype) - )(a, b) - - -def create_attn_tensors( - b, s, hq, hkv, d, dtype, key, *, minval=-2.0, maxval=2.0, fill_value=None -): - qkey, kkey, vkey = jax.random.split(key, 3) - return ( - create_tensor( - (b, s, hq, d), - dtype, - key, - minval=minval, - maxval=maxval, - fill_value=fill_value, - ), - create_tensor( - (b, s, hkv, d), - dtype, - key, - minval=minval, - maxval=maxval, - fill_value=fill_value, - ), - create_tensor( - (b, s, hkv, d), - dtype, - key, - minval=minval, - maxval=maxval, - fill_value=fill_value, - ), - ) - - -def attn_ref(q, k, v, is_causal: bool): - return jax.jit( - lambda q, k, v: jax.nn.dot_product_attention( - q, k, v, is_causal=is_causal, implementation="cudnn" - ) - )(q, k, v) diff --git a/.github/container/cutlass_dsl_jax/tests/test_args.py b/.github/container/cutlass_dsl_jax/tests/test_args.py deleted file mode 100644 index b0f66fbed..000000000 --- a/.github/container/cutlass_dsl_jax/tests/test_args.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Tuple -import pytest -from functools import partial - -import cuda.bindings.driver as cuda - -import cutlass -import cutlass.cute as cute - -import jax -import jax.numpy as jnp - -from jax_cutlass import cutlass_call, TensorMode as TM - -from .tensor import create_tensor - - -class TestConstexprArgs: - @cute.kernel - def kernel( - self, - x: cute.Tensor, - y: cute.Tensor, - z: cute.Tensor, - const_a: cutlass.Constexpr, - const_b: cutlass.Constexpr, - ): - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - frgA = cute.make_rmem_tensor(cute.size(x, mode=[0]), x.element_type) - frgB = cute.make_rmem_tensor(cute.size(y, mode=[0]), y.element_type) - frgC = cute.make_rmem_tensor(cute.size(z, mode=[0]), z.element_type) - - cute.autovec_copy(x[None, tidx, bidx], frgA) - cute.autovec_copy(y[None, tidx, bidx], frgB) - frgC.store(frgA.load() * const_a + frgB.load() * const_b) - cute.autovec_copy(frgC, z[None, tidx, bidx]) - - @cute.jit - def launch( - self, - stream: cuda.CUstream, - a1: cute.Tensor, - b1: cute.Tensor, - c1: cute.Tensor, - *, - const_a: cutlass.Constexpr[float], - const_b: cutlass.Constexpr[float] - ): - self.kernel(a1, b1, c1, const_a, const_b).launch( - grid=[a1.shape[-1], 1, 1], block=[a1.shape[-2], 1, 1], stream=stream - ) - - @partial(jax.jit, static_argnums=[0, 3, 4]) - def ref_call(self, a, b, const_a, const_b): - return a * const_a + b * const_b - - def test(self): - shape = (4, 16, 16) - dtype = jnp.float32 - a_key, b_key = jax.random.split(jax.random.key(1123), 2) - - a = create_tensor(shape, dtype, a_key) - b = create_tensor(shape, dtype, b_key) - - call = partial( - cutlass_call, - self.launch, - output_shape_dtype=jax.ShapeDtypeStruct(shape, dtype), - use_static_tensors=True, - ) - c = call(const_a=1.0, const_b=1.0)(a, b) - c_ref = self.ref_call(a, b, 1.0, 1.0) - - c = call(const_a=4.0, 
const_b=-1.0)(a, b) - c_ref = self.ref_call(a, b, 4.0, -1.0) - - # will use compile cache - c = call(const_a=4.0, const_b=-1.0)(a, b) - c_ref = self.ref_call(a, b, 4.0, -1.0) - - assert jnp.allclose(c, c_ref) - - -class TestListArgs: - - @cute.kernel - def kernel( - self, - a: cute.Tensor, - b: list[cute.Tensor], - c: tuple[cute.Tensor, ...], - ): - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - for idx in cutlass.range_constexpr(len(b)): - frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type) - cute.autovec_copy(a[None, tidx, bidx], frgA) - frgB = cute.make_rmem_tensor( - cute.size(b[int(idx)], mode=[0]), b[idx].element_type - ) - frgC = cute.make_rmem_tensor( - cute.size(c[idx], mode=[0]), c[idx].element_type - ) - cute.autovec_copy(b[idx][None, tidx, bidx], frgB) - frgC.store(frgA.load() + frgB.load()) - cute.autovec_copy(frgC, c[idx][None, tidx, bidx]) - - @cute.jit - def launch( - self, - stream: cuda.CUstream, - a: cute.Tensor, - b: list[cute.Tensor], - c: tuple[cute.Tensor, ...], - ): - self.kernel(a, b, c).launch( - grid=[a.shape[-1], 1, 1], block=[a.shape[-2], 1, 1], stream=stream - ) - - def ref_call(self, a, b): - @partial(jax.jit) - def _call(a, b): - return a + b - - return [_call(a, bi) for bi in b] - - def test(self): - key = jax.random.key(1123) - a_key, *b_keys = jax.random.split(key, 2 + 8) - - shape = (4, 16, 16) - dtype = jnp.bfloat16 - a = create_tensor(shape, dtype, a_key) - b = [create_tensor(shape, dtype, k) for k in b_keys] - c = [jax.ShapeDtypeStruct(shape, dtype) for x in b] - - call = cutlass_call( - self.launch, - output_shape_dtype=(c,), - input_mode=(TM(static=True), [TM(static=True)] * len(b)), - output_mode=[TM(static=True)] * len(b), - ) - (c,) = call(a, b) - - c_ref = self.ref_call(a, b) - for ci, ci_ref in zip(c, c_ref): - assert jnp.allclose(ci, ci_ref) - - -class TestListArgsAlias: - - @cute.kernel - def kernel( - self, - a: cute.Tensor, - b: list[cute.Tensor], - c: 
tuple[cute.Tensor, ...], - ): - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - # Only write to the even lists - for idx in cutlass.range_constexpr(0, len(b), 2): - frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type) - cute.autovec_copy(a[None, tidx, bidx], frgA) - frgB = cute.make_rmem_tensor( - cute.size(b[idx], mode=[0]), b[idx].element_type - ) - frgC = cute.make_rmem_tensor( - cute.size(c[idx], mode=[0]), c[idx].element_type - ) - cute.autovec_copy(b[idx][None, tidx, bidx], frgB) - frgC.store(frgA.load() + frgB.load()) - cute.autovec_copy(frgC, c[idx][None, tidx, bidx]) - - @cute.jit - def launch( - self, - stream: cuda.CUstream, - a: cute.Tensor, - b: list[cute.Tensor], - c: tuple[cute.Tensor, ...], - ): - self.kernel(a, b, c).launch( - grid=[a.shape[-1], 1, 1], block=[a.shape[-2], 1, 1], stream=stream - ) - - def ref_call(self, a, b): - @partial(jax.jit) - def _call(a, b): - return a + b - - results = [None] * len(b) - for idx, bi in enumerate(b): - if idx % 2 == 0: - results[idx] = _call(a, bi) - else: - results[idx] = jnp.full(bi.shape, idx + 1, bi.dtype) - return results - - def test(self): - key = jax.random.key(1123) - a_key, *b_keys = jax.random.split(key, 2 + 8) - - shape = (4, 16, 16) - dtype = jnp.bfloat16 - a = create_tensor(shape, dtype, a_key) - b = [create_tensor(shape, dtype, k) for k in b_keys] - - # This list of arrays will be updated by the call - c = [jnp.full(shape, idx + 1, dtype) for idx in range(len(b))] - - call = cutlass_call( - self.launch, - output_shape_dtype=(c,), - input_output_aliases={2: 0}, - input_mode=( - TM(static=True), - [TM(static=True)] * len(b), - [TM(static=True)] * len(b), - ), - output_mode=[TM(static=True)] * len(b), - ) - (c,) = call(a, b, c) - - c_ref = self.ref_call(a, b) - for ci, ci_ref in zip(c, c_ref): - assert jnp.allclose(ci, ci_ref) - - -class TestPartialBoundArgs: - @cute.kernel - def kernel( - self, - x: cute.Tensor, - y: cute.Tensor, - z: cute.Tensor, - 
const_a: cutlass.Constexpr, - const_b: cutlass.Constexpr, - ): - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - frgA = cute.make_rmem_tensor(cute.size(x, mode=[0]), x.element_type) - frgB = cute.make_rmem_tensor(cute.size(y, mode=[0]), y.element_type) - frgC = cute.make_rmem_tensor(cute.size(z, mode=[0]), z.element_type) - - cute.autovec_copy(x[None, tidx, bidx], frgA) - cute.autovec_copy(y[None, tidx, bidx], frgB) - frgC.store(frgA.load() * const_a + frgB.load() * const_b) - cute.autovec_copy(frgC, z[None, tidx, bidx]) - - @cute.jit - def launch( - self, - stream: cuda.CUstream, - a1: cute.Tensor, - b1: cute.Tensor, - c1: cute.Tensor, - *, - const_a: cutlass.Constexpr[float], - const_b: cutlass.Constexpr[float] - ): - self.kernel(a1, b1, c1, const_a, const_b).launch( - grid=[a1.shape[-1], 1, 1], block=[a1.shape[-2], 1, 1], stream=stream - ) - - @partial(jax.jit, static_argnums=[0, 3, 4]) - def ref_call(self, a, b, const_a, const_b): - return a * const_a + b * const_b - - def test(self): - shape = (4, 16, 16) - dtype = jnp.float32 - a_key, b_key = jax.random.split(jax.random.key(1123), 2) - - a = create_tensor(shape, dtype, a_key) - b = create_tensor(shape, dtype, b_key) - - fn = partial(self.launch, const_a=2.0) - - call = partial( - cutlass_call, - fn, - output_shape_dtype=jax.ShapeDtypeStruct(shape, dtype), - input_mode=(TM(static=True), TM(static=True)), - output_mode=TM(static=True), - const_b=-3.0, - ) - c = call()(a, b) - c_ref = self.ref_call(a, b, 2.0, -3.0) - - assert jnp.allclose(c, c_ref) - - -class TestCompileOptionsPassing: - @cute.kernel - def kernel(self, x: cute.Tensor, z: cute.Tensor): - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - frgA = cute.make_rmem_tensor(cute.size(x, mode=[0]), x.element_type) - frgC = cute.make_rmem_tensor(cute.size(z, mode=[0]), z.element_type) - - cute.autovec_copy(x[None, tidx, bidx], frgA) - frgC.store(frgA.load()) - cute.autovec_copy(frgC, z[None, tidx, 
bidx]) - - @cute.jit - def launch( - self, - stream: cuda.CUstream, - a1: cute.Tensor, - c1: cute.Tensor, - ): - self.kernel(a1, c1).launch( - grid=[a1.shape[-1], 1, 1], block=[a1.shape[-2], 1, 1], stream=stream - ) - - def test(self): - shape = (4, 16, 16) - dtype = jnp.float32 - a_key = jax.random.key(1123) - a = create_tensor(shape, dtype, a_key) - - call = cutlass_call( - self.launch, - output_shape_dtype=jax.ShapeDtypeStruct(shape, dtype), - input_mode=TM(static=True), - output_mode=TM(static=True), - compile_options="--opt-level=0", - ) - - c = call(a) - assert jnp.allclose(c, a) - - # Combine typed and string - from cutlass.cute import ( - OptLevel, - EnableAssertions, - GenerateLineInfo, - KeepCUBIN, - KeepPTX, - ) - - my_debugging_options = ( - "--opt-level=1", - EnableAssertions, - GenerateLineInfo, - KeepCUBIN, - KeepPTX, - ) - - call = cutlass_call( - self.launch, - output_shape_dtype=jax.ShapeDtypeStruct(shape, dtype), - use_static_tensors=True, - compile_options=my_debugging_options, - ) - - c = call(a * 2.0) - assert jnp.allclose(c, a * 2.0) diff --git a/.github/container/cutlass_dsl_jax/tests/test_misc.py b/.github/container/cutlass_dsl_jax/tests/test_misc.py deleted file mode 100644 index d12594ca0..000000000 --- a/.github/container/cutlass_dsl_jax/tests/test_misc.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest -from functools import partial - -import cuda.bindings.driver as cuda - -import cutlass -import cutlass.cute as cute - -import jax -import jax.numpy as jnp - -from jax_cutlass import cutlass_call - - -@cute.jit -def launch(stream: cuda.CUstream, out: cute.Tensor): - pass - - -def test_vjp_not_allowed(): - with pytest.raises( - NotImplementedError, match=r".*cutlass_call does not support VJP.*" - ): - empty = jnp.zeros(tuple(), jnp.float32) - call = cutlass_call(launch, output_shape_dtype=empty) - jax.value_and_grad(call)(empty) - - -def test_transpose_not_allowed(): - with pytest.raises( - NotImplementedError, match=r".*cutlass_call does not support transpose.*" - ): - empty = jnp.zeros(tuple(), jnp.float32) - call = cutlass_call(launch, output_shape_dtype=empty) - jax.linear_transpose(call, jax.ShapeDtypeStruct(empty.shape, empty.dtype))( - empty - ) - - -def test_vmap_not_allowed(): - with pytest.raises( - NotImplementedError, - match=r".*cutlass_call does not support batching with jax\.vmap.*", - ): - empty = jnp.zeros(tuple(), jnp.float32) - empty_b = jnp.zeros((8,), jnp.float32) - call = cutlass_call(launch, output_shape_dtype=empty) - jax.vmap(call)(empty_b) diff --git a/.github/container/cutlass_dsl_jax/tests/test_sharding.py b/.github/container/cutlass_dsl_jax/tests/test_sharding.py deleted file mode 100644 index e6148dd15..000000000 --- a/.github/container/cutlass_dsl_jax/tests/test_sharding.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -from functools import partial - -import cuda.bindings.driver as cuda - -import cutlass -import cutlass.cute as cute - -import jax -import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P - -from jax_cutlass import cutlass_call, TensorMode as TM - -from .tensor import create_tensor - - -@cute.kernel -def kernel( - a: cute.Tensor, - b: cute.Tensor, - c: cute.Tensor, - const_a: cutlass.Constexpr, - const_b: cutlass.Constexpr, -): - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type) - frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type) - frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type) - - cute.autovec_copy(a[None, tidx, bidx], frgA) - cute.autovec_copy(b[None, tidx, bidx], frgB) - frgC.store(frgA.load() * const_a + frgB.load() * const_b) - cute.autovec_copy(frgC, c[None, tidx, bidx]) - - -@cute.jit -def launch( - stream: cuda.CUstream, - a: cute.Tensor, - b: cute.Tensor, - c: cute.Tensor, - *, - const_a: cutlass.Constexpr, - const_b: cutlass.Constexpr, -): - # these two kernels are launched to the same stream. 
- kernel(a, b, c, const_a, const_b).launch( - grid=[a.shape[-1], 1, 1], block=[a.shape[-2], 1, 1], stream=stream - ) - - -@pytest.mark.parametrize("n", range(3)) -def test_jit_sharding(n): - ngpu = jax.device_count() - mesh = jax.make_mesh((ngpu,), "b") - sharding = P("b", None, None) - - key = jax.random.key(1123 + n) - a_key, b_keys = jax.random.split(key, 2) - - shape = (32 * 8, 32, 32) - dtype = jnp.float32 - a = create_tensor(shape, dtype, a_key) - b = create_tensor(shape, dtype, b_keys) - a = jax.device_put(a, NamedSharding(mesh, sharding)) - b = jax.device_put(b, NamedSharding(mesh, sharding)) - - @partial(jax.jit, static_argnums=(2, 3)) - def compute(a, b, const_a, const_b): - call = cutlass_call( - launch, - output_shape_dtype=jax.ShapeDtypeStruct(a.shape, b.dtype), - input_mode=(TM(static=True), TM(static=True)), - output_mode=TM(static=True), - const_a=const_a, - const_b=const_b, - ) - ref_result = a * const_a + b * const_b - return call(a, b), ref_result - - c, c_ref = compute(a, b, 1.0, 2.0) - assert jnp.allclose(c, c_ref) - - c, c_ref = compute(a, b, 3.0, 4.0) - assert jnp.allclose(c, c_ref) - - c, c_ref = compute(a, b, 1.0, 2.0) - assert jnp.allclose(c, c_ref) - - -@pytest.mark.parametrize("n", range(3)) -def test_shardmap(n): - ngpu = jax.device_count() - mesh = jax.make_mesh((ngpu,), "b") - sharding = P("b", None, None) - - @partial(jax.jit, static_argnums=[0, 1]) - def compute(const_a, const_b): - key = jax.random.key(1123 + n) - a_key, b_keys = jax.random.split(key, 2) - - shape = (32 * 8, 32, 32) - dtype = jnp.float32 - a = create_tensor(shape, dtype, a_key) - b = create_tensor(shape, dtype, b_keys) - a = jax.lax.with_sharding_constraint(a, NamedSharding(mesh, sharding)) - b = jax.lax.with_sharding_constraint(b, NamedSharding(mesh, sharding)) - - @partial( - jax.shard_map, - mesh=mesh, - in_specs=(sharding, sharding), - out_specs=(sharding, sharding), - ) - def sharded_call(a_block, b_block): - call = cutlass_call( - launch, - 
output_shape_dtype=jax.ShapeDtypeStruct(a_block.shape, a_block.dtype), - input_mode=(TM(static=True), TM(static=True)), - output_mode=TM(static=True), - const_a=const_a, - const_b=const_b, - ) - ref_result = a_block * const_a + b_block * const_b - return call(a_block, b_block), ref_result - - return sharded_call(a, b) - - c, c_ref = compute(1.0, 2.0) - assert jnp.allclose(c, c_ref) - - c, c_ref = compute(3.0, 4.0) - assert jnp.allclose(c, c_ref) - - c, c_ref = compute(1.0, 2.0) - assert jnp.allclose(c, c_ref) diff --git a/.github/container/cutlass_dsl_jax/tests/test_stream.py b/.github/container/cutlass_dsl_jax/tests/test_stream.py deleted file mode 100644 index 72ac51055..000000000 --- a/.github/container/cutlass_dsl_jax/tests/test_stream.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest -from functools import partial - -import cuda.bindings.driver as cuda - -import cutlass -import cutlass.cute as cute - -import jax -import jax.numpy as jnp - -from jax_cutlass import cutlass_call, TensorMode as TM - -from .tensor import create_tensor - - -@cute.kernel -def kernel( - a: cute.Tensor, - b: cute.Tensor, - c: cute.Tensor, - const_a: cutlass.Constexpr, - const_b: cutlass.Constexpr, -): - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type) - frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type) - frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type) - - cute.autovec_copy(a[None, tidx, bidx], frgA) - cute.autovec_copy(b[None, tidx, bidx], frgB) - frgC.store(frgA.load() * const_a + frgB.load() * const_b) - cute.autovec_copy(frgC, c[None, tidx, bidx]) - - -@cute.jit -def launch( - stream: cuda.CUstream, - a: cute.Tensor, - b: cute.Tensor, - c: cute.Tensor, - d: cute.Tensor, -): - # these two kernels are launched to the same stream. 
- # the second call depends on the first - kernel(a, b, c, 2.0, -3.0).launch( - grid=[a.shape[-1], 1, 1], block=[a.shape[-2], 1, 1], stream=stream - ) - kernel(a, c, d, -4.0, 5.0).launch( - grid=[a.shape[-1], 1, 1], block=[a.shape[-2], 1, 1], stream=stream - ) - - -@pytest.mark.parametrize("n", range(3)) -def test_back_to_back(n): - def ref_call(a, b): - c = a * 2.0 + b * -3.0 - d = a * -4.0 + c * 5.0 - return c, d - - shape = (4, 128, 128) - dtype = jnp.float32 - - for i in range(3): - a_key, b_key = jax.random.split(jax.random.key(1123 + i), 2) - - a = create_tensor(shape, dtype, a_key) - b = create_tensor(shape, dtype, b_key) - c, d = cutlass_call( - launch, - output_shape_dtype=((a, b)), - input_mode=[TM(static=True)] * 2, - output_mode=[TM(static=True)] * 2, - )(a, b) - - c_ref, d_ref = ref_call(a, b) - assert jnp.allclose(c, c_ref, atol=1e-6), "c" - assert jnp.allclose(d, d_ref, atol=1e-6), "d" diff --git a/.github/eks-workflow-files/jax-cutlass/scripts/unittest.sh b/.github/eks-workflow-files/jax-cutlass/scripts/unittest.sh index cd243b10b..9e475d06c 100644 --- a/.github/eks-workflow-files/jax-cutlass/scripts/unittest.sh +++ b/.github/eks-workflow-files/jax-cutlass/scripts/unittest.sh @@ -1,34 +1,31 @@ - set -xu -o pipefail - - LOG_DIR=${LOG_DIR:-/opt/output} - - SRC_ROOT=${SRC_ROOT:-$PWD/jax-cutlass-src} - SRC_ROOT=$(realpath $SRC_ROOT) - - pip install pytest-reportlog pytest-xdist - - # nvidia-cutlass-dsl-jax is not yet installed to the container by default. - # Clone if not already present locally as indicated by SRC_ROOT - if [[ ! 
-d ${SRC_ROOT} ]]; then - git clone https://github.com/NVIDIA/JAX-Toolbox.git --branch ${JAX_TOOLBOX_REF} ${SRC_ROOT} - PIP_SRC=${SRC_ROOT}/.github/container/cutlass_dsl_jax - else - PIP_SRC=${SRC_ROOT} - fi - - pip install ${PIP_SRC} - - # Clone CUTLASS examples - CUTLASS_ROOT="${SRC_ROOT}/cutlass" - CUTLASS_EXAMPLES_ROOT="${CUTLASS_ROOT}/examples/python/CuTeDSL" - git clone https://github.com/NVIDIA/cutlass.git ${CUTLASS_ROOT} - - NGPUS=$(nvidia-smi --list-gpus | wc -l) - - # Start MPS daemon - nvidia-cuda-mps-control -d - - export PYTHONPATH=${CUTLASS_EXAMPLES_ROOT} - pytest-xdist.sh ${NGPUS} 1 ${LOG_DIR}/pytest-report.jsonl pytest -xsv --log-file=${LOG_DIR}/pytest_log.log --log-file-level=INFO ${PIP_SRC}/tests/ | tee -a ${LOG_DIR}/pytest_stdout_dist.log - - touch ${LOG_DIR}/done
+ # Run every CuTeDSL JAX example from a fresh CUTLASS clone; append one JSON result line per script to ${LOG_DIR}/pytest-report.jsonl.
+ set -xu -o pipefail
+
+ LOG_DIR=${LOG_DIR:-/opt/output}
+
+ SRC_ROOT=${SRC_ROOT:-$PWD/jax-cutlass-src}
+ SRC_ROOT=$(realpath "${SRC_ROOT}")
+ pip install pytest-reportlog pytest-xdist flatbuffers
+
+ # Clone CUTLASS examples
+ CUTLASS_ROOT="${SRC_ROOT}/cutlass"
+ CUTLASS_EXAMPLES_ROOT="${CUTLASS_ROOT}/examples/python/CuTeDSL"
+ git clone https://github.com/NVIDIA/cutlass.git "${CUTLASS_ROOT}"
+
+ NGPUS=$(nvidia-smi --list-gpus | wc -l)  # not read below; the GPU count is only visible in the -x trace
+
+ # Start MPS daemon
+ nvidia-cuda-mps-control -d
+
+ # Run the examples. A failing example is recorded as "failed" and the loop continues (no `set -e`).
+ for f in "${CUTLASS_EXAMPLES_ROOT}"/jax/*.py; do
+   echo "[Executing] $f"
+   if log_output=$(python "$f" 2>&1); then outcome="passed"; else outcome="failed"; fi
+   echo "=== ${f} ===" | tee -a "${LOG_DIR}/pytest_stdout.log"
+   printf '%s\n' "${log_output}" | tee -a "${LOG_DIR}/pytest_stdout.log"
+   # The log is fed through stdin: as an argv entry a large log can exceed the per-argument size limit and lose the report line.
+   printf '%s' "${log_output}" | python3 -c "import json,sys; print(json.dumps({'outcome': sys.argv[1], 'nodeid': sys.argv[2], 'longrepr': sys.stdin.buffer.read().decode('utf-8', 'replace')}))" \
+     "${outcome}" "${f}" >> "${LOG_DIR}/pytest-report.jsonl"
+ done
+
+ touch "${LOG_DIR}/done"
diff --git a/.github/eks-workflow-files/transformer-engine/scripts/unittest.sh b/.github/eks-workflow-files/transformer-engine/scripts/unittest.sh index bb70285bf..18a36338b 100644 --- a/.github/eks-workflow-files/transformer-engine/scripts/unittest.sh +++ b/.github/eks-workflow-files/transformer-engine/scripts/unittest.sh @@ -3,16 +3,16 @@ LOG_DIR=/opt/output pip install pytest-reportlog pytest-xdist - + # Start MPS daemon nvidia-cuda-mps-control -d - + # TE's default is slightly different, without the hyphen export TE_PATH=${SRC_PATH_TRANSFORMER_ENGINE} - + # 1 GPU per worker, 4 workers per GPU pytest-xdist.sh 1 4 ${LOG_DIR}/pytest-report-L0-unittest.jsonl bash ${TE_PATH}/qa/L0_jax_unittest/test.sh | tee -a ${LOG_DIR}/pytest_stdout.log - + # 8 GPUs per worker, 1 worker per GPU. pytest-xdist.sh allows aggregation # into a single .jsonl file of results from multiple pytest invocations # inside the test.sh script, so it's useful even with a single worker per