Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 29ed75d

Browse files
committed
Fix InferShape pass
1 parent 992c3c0 commit 29ed75d

File tree

2 files changed

+77
-6
lines changed

2 files changed

+77
-6
lines changed

src/executor/infer_graph_attr_pass.cc

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,26 @@ bool ApplyOpInferAttr<int, FInferStorageType>(const nnvm::Graph& g,
6868
* shape/type inference functions'. The nnvm InferAttr will be deprecated
6969
* in the future. Please use interfaces InferShape, InferType, and InferStorageType
7070
* to call this function.
71+
*
72+
* \param ret graph used for attribute inference
73+
* \param emmpty_val empty value of the attribute
74+
* \param infer_name name of the function used for attribute inference
75+
* \param input_name name of the attribute in the graph used to store the
76+
* input data for attribute inference
77+
* \param attr_key_name name of the attribute used for inference for variable nodes
78+
* \param attr_name name of the inferred attribute
79+
* \param unknown_name name of the attribute storing number of entries
80+
* impossible to infer
81+
* \param fis_none function returning true for not fully inferred values
82+
* \param fdefault default function used for inference if the node does not
83+
* provide its own implementation.
84+
* \param bwd_identity_assign whether the attributes of forward NDArray and backward
85+
* NDArray have to be the same. False only for storage
86+
* type inference
87+
* \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
88+
* storage type inference
89+
* \param default_mode_val default value of the dispatch mode attribute on the node. Used
90+
* for storage type inference
7191
*/
7292
template<typename AttrType, typename FInferType, typename IsNone, typename FDefault>
7393
nnvm::Graph InferAttr(nnvm::Graph &&ret,
@@ -322,7 +342,32 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
322342
return ret;
323343
}
324344

325-
template<typename IsNone, typename FDefault>
345+
/*!\brief
346+
* This is a version of the InferAttr function specifically for shape inference.
347+
*
348+
* \param ret graph used for attribute inference
349+
* \param emmpty_val empty value of the attribute
350+
* \param infer_name name of the function used for attribute inference
351+
* \param input_name name of the attribute in the graph used to store the
352+
* input data for attribute inference
353+
* \param attr_key_name name of the attribute used for inference for variable nodes
354+
* \param attr_name name of the inferred attribute
355+
* \param unknown_name name of the attribute storing number of entries
356+
* impossible to infer
357+
* \param fis_none function returning true for not fully inferred values
358+
* \param fnum_unknown function returning how many elements are unknown in
359+
* partially inferred value of the attribute
360+
* \param fdefault default function used for inference if the node does not
361+
* provide its own implementation.
362+
* \param bwd_identity_assign whether the attributes of forward NDArray and backward
363+
* NDArray have to be the same. False only for storage
364+
* type inference
365+
* \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
366+
* storage type inference
367+
* \param default_mode_val default value of the dispatch mode attribute on the node. Used
368+
* for storage type inference
369+
*/
370+
template<typename IsNone, typename FDefault, typename FNumUnknown>
326371
nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
327372
const nnvm::TShape empty_val,
328373
const char* infer_name,
@@ -331,6 +376,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
331376
const char* attr_name,
332377
const char* unknown_name,
333378
IsNone fis_none,
379+
FNumUnknown fnum_unknown,
334380
FDefault fdefault,
335381
bool bwd_identity_assign,
336382
const char* dispatch_mode_name,
@@ -548,12 +594,12 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
548594
};
549595

550596
size_t last_num_unknown;
551-
size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
552-
size_t num_unknown_entry_attr = entry_end - entry_start;
553-
size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
597+
size_t num_unknown = static_cast<size_t>(-1); // Infinity
598+
554599
int i = 0;
555600
do {
556601
if (i % 2 == 0) {
602+
// forward inference
557603
for (uint32_t nid = node_start; nid < node_end; ++nid) {
558604
infer_step(nid, false);
559605
}
@@ -567,7 +613,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
567613
num_unknown = 0;
568614
for (size_t j = entry_start; j < entry_end; ++j) {
569615
if (fis_none(rshape[j])) {
570-
++num_unknown;
616+
num_unknown += fnum_unknown(rshape[j]);
571617
}
572618
}
573619
if (dispatch_mode_name) {
@@ -598,11 +644,23 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
598644
if (shape_attr_key.length() != 0) {
599645
graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
600646
}
601-
return InferAttr<mxnet::TShape, mxnet::FInferShape>(
647+
return InferShapeAttr(
602648
std::move(graph), mxnet::TShape(),
603649
"FInferShape", "shape_inputs", "shape_attr_key",
604650
"shape", "shape_num_unknown_nodes",
605651
[](const mxnet::TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
652+
[](const mxnet::TShape& s) {
653+
if (s.ndim() == 0) { // TODO(reminisce): Usage of ndim
654+
return static_cast<size_t>(1);
655+
}
656+
size_t ret = 0;
657+
for (const auto& val : s) {
658+
if (val == 0) {
659+
++ret;
660+
}
661+
}
662+
return ret;
663+
},
606664
nullptr, true, nullptr);
607665
}
608666

tests/python/unittest/test_symbol.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,19 @@ def test_symbol_infer_shape():
157157
assert arg_shapes['x2h_weight'] == (num_hidden, num_dim)
158158
assert arg_shapes['h2h_weight'] == (num_hidden, num_hidden)
159159

160+
# Partial shape inference with some unknown dimensions
161+
data_shape = (1, 0, 0, 0)
162+
data = mx.sym.Variable('data', shape=data_shape)
163+
weight = mx.sym.Variable('weight')
164+
cdata = mx.sym.cast(data, dtype='float16')
165+
cweight = mx.sym.cast(weight, dtype='float16')
166+
test = mx.sym.Convolution(data=cdata, weight=cweight, pad=(3, 3), num_filter=64, stride=(2, 2), no_bias=True, kernel=(7, 7))
167+
168+
arg, _, _ = test.infer_shape_partial()
169+
arg_shapes = dict(zip(test.list_arguments(), arg))
170+
assert arg_shapes['data'] == data_shape
171+
assert arg_shapes['weight'] == (64, 0, 7, 7)
172+
160173

161174
def test_symbol_infer_shape_var():
162175
"Test specifying shape information when constructing a variable"

0 commit comments

Comments
 (0)