Skip to content
11 changes: 10 additions & 1 deletion keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from keras.src import constraints
from keras.src import dtype_policies
from keras.src import initializers
from keras.src import ops
from keras.src import regularizers
from keras.src import tree
from keras.src import utils
Expand Down Expand Up @@ -974,7 +975,15 @@ def maybe_convert(x):
if self.activity_regularizer is not None:
for output in tree.flatten(outputs):
if backend.is_tensor(output):
self.add_loss(self.activity_regularizer(output))
loss = self.activity_regularizer(output)
if output.ndim > 0:
# Normalize by batch size to ensure consistent
# regularization strength across batch sizes
batch_size = ops.cast(
ops.shape(output)[0], dtype=loss.dtype
)
loss = ops.divide_no_nan(loss, batch_size)
self.add_loss(loss)

# Set `previous_mask` on outputs if available. It is provided only
# for the first positional input arg and its mask.
Expand Down
17 changes: 17 additions & 0 deletions keras/src/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,23 @@ def call(self, x):
layer(layers.Input(batch_shape=(2, 2)))
self.assertLen(layer.losses, 0)

@parameterized.named_parameters(
    ("batch_size_0", 0),
    ("batch_size_1", 1),
    ("batch_size_5", 5),
    ("batch_size_10", 10),
)
def test_activity_regularization_batch_normalization(self, batch_size):
    """Activity-regularization loss should not scale with batch size.

    An identity layer with an `"l2"` activity regularizer is fed a
    constant input; after batch-size normalization the recorded loss is
    the same for every non-empty batch, and zero for an empty batch.
    """

    class IdentityLayer(layers.Layer):
        def call(self, x):
            # Pass-through: the output equals the input, so the activity
            # regularizer sees exactly the tensor we construct below.
            return x

    reg_layer = IdentityLayer(activity_regularizer="l2")
    reg_layer(ops.ones((batch_size, 5)) * 2.0)
    self.assertLen(reg_layer.losses, 1)
    # Per-batch normalization makes the loss batch-size independent
    # (0.2 for this input); an empty batch contributes no loss.
    if batch_size == 0:
        expected = 0.0
    else:
        expected = 0.2
    self.assertAllClose(reg_layer.losses[0], expected)

@pytest.mark.requires_trainable_backend
def test_add_loss(self):
class LossLayer(layers.Layer):
Expand Down
4 changes: 2 additions & 2 deletions keras/src/regularizers/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class Regularizer:
>>> out = layer(tensor)

>>> # The kernel regularization term is 0.25
>>> # The activity regularization term (after dividing by the batch size)
>>> # is 5
>>> # The activity regularization term (after dividing by batch size of 5)
>>> # is 5.0
>>> ops.sum(layer.losses)
5.25

Expand Down
Loading