Skip to content

[NFC][MLIR][NVVM] Add unit and Python bindings tests for MBarrierTestWaitOp/MBarrierTryWaitOp#193600

Open
kvederni wants to merge 2 commits intollvm:mainfrom
kvederni:mbarrier_cpp_py_tests
Open

[NFC][MLIR][NVVM] Add unit and Python bindings tests for MBarrierTestWaitOp/MBarrierTryWaitOp#193600
kvederni wants to merge 2 commits intollvm:mainfrom
kvederni:mbarrier_cpp_py_tests

Conversation

@kvederni
Copy link
Copy Markdown
Contributor

Add C++ gtest unit test covering construction and result-type verification for MBarrierTestWaitOp and MBarrierTryWaitOp.

Add Python bindings test exercising both OpView class construction and free function style for MBarrierTestWaitOp and MBarrierTryWaitOp.

MBarrierTestWaitOp/MBarrierTryWaitOp

Add C++ gtest unit test covering construction and result-type
verification for MBarrierTestWaitOp and MBarrierTryWaitOp.

Add Python bindings test exercising both OpView class construction and
free function style for MBarrierTestWaitOp and MBarrierTryWaitOp.
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 22, 2026

@llvm/pr-subscribers-mlir-llvm

Author: Kirill Vedernikov (kvederni)

Changes

Add C++ gtest unit test covering construction and result-type verification for MBarrierTestWaitOp and MBarrierTryWaitOp.

Add Python bindings test exercising both OpView class construction and free function style for MBarrierTestWaitOp and MBarrierTryWaitOp.


Full diff: https://github.com/llvm/llvm-project/pull/193600.diff

4 Files Affected:

  • (added) mlir/test/python/dialects/nvvm/mbarrier_wait.py (+44)
  • (modified) mlir/unittests/Dialect/LLVMIR/CMakeLists.txt (+2)
  • (added) mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp (+31)
  • (added) mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc (+68)
diff --git a/mlir/test/python/dialects/nvvm/mbarrier_wait.py b/mlir/test/python/dialects/nvvm/mbarrier_wait.py
new file mode 100644
index 0000000000000..b86d3663adb17
--- /dev/null
+++ b/mlir/test/python/dialects/nvvm/mbarrier_wait.py
@@ -0,0 +1,44 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+"""Tests for MBarrierTestWaitOp and MBarrierTryWaitOp Python bindings.
+Covers the none-phase (single-result i1) variant.
+Two construction styles: OpView class and free function.
+"""
+
+from mlir.ir import *
+from mlir.dialects import nvvm, func
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
+
+
+# CHECK-LABEL: TEST: test_mbarrier_wait
+@run
+def test_mbarrier_wait():
+    """MBarrierTestWaitOp and MBarrierTryWaitOp -- OpView and free-function styles."""
+    i64 = IntegerType.get_signless(64)
+    ptr = Type.parse("!llvm.ptr<3>")
+
+    @func.FuncOp.from_py_func(ptr, i64)
+    def none_phase(addr, state):
+        op_test = nvvm.MBarrierTestWaitOp(addr=addr, stateOrPhase=state)
+        assert op_test.res is not None
+        op_try = nvvm.MBarrierTryWaitOp(addr=addr, stateOrPhase=state)
+        assert op_try.res is not None
+        wc_test = nvvm.mbarrier_test_wait(addr=addr, state_or_phase=state)
+        assert wc_test is not None
+        wc_try = nvvm.mbarrier_try_wait(addr=addr, state_or_phase=state)
+        assert wc_try is not None
+
+# CHECK: func.func @none_phase
+# CHECK:   nvvm.mbarrier.test.wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK:   nvvm.mbarrier.try_wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK:   nvvm.mbarrier.test.wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK:   nvvm.mbarrier.try_wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
diff --git a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
index 7cc130d02ad74..273fc3db9c659 100644
--- a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
@@ -1,7 +1,9 @@
 add_mlir_unittest(MLIRLLVMIRTests
   LLVMTypeTest.cpp
+  NVVMTests.cpp
 )
 mlir_target_link_libraries(MLIRLLVMIRTests
   PRIVATE
   MLIRLLVMDialect
+  MLIRNVVMDialect
   )
