Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 0af40f7

Browse files
junrushaoszha
authored and committed
[MXNET-1325] Make InferShapeAttr a standalone pass (#14193)
* Make InferShapeAttr a standalone pass * Fix * Fix * Fix
1 parent 5f32f32 commit 0af40f7

File tree

1 file changed

+267
-1
lines changed

1 file changed

+267
-1
lines changed

src/executor/infer_graph_attr_pass.cc

Lines changed: 267 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,272 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
322322
return ret;
323323
}
324324

325+
template<typename IsNone, typename FDefault>
326+
nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
327+
const nnvm::TShape empty_val,
328+
const char* infer_name,
329+
const char* input_name,
330+
const char* attr_key_name,
331+
const char* attr_name,
332+
const char* unknown_name,
333+
IsNone fis_none,
334+
FDefault fdefault,
335+
bool bwd_identity_assign,
336+
const char* dispatch_mode_name,
337+
const DispatchMode default_mode_val = DispatchMode::kUndefined) {
338+
using nnvm::IndexedGraph;
339+
using nnvm::Op;
340+
using AttrType = nnvm::TShape;
341+
using FInferType = nnvm::FInferShape;
342+
using AttrVector = std::vector<AttrType>;
343+
using NodeAttrVector = std::vector<DispatchMode>;
344+
using dmlc::any;
345+
const IndexedGraph& idx = ret.indexed_graph();
346+
static auto& finfer_shape =
347+
Op::GetAttr<FInferType>(infer_name);
348+
static auto& is_backward =
349+
Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
350+
// gradient function, used to get node correspondence.
351+
static auto& fgrad =
352+
Op::GetAttr<nnvm::FGradient>("FGradient");
353+
// reshape shape vector
354+
AttrVector rshape;
355+
// dispatch mode vector
356+
DispatchModeVector dispatch_modes;
357+
if (ret.attrs.count(attr_name) != 0) {
358+
rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
359+
} else {
360+
rshape.resize(idx.num_node_entries(), empty_val);
361+
}
362+
363+
if (ret.attrs.count(input_name) != 0) {
364+
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
365+
CHECK_LE(shape_args.size(), idx.input_nodes().size())
366+
<< "More provided " << attr_name << "s than number of arguments.";
367+
for (size_t i = 0; i < shape_args.size(); ++i) {
368+
rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
369+
}
370+
}
371+
372+
// get the shape hints
373+
std::string shape_hints_key = std::string(attr_name) + "_hints";
374+
if (ret.attrs.count(shape_hints_key)) {
375+
nnvm::NodeEntryMap<AttrType> shape_hints =
376+
ret.GetAttr<nnvm::NodeEntryMap<AttrType>>(shape_hints_key);
377+
for (const auto& kv : shape_hints) {
378+
nnvm::NodeEntry e = kv.first;
379+
if (idx.exist(e.node.get())) {
380+
rshape[idx.entry_id(kv.first)] = kv.second;
381+
}
382+
}
383+
}
384+
385+
std::string shape_attr_key;
386+
if (ret.attrs.count(attr_key_name) != 0) {
387+
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
388+
// erase the provided arguments
389+
ret.attrs.erase(attr_key_name);
390+
}
391+
392+
// limit inference to part of the graph
393+
uint32_t node_start = 0, node_end = idx.num_nodes();
394+
if (ret.attrs.count("node_range")) {
395+
const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("node_range");
396+
node_start = range.first;
397+
node_end = range.second;
398+
CHECK_GE(node_start, 0);
399+
CHECK_LE(node_end, idx.num_nodes());
400+
ret.attrs.erase("node_range");
401+
}
402+
uint32_t entry_start = 0, entry_end = idx.num_node_entries();
403+
if (ret.attrs.count("entry_range")) {
404+
const auto& range = ret.GetAttr<std::pair<uint32_t, uint32_t> >("entry_range");
405+
entry_start = range.first;
406+
entry_end = range.second;
407+
CHECK_GE(entry_start, 0);
408+
CHECK_LE(entry_end, idx.num_node_entries());
409+
ret.attrs.erase("entry_range");
410+
}
411+
// populate the node attribute vector
412+
if (dispatch_mode_name != nullptr) {
413+
if (ret.attrs.count(dispatch_mode_name) != 0) {
414+
dispatch_modes = ret.MoveCopyAttr<NodeAttrVector>(dispatch_mode_name);
415+
} else {
416+
LOG(FATAL) << "Node attribute " << dispatch_mode_name << " does not exist in the graph";
417+
}
418+
}
419+
420+
// Temp space for shape inference.
421+
std::vector<AttrType> ishape, oshape;
422+
// whether a shape is dynamic
423+
std::vector<int> is_dynamic(rshape.size(), 0);
424+
// inference step function for nid
425+
auto infer_step = [&](uint32_t nid, bool last_iter) {
426+
const auto& inode = idx[nid];
427+
const std::string name = inode.source->attrs.name;
428+
const uint32_t num_inputs = inode.inputs.size();
429+
const uint32_t num_outputs = inode.source->num_outputs();
430+
if (inode.source->is_variable()) {
431+
// Variable node. No operator. Only one output entry.
432+
CHECK(inode.source->op() == nullptr);
433+
CHECK_EQ(num_outputs, 1U);
434+
const uint32_t out_ent_id = idx.entry_id(nid, 0);
435+
if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
436+
auto it = inode.source->attrs.dict.find(shape_attr_key);
437+
if (it != inode.source->attrs.dict.end()) {
438+
std::istringstream is(it->second);
439+
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
440+
}
441+
}
442+
// assign a default value to node attribute
443+
if (dispatch_mode_name != nullptr) {
444+
op::dispatch_mode_assign(&dispatch_modes[nid], default_mode_val);
445+
}
446+
} else if (is_backward.get(inode.source->op(), false) &&
447+
inode.control_deps.size() && bwd_identity_assign) {
448+
CHECK(dispatch_mode_name == nullptr)
449+
<< "Backward inference for node attributes is not available";
450+
CHECK_GE(inode.control_deps.size(), 1U)
451+
<< "BackwardOp need to have control_deps to its forward op";
452+
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
453+
nnvm::NodePtr fwd_ptr = inode.source->control_deps[0];
454+
CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable";
455+
// use gradient function to find out the correspondence.
456+
std::vector<nnvm::NodeEntry> ograd(fwd_ptr->num_outputs());
457+
for (size_t i = 0; i < ograd.size(); ++i) {
458+
ograd[i].index = static_cast<uint32_t>(i);
459+
}
460+
// input gradient list
461+
auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd);
462+
const nnvm::Node* igrad_node = nullptr;
463+
// Input gradient assignement
464+
for (size_t i = 0; i < igrad.size(); ++i) {
465+
if (igrad[i].node->op() == inode.source->op()) {
466+
uint32_t eid = idx.entry_id(nid, igrad[i].index);
467+
if (fis_none(rshape[eid])) {
468+
rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])];
469+
} else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) {
470+
// Need to skip empty forward shape, because it may not be
471+
// available now and it is possible to infer the forward
472+
// shape in one of the next a few passes
473+
CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])])
474+
<< "Backward shape inconsistent with the forward shape";
475+
}
476+
if (igrad_node == nullptr) {
477+
igrad_node = igrad[i].node.get();
478+
} else {
479+
CHECK(igrad_node == igrad[i].node.get());
480+
}
481+
}
482+
}
483+
// out grad entries
484+
CHECK(igrad_node != nullptr)
485+
<< "Cannot find matching backward op for " << inode.source->attrs.name;
486+
for (size_t i = 0; i < igrad_node->inputs.size(); ++i) {
487+
const nnvm::NodeEntry& e = igrad_node->inputs[i];
488+
if (e.node == nullptr) {
489+
uint32_t eid = idx.entry_id(inode.inputs[i]);
490+
if (fis_none(rshape[eid])) {
491+
rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)];
492+
}
493+
}
494+
}
495+
} else {
496+
DispatchMode* dispatch_mode = nullptr;
497+
bool forward_known = true;
498+
// Forward operator inference.
499+
ishape.resize(num_inputs, empty_val);
500+
bool is_input_dynamic_shape = false;
501+
for (uint32_t i = 0; i < ishape.size(); ++i) {
502+
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
503+
if (ishape[i].ndim() == 0 && is_dynamic[idx.entry_id(inode.inputs[i])]) {
504+
is_input_dynamic_shape = true;
505+
}
506+
if (fis_none(ishape[i])) forward_known = false;
507+
}
508+
oshape.resize(num_outputs, empty_val);
509+
for (uint32_t i = 0; i < oshape.size(); ++i) {
510+
oshape[i] = rshape[idx.entry_id(nid, i)];
511+
if (fis_none(oshape[i])) forward_known = false;
512+
}
513+
if (dispatch_mode_name != nullptr) {
514+
dispatch_mode = &dispatch_modes[nid];
515+
if (dispatch_modes[nid] == DispatchMode::kUndefined) forward_known = false;
516+
}
517+
auto finfer = finfer_shape.get(inode.source->op(), fdefault);
518+
if (finfer == nullptr || is_input_dynamic_shape) {
519+
for (uint32_t i = 0; i < oshape.size(); ++i) {
520+
if (oshape[i].ndim() == 0) {
521+
is_dynamic[idx.entry_id(nid, i)] = 1;
522+
}
523+
}
524+
} else if (!forward_known) {
525+
if (finfer != nullptr) {
526+
// Call inference function of the operator.
527+
try {
528+
forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs,
529+
nid, &ishape, &oshape, dispatch_mode);
530+
} catch (const std::exception& e) {
531+
throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
532+
}
533+
} else {
534+
CHECK(!last_iter)
535+
<< "Attribute " << infer_name
536+
<< " is not registed by op " << inode.source->op()->name
537+
<< " we are not able to complete the inference because of this";
538+
}
539+
}
540+
// Save to the result map.
541+
for (uint32_t i = 0; i < num_inputs; ++i) {
542+
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
543+
}
544+
for (uint32_t i = 0; i < num_outputs; ++i) {
545+
rshape[idx.entry_id(nid, i)] = oshape[i];
546+
}
547+
}
548+
};
549+
550+
size_t last_num_unknown;
551+
size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - node_start : 0;
552+
size_t num_unknown_entry_attr = entry_end - entry_start;
553+
size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
554+
int i = 0;
555+
do {
556+
if (i % 2 == 0) {
557+
for (uint32_t nid = node_start; nid < node_end; ++nid) {
558+
infer_step(nid, false);
559+
}
560+
} else {
561+
// backward inference
562+
for (uint32_t i = node_end; i != node_start; --i) {
563+
infer_step(i - 1, false);
564+
}
565+
}
566+
last_num_unknown = num_unknown;
567+
num_unknown = 0;
568+
for (size_t j = entry_start; j < entry_end; ++j) {
569+
if (fis_none(rshape[j])) {
570+
++num_unknown;
571+
}
572+
}
573+
if (dispatch_mode_name) {
574+
for (size_t i = node_start; i < node_end; i++) {
575+
if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown;
576+
}
577+
}
578+
++i;
579+
} while (num_unknown > 0 && last_num_unknown > num_unknown);
580+
// set the shapes
581+
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
582+
// set the shapes
583+
if (dispatch_mode_name) {
584+
ret.attrs[dispatch_mode_name] = std::make_shared<any>(std::move(dispatch_modes));
585+
}
586+
// number of nodes who knows the shape.
587+
ret.attrs[unknown_name] = std::make_shared<any>(num_unknown);
588+
return ret;
589+
}
590+
325591
nnvm::Graph InferShape(nnvm::Graph&& graph,
326592
nnvm::ShapeVector&& shape_inputs,
327593
const std::string& shape_attr_key) {
@@ -332,7 +598,7 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
332598
if (shape_attr_key.length() != 0) {
333599
graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
334600
}
335-
return InferAttr<nnvm::TShape, nnvm::FInferShape>(
601+
return InferShapeAttr(
336602
std::move(graph), nnvm::TShape(),
337603
"FInferShape", "shape_inputs", "shape_attr_key",
338604
"shape", "shape_num_unknown_nodes",

0 commit comments

Comments
 (0)