@@ -322,6 +322,272 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
   return ret;
 }
 
+template<typename IsNone, typename FDefault>
+nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
+                           const nnvm::TShape empty_val,
+                           const char* infer_name,
+                           const char* input_name,
+                           const char* attr_key_name,
+                           const char* attr_name,
+                           const char* unknown_name,
+                           IsNone fis_none,
+                           FDefault fdefault,
+                           bool bwd_identity_assign,
+                           const char* dispatch_mode_name,
+                           const DispatchMode default_mode_val = DispatchMode::kUndefined) {
+  using nnvm::IndexedGraph;
+  using nnvm::Op;
+  using AttrType = nnvm::TShape;
+  using FInferType = nnvm::FInferShape;
+  using AttrVector = std::vector<AttrType>;
+  using NodeAttrVector = std::vector<DispatchMode>;
+  using dmlc::any;
+  const IndexedGraph& idx = ret.indexed_graph();
+  static auto& finfer_shape =
+      Op::GetAttr<FInferType>(infer_name);
+  static auto& is_backward =
+      Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
+  // gradient function, used to get node correspondence.
+  static auto& fgrad =
+      Op::GetAttr<nnvm::FGradient>("FGradient");
+  // result shape vector, one entry per node entry in the graph
+  AttrVector rshape;
+  // dispatch mode vector
+  DispatchModeVector dispatch_modes;
+  if (ret.attrs.count(attr_name) != 0) {
+    rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
+  } else {
+    rshape.resize(idx.num_node_entries(), empty_val);
+  }
+
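+  // copy the shapes provided for the graph inputs into the corresponding entries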
+  if (ret.attrs.count(input_name) != 0) {
+    const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
+    CHECK_LE(shape_args.size(), idx.input_nodes().size())
+        << "More provided " << attr_name << "s than number of arguments.";
+    for (size_t i = 0; i < shape_args.size(); ++i) {
+      rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
+    }
+  }
+
+  // get the shape hints
+  std::string shape_hints_key = std::string(attr_name) + "_hints";
+  if (ret.attrs.count(shape_hints_key)) {
+    nnvm::NodeEntryMap<AttrType> shape_hints =
+        ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key);
+    for (const auto& kv : shape_hints) {
+      nnvm::NodeEntry e = kv.first;
+      if (idx.exist(e.node.get())) {
+        rshape[idx.entry_id(kv.first)] = kv.second;
+      }
+    }
+  }
+
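+  // key used to read a shape stored directly in a variable node's attribute dictionary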
+  std::string shape_attr_key;
+  if (ret.attrs.count(attr_key_name) != 0) {
+    shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
+    // erase the provided arguments
+    ret.attrs.erase(attr_key_name);
+  }
+
+  // limit inference to part of the graph
+  uint32_t node_start = 0, node_end = idx.num_nodes();
+  if (ret.attrs.count("node_range")) {
+    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
+    node_start = range.first;
+    node_end = range.second;
+    CHECK_GE(node_start, 0);
+    CHECK_LE(node_end, idx.num_nodes());
+    ret.attrs.erase("node_range");
+  }
+  uint32_t entry_start = 0, entry_end = idx.num_node_entries();
+  if (ret.attrs.count("entry_range")) {
+    const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range");
+    entry_start = range.first;
+    entry_end = range.second;
+    CHECK_GE(entry_start, 0);
+    CHECK_LE(entry_end, idx.num_node_entries());
+    ret.attrs.erase("entry_range");
+  }
+  // populate the node attribute vector
+  if (dispatch_mode_name != nullptr) {
+    if (ret.attrs.count(dispatch_mode_name) != 0) {
+      dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name);
+    } else {
+      LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph";
+    }
+  }
+
+  // Temp space for shape inference.
+  std::vector<AttrType> ishape, oshape;
+  // whether a shape is dynamic
+  std::vector<int> is_dynamic(rshape.size(), 0);
+  // inference step function for nid
+  auto infer_step = [&](uint32_t nid, bool last_iter) {
+    const auto& inode = idx[nid];
+    const std::string name = inode.source->attrs.name;
+    const uint32_t num_inputs = inode.inputs.size();
+    const uint32_t num_outputs = inode.source->num_outputs();
+    if (inode.source->is_variable()) {
+      // Variable node. No operator. Only one output entry.
+      CHECK(inode.source->op() == nullptr);
+      CHECK_EQ(num_outputs, 1U);
+      const uint32_t out_ent_id = idx.entry_id(nid, 0);
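+      // a variable may carry its shape in its attribute dictionary under shape_attr_key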
+      if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
+        auto it = inode.source->attrs.dict.find(shape_attr_key);
+        if (it != inode.source->attrs.dict.end()) {
+          std::istringstream is(it->second);
+          CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
+        }
+      }
+      // assign a default value to node attribute
+      if (dispatch_mode_name != nullptr) {
+        op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
+      }
+    } else if (is_backward.get(inode.source->op(), false) &&
+               inode.control_deps.size() && bwd_identity_assign) {
+      CHECK(dispatch_mode_name == nullptr)
+          << "Backward inference for node attributes is not available";
+      CHECK_GE(inode.control_deps.size(), 1U)
+          << "BackwardOp needs to have control_deps to its forward op";
+      const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
+      nnvm::NodePtr fwd_ptr = inode.source->control_deps[0];
+      CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
+      // use gradient function to find out the correspondence.
+      std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
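+      // only the output index is filled in; the node pointer stays empty so it can be recognized below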
+      for (size_t i = 0; i < ograd.size(); ++i) {
+        ograd[i].index = static_cast<uint32_t>(i);
+      }
+      // input gradient list
+      auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
+      const nnvm::Node* igrad_node = nullptr;
+      // input gradient assignment
+      for (size_t i = 0; i < igrad.size(); ++i) {
+        if (igrad[i].node->op() == inode.source->op()) {
+          uint32_t eid = idx.entry_id(nid, igrad[i].index);
+          if (fis_none(rshape[eid])) {
+            rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
+          } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
+            // Need to skip empty forward shape, because it may not be
+            // available now and it is possible to infer the forward
+            // shape in one of the next few passes
+            CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
+                << "Backward shape inconsistent with the forward shape";
+          }
+          if (igrad_node == nullptr) {
+            igrad_node = igrad[i].node.get();
+          } else {
+            CHECK(igrad_node == igrad[i].node.get());
+          }
+        }
+      }
+      // out grad entries
+      CHECK(igrad_node != nullptr)
+          << "Cannot find matching backward op for " << inode.source->attrs.name;
+      for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
+        const nnvm::NodeEntry& e = igrad_node->inputs[i];
+        if (e.node == nullptr) {
+          uint32_t eid = idx.entry_id(inode.inputs[i]);
+          if (fis_none(rshape[eid])) {
+            rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
+          }
+        }
+      }
+    } else {
+      DispatchMode* dispatch_mode = nullptr;
+      bool forward_known = true;
+      // Forward operator inference.
+      ishape.resize(num_inputs, empty_val);
+      bool is_input_dynamic_shape = false;
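+      // collect input shapes and note whether any input shape is still dynamic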
+      for (uint32_t i = 0; i < ishape.size(); ++i) {
+        ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
+        if (ishape[i].ndim() == 0 && is_dynamic[idx.entry_id(inode.inputs[i])]) {
+          is_input_dynamic_shape = true;
+        }
+        if (fis_none(ishape[i])) forward_known = false;
+      }
+      oshape.resize(num_outputs, empty_val);
+      for (uint32_t i = 0; i < oshape.size(); ++i) {
+        oshape[i] = rshape[idx.entry_id(nid, i)];
+        if (fis_none(oshape[i])) forward_known = false;
+      }
+      if (dispatch_mode_name != nullptr) {
+        dispatch_mode = &dispatch_modes[nid];
+        if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
+      }
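+      // look up the operator's registered shape inference function, falling back to fdefault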
+      auto finfer = finfer_shape.get(inode.source->op(), fdefault);
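+      // with no inference function, or with dynamic inputs, mark still-unknown outputs as dynamic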
+      if (finfer == nullptr || is_input_dynamic_shape) {
+        for (uint32_t i = 0; i < oshape.size(); ++i) {
+          if (oshape[i].ndim() == 0) {
+            is_dynamic[idx.entry_id(nid, i)] = 1;
+          }
+        }
+      } else if (!forward_known) {
+        if (finfer != nullptr) {
+          // Call inference function of the operator.
+          try {
+            forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs,
+                                             nid, &ishape, &oshape, dispatch_mode);
+          } catch (const std::exception& e) {
+            throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
+          }
+        } else {
+          CHECK(!last_iter)
+              << "Attribute " << infer_name
+              << " is not registered by op " << inode.source->op()->name
+              << "; we are not able to complete the inference because of this";
+        }
+      }
+      // Save to the result map.
+      for (uint32_t i = 0; i < num_inputs; ++i) {
+        rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
+      }
+      for (uint32_t i = 0; i < num_outputs; ++i) {
+        rshape[idx.entry_id(nid, i)] = oshape[i];
+      }
+    }
+  };
+
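+  // alternate forward and backward sweeps until no additional shapes or dispatch modes can be inferred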
+  size_t last_num_unknown;
+  size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
+  size_t num_unknown_entry_attr = entry_end - entry_start;
+  size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
+  int i = 0;
+  do {
+    if (i % 2 == 0) {
+      for (uint32_t nid = node_start; nid < node_end; ++nid) {
+        infer_step(nid, false);
+      }
+    } else {
+      // backward inference
+      for (uint32_t nid = node_end; nid != node_start; --nid) {
+        infer_step(nid - 1, false);
+      }
+    }
+    last_num_unknown = num_unknown;
+    num_unknown = 0;
+    for (size_t j = entry_start; j < entry_end; ++j) {
+      if (fis_none(rshape[j])) {
+        ++num_unknown;
+      }
+    }
+    if (dispatch_mode_name) {
+      for (size_t nid = node_start; nid < node_end; ++nid) {
+        if (dispatch_modes[nid] == DispatchMode::kUndefined) ++num_unknown;
+      }
+    }
+    ++i;
+  } while (num_unknown > 0 && last_num_unknown > num_unknown);
+  // set the shapes
+  ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
+  // set the dispatch modes
+  if (dispatch_mode_name) {
+    ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes));
+  }
+  // number of entries (and dispatch modes) that are still unknown
+  ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
+  return ret;
+}
+
 nnvm::Graph InferShape(nnvm::Graph&& graph,
                        nnvm::ShapeVector&& shape_inputs,
                        const std::string& shape_attr_key) {
@@ -332,7 +598,7 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
   if (shape_attr_key.length() != 0) {
     graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
   }
-  return InferAttr<nnvm::TShape, nnvm::FInferShape>(
+  return InferShapeAttr(
       std::move(graph), nnvm::TShape(),
       "FInferShape", "shape_inputs", "shape_attr_key",
       "shape", "shape_num_unknown_nodes",