Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit c7d35db

Browse files
ptrendx and liuzh47 authored
fix parameter names in the estimator api (#17051) (#17162)
Co-authored-by: liuzh91 <liuzhuanghua1991@gmail.com>
1 parent 80a850d commit c7d35db

File tree

4 files changed

+35
-35
lines changed

4 files changed

+35
-35
lines changed

python/mxnet/gluon/contrib/estimator/batch_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def evaluate_batch(self, estimator,
6161
Batch axis to split the validation data into devices.
6262
"""
6363
data, label = self._get_data_and_label(val_batch, estimator.context, batch_axis)
64-
pred = [estimator.eval_net(x) for x in data]
65-
loss = [estimator.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
64+
pred = [estimator.val_net(x) for x in data]
65+
loss = [estimator.val_loss(y_hat, y) for y_hat, y in zip(pred, label)]
6666

6767
return data, label, pred, loss
6868

python/mxnet/gluon/contrib/estimator/estimator.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,29 +61,29 @@ class Estimator(object):
6161
Trainer to apply optimizer on network parameters.
6262
context : Context or list of Context
6363
Device(s) to run the training on.
64-
evaluation_loss : gluon.loss.loss
65-
Loss (objective) function to calculate during validation. If set evaluation_loss
66-
None, it will use the same loss function as self.loss
67-
eval_net : gluon.Block
64+
val_net : gluon.Block
6865
The model used for validation. The validation model does not necessarily belong to
6966
the same model class as the training model. But the two models typically share the
7067
same architecture. Therefore the validation model can reuse parameters of the
7168
training model.
7269
73-
The code example of consruction of eval_net sharing the same network parameters as
70+
The code example of consruction of val_net sharing the same network parameters as
7471
the training net is given below:
7572
7673
>>> net = _get_train_network()
77-
>>> eval_net = _get_test_network(params=net.collect_params())
74+
>>> val_net = _get_test_network(params=net.collect_params())
7875
>>> net.initialize(ctx=ctx)
79-
>>> est = Estimator(net, loss, eval_net=eval_net)
76+
>>> est = Estimator(net, loss, val_net=val_net)
8077
8178
Proper namespace match is required for weight sharing between two networks. Most networks
8279
inheriting :py:class:`Block` can share their parameters correctly. An exception is
8380
Sequential networks that Block scope must be specified for correct weight sharing. For
8481
the naming in mxnet Gluon API, please refer to the site
8582
(https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/naming.html)
8683
for future information.
84+
val_loss : gluon.loss.loss
85+
Loss (objective) function to calculate during validation. If set val_loss
86+
None, it will use the same loss function as self.loss
8787
batch_processor: BatchProcessor
8888
BatchProcessor provides customized fit_batch() and evaluate_batch() methods
8989
"""
@@ -113,21 +113,21 @@ def __init__(self, net,
113113
initializer=None,
114114
trainer=None,
115115
context=None,
116-
evaluation_loss=None,
117-
eval_net=None,
116+
val_net=None,
117+
val_loss=None,
118118
batch_processor=None):
119119
self.net = net
120120
self.loss = self._check_loss(loss)
121121
self._train_metrics = _check_metrics(train_metrics)
122122
self._val_metrics = _check_metrics(val_metrics)
123123
self._add_default_training_metrics()
124124
self._add_validation_metrics()
125-
self.evaluation_loss = self.loss
126-
if evaluation_loss is not None:
127-
self.evaluation_loss = self._check_loss(evaluation_loss)
128-
self.eval_net = self.net
129-
if eval_net is not None:
130-
self.eval_net = eval_net
125+
self.val_loss = self.loss
126+
if val_loss is not None:
127+
self.val_loss = self._check_loss(val_loss)
128+
self.val_net = self.net
129+
if val_net is not None:
130+
self.val_net = val_net
131131

132132
self.logger = logging.Logger(name='Estimator', level=logging.INFO)
133133
self.logger.addHandler(logging.StreamHandler(sys.stdout))

tests/python/unittest/test_gluon_batch_processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_batch_processor_validation():
8484
ctx = mx.cpu()
8585
loss = gluon.loss.L2Loss()
8686
acc = mx.metric.Accuracy()
87-
evaluation_loss = gluon.loss.L1Loss()
87+
val_loss = gluon.loss.L1Loss()
8888
net.initialize(ctx=ctx)
8989
processor = BatchProcessor()
9090
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
@@ -93,7 +93,7 @@ def test_batch_processor_validation():
9393
train_metrics=acc,
9494
trainer=trainer,
9595
context=ctx,
96-
evaluation_loss=evaluation_loss,
96+
val_loss=val_loss,
9797
batch_processor=processor)
9898
# Input dataloader
9999
est.fit(train_data=dataloader,

tests/python/unittest/test_gluon_estimator.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,15 @@ def test_validation():
8888
ctx = mx.cpu()
8989
loss = gluon.loss.L2Loss()
9090
acc = mx.metric.Accuracy()
91-
evaluation_loss = gluon.loss.L1Loss()
91+
val_loss = gluon.loss.L1Loss()
9292
net.initialize(ctx=ctx)
9393
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
9494
est = Estimator(net=net,
9595
loss=loss,
9696
train_metrics=acc,
9797
trainer=trainer,
9898
context=ctx,
99-
evaluation_loss=evaluation_loss)
99+
val_loss=val_loss)
100100
# Input dataloader
101101
est.fit(train_data=dataloader,
102102
val_data=dataloader,
@@ -376,16 +376,16 @@ def test_default_handlers():
376376
assert isinstance(handlers[1], MetricHandler)
377377
assert isinstance(handlers[4], LoggingHandler)
378378

379-
def test_eval_net():
380-
''' test estimator with a different evaluation net '''
379+
def test_val_net():
380+
''' test estimator with different training and validation networks '''
381381
''' test weight sharing of sequential networks without namescope '''
382382
net = _get_test_network()
383-
eval_net = _get_test_network(params=net.collect_params())
383+
val_net = _get_test_network(params=net.collect_params())
384384
dataloader, dataiter = _get_test_data()
385385
num_epochs = 1
386386
ctx = mx.cpu()
387387
loss = gluon.loss.L2Loss()
388-
evaluation_loss = gluon.loss.L2Loss()
388+
val_loss = gluon.loss.L2Loss()
389389
acc = mx.metric.Accuracy()
390390
net.initialize(ctx=ctx)
391391
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
@@ -394,8 +394,8 @@ def test_eval_net():
394394
train_metrics=acc,
395395
trainer=trainer,
396396
context=ctx,
397-
evaluation_loss=evaluation_loss,
398-
eval_net=eval_net)
397+
val_loss=val_loss,
398+
val_net=val_net)
399399

