Skip to content

Commit abfd76d

Browse files
authored
Added onnx export functionality for MissingValueIndicatorTransformer (#4194)
1 parent 42e39b9 commit abfd76d

File tree

7 files changed

+277
-7
lines changed

7 files changed

+277
-7
lines changed

src/Microsoft.ML.OnnxConverter/OnnxUtils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ private static TensorProto.Types.DataType ConvertToTensorProtoType(Type rawType)
227227
var dataType = TensorProto.Types.DataType.Undefined;
228228

229229
if (rawType == typeof(bool))
230-
dataType = TensorProto.Types.DataType.Float;
230+
dataType = TensorProto.Types.DataType.Bool;
231231
else if (rawType == typeof(ReadOnlyMemory<char>))
232232
dataType = TensorProto.Types.DataType.String;
233233
else if (rawType == typeof(sbyte))

src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.ML.CommandLine;
1111
using Microsoft.ML.Data;
1212
using Microsoft.ML.Internal.Utilities;
13+
using Microsoft.ML.Model.OnnxConverter;
1314
using Microsoft.ML.Runtime;
1415
using Microsoft.ML.Transforms;
1516

@@ -140,7 +141,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
140141

141142
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
142143

143-
private sealed class Mapper : OneToOneMapperBase
144+
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
144145
{
145146
private readonly MissingValueIndicatorTransformer _parent;
146147
private readonly ColInfo[] _infos;
@@ -426,6 +427,46 @@ private void FillValues(int srcLength, ref VBuffer<bool> dst, List<int> indices,
426427
dst = editor.Commit();
427428
}
428429
}
430+
431+
/// <summary>
/// Reports whether this mapper can be exported to ONNX. The answer is always <c>true</c>;
/// columns that turn out not to be exportable are removed from the context by <see cref="SaveAsOnnx"/>.
/// </summary>
/// <param name="ctx">The ONNX export context (not inspected).</param>
/// <returns>Always <c>true</c>.</returns>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
432+
433+
/// <summary>
/// Walks every column handled by this mapper and tries to add its ONNX representation to
/// <paramref name="ctx"/>. A column whose input is not present in the context, or whose export
/// is rejected by <see cref="SaveAsOnnxCore"/>, is removed from the context instead.
/// </summary>
/// <param name="ctx">The ONNX export context; must not be null.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    Host.CheckValue(ctx, nameof(ctx));

    for (int index = 0; index < _infos.Length; index++)
    {
        ColInfo column = _infos[index];
        string sourceColumn = column.InputColumnName;

        if (ctx.ContainsColumn(sourceColumn))
        {
            // Resolve the source variable first, then register the destination variable.
            var sourceVariable = ctx.GetVariableName(sourceColumn);
            var targetVariable = ctx.AddIntermediateVariable(column.OutputType, column.Name);

            bool exported = SaveAsOnnxCore(ctx, index, column, sourceVariable, targetVariable);
            if (!exported)
            {
                // The export was refused: drop the output column that was just registered.
                ctx.RemoveColumn(column.Name, true);
            }
        }
        else
        {
            // The input never made it into the graph, so the output cannot be produced either.
            ctx.RemoveColumn(column.Name, false);
        }
    }
}
454+
455+
/// <summary>
/// Emits the ONNX node for a single column: an <c>IsNaN</c> node that reads
/// <paramref name="srcVariableName"/> and writes <paramref name="dstVariableName"/>.
/// </summary>
/// <param name="ctx">The ONNX export context the node is added to.</param>
/// <param name="iinfo">Index of the column in <c>_infos</c>; used to look up the input type.</param>
/// <param name="info">Column information for the same column (not read by this method).</param>
/// <param name="srcVariableName">Name of the ONNX variable holding the input column.</param>
/// <param name="dstVariableName">Name of the ONNX variable that receives the indicator output.</param>
/// <returns>
/// <c>true</c> if a node was created; <c>false</c> if the input item type is not <see cref="float"/>,
/// in which case nothing is added to <paramref name="ctx"/>.
/// </returns>
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
    var inputType = _infos[iinfo].InputType;

    // For vector columns the decision is made on the item type, otherwise on the column type itself.
    Type rawType = (inputType is VectorDataViewType vectorType) ? vectorType.ItemType.RawType : inputType.RawType;

    // Only single-precision input is exported; the caller removes the column when this returns false.
    if (rawType != typeof(float))
        return false;

    // The node writes straight into dstVariableName, so no additional intermediate variable is registered.
    const string opType = "IsNaN";
    ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType), "");

    return true;
}
429470
}
430471
}
431472

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@
459459
"name": "PredictedLabel0",
460460
"type": {
461461
"tensorType": {
462-
"elemType": "FLOAT",
462+
"elemType": "BOOL",
463463
"shape": {
464464
"dim": [
465465
{

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@
786786
"name": "PredictedLabel0",
787787
"type": {
788788
"tensorType": {
789-
"elemType": "FLOAT",
789+
"elemType": "BOOL",
790790
"shape": {
791791
"dim": [
792792
{

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@
414414
"name": "Label",
415415
"type": {
416416
"tensorType": {
417-
"elemType": "FLOAT",
417+
"elemType": "BOOL",
418418
"shape": {
419419
"dim": [
420420
{
@@ -470,7 +470,7 @@
470470
"name": "Label0",
471471
"type": {
472472
"tensorType": {
473-
"elemType": "FLOAT",
473+
"elemType": "BOOL",
474474
"shape": {
475475
"dim": [
476476
{
@@ -542,7 +542,7 @@
542542
"name": "PredictedLabel0",
543543
"type": {
544544
"tensorType": {
545-
"elemType": "FLOAT",
545+
"elemType": "BOOL",
546546
"shape": {
547547
"dim": [
548548
{
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
{
2+
"irVersion": "3",
3+
"producerName": "ML.NET",
4+
"producerVersion": "##VERSION##",
5+
"domain": "machinelearning.dotnet",
6+
"graph": {
7+
"node": [
8+
{
9+
"input": [
10+
"Features"
11+
],
12+
"output": [
13+
"MissingIndicator"
14+
],
15+
"name": "IsNaN",
16+
"opType": "IsNaN"
17+
},
18+
{
19+
"input": [
20+
"MissingIndicator"
21+
],
22+
"output": [
23+
"MissingIndicator0"
24+
],
25+
"name": "Cast",
26+
"opType": "Cast",
27+
"attribute": [
28+
{
29+
"name": "to",
30+
"i": "6",
31+
"type": "INT"
32+
}
33+
]
34+
},
35+
{
36+
"input": [
37+
"Features"
38+
],
39+
"output": [
40+
"Features0"
41+
],
42+
"name": "Identity",
43+
"opType": "Identity"
44+
},
45+
{
46+
"input": [
47+
"MissingIndicator0"
48+
],
49+
"output": [
50+
"MissingIndicator1"
51+
],
52+
"name": "Identity0",
53+
"opType": "Identity"
54+
}
55+
],
56+
"name": "model",
57+
"input": [
58+
{
59+
"name": "Features",
60+
"type": {
61+
"tensorType": {
62+
"elemType": "FLOAT",
63+
"shape": {
64+
"dim": [
65+
{
66+
"dimValue": "1"
67+
},
68+
{
69+
"dimValue": "3"
70+
}
71+
]
72+
}
73+
}
74+
}
75+
}
76+
],
77+
"output": [
78+
{
79+
"name": "Features0",
80+
"type": {
81+
"tensorType": {
82+
"elemType": "FLOAT",
83+
"shape": {
84+
"dim": [
85+
{
86+
"dimValue": "1"
87+
},
88+
{
89+
"dimValue": "3"
90+
}
91+
]
92+
}
93+
}
94+
}
95+
},
96+
{
97+
"name": "MissingIndicator1",
98+
"type": {
99+
"tensorType": {
100+
"elemType": "INT32",
101+
"shape": {
102+
"dim": [
103+
{
104+
"dimValue": "1"
105+
},
106+
{
107+
"dimValue": "3"
108+
}
109+
]
110+
}
111+
}
112+
}
113+
}
114+
],
115+
"valueInfo": [
116+
{
117+
"name": "MissingIndicator",
118+
"type": {
119+
"tensorType": {
120+
"elemType": "BOOL",
121+
"shape": {
122+
"dim": [
123+
{
124+
"dimValue": "1"
125+
},
126+
{
127+
"dimValue": "3"
128+
}
129+
]
130+
}
131+
}
132+
}
133+
},
134+
{
135+
"name": "MissingIndicator0",
136+
"type": {
137+
"tensorType": {
138+
"elemType": "INT32",
139+
"shape": {
140+
"dim": [
141+
{
142+
"dimValue": "1"
143+
},
144+
{
145+
"dimValue": "3"
146+
}
147+
]
148+
}
149+
}
150+
}
151+
}
152+
]
153+
},
154+
"opsetImport": [
155+
{
156+
"domain": "ai.onnx.ml",
157+
"version": "1"
158+
},
159+
{
160+
"version": "9"
161+
}
162+
]
163+
}

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ public OnnxConversionTest(ITestOutputHelper output) : base(output)
4141
{
4242
}
4343

44+
/// <summary>
/// Decides whether the ONNX-runtime half of a test should run in the current process.
/// </summary>
/// <returns>
/// <c>false</c> for a non-64-bit process; <c>true</c> for a 64-bit process on any non-Linux platform;
/// on Linux, the result of the libc version check against 2.23.
/// </returns>
private bool IsOnnxRuntimeSupported()
{
    if (!Environment.Is64BitProcess)
    {
        return false;
    }

    if (!RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
    {
        return true;
    }

    // Only reached on 64-bit Linux, mirroring the short-circuit of the original expression.
    var minimumLibcVersion = new System.Version(2, 23);
    return AttributeHelpers.CheckLibcVersionGreaterThanMinimum(minimumLibcVersion);
}
48+
4449
/// <summary>
4550
/// In this test, we convert a trained <see cref="TransformerChain"/> into ONNX <see cref="ModelProto"/> file and then
4651
/// call <see cref="OnnxScoringEstimator"/> to evaluate that file. The outputs of <see cref="OnnxScoringEstimator"/> are checked against the original
@@ -780,6 +785,67 @@ public void PcaOnnxConversionTest()
780785
Done();
781786
}
782787

788+
/// <summary>
/// Row type used to read the pipeline output: the inherited <c>DataPoint</c> members plus the
/// three-element <see cref="MissingIndicator"/> vector. Two instances are equal when their
/// <see cref="MissingIndicator"/> vectors hold the same values in the same order.
/// </summary>
private class TransformedDataPoint : DataPoint, IEquatable<TransformedDataPoint>
{
    [VectorType(3)]
    public int[] MissingIndicator { get; set; }

    /// <summary>
    /// Compares the <see cref="MissingIndicator"/> vectors element by element.
    /// </summary>
    /// <param name="other">The instance to compare with; may be null.</param>
    /// <returns><c>true</c> when both vectors are null or contain the same sequence; otherwise <c>false</c>.</returns>
    public bool Equals(TransformedDataPoint other)
    {
        if (other == null)
            return false;

        if (ReferenceEquals(this, other))
            return true;

        // SequenceEqual throws on null arguments, so null vectors are compared by reference.
        if (MissingIndicator == null || other.MissingIndicator == null)
            return MissingIndicator == other.MissingIndicator;

        return Enumerable.SequenceEqual(MissingIndicator, other.MissingIndicator);
    }

    /// <summary>
    /// Object-typed equality, kept consistent with <see cref="Equals(TransformedDataPoint)"/>.
    /// </summary>
    public override bool Equals(object obj)
    {
        return Equals(obj as TransformedDataPoint);
    }

    /// <summary>
    /// Hash code derived from the <see cref="MissingIndicator"/> elements, so that equal
    /// instances produce equal hash codes.
    /// </summary>
    public override int GetHashCode()
    {
        if (MissingIndicator == null)
            return 0;

        unchecked
        {
            int hash = 17;
            foreach (int value in MissingIndicator)
                hash = (hash * 31) + value;
            return hash;
        }
    }
}
798+
799+
/// <summary>
/// Fits an IndicateMissingValues pipeline, converts it to ONNX, saves the model next to its text
/// baseline, and (when the ONNX runtime is usable in this process) checks that the ONNX output
/// matches the ML.NET output.
/// </summary>
[Fact]
public void IndicateMissingValuesOnnxConversionTest()
{
    var mlContext = new MLContext(seed: 1);

    // Three rows covering: no missing values, one NaN, and NaN together with +Infinity.
    var samples = new List<DataPoint>()
    {
        new DataPoint() { Features = new float[3] {1, 1, 0}, },
        new DataPoint() { Features = new float[3] {0, float.NaN, 1}, },
        new DataPoint() { Features = new float[3] {-1, float.NaN, float.PositiveInfinity}, },
    };
    var dataView = mlContext.Data.LoadFromEnumerable(samples);

    // IsNaN outputs a boolean tensor. Support for boolean tensors has been added to the latest
    // ONNX runtime sources, but that version has not been released yet, so the output is
    // converted to Int32 for now.
    // The ConvertType step can be removed once a new ONNX runtime release is picked up.
    var pipeline = mlContext.Transforms.IndicateMissingValues(new[] { new InputOutputColumnPair("MissingIndicator", "Features"), })
        .Append(mlContext.Transforms.Conversion.ConvertType("MissingIndicator", outputKind: DataKind.Int32));

    var model = pipeline.Fit(dataView);
    var transformedData = model.Transform(dataView);
    // Typed view over the ML.NET output; it is not read further in this test.
    var mlnetData = mlContext.Data.CreateEnumerable<TransformedDataPoint>(transformedData, false);
    var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

    var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms");
    var onnxFileName = "IndicateMissingValues.onnx";
    var onnxTextName = "IndicateMissingValues.txt";
    var onnxModelPath = GetOutputPath(onnxFileName);
    var onnxTextPath = GetOutputPath(subDir, onnxTextName);

    SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath);

    // Compare results produced by ML.NET and the ONNX runtime.
    if (IsOnnxRuntimeSupported())
    {
        // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
        string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
        string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
        var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
        var onnxTransformer = onnxEstimator.Fit(dataView);
        var onnxResult = onnxTransformer.Transform(dataView);

        // NOTE(review): outputNames[1] is assumed to be the missing-indicator output, with the
        // passed-through Features column at index 0 - confirm against the saved baseline.
        CompareSelectedVectorColumns<int>(model.LastTransformer.ColumnPairs[0].outputColumnName, outputNames[1], transformedData, onnxResult);
    }

    CheckEquality(subDir, onnxTextName, parseOption: NumberParseOption.UseSingle);
    Done();
}
848+
783849
private void CreateDummyExamplesToMakeComplierHappy()
784850
{
785851
var dummyExample = new BreastCancerFeatureVector() { Features = null };

0 commit comments

Comments
 (0)