[mlir][EmitC]Expand the MemRefToEmitC pass - Adding scalars#148055
[mlir][EmitC]Expand the MemRefToEmitC pass - Adding scalars#148055
Conversation
|
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Jaden Angella (Jaddyen) ChangesThis aims to expand the the MemRefToEmitC pass so that it can accept global scalars. Full diff: https://github.com/llvm/llvm-project/pull/148055.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1d1cac8..e55c8e48ad105 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -16,7 +16,9 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
@@ -83,7 +85,7 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
-
+ MemRefType type = op.getType();
if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform global with dynamic shape");
@@ -95,7 +97,13 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
op.getLoc(), "global variable with alignment requirement is "
"currently not supported");
}
- auto resultTy = getTypeConverter()->convertType(op.getType());
+
+ Type resultTy;
+ if (type.getRank() == 0)
+ resultTy = getTypeConverter()->convertType(type.getElementType());
+ else
+ resultTy = getTypeConverter()->convertType(type);
+
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
@@ -114,6 +122,10 @@ struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
bool externSpecifier = !staticSpecifier;
Attribute initialValue = operands.getInitialValueAttr();
+ if (type.getRank() == 0) {
+ auto elementsAttr = llvm::cast<ElementsAttr>(*op.getInitialValue());
+ initialValue = elementsAttr.getSplatValue<Attribute>();
+ }
if (isa_and_present<UnitAttr>(initialValue))
initialValue = {};
@@ -132,7 +144,17 @@ struct ConvertGetGlobal final
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
- auto resultTy = getTypeConverter()->convertType(op.getType());
+ MemRefType type = op.getType();
+ Type resultTy;
+ if (type.getRank() == 0)
+ resultTy = emitc::LValueType::get(
+ getTypeConverter()->convertType(type.getElementType()));
+ else
+ resultTy = getTypeConverter()->convertType(type);
+
+ if (!resultTy)
+ return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
+
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1de90add..445a28534325a 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -41,6 +41,8 @@ func.func @memref_load(%buff : memref<4x8xf32>, %i: index, %j: index) -> f32 {
module @globals {
memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
// CHECK-NEXT: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+ memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
+ // CHECK-NEXT: emitc.global static const @__constant_xi32 : i32 = -1
memref.global @public_global : memref<3x7xf32>
// CHECK-NEXT: emitc.global extern @public_global : !emitc.array<3x7xf32>
memref.global @uninitialized_global : memref<3x7xf32> = uninitialized
@@ -50,6 +52,8 @@ module @globals {
func.func @use_global() {
// CHECK-NEXT: emitc.get_global @public_global : !emitc.array<3x7xf32>
%0 = memref.get_global @public_global : memref<3x7xf32>
+ // CHECK- NEXT: emitc.get_global @__constant_xi32 : !emitc.lvalue<i32>
+ %1 = memref.get_global @__constant_xi32 : memref<i32>
return
}
}
|
There was a problem hiding this comment.
| MemRefType type = op.getType(); | |
| MemRefType opTy = op.getType(); |
I'm not sure we want to use type as a variable name...
There was a problem hiding this comment.
I agree. Thanks for the pointer!
There was a problem hiding this comment.
Maybe introduce a helper, since I see a similar pattern in a few spots?
There was a problem hiding this comment.
Def! Thanks for the pointer.
There was a problem hiding this comment.
I was expecting to see a corresponding change to memref-to-emitc-failed.mlir. Looking I suppose it isn't there, but are any of the cases, like https://github.com/llvm/llvm-project/blob/36b61a6b5731a524b4e79e77d7505c7a5ef3d0f9/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir#L46 going to work now?
There was a problem hiding this comment.
Not yet. This would require changing those particular ops. I could circle back to this once we can completely lower the inliner model.
There was a problem hiding this comment.
You should be able to just do
resultTy = getTypeConverter()->convertType(getElementTypeOrSelf(type));
There was a problem hiding this comment.
Just to check, so previously memref<1xi32> worked, but memref<i32> didn't work before this?
There was a problem hiding this comment.
What is the initialValue before this vs splat value returned?
There was a problem hiding this comment.
Without using getSplatValue, we have the initial value as: initial_value = dense<-1> : tensor<i32>
After getting the splat value, we have emitc.global static const @__constant_xi32 : i32 = -1
There was a problem hiding this comment.
Why does this need to be LValue, while one below not?
There was a problem hiding this comment.
memref.getGlobal gets converted to emitc.getglobal but emitc.getglobal only returns LValue or Array. So in the case that we have a constant, we create an LValue, else we return an array.
There was a problem hiding this comment.
Note that in general lowering getglobal to lvalue doesn't work correctly when the result is passed to function calls for example. So I would expect rank 0 memrefs to be lowered to pointers (at least when it might escape).
There was a problem hiding this comment.
Thanks for pointing this out. I've updated the conversion to reflect this.
There was a problem hiding this comment.
Sorry für being unclear. I think globals should still be lowered to lvalues so that it allocates the necessary storage. But the get_global may be lowered to EmitC.get_global + EmitC.apply "&" to get a Pointer to the variable. But I haven't tried this out.
There was a problem hiding this comment.
What I've tried and been able to implement is the conversion:
From:
memref.global "private" constant @__constant_xi32 : memref<i32> = dense<-1>
func.func @globals() {
memref.get_global @__constant_xi32 : memref<i32>
}
To:
emitc.global static const @__constant_xi32 : i32 = -1
emitc.func @globals() {
%0 = get_global @__constant_xi32 : !emitc.lvalue<i32>
%1 = apply "&"(%0) : (!emitc.lvalue<i32>) -> !emitc.ptr<i32>
return
}
There was a problem hiding this comment.
Don't you have this check just below too?
There was a problem hiding this comment.
I do! Thanks for the pointer.
jpienaar
left a comment
There was a problem hiding this comment.
Looks good in general. I believe you also tested it with compiling the output too.
Yes, I did! |
This aims to expand the the MemRefToEmitC pass so that it can accept global scalars.
From:
To: