diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py index 0107014002..230eea9c03 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py @@ -360,7 +360,8 @@ def wrapper(self, other, **kwargs): def _binary_op(op): """ Decorator to check if the 'other' argument is an ArithValue. - If not, returns NotImplemented. + If 'other' is a Python scalar (int, float, bool), it is cast to a constant + MLIR value matching the type and signedness of 'self'. """ def wrapper(self, other, **kwargs):