400400
with assert_raises(RuntimeError):
401401
est.fit(train_data=dataloader,
@@ -404,16 +404,16 @@ def test_eval_net():
404404

405405
''' test weight sharing of sequential networks with namescope '''
406406
net = _get_test_network_with_namescope()
407-
eval_net = _get_test_network_with_namescope(params=net.collect_params())
407+
val_net = _get_test_network_with_namescope(params=net.collect_params())
408408
net.initialize(ctx=ctx)
409409
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
410410
est = Estimator(net=net,
411411
loss=loss,
412412
train_metrics=acc,
413413
trainer=trainer,
414414
context=ctx,
415-
evaluation_loss=evaluation_loss,
416-
eval_net=eval_net)
415+
val_loss=val_loss,
416+
val_net=val_net)
417417

418418
est.fit(train_data=dataloader,
419419
val_data=dataloader,
@@ -422,20 +422,20 @@ def test_eval_net():
422422
''' test weight sharing of two resnets '''
423423
net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
424424
net.output = gluon.nn.Dense(10)
425-
eval_net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
426-
eval_net.output = gluon.nn.Dense(10, params=net.collect_params())
425+
val_net = gluon.model_zoo.vision.resnet18_v1(pretrained=False, ctx=ctx)
426+
val_net.output = gluon.nn.Dense(10, params=net.collect_params())
427427
dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10)))
428428
dataloader = gluon.data.DataLoader(dataset=dataset, batch_size=5)
429429
net.initialize(ctx=ctx)
430-
eval_net.initialize(ctx=ctx)
430+
val_net.initialize(ctx=ctx)
431431
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
432432
est = Estimator(net=net,
433433
loss=loss,
434434
train_metrics=acc,
435435
trainer=trainer,
436436
context=ctx,
437-
evaluation_loss=evaluation_loss,
438-
eval_net=eval_net)
437+
val_loss=val_loss,
438+
val_net=val_net)
439439

440440
est.fit(train_data=dataloader,
441441
val_data=dataloader,

0 commit comments

Comments (0)