Description
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Fedora 43
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
  - flax 0.12.2
  - jax 0.8.2
  - jaxlib 0.8.2
- Python version: 3.11.14
- GPU/TPU model and memory: n/a
- CUDA version (if applicable): n/a
Problem you have encountered:
When I run the docs example code for "Scan over layers", I get an error.
What you expected to happen:
The docs should be updated to show that the `split_rngs` decorator in `__call__` (and therefore `self.num_layers`?) is no longer necessary. 0.10.6 was the last release where the example code worked.
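For reference, here is a minimal sketch of what I would expect the updated `__call__` in the example to look like, assuming the fix is simply dropping the extra `split_rngs` decorator (and, with it, `self.num_layers`); I'm not certain this is the intended API:

```python
# Hypothetical updated __call__ (my assumption, not the documented fix):
# the RNG streams in self.blocks were already split per layer in __init__,
# so scan is applied directly without splitting them again.
def __call__(self, x):
  @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
  def forward(x, model):
    x = model(x)
    return x
  return forward(x, self.blocks)
```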
Logs, error messages, etc:
```
Traceback (most recent call last):
  File "/.../bug.py", line 35, in <module>
    model(jax.numpy.ones((10, 64)))
  File "/.../bug.py", line 32, in __call__
    return forward(x, self.blocks)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/flax/nnx/rnglib.py", line 804, in split_rngs_wrapper
    with split_rngs(
         ^^^^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/flax/nnx/rnglib.py", line 824, in split_rngs
    key = stream()
          ^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/flax/nnx/rnglib.py", line 119, in __call__
    key = random.fold_in(self.key[...], self.count[...])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/jax/_src/random.py", line 262, in fold_in
    key, wrapped = _check_prng_key("fold_in", key)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/jax/_src/random.py", line 101, in _check_prng_key
    raise ValueError(f"{name} accepts a single key, but was given a key array of"
ValueError: fold_in accepts a single key, but was given a key array of shape (5,) != (). Use jax.vmap for batching.
```
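As far as I can tell, the underlying failure is just JAX's check that `fold_in` only accepts a single scalar key, while the stream's key is presumably already a per-layer key array by the time the second `split_rngs` runs. A minimal plain-JAX sketch (no Flax) that trips the same check, assuming new-style typed keys:

```python
# Plain-JAX sketch of the same error: fold_in rejects a batched key array.
import jax

keys = jax.random.split(jax.random.key(0), 5)  # typed key array of shape (5,)
jax.random.fold_in(keys, 0)
# ValueError: fold_in accepts a single key, but was given a key array of shape (5,) != ().
```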
Steps to reproduce:
```python
# bug.py
import jax
from flax import nnx


class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x  # No need to return a second output!


class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)


model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
model(jax.numpy.ones((10, 64)))
# > ValueError: fold_in accepts a single key, but was given a key array of shape (5,) != (). Use jax.vmap for batching.
```