Skip to content

Commit 657d706

Browse files
committed
Fix rnglib doctests for graph_updates=False default
With graph_updates=False, split_rngs and fork_rngs no longer mutate the input node in-place. Doctests that relied on in-place mutation, restore_rngs, or StateAxes in vmap now explicitly pass graph_updates=True.
1 parent 7075c6e commit 657d706

1 file changed

Lines changed: 11 additions & 11 deletions

File tree

flax/nnx/rnglib.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,18 +1117,18 @@ def split_rngs(
11171117
>>> from flax import nnx
11181118
...
11191119
>>> rngs = nnx.Rngs(params=0, dropout=1)
1120-
>>> _ = nnx.split_rngs(rngs, splits=5)
1120+
>>> _ = nnx.split_rngs(rngs, splits=5, graph_updates=True)
11211121
>>> rngs.params.key.shape, rngs.dropout.key.shape
11221122
((5,), (5,))
11231123
11241124
>>> rngs = nnx.Rngs(params=0, dropout=1)
1125-
>>> _ = nnx.split_rngs(rngs, splits=(2, 5))
1125+
>>> _ = nnx.split_rngs(rngs, splits=(2, 5), graph_updates=True)
11261126
>>> rngs.params.key.shape, rngs.dropout.key.shape
11271127
((2, 5), (2, 5))
11281128
11291129
11301130
>>> rngs = nnx.Rngs(params=0, dropout=1)
1131-
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
1131+
>>> _ = nnx.split_rngs(rngs, splits=5, only='params', graph_updates=True)
11321132
>>> rngs.params.key.shape, rngs.dropout.key.shape
11331133
((5,), ())
11341134
@@ -1140,11 +1140,11 @@ def split_rngs(
11401140
... self.dropout = nnx.Dropout(0.5, rngs=rngs)
11411141
...
11421142
>>> rngs = nnx.Rngs(params=0, dropout=1)
1143-
>>> _ = nnx.split_rngs(rngs, splits=5, only='params')
1143+
>>> _ = nnx.split_rngs(rngs, splits=5, only='params', graph_updates=True)
11441144
...
11451145
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
11461146
...
1147-
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
1147+
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes, graph_updates=True)
11481148
... def create_model(rngs):
11491149
... return Model(rngs)
11501150
...
@@ -1158,7 +1158,7 @@ def split_rngs(
11581158
11591159
>>> rngs = nnx.Rngs(params=0, dropout=1)
11601160
...
1161-
>>> backups = nnx.split_rngs(rngs, splits=5, only='params')
1161+
>>> backups = nnx.split_rngs(rngs, splits=5, only='params', graph_updates=True)
11621162
>>> model = create_model(rngs)
11631163
>>> nnx.restore_rngs(backups)
11641164
...
@@ -1170,16 +1170,16 @@ def split_rngs(
11701170
11711171
>>> rngs = nnx.Rngs(params=0, dropout=1)
11721172
...
1173-
>>> with nnx.split_rngs(rngs, splits=5, only='params'):
1173+
>>> with nnx.split_rngs(rngs, splits=5, only='params', graph_updates=True):
11741174
... model = create_model(rngs)
11751175
...
11761176
>>> model.dropout.rngs.key.shape
11771177
()
11781178
11791179
>>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
11801180
...
1181-
>>> @nnx.split_rngs(splits=5, only='params')
1182-
... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
1181+
>>> @nnx.split_rngs(splits=5, only='params', graph_updates=True)
1182+
... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes, graph_updates=True)
11831183
... def create_model(rngs):
11841184
... return Model(rngs)
11851185
...
@@ -1420,7 +1420,7 @@ def fork_rngs(
14201420
14211421
>>> rngs = nnx.Rngs(params=0, dropout=1)
14221422
...
1423-
>>> backups = nnx.fork_rngs(rngs)
1423+
>>> backups = nnx.fork_rngs(rngs, graph_updates=True)
14241424
>>> model = nnx.Linear(2, 3, rngs=rngs)
14251425
>>> nnx.restore_rngs(backups)
14261426
...
@@ -1430,7 +1430,7 @@ def fork_rngs(
14301430
14311431
>>> rngs = nnx.Rngs(params=0, dropout=1)
14321432
...
1433-
>>> with nnx.fork_rngs(rngs):
1433+
>>> with nnx.fork_rngs(rngs, graph_updates=True):
14341434
... model = nnx.Linear(2, 3, rngs=rngs)
14351435
14361436
"""

0 commit comments

Comments
 (0)