
Commit a4b85a5

wkcn authored and nswamy committed
Support SyncBatchNorm5D (#14542)
* support SyncBatchNorm5D
* fix
* update testcase and reformat code
* retrigger CI
* update test case
* test
* Retrigger CI
* disable cudnn for batchnorm
* fix BatchNorm(cudnn)
* fix build
* Remove a testcase
* Update sync_batch_norm-inl.h
* update unittest
* update unittest
* update test
* fix test
* change atol and rtol
* BN(cudnn) 5d
* update test
* test
* Testing
* Update batch_norm.cu
* test cudnnoff
* Update test_operator.py
* update BN! : )
1 parent e5aadca commit a4b85a5

6 files changed, +355 -152 lines changed


src/operator/contrib/sync_batch_norm-inl.h

Lines changed: 19 additions & 12 deletions
@@ -69,7 +69,6 @@ struct SyncBatchNormParam : public dmlc::Parameter<SyncBatchNormParam> {
     DMLC_DECLARE_FIELD(ndev).set_default(1)
       .describe("The count of GPU devices");
     DMLC_DECLARE_FIELD(key)
-      .set_default("")
       .describe("Hash key for synchronization, please set the same hash key for same layer, "
                 "Block.prefix is typically used as in :class:`gluon.nn.contrib.SyncBatchNorm`.");
   }
@@ -275,14 +274,18 @@ class SyncBatchNorm : public Operator {
         static_cast<real_t>(in_data[syncbatchnorm::kData].shape_.Size());
     Tensor<xpu, 4> data;
     Tensor<xpu, 4> out;
-    if (in_data[syncbatchnorm::kData].ndim() == 2) {
+    if (in_data[syncbatchnorm::kData].ndim() == 4) {
+      data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
+      out = out_data[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
+    } else {
+      index_t num_channels = in_data[syncbatchnorm::kData].ndim() > 1 ?
+        in_data[syncbatchnorm::kData].shape_[1] : 1;
+      index_t spatial_size = in_data[syncbatchnorm::kData].shape_.ProdShape(2,
+        in_data[syncbatchnorm::kData].ndim());
       Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0],
-                               in_data[syncbatchnorm::kData].shape_[1], 1, 1);
+                               num_channels, 1, spatial_size);
       data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
       out = out_data[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
-    } else {
-      data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
-      out = out_data[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
     }
     Tensor<xpu, 1> slope = in_data[syncbatchnorm::kGamma].get<xpu, 1, real_t>(s);
     Tensor<xpu, 1> bias = in_data[syncbatchnorm::kBeta].get<xpu, 1, real_t>(s);
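Note: the forward pass above (and the backward hunk below) no longer special-cases 2-D input; any non-4-D tensor is folded into an (N, C, 1, spatial) view so the existing 4-D kernels apply unchanged. A minimal standalone sketch of that collapse, with an illustrative helper name (`CollapseTo4D` is not MXNet code):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative helper (not part of MXNet): collapse an arbitrary input shape
// into the 4-D view (N, C, 1, spatial) used by the patched Forward/Backward,
// mirroring the num_channels / spatial_size computation in the diff above.
std::vector<int64_t> CollapseTo4D(const std::vector<int64_t>& shape) {
  int64_t num_channels = shape.size() > 1 ? shape[1] : 1;
  int64_t spatial_size = 1;
  for (size_t i = 2; i < shape.size(); ++i)  // analogous to ProdShape(2, ndim)
    spatial_size *= shape[i];
  return {shape[0], num_channels, 1, spatial_size};
}

int main() {
  // 5-D NCDHW input: (2, 3, 4, 8, 8) -> (2, 3, 1, 256); the element count per
  // channel is preserved, so the per-channel mean/variance are unchanged.
  for (int64_t d : CollapseTo4D({2, 3, 4, 8, 8})) std::cout << d << ' ';
  std::cout << '\n';
  // 2-D input: (32, 10) -> (32, 10, 1, 1), matching the old special case.
  for (int64_t d : CollapseTo4D({32, 10})) std::cout << d << ' ';
  std::cout << '\n';
}
```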
@@ -354,16 +357,20 @@ class SyncBatchNorm : public Operator {
     Tensor<xpu, 4> data, grad, grad_in;
     const real_t scale = static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_[1]) /
       static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_.Size());
-    if (in_data[syncbatchnorm::kData].ndim() == 2) {
+    if (in_data[syncbatchnorm::kData].ndim() == 4) {
+      data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
+      grad = out_grad[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
+      grad_in = in_grad[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
+    } else {
+      index_t num_channels = out_grad[syncbatchnorm::kOut].ndim() > 1 ?
+        out_grad[syncbatchnorm::kOut].shape_[1] : 1;
+      index_t spatial_size = out_grad[syncbatchnorm::kOut].shape_.ProdShape(2,
+        out_grad[syncbatchnorm::kOut].ndim());
       Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0],
-                               out_grad[syncbatchnorm::kOut].shape_[1], 1, 1);
+                               num_channels, 1, spatial_size);
       data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
       grad = out_grad[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
       grad_in = in_grad[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
-    } else {
-      data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
-      grad = out_grad[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
-      grad_in = in_grad[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
     }

     Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, real_t>(s);

src/operator/nn/batch_norm.cu

Lines changed: 2 additions & 2 deletions
@@ -668,7 +668,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,

   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
-  if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
+  if (!param.use_global_stats && !param.cudnn_off
       && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
@@ -697,7 +697,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
-  if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
+  if (!param.use_global_stats && !param.cudnn_off
       && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
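Dropping the `shape.ndim() <= 4` term means input rank no longer decides the dispatch: 5-D batch norm now reaches the cuDNN kernel, and the non-cuDNN fallback is still reachable via `cudnn_off` (which is what the commit's test updates exercise). A hedged restatement of the gating predicate, not MXNet code, with illustrative names:

```cpp
#include <iostream>

// Illustrative restatement of the cuDNN dispatch condition after this change:
// only the flags and the axis matter, not the input rank.
struct BNParam { bool use_global_stats; bool cudnn_off; int axis; };
constexpr int kDefaultAxis = 1;  // stand-in for mxnet::op::batchnorm::DEFAULT_AXIS

bool UseCuDNN(const BNParam& p, int /*ndim*/) {
  // Before the patch an extra `ndim <= 4` term excluded 5-D input from cuDNN.
  return !p.use_global_stats && !p.cudnn_off && p.axis == kDefaultAxis;
}

int main() {
  BNParam p{false, false, 1};
  std::cout << UseCuDNN(p, 5) << '\n';  // 1: 5-D input now takes the cuDNN path
  p.cudnn_off = true;
  std::cout << UseCuDNN(p, 5) << '\n';  // 0: cudnn_off still forces the fallback
}
```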

src/operator/nn/cudnn/cudnn_batch_norm-inl.h

Lines changed: 8 additions & 6 deletions
@@ -84,7 +84,6 @@ class CuDNNBatchNormOp {
     }
     CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo);
     CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2);
-    CHECK_LE(in_data[cudnnbatchnorm::kData].ndim(), 4);

     Init(in_data[cudnnbatchnorm::kData]);
     Stream<gpu> *s = ctx.get_stream<gpu>();
@@ -273,12 +272,15 @@ class CuDNNBatchNormOp {

  private:
  void Init(const TBlob &in_data) {
-    for (int i = 0; i < 4; ++i) {
-      if (i < in_data.ndim()) {
+    if (in_data.ndim() == 4) {
+      for (int i = 0; i < 4; ++i)
         shape_[i] = in_data.shape_[i];
-      } else {
-        shape_[i] = 1;
-      }
+    } else {
+      // when in_data.ndim() != 4
+      shape_[0] = in_data.shape_[0];
+      shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+      shape_[2] = 1;
+      shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
     }

     CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
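With this change, Init() flattens everything after the channel axis into the last descriptor dimension instead of padding missing dimensions with 1. A hedged standalone sketch of the descriptor such an input would produce, assuming an NCHW float descriptor purely for illustration (the actual format and data type used by the MXNet operator are not shown in this diff):

```cpp
#include <cudnn.h>
#include <cstdio>

// Sketch (not the MXNet member function): build the 4-D cuDNN descriptor that
// the patched Init() would compute for a 5-D NCDHW input of shape (2, 3, 4, 8, 8):
// shape_ = {2, 3, 1, 4 * 8 * 8} = {2, 3, 1, 256}.
int main() {
  cudnnTensorDescriptor_t desc;
  cudnnCreateTensorDescriptor(&desc);
  cudnnStatus_t st = cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW,
                                                CUDNN_DATA_FLOAT, 2, 3, 1, 256);
  std::printf("%s\n", cudnnGetErrorString(st));  // expect CUDNN_STATUS_SUCCESS
  cudnnDestroyTensorDescriptor(desc);
  return 0;
}
```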