diff --git a/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp b/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp
new file mode 100644
index 0000000000000..b43645de9aecb
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp
@@ -0,0 +1,31 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Builder tests for NVVM Dialect 
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::NVVM;
+using namespace mlir::LLVM;
+
+namespace {
+
+#include "nvvm/NVVMMBarrierWaitBuilderTest.inc"
+
+} // namespace
diff --git a/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc b/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc
new file mode 100644
index 0000000000000..1c2cde1b8eb37
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc
@@ -0,0 +1,68 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Builder tests for NVVM MBarrierTestWaitOp and MBarrierTryWaitOp.
+///
+//===----------------------------------------------------------------------===//
+
+class NVVMMBarrierWaitTest : public ::testing::Test {
+protected:
+  NVVMMBarrierWaitTest() {
+    context.loadDialect<NVVMDialect>();
+    context.loadDialect<LLVMDialect>();
+  }
+
+  // Creates a block, sets the builder insertion point, and adds addr
+  // (!llvm.ptr<3>) and stateOrPhase (stateOrPhaseTy) as block arguments.
+  // Returns {addr, stateOrPhase}.
+  std::pair<Value, Value> makeArgs(Block &block, Type stateOrPhaseTy) {
+    OpBuilder builder(&context);
+    builder.setInsertionPointToStart(&block);
+    auto loc = builder.getUnknownLoc();
+    Value addr = block.addArgument(LLVMPointerType::get(&context, 3), loc);
+    Value stateOrPhase = block.addArgument(stateOrPhaseTy, loc);
+    return {addr, stateOrPhase};
+  }
+
+  MLIRContext context;
+};
+
+TEST_F(NVVMMBarrierWaitTest, MBarrierWait) {
+  auto loc = UnknownLoc::get(&context);
+  auto i1Ty = IntegerType::get(&context, 1);
+  auto i32Ty = IntegerType::get(&context, 32);
+  auto i64Ty = IntegerType::get(&context, 64);
+
+  Block block;
+  auto [addr, state] = makeArgs(block, i64Ty);
+  Value ticks = block.addArgument(i32Ty, loc);
+  OpBuilder builder(&context);
+  builder.setInsertionPointToEnd(&block);
+
+  // MBarrierTestWaitOp.
+  auto testWait = MBarrierTestWaitOp::create(builder, loc, i1Ty, addr, state);
+  EXPECT_EQ(testWait.getRes().getType(), i1Ty);
+  EXPECT_EQ(testWait.getAddr(), addr);
+  EXPECT_EQ(testWait.getStateOrPhase(), state);
+
+  // MBarrierTryWaitOp without ticks.
+  auto tryWaitNoTicks =
+      MBarrierTryWaitOp::create(builder, loc, i1Ty, addr, state, Value{});
+  EXPECT_EQ(tryWaitNoTicks.getRes().getType(), i1Ty);
+  EXPECT_EQ(tryWaitNoTicks.getAddr(), addr);
+  EXPECT_EQ(tryWaitNoTicks.getStateOrPhase(), state);
+  EXPECT_FALSE(tryWaitNoTicks.getTicks());
+
+  // MBarrierTryWaitOp with ticks.
+  auto tryWaitWithTicks =
+      MBarrierTryWaitOp::create(builder, loc, i1Ty, addr, state, ticks);
+  EXPECT_EQ(tryWaitWithTicks.getRes().getType(), i1Ty);
+  EXPECT_TRUE(tryWaitWithTicks.getTicks());
+  EXPECT_EQ(tryWaitWithTicks.getTicks(), ticks);
+}

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Apr 22, 2026

@llvm/pr-subscribers-mlir

Author: Kirill Vedernikov (kvederni)

Changes

Add C++ gtest unit test covering construction and result-type verification for MBarrierTestWaitOp and MBarrierTryWaitOp.

Add Python bindings test exercising both OpView class construction and free function style for MBarrierTestWaitOp and MBarrierTryWaitOp.


Full diff: https://github.com/llvm/llvm-project/pull/193600.diff

4 Files Affected:

  • (added) mlir/test/python/dialects/nvvm/mbarrier_wait.py (+44)
  • (modified) mlir/unittests/Dialect/LLVMIR/CMakeLists.txt (+2)
  • (added) mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp (+31)
  • (added) mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc (+68)
diff --git a/mlir/test/python/dialects/nvvm/mbarrier_wait.py b/mlir/test/python/dialects/nvvm/mbarrier_wait.py
new file mode 100644
index 0000000000000..b86d3663adb17
--- /dev/null
+++ b/mlir/test/python/dialects/nvvm/mbarrier_wait.py
@@ -0,0 +1,44 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+"""Tests for MBarrierTestWaitOp and MBarrierTryWaitOp Python bindings.
+Covers the none-phase (single-result i1) variant.
+Two construction styles: OpView class and free function.
+"""
+
+from mlir.ir import *
+from mlir.dialects import nvvm, func
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
+
+
+# CHECK-LABEL: TEST: test_mbarrier_wait
+@run
+def test_mbarrier_wait():
+    """MBarrierTestWaitOp and MBarrierTryWaitOp -- OpView and free-function styles."""
+    i64 = IntegerType.get_signless(64)
+    ptr = Type.parse("!llvm.ptr<3>")
+
+    @func.FuncOp.from_py_func(ptr, i64)
+    def none_phase(addr, state):
+        op_test = nvvm.MBarrierTestWaitOp(addr=addr, stateOrPhase=state)
+        assert op_test.res is not None
+        op_try = nvvm.MBarrierTryWaitOp(addr=addr, stateOrPhase=state)
+        assert op_try.res is not None
+        wc_test = nvvm.mbarrier_test_wait(addr=addr, state_or_phase=state)
+        assert wc_test is not None
+        wc_try = nvvm.mbarrier_try_wait(addr=addr, state_or_phase=state)
+        assert wc_try is not None
+
+# CHECK: func.func @none_phase
+# CHECK:   nvvm.mbarrier.test.wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK:   nvvm.mbarrier.try_wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK:   nvvm.mbarrier.test.wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK:   nvvm.mbarrier.try_wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
diff --git a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
index 7cc130d02ad74..273fc3db9c659 100644
--- a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
@@ -1,7 +1,9 @@
 add_mlir_unittest(MLIRLLVMIRTests
   LLVMTypeTest.cpp
+  NVVMTests.cpp
 )
 mlir_target_link_libraries(MLIRLLVMIRTests
   PRIVATE
   MLIRLLVMDialect
+  MLIRNVVMDialect
   )
