Problems for CIFAR experiments #21
Description
When I try to run the cifar and cifar-multi problems, I get an error saying that the boolean is_training flag is not specified:
ValueError: Boolean is_training flag must be explicitly specified when using batch normalization.
originally defined at:
File "train.py", line 61, in main
problem, net_config, net_assignments = util.get_config(FLAGS.problem)
File "/qydata/wwangbc/code/learning_to_optimize/l2l/util.py", line 113, in get_config
mode=mode)
File "/qydata/wwangbc/code/learning_to_optimize/l2l/problems.py", line 258, in cifar10
use_batch_norm=batch_norm)
File "/qydata/wwangbc/bin/anaconda/lib/python2.7/site-packages/sonnet/python/modules/nets/convnet.py", line 142, in __init__
super(ConvNet2D, self).__init__(name=name)
File "/qydata/wwangbc/bin/anaconda/lib/python2.7/site-packages/sonnet/python/modules/base.py", line 124, in __init__
custom_getter=self._custom_getter)
originally defined at:
File "train.py", line 61, in main
problem, net_config, net_assignments = util.get_config(FLAGS.problem)
File "/qydata/wwangbc/code/learning_to_optimize/l2l/util.py", line 113, in get_config
mode=mode)
File "/qydata/wwangbc/code/learning_to_optimize/l2l/problems.py", line 268, in cifar10
network = snt.Sequential([conv, snt.BatchFlatten(), mlp])
File "/qydata/wwangbc/bin/anaconda/lib/python2.7/site-packages/sonnet/python/modules/sequential.py", line 65, in __init__
super(Sequential, self).__init__(name=name)
File "/qydata/wwangbc/bin/anaconda/lib/python2.7/site-packages/sonnet/python/modules/base.py", line 124, in __init__
custom_getter=self._custom_getter)
I think is_training should be passed to the batch normalization in both the ConvNet2D and the MLP. snt.Sequential also seems to be misused here, since we need to pass extra build arguments to those modules.
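To illustrate what I mean, here is a minimal pure-Python sketch of the pattern, not the actual Sonnet code. ToySequential and ToyBatchNormLayer are hypothetical stand-ins I made up: the first assumes a container that calls each layer with only the previous output, the second mimics a module that refuses to run without an explicit is_training flag. One possible workaround is to bind the flag before the layer goes into the container, for example with functools.partial or a lambda:

```python
from functools import partial


class ToySequential:
    """Hypothetical stand-in for a sequential container.

    Assumption: each layer is called with a single positional argument,
    so extra build arguments such as is_training are never forwarded.
    """

    def __init__(self, layers):
        self._layers = list(layers)

    def __call__(self, inputs):
        outputs = inputs
        for layer in self._layers:
            outputs = layer(outputs)
        return outputs


class ToyBatchNormLayer:
    """Hypothetical layer that requires an explicit is_training flag."""

    def __call__(self, inputs, is_training=None):
        if is_training is None:
            raise ValueError(
                "Boolean is_training flag must be explicitly specified "
                "when using batch normalization.")
        if is_training:
            # Training mode: centre the batch on its own mean.
            mean = sum(inputs) / len(inputs)
            return [x - mean for x in inputs]
        # Inference mode: this toy pretends the stored mean is zero.
        return list(inputs)


def build_network(is_training):
    """Bind is_training up front so the container can call the layer."""
    layer = ToyBatchNormLayer()
    return ToySequential([partial(layer, is_training=is_training)])


if __name__ == "__main__":
    print(build_network(True)([1.0, 2.0, 3.0]))
    print(build_network(False)([1.0, 2.0, 3.0]))
```

If something like this is right, the conv and mlp modules in problems.py would presumably need to be wrapped in the same way before being handed to snt.Sequential, but I have not checked that against the real Sonnet API.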
What's more, the line "network = snt.Sequential([conv, snt.BatchFlatten(), mlp])" suggests that the network has only one convolutional layer, while the paper uses 3.
Could you please fix this bug and implement the complete 3-layer CNN?
Thanks a lot!