@@ -1117,18 +1117,18 @@ def split_rngs(
11171117 >>> from flax import nnx
11181118 ...
11191119 >>> rngs = nnx.Rngs(params=0, dropout=1)
1120- >>> _ = nnx.split_rngs(rngs, splits=5)
1120+ >>> _ = nnx.split_rngs(rngs, splits=5, graph_updates=True )
11211121 >>> rngs.params.key.shape, rngs.dropout.key.shape
11221122 ((5,), (5,))
11231123
11241124 >>> rngs = nnx.Rngs(params=0, dropout=1)
1125- >>> _ = nnx.split_rngs(rngs, splits=(2, 5))
1125+ >>> _ = nnx.split_rngs(rngs, splits=(2, 5), graph_updates=True )
11261126 >>> rngs.params.key.shape, rngs.dropout.key.shape
11271127 ((2, 5), (2, 5))
11281128
11291129
11301130 >>> rngs = nnx.Rngs(params=0, dropout=1)
1131- >>> _ = nnx.split_rngs(rngs, splits=5, only='params')
1131+ >>> _ = nnx.split_rngs(rngs, splits=5, only='params', graph_updates=True )
11321132 >>> rngs.params.key.shape, rngs.dropout.key.shape
11331133 ((5,), ())
11341134
@@ -1140,11 +1140,11 @@ def split_rngs(
11401140 ... self.dropout = nnx.Dropout(0.5, rngs=rngs)
11411141 ...
11421142 >>> rngs = nnx.Rngs(params=0, dropout=1)
1143- >>> _ = nnx.split_rngs(rngs, splits=5, only='params')
1143+ >>> _ = nnx.split_rngs(rngs, splits=5, only='params', graph_updates=True )
11441144 ...
11451145 >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
11461146 ...
1147- >>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
1147+ >>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes, graph_updates=True )
11481148 ... def create_model(rngs):
11491149 ... return Model(rngs)
11501150 ...
@@ -1158,7 +1158,7 @@ def split_rngs(
11581158
11591159 >>> rngs = nnx.Rngs(params=0, dropout=1)
11601160 ...
1161- >>> backups = nnx.split_rngs(rngs, splits=5, only='params')
1161+ >>> backups = nnx.split_rngs(rngs, splits=5, only='params', graph_updates=True )
11621162 >>> model = create_model(rngs)
11631163 >>> nnx.restore_rngs(backups)
11641164 ...
@@ -1170,16 +1170,16 @@ def split_rngs(
11701170
11711171 >>> rngs = nnx.Rngs(params=0, dropout=1)
11721172 ...
1173- >>> with nnx.split_rngs(rngs, splits=5, only='params'):
1173+ >>> with nnx.split_rngs(rngs, splits=5, only='params', graph_updates=True ):
11741174 ... model = create_model(rngs)
11751175 ...
11761176 >>> model.dropout.rngs.key.shape
11771177 ()
11781178
11791179 >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
11801180 ...
1181- >>> @nnx.split_rngs(splits=5, only='params')
1182- ... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
1181+ >>> @nnx.split_rngs(splits=5, only='params', graph_updates=True )
1182+ ... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes, graph_updates=True )
11831183 ... def create_model(rngs):
11841184 ... return Model(rngs)
11851185 ...
@@ -1420,7 +1420,7 @@ def fork_rngs(
14201420
14211421 >>> rngs = nnx.Rngs(params=0, dropout=1)
14221422 ...
1423- >>> backups = nnx.fork_rngs(rngs)
1423+ >>> backups = nnx.fork_rngs(rngs, graph_updates=True )
14241424 >>> model = nnx.Linear(2, 3, rngs=rngs)
14251425 >>> nnx.restore_rngs(backups)
14261426 ...
@@ -1430,7 +1430,7 @@ def fork_rngs(
14301430
14311431 >>> rngs = nnx.Rngs(params=0, dropout=1)
14321432 ...
1433- >>> with nnx.fork_rngs(rngs):
1433+ >>> with nnx.fork_rngs(rngs, graph_updates=True ):
14341434 ... model = nnx.Linear(2, 3, rngs=rngs)
14351435
14361436 """
0 commit comments