diff --git a/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp b/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp
new file mode 100644
index 0000000000000..b43645de9aecb
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp
@@ -0,0 +1,31 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Builder tests for NVVM Dialect 
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::NVVM;
+using namespace mlir::LLVM;
+
+namespace {
+
+#include "nvvm/NVVMMBarrierWaitBuilderTest.inc"
+
+} // namespace
diff --git a/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc b/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc
new file mode 100644
index 0000000000000..1c2cde1b8eb37
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc
@@ -0,0 +1,68 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Builder tests for NVVM MBarrierTestWaitOp and MBarrierTryWaitOp.
+///
+//===----------------------------------------------------------------------===//
+
+class NVVMMBarrierWaitTest : public ::testing::Test {
+protected:
+  NVVMMBarrierWaitTest() {
+    context.loadDialect<NVVMDialect>();
+    context.loadDialect<LLVMDialect>();
+  }
+
+  // Creates a block, sets the builder insertion point, and adds addr
+  // (!llvm.ptr<3>) and stateOrPhase (stateOrPhaseTy) as block arguments.
+  // Returns {addr, stateOrPhase}.
+  std::pair<Value, Value> makeArgs(Block &block, Type stateOrPhaseTy) {
+    OpBuilder builder(&context);
+    builder.setInsertionPointToStart(&block);
+    auto loc = builder.getUnknownLoc();
+    Value addr = block.addArgument(LLVMPointerType::get(&context, 3), loc);
+    Value stateOrPhase = block.addArgument(stateOrPhaseTy, loc);
+    return {addr, stateOrPhase};
+  }
+
+  MLIRContext context;
+};
+
+TEST_F(NVVMMBarrierWaitTest, MBarrierWait) {
+  auto loc = UnknownLoc::get(&context);
+  auto i1Ty = IntegerType::get(&context, 1);
+  auto i32Ty = IntegerType::get(&context, 32);
+  auto i64Ty = IntegerType::get(&context, 64);
+
+  Block block;
+  auto [addr, state] = makeArgs(block, i64Ty);
+  Value ticks = block.addArgument(i32Ty, loc);
+  OpBuilder builder(&context);
+  builder.setInsertionPointToEnd(&block);
+
+  // MBarrierTestWaitOp.
+  auto testWait = MBarrierTestWaitOp::create(builder, loc, i1Ty, addr, state);
+  EXPECT_EQ(testWait.getRes().getType(), i1Ty);
+  EXPECT_EQ(testWait.getAddr(), addr);
+  EXPECT_EQ(testWait.getStateOrPhase(), state);
+
+  // MBarrierTryWaitOp without ticks.
+  auto tryWaitNoTicks =
+      MBarrierTryWaitOp::create(builder, loc, i1Ty, addr, state, Value{});
+  EXPECT_EQ(tryWaitNoTicks.getRes().getType(), i1Ty);
+  EXPECT_EQ(tryWaitNoTicks.getAddr(), addr);
+  EXPECT_EQ(tryWaitNoTicks.getStateOrPhase(), state);
+  EXPECT_FALSE(tryWaitNoTicks.getTicks());
+
+  // MBarrierTryWaitOp with ticks.
+  auto tryWaitWithTicks =
+      MBarrierTryWaitOp::create(builder, loc, i1Ty, addr, state, ticks);
+  EXPECT_EQ(tryWaitWithTicks.getRes().getType(), i1Ty);
+  EXPECT_TRUE(tryWaitWithTicks.getTicks());
+  EXPECT_EQ(tryWaitWithTicks.getTicks(), ticks);
+}

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Apr 22, 2026

✅ With the latest revision this PR passed the Python code formatter.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Apr 22, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: why call this .inc instead of calling calling it .h or just inlining it in the cpp?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to inline it to the .cpp because the .cpp could become unreadable if/when more unit tests are added. I believe it would be helpful to implement more tests for NVVM dialect Ops here. So, having each Op in a separate test file looks convenient.
Regarding using .inc and not .h. The file contains not a class definition but a test definition. In my mind using .h could be misleading in this case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants