[NFC][MLIR][NVVM] Add unit and Python bindings tests for MBarrierTestWaitOp/MBarrierTryWaitOp#193600
[NFC][MLIR][NVVM] Add unit and Python bindings tests for MBarrierTestWaitOp/MBarrierTryWaitOp#193600
Conversation
MBarrierTestWaitOp/MBarrierTryWaitOp Add C++ gtest unit test covering construction and result-type verification for MBarrierTestWaitOp and MBarrierTryWaitOp. Add Python bindings test exercising both OpView class construction and free function style for MBarrierTestWaitOp and MBarrierTryWaitOp.
|
@llvm/pr-subscribers-mlir-llvm Author: Kirill Vedernikov (kvederni) ChangesAdd C++ gtest unit test covering construction and result-type verification for MBarrierTestWaitOp and MBarrierTryWaitOp. Add Python bindings test exercising both OpView class construction and free function style for MBarrierTestWaitOp and MBarrierTryWaitOp. Full diff: https://github.com/llvm/llvm-project/pull/193600.diff 4 Files Affected:
diff --git a/mlir/test/python/dialects/nvvm/mbarrier_wait.py b/mlir/test/python/dialects/nvvm/mbarrier_wait.py
new file mode 100644
index 0000000000000..b86d3663adb17
--- /dev/null
+++ b/mlir/test/python/dialects/nvvm/mbarrier_wait.py
@@ -0,0 +1,44 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+"""Tests for MBarrierTestWaitOp and MBarrierTryWaitOp Python bindings.
+Covers the none-phase (single-result i1) variant.
+Two construction styles: OpView class and free function.
+"""
+
+from mlir.ir import *
+from mlir.dialects import nvvm, func
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f()
+ print(module)
+ return f
+
+
+# CHECK-LABEL: TEST: test_mbarrier_wait
+@run
+def test_mbarrier_wait():
+ """MBarrierTestWaitOp and MBarrierTryWaitOp -- OpView and free-function styles."""
+ i64 = IntegerType.get_signless(64)
+ ptr = Type.parse("!llvm.ptr<3>")
+
+ @func.FuncOp.from_py_func(ptr, i64)
+ def none_phase(addr, state):
+ op_test = nvvm.MBarrierTestWaitOp(addr=addr, stateOrPhase=state)
+ assert op_test.res is not None
+ op_try = nvvm.MBarrierTryWaitOp(addr=addr, stateOrPhase=state)
+ assert op_try.res is not None
+ wc_test = nvvm.mbarrier_test_wait(addr=addr, state_or_phase=state)
+ assert wc_test is not None
+ wc_try = nvvm.mbarrier_try_wait(addr=addr, state_or_phase=state)
+ assert wc_try is not None
+
+# CHECK: func.func @none_phase
+# CHECK: nvvm.mbarrier.test.wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK: nvvm.mbarrier.try_wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK: nvvm.mbarrier.test.wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK: nvvm.mbarrier.try_wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
diff --git a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
index 7cc130d02ad74..273fc3db9c659 100644
--- a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
@@ -1,7 +1,9 @@
add_mlir_unittest(MLIRLLVMIRTests
LLVMTypeTest.cpp
+ NVVMTests.cpp
)
mlir_target_link_libraries(MLIRLLVMIRTests
PRIVATE
MLIRLLVMDialect
+ MLIRNVVMDialect
)
diff --git a/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp b/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp
new file mode 100644
index 0000000000000..b43645de9aecb
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp
@@ -0,0 +1,31 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Builder tests for NVVM Dialect
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::NVVM;
+using namespace mlir::LLVM;
+
+namespace {
+
+#include "nvvm/NVVMMBarrierWaitBuilderTest.inc"
+
+} // namespace
diff --git a/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc b/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc
new file mode 100644
index 0000000000000..1c2cde1b8eb37
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc
@@ -0,0 +1,68 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Builder tests for NVVM MBarrierTestWaitOp and MBarrierTryWaitOp.
+///
+//===----------------------------------------------------------------------===//
+
+class NVVMMBarrierWaitTest : public ::testing::Test {
+protected:
+ NVVMMBarrierWaitTest() {
+ context.loadDialect<NVVMDialect>();
+ context.loadDialect<LLVMDialect>();
+ }
+
+ // Creates a block, sets the builder insertion point, and adds addr
+ // (!llvm.ptr<3>) and stateOrPhase (stateOrPhaseTy) as block arguments.
+ // Returns {addr, stateOrPhase}.
+ std::pair<Value, Value> makeArgs(Block &block, Type stateOrPhaseTy) {
+ OpBuilder builder(&context);
+ builder.setInsertionPointToStart(&block);
+ auto loc = builder.getUnknownLoc();
+ Value addr = block.addArgument(LLVMPointerType::get(&context, 3), loc);
+ Value stateOrPhase = block.addArgument(stateOrPhaseTy, loc);
+ return {addr, stateOrPhase};
+ }
+
+ MLIRContext context;
+};
+
+TEST_F(NVVMMBarrierWaitTest, MBarrierWait) {
+ auto loc = UnknownLoc::get(&context);
+ auto i1Ty = IntegerType::get(&context, 1);
+ auto i32Ty = IntegerType::get(&context, 32);
+ auto i64Ty = IntegerType::get(&context, 64);
+
+ Block block;
+ auto [addr, state] = makeArgs(block, i64Ty);
+ Value ticks = block.addArgument(i32Ty, loc);
+ OpBuilder builder(&context);
+ builder.setInsertionPointToEnd(&block);
+
+ // MBarrierTestWaitOp.
+ auto testWait = MBarrierTestWaitOp::create(builder, loc, i1Ty, addr, state);
+ EXPECT_EQ(testWait.getRes().getType(), i1Ty);
+ EXPECT_EQ(testWait.getAddr(), addr);
+ EXPECT_EQ(testWait.getStateOrPhase(), state);
+
+ // MBarrierTryWaitOp without ticks.
+ auto tryWaitNoTicks =
+ MBarrierTryWaitOp::create(builder, loc, i1Ty, addr, state, Value{});
+ EXPECT_EQ(tryWaitNoTicks.getRes().getType(), i1Ty);
+ EXPECT_EQ(tryWaitNoTicks.getAddr(), addr);
+ EXPECT_EQ(tryWaitNoTicks.getStateOrPhase(), state);
+ EXPECT_FALSE(tryWaitNoTicks.getTicks());
+
+ // MBarrierTryWaitOp with ticks.
+ auto tryWaitWithTicks =
+ MBarrierTryWaitOp::create(builder, loc, i1Ty, addr, state, ticks);
+ EXPECT_EQ(tryWaitWithTicks.getRes().getType(), i1Ty);
+ EXPECT_TRUE(tryWaitWithTicks.getTicks());
+ EXPECT_EQ(tryWaitWithTicks.getTicks(), ticks);
+}
|
|
@llvm/pr-subscribers-mlir Author: Kirill Vedernikov (kvederni) ChangesAdd C++ gtest unit test covering construction and result-type verification for MBarrierTestWaitOp and MBarrierTryWaitOp. Add Python bindings test exercising both OpView class construction and free function style for MBarrierTestWaitOp and MBarrierTryWaitOp. Full diff: https://github.com/llvm/llvm-project/pull/193600.diff 4 Files Affected:
diff --git a/mlir/test/python/dialects/nvvm/mbarrier_wait.py b/mlir/test/python/dialects/nvvm/mbarrier_wait.py
new file mode 100644
index 0000000000000..b86d3663adb17
--- /dev/null
+++ b/mlir/test/python/dialects/nvvm/mbarrier_wait.py
@@ -0,0 +1,44 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+"""Tests for MBarrierTestWaitOp and MBarrierTryWaitOp Python bindings.
+Covers the none-phase (single-result i1) variant.
+Two construction styles: OpView class and free function.
+"""
+
+from mlir.ir import *
+from mlir.dialects import nvvm, func
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ f()
+ print(module)
+ return f
+
+
+# CHECK-LABEL: TEST: test_mbarrier_wait
+@run
+def test_mbarrier_wait():
+ """MBarrierTestWaitOp and MBarrierTryWaitOp -- OpView and free-function styles."""
+ i64 = IntegerType.get_signless(64)
+ ptr = Type.parse("!llvm.ptr<3>")
+
+ @func.FuncOp.from_py_func(ptr, i64)
+ def none_phase(addr, state):
+ op_test = nvvm.MBarrierTestWaitOp(addr=addr, stateOrPhase=state)
+ assert op_test.res is not None
+ op_try = nvvm.MBarrierTryWaitOp(addr=addr, stateOrPhase=state)
+ assert op_try.res is not None
+ wc_test = nvvm.mbarrier_test_wait(addr=addr, state_or_phase=state)
+ assert wc_test is not None
+ wc_try = nvvm.mbarrier_try_wait(addr=addr, state_or_phase=state)
+ assert wc_try is not None
+
+# CHECK: func.func @none_phase
+# CHECK: nvvm.mbarrier.test.wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK: nvvm.mbarrier.try_wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK: nvvm.mbarrier.test.wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
+# CHECK: nvvm.mbarrier.try_wait %{{.*}}, %{{.*}} : !llvm.ptr<3>, i64 -> i1
diff --git a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
index 7cc130d02ad74..273fc3db9c659 100644
--- a/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
@@ -1,7 +1,9 @@
add_mlir_unittest(MLIRLLVMIRTests
LLVMTypeTest.cpp
+ NVVMTests.cpp
)
mlir_target_link_libraries(MLIRLLVMIRTests
PRIVATE
MLIRLLVMDialect
+ MLIRNVVMDialect
)
diff --git a/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp b/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp
new file mode 100644
index 0000000000000..b43645de9aecb
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/NVVMTests.cpp
@@ -0,0 +1,31 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Builder tests for NVVM Dialect
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::NVVM;
+using namespace mlir::LLVM;
+
+namespace {
+
+#include "nvvm/NVVMMBarrierWaitBuilderTest.inc"
+
+} // namespace
diff --git a/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc b/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc
new file mode 100644
index 0000000000000..1c2cde1b8eb37
--- /dev/null
+++ b/mlir/unittests/Dialect/LLVMIR/nvvm/NVVMMBarrierWaitBuilderTest.inc
@@ -0,0 +1,68 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Builder tests for NVVM MBarrierTestWaitOp and MBarrierTryWaitOp.
+///
+//===----------------------------------------------------------------------===//
+
+class NVVMMBarrierWaitTest : public ::testing::Test {
+protected:
+ NVVMMBarrierWaitTest() {
+ context.loadDialect<NVVMDialect>();
+ context.loadDialect<LLVMDialect>();
+ }
+
+ // Creates a block, sets the builder insertion point, and adds addr
+ // (!llvm.ptr<3>) and stateOrPhase (stateOrPhaseTy) as block arguments.
+ // Returns {addr, stateOrPhase}.
+ std::pair<Value, Value> makeArgs(Block &block, Type stateOrPhaseTy) {
+ OpBuilder builder(&context);
+ builder.setInsertionPointToStart(&block);
+ auto loc = builder.getUnknownLoc();
+ Value addr = block.addArgument(LLVMPointerType::get(&context, 3), loc);
+ Value stateOrPhase = block.addArgument(stateOrPhaseTy, loc);
+ return {addr, stateOrPhase};
+ }
+
+ MLIRContext context;
+};
+
+TEST_F(NVVMMBarrierWaitTest, MBarrierWait) {
+ auto loc = UnknownLoc::get(&context);
+ auto i1Ty = IntegerType::get(&context, 1);
+ auto i32Ty = IntegerType::get(&context, 32);
+ auto i64Ty = IntegerType::get(&context, 64);
+
+ Block block;
+ auto [addr, state] = makeArgs(block, i64Ty);
+ Value ticks = block.addArgument(i32Ty, loc);
+ OpBuilder builder(&context);
+ builder.setInsertionPointToEnd(&block);
+
+ // MBarrierTestWaitOp.
+ auto testWait = MBarrierTestWaitOp::create(builder, loc, i1Ty, addr, state);
+ EXPECT_EQ(testWait.getRes().getType(), i1Ty);
+ EXPECT_EQ(testWait.getAddr(), addr);
+ EXPECT_EQ(testWait.getStateOrPhase(), state);
+
+ // MBarrierTryWaitOp without ticks.
+ auto tryWaitNoTicks =
+ MBarrierTryWaitOp::create(builder, loc, i1Ty, addr, state, Value{});
+ EXPECT_EQ(tryWaitNoTicks.getRes().getType(), i1Ty);
+ EXPECT_EQ(tryWaitNoTicks.getAddr(), addr);
+ EXPECT_EQ(tryWaitNoTicks.getStateOrPhase(), state);
+ EXPECT_FALSE(tryWaitNoTicks.getTicks());
+
+ // MBarrierTryWaitOp with ticks.
+ auto tryWaitWithTicks =
+ MBarrierTryWaitOp::create(builder, loc, i1Ty, addr, state, ticks);
+ EXPECT_EQ(tryWaitWithTicks.getRes().getType(), i1Ty);
+ EXPECT_TRUE(tryWaitWithTicks.getTicks());
+ EXPECT_EQ(tryWaitWithTicks.getTicks(), ticks);
+}
|
|
✅ With the latest revision this PR passed the Python code formatter. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Just curious: why call this .inc instead of calling calling it .h or just inlining it in the cpp?
There was a problem hiding this comment.
I don't want to inline it to the .cpp because the .cpp could become unreadable if/when more unit tests are added. I believe it would be helpful to implement more tests for NVVM dialect Ops here. So, having each Op in a separate test file looks convenient.
Regarding using .inc and not .h. The file contains not a class definition but a test definition. In my mind using .h could be misleading in this case.
Add C++ gtest unit test covering construction and result-type verification for MBarrierTestWaitOp and MBarrierTryWaitOp.
Add Python bindings test exercising both OpView class construction and free function style for MBarrierTestWaitOp and MBarrierTryWaitOp.