@@ -68,6 +68,26 @@ bool ApplyOpInferAttr<int, FInferStorageType>(const nnvm::Graph& g,
6868 * shape/type inference functions'. The nnvm InferAttr will be deprecated
6969 * in the future. Please use interfaces InferShape, InferType, and InferStorageType
7070 * to call this function.
71+ *
72+ * \param ret graph used for attribute inference
73+ * \param emmpty_val empty value of the attribute
74+ * \param infer_name name of the function used for attribute inference
75+ * \param input_name name of the attribute in the graph used to store the
76+ * input data for attribute inference
77+ * \param attr_key_name name of the attribute used for inference for variable nodes
78+ * \param attr_name name of the inferred attribute
79+ * \param unknown_name name of the attribute storing number of entries
80+ * impossible to infer
81+ * \param fis_none function returning true for not fully inferred values
82+ * \param fdefault default function used for inference if the node does not
83+ * provide its own implementation.
84+ * \param bwd_identity_assign whether the attributes of forward NDArray and backward
85+ * NDArray have to be the same. False only for storage
86+ * type inference
87+ * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
88+ * storage type inference
89+ * \param default_mode_val default value of the dispatch mode attribute on the node. Used
90+ * for storage type inference
7191 */
7292template <typename AttrType, typename FInferType, typename IsNone, typename FDefault>
7393nnvm::Graph InferAttr (nnvm::Graph &&ret,
@@ -322,7 +342,32 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
322342 return ret;
323343}
324344
325- template <typename IsNone, typename FDefault>
345+ /* !\brief
346+ * This is a version of the InferAttr function specifically for shape inference.
347+ *
348+ * \param ret graph used for attribute inference
349+ * \param emmpty_val empty value of the attribute
350+ * \param infer_name name of the function used for attribute inference
351+ * \param input_name name of the attribute in the graph used to store the
352+ * input data for attribute inference
353+ * \param attr_key_name name of the attribute used for inference for variable nodes
354+ * \param attr_name name of the inferred attribute
355+ * \param unknown_name name of the attribute storing number of entries
356+ * impossible to infer
357+ * \param fis_none function returning true for not fully inferred values
358+ * \param fnum_unknown function returning how many elements are unknown in
359+ * partially inferred value of the attribute
360+ * \param fdefault default function used for inference if the node does not
361+ * provide its own implementation.
362+ * \param bwd_identity_assign whether the attributes of forward NDArray and backward
363+ * NDArray have to be the same. False only for storage
364+ * type inference
365+ * \param dispatch_mode_name name of the dispatch mode attribute on the node. Used for
366+ * storage type inference
367+ * \param default_mode_val default value of the dispatch mode attribute on the node. Used
368+ * for storage type inference
369+ */
370+ template <typename IsNone, typename FDefault, typename FNumUnknown>
326371nnvm::Graph InferShapeAttr (nnvm::Graph &&ret,
327372 const nnvm::TShape empty_val,
328373 const char * infer_name,
@@ -331,6 +376,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
331376 const char * attr_name,
332377 const char * unknown_name,
333378 IsNone fis_none,
379+ FNumUnknown fnum_unknown,
334380 FDefault fdefault,
335381 bool bwd_identity_assign,
336382 const char * dispatch_mode_name,
@@ -548,12 +594,12 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
548594 };
549595
550596 size_t last_num_unknown;
551- size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0 ;
552- size_t num_unknown_entry_attr = entry_end - entry_start;
553- size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
597+ size_t num_unknown = static_cast <size_t >(-1 ); // Infinity
598+
554599 int i = 0 ;
555600 do {
556601 if (i % 2 == 0 ) {
602+ // forward inference
557603 for (uint32_t nid = node_start; nid < node_end; ++nid) {
558604 infer_step (nid, false );
559605 }
@@ -567,7 +613,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
567613 num_unknown = 0 ;
568614 for (size_t j = entry_start; j < entry_end; ++j) {
569615 if (fis_none (rshape[j])) {
570- ++ num_unknown;
616+ num_unknown += fnum_unknown (rshape[j]) ;
571617 }
572618 }
573619 if (dispatch_mode_name) {
@@ -598,11 +644,23 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
598644 if (shape_attr_key.length () != 0 ) {
599645 graph.attrs [" shape_attr_key" ] = std::make_shared<any>(shape_attr_key);
600646 }
601- return InferAttr<mxnet::TShape, mxnet::FInferShape> (
647+ return InferShapeAttr (
602648 std::move (graph), mxnet::TShape (),
603649 " FInferShape" , " shape_inputs" , " shape_attr_key" ,
604650 " shape" , " shape_num_unknown_nodes" ,
605651 [](const mxnet::TShape& s) { return s.ndim () == 0 || s.Size () == 0 ; },
652+ [](const mxnet::TShape& s) {
653+ if (s.ndim () == 0 ) { // TODO(reminisce): Usage of ndim
654+ return static_cast <size_t >(1 );
655+ }
656+ size_t ret = 0 ;
657+ for (const auto & val : s) {
658+ if (val == 0 ) {
659+ ++ret;
660+ }
661+ }
662+ return ret;
663+ },
606664 nullptr , true , nullptr );
607665}
608666
0 commit comments