Outdated docs for Scan over layers #5158

@giovannic

Description

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Fedora 43
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):

Name: flax
Version: 0.12.2

Name: jax
Version: 0.8.2

Name: jaxlib
Version: 0.8.2

  • Python version: 3.11.14
  • GPU/TPU model and memory: n/a
  • CUDA version (if applicable): n/a

Problem you have encountered:

When I run the example code from the "Scan over layers" section of the docs, I get an error.

What you expected to happen:

The docs should be updated to show that the split_rngs decorator in __call__ (and therefore the self.num_layers attribute?) is no longer necessary; a sketch of that change follows the reproduction script below.

Flax 0.10.6 was the last release where the example code worked.

Logs, error messages, etc:

Traceback (most recent call last):
  File "/.../bug.py", line 35, in <module>
    model(jax.numpy.ones((10, 64)))
  File "/.../bug.py", line 32, in __call__
    return forward(x, self.blocks)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/flax/nnx/rnglib.py", line 804, in split_rngs_wrapper
    with split_rngs(
         ^^^^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/flax/nnx/rnglib.py", line 824, in split_rngs
    key = stream()
          ^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/flax/nnx/rnglib.py", line 119, in __call__
    key = random.fold_in(self.key[...], self.count[...])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/jax/_src/random.py", line 262, in fold_in
    key, wrapped = _check_prng_key("fold_in", key)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../env/lib/python3.11/site-packages/jax/_src/random.py", line 101, in _check_prng_key
    raise ValueError(f"{name} accepts a single key, but was given a key array of"
ValueError: fold_in accepts a single key, but was given a key array of shape (5,) != (). Use jax.vmap for batching.
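
For context, the underlying JAX behavior can be reproduced in isolation. A minimal standalone sketch (not from the issue, just illustrating why fold_in rejects the stacked keys):

# standalone sketch of the underlying JAX error, independent of Flax
import jax

keys = jax.random.split(jax.random.key(0), 5)  # key array of shape (5,)
try:
    # fold_in expects a single key, so the stacked key array raises the
    # same ValueError seen in the traceback above
    jax.random.fold_in(keys, 1)
except ValueError as e:
    print(e)
# Batching over the keys with jax.vmap works, as the error message suggests:
folded = jax.vmap(lambda k: jax.random.fold_in(k, 1))(keys)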

Steps to reproduce:

colab link

# bug.py

import jax
from flax import nnx

class Block(nnx.Module):
  def __init__(self, input_dim, features, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x: jax.Array):  # No need to require a second input!
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x   # No need to return a second output!

class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    @nnx.split_rngs(splits=self.num_layers)
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)

model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
model(jax.numpy.ones((10, 64)))
# > ValueError: fold_in accepts a single key, but was given a key array of shape (5,) != (). Use jax.vmap for batching.
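
For reference, a minimal sketch of the change suggested above: dropping split_rngs (and the self.num_layers attribute) from __call__. This is an untested suggestion based on the expectation stated earlier, not a confirmed fix; it reuses Block and the imports from bug.py.

# suggested.py (hypothetical sketch, reusing Block and imports from bug.py above)
class MLP(nnx.Module):
  def __init__(self, features, num_layers, rngs):
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)  # stacked Blocks, one rng key per layer

  def __call__(self, x):
    # No @nnx.split_rngs here: the scan consumes the per-layer rng state
    # that was already split at construction time.
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      return model(x)

    return forward(x, self.blocks)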
