@@ -375,41 +375,30 @@ struct NumpyMoveaxisParam : public dmlc::Parameter<NumpyMoveaxisParam> {
375375 }
376376};
377377
378- template <typename xpu>
379- void NumpyMoveaxisCompute (const nnvm::NodeAttrs& attrs,
380- const OpContext& ctx,
381- const std::vector<TBlob>& inputs,
382- const std::vector<OpReqType>& req,
383- const std::vector<TBlob>& outputs) {
384- using namespace mshadow ;
385- using namespace mshadow ::expr;
378+ inline mxnet::TShape NumpyMoveaxisShapeImpl (const nnvm::NodeAttrs& attrs,
379+ const int & ndim) {
386380 const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(attrs.parsed );
387- CHECK_EQ (inputs.size (), 1U );
388- CHECK_EQ (outputs.size (), 1U );
389- CHECK_EQ (req[0 ], kWriteTo ) << " Moveaxis does not support inplace" ;
390- mxnet::TShape axes (inputs[0 ].ndim (), -1 );
381+ mxnet::TShape axes (ndim, -1 );
382+ std::vector<bool > state_axes (ndim, false );
391383 mxnet::TShape real_src (param.source .ndim (), -1 );
392384 mxnet::TShape real_des (param.destination .ndim (), -1 );
393- std::vector<bool > state_axes (inputs[0 ].ndim (), false );
394- CHECK_EQ (param.source .ndim (), param.destination .ndim ())
395- << " source and destination not equal." ;
396385 for (int i = 0 ; i < param.source .ndim (); ++i) {
397386 if (param.source [i] >= 0 ) {
398- CHECK_LT (static_cast <size_t >(param.source [i]), inputs[ 0 ]. ndim () );
387+ CHECK_LT (static_cast <size_t >(param.source [i]), ndim);
399388 real_src[i] = param.source [i];
400389 } else {
401- CHECK_LT (param.source [i] + inputs[ 0 ]. ndim (), inputs[ 0 ]. ndim () );
402- real_src[i] = param.source [i] + inputs[ 0 ]. ndim () ;
390+ CHECK_LT (param.source [i] + ndim, ndim);
391+ real_src[i] = param.source [i] + ndim;
403392 }
404393 if (param.destination [i] >= 0 ) {
405- CHECK_LT (static_cast <size_t >(param.destination [i]), inputs[ 0 ]. ndim () );
394+ CHECK_LT (static_cast <size_t >(param.destination [i]), ndim);
406395 real_des[i] = param.destination [i];
407396 } else {
408- CHECK_LT (param.destination [i] + inputs[ 0 ]. ndim (), inputs[ 0 ]. ndim () );
409- real_des[i] = param.destination [i] + inputs[ 0 ]. ndim () ;
397+ CHECK_LT (param.destination [i] + ndim, ndim);
398+ real_des[i] = param.destination [i] + ndim;
410399 }
411400 }
412- if (inputs[ 0 ]. ndim () > 1 ) {
401+ if (ndim > 1 ) {
413402 for (int i = 0 ; i < param.source .ndim () - 1 ; ++i) {
414403 for (int j = i + 1 ; j < param.source .ndim (); ++j) {
415404 CHECK_NE (real_src[i], real_src[j])
@@ -434,6 +423,25 @@ void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
434423 }
435424 }
436425 }
426+ return axes;
427+ }
428+
429+ template <typename xpu>
430+ void NumpyMoveaxisCompute (const nnvm::NodeAttrs& attrs,
431+ const OpContext& ctx,
432+ const std::vector<TBlob>& inputs,
433+ const std::vector<OpReqType>& req,
434+ const std::vector<TBlob>& outputs) {
435+ using namespace mshadow ;
436+ using namespace mshadow ::expr;
437+ const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(attrs.parsed );
438+ CHECK_EQ (inputs.size (), 1U );
439+ CHECK_EQ (outputs.size (), 1U );
440+ CHECK_EQ (req[0 ], kWriteTo ) << " Moveaxis does not support inplace" ;
441+ CHECK_EQ (param.source .ndim (), param.destination .ndim ())
442+ << " source and destination not equal." ;
443+ mxnet::TShape axes;
444+ axes = NumpyMoveaxisShapeImpl (attrs, inputs[0 ].ndim ());
437445 MSHADOW_TYPE_SWITCH (outputs[0 ].type_flag_ , Dtype, {
438446 TransposeImpl<xpu>(ctx.run_ctx , inputs[0 ], outputs[0 ], axes);
439447 })
0 commit comments