Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 8ce4075

Browse files
committed
refactor moveaxis code
1 parent 9753eb1 commit 8ce4075

File tree

2 files changed

+32
-68
lines changed

2 files changed

+32
-68
lines changed

src/operator/numpy/np_matrix_op-inl.h

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -375,41 +375,30 @@ struct NumpyMoveaxisParam : public dmlc::Parameter<NumpyMoveaxisParam> {
375375
}
376376
};
377377

378-
template<typename xpu>
379-
void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
380-
const OpContext& ctx,
381-
const std::vector<TBlob>& inputs,
382-
const std::vector<OpReqType>& req,
383-
const std::vector<TBlob>& outputs) {
384-
using namespace mshadow;
385-
using namespace mshadow::expr;
378+
inline mxnet::TShape NumpyMoveaxisShapeImpl(const nnvm::NodeAttrs& attrs,
379+
const int& ndim) {
386380
const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(attrs.parsed);
387-
CHECK_EQ(inputs.size(), 1U);
388-
CHECK_EQ(outputs.size(), 1U);
389-
CHECK_EQ(req[0], kWriteTo) << "Moveaxis does not support inplace";
390-
mxnet::TShape axes(inputs[0].ndim(), -1);
381+
mxnet::TShape axes(ndim, -1);
382+
std::vector<bool> state_axes(ndim, false);
391383
mxnet::TShape real_src(param.source.ndim(), -1);
392384
mxnet::TShape real_des(param.destination.ndim(), -1);
393-
std::vector<bool> state_axes(inputs[0].ndim(), false);
394-
CHECK_EQ(param.source.ndim(), param.destination.ndim())
395-
<< "source and destination not equal.";
396385
for (int i = 0; i < param.source.ndim(); ++i) {
397386
if (param.source[i] >= 0) {
398-
CHECK_LT(static_cast<size_t>(param.source[i]), inputs[0].ndim());
387+
CHECK_LT(static_cast<size_t>(param.source[i]), ndim);
399388
real_src[i] = param.source[i];
400389
} else {
401-
CHECK_LT(param.source[i] + inputs[0].ndim(), inputs[0].ndim());
402-
real_src[i] = param.source[i] + inputs[0].ndim();
390+
CHECK_LT(param.source[i] + ndim, ndim);
391+
real_src[i] = param.source[i] + ndim;
403392
}
404393
if (param.destination[i] >= 0) {
405-
CHECK_LT(static_cast<size_t>(param.destination[i]), inputs[0].ndim());
394+
CHECK_LT(static_cast<size_t>(param.destination[i]), ndim);
406395
real_des[i] = param.destination[i];
407396
} else {
408-
CHECK_LT(param.destination[i] + inputs[0].ndim(), inputs[0].ndim());
409-
real_des[i] = param.destination[i] + inputs[0].ndim();
397+
CHECK_LT(param.destination[i] + ndim, ndim);
398+
real_des[i] = param.destination[i] + ndim;
410399
}
411400
}
412-
if (inputs[0].ndim() > 1) {
401+
if (ndim > 1) {
413402
for (int i = 0; i < param.source.ndim() - 1; ++i) {
414403
for (int j = i + 1; j < param.source.ndim(); ++j) {
415404
CHECK_NE(real_src[i], real_src[j])
@@ -434,6 +423,25 @@ void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
434423
}
435424
}
436425
}
426+
return axes;
427+
}
428+
429+
template<typename xpu>
430+
void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
431+
const OpContext& ctx,
432+
const std::vector<TBlob>& inputs,
433+
const std::vector<OpReqType>& req,
434+
const std::vector<TBlob>& outputs) {
435+
using namespace mshadow;
436+
using namespace mshadow::expr;
437+
const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(attrs.parsed);
438+
CHECK_EQ(inputs.size(), 1U);
439+
CHECK_EQ(outputs.size(), 1U);
440+
CHECK_EQ(req[0], kWriteTo) << "Moveaxis does not support inplace";
441+
CHECK_EQ(param.source.ndim(), param.destination.ndim())
442+
<< "source and destination not equal.";
443+
mxnet::TShape axes;
444+
axes = NumpyMoveaxisShapeImpl(attrs, inputs[0].ndim());
437445
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
438446
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
439447
})

src/operator/numpy/np_matrix_op.cc

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -625,51 +625,8 @@ bool NumpyMoveaxisShape(const nnvm::NodeAttrs& attrs,
625625
CHECK_EQ(param.source.ndim(), param.destination.ndim())
626626
<< "source and destination not equal.";
627627
mxnet::TShape ret(shp.ndim(), -1);
628-
mxnet::TShape axes(shp.ndim(), -1);
629-
std::vector<bool> state_axes(shp.ndim(), false);
630-
mxnet::TShape real_src(param.source.ndim(), -1);
631-
mxnet::TShape real_des(param.destination.ndim(), -1);
632-
for (int i = 0; i < param.source.ndim(); ++i) {
633-
if (param.source[i] >= 0) {
634-
CHECK_LT(static_cast<size_t>(param.source[i]), shp.ndim());
635-
real_src[i] = param.source[i];
636-
} else {
637-
CHECK_LT(param.source[i] + shp.ndim(), shp.ndim());
638-
real_src[i] = param.source[i] + shp.ndim();
639-
}
640-
if (param.destination[i] >= 0) {
641-
CHECK_LT(static_cast<size_t>(param.destination[i]), shp.ndim());
642-
real_des[i] = param.destination[i];
643-
} else {
644-
CHECK_LT(param.destination[i] + shp.ndim(), shp.ndim());
645-
real_des[i] = param.destination[i] + shp.ndim();
646-
}
647-
}
648-
if (shp.ndim() > 1) {
649-
for (int i = 0; i < param.source.ndim() - 1; ++i) {
650-
for (int j = i + 1; j < param.source.ndim(); ++j) {
651-
CHECK_NE(real_src[i], real_src[j])
652-
<< "repeated axis in `source` argument";
653-
CHECK_NE(real_des[i], real_des[j])
654-
<< "repeated axis in `destination` argument";
655-
}
656-
}
657-
}
658-
for (int i = 0; i < param.source.ndim(); ++i) {
659-
axes[real_des[i]] = real_src[i];
660-
state_axes[real_src[i]] = true;
661-
}
662-
for (int i = 0; i < axes.ndim(); ++i) {
663-
if (axes[i] < 0) {
664-
for (int j = 0; j < axes.ndim(); ++j) {
665-
if (state_axes[j] == false) {
666-
axes[i] = j;
667-
state_axes[j] = true;
668-
break;
669-
}
670-
}
671-
}
672-
}
628+
mxnet::TShape axes;
629+
axes = NumpyMoveaxisShapeImpl(attrs, shp.ndim());
673630
for (int i = 0; i < shp.ndim(); ++i) {
674631
CHECK(axes[i] < static_cast<int64_t>(shp.ndim()));
675632
ret[i] = shp[axes[i]];
@@ -745,7 +702,6 @@ inline bool NumpyRot90Shape(const nnvm::NodeAttrs& attrs,
745702
res[real_axes[0]] += res[real_axes[1]];
746703
res[real_axes[1]] = res[real_axes[0]] - res[real_axes[1]];
747704
res[real_axes[0]] -= res[real_axes[1]];
748-
749705
SHAPE_ASSIGN_CHECK(*out_attrs, 0, res);
750706
return shape_is_known(res);
751707
}

0 commit comments

Comments
 (0)