Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,40 +827,41 @@ def make_mutable_arraydef(value: variablelib.Ref):
nodes.append(nodedef)

for key, value in values:
value_node_impl = get_node_impl(value)
if path is not None:
path.append(key)
if value_node_impl is not None or isinstance(value, Variable):
attributes.append((key, NODE_ATTR))
_graph_flatten(
value,
value_node_impl,
path,
ref_index,
ref_outer_index,
nodes,
attributes,
leaves,
paths,
)
elif variablelib.is_array_ref(value):
attributes.append((key, MUTABLE_ARRAY_ATTR))
array_refdef, leaf = make_mutable_arraydef(value)
if not isinstance(leaf, Repeated):
leaves.append(leaf)
if not hasattr(node, '_pytree__nodes') or (
key in node._pytree__nodes and node._pytree__nodes[key]):
value_node_impl = get_node_impl(value)
if value_node_impl is not None or isinstance(value, Variable):
attributes.append((key, NODE_ATTR))
_graph_flatten(
value,
value_node_impl,
path,
ref_index,
ref_outer_index,
nodes,
attributes,
leaves,
paths,
)
elif variablelib.is_array_ref(value):
attributes.append((key, MUTABLE_ARRAY_ATTR))
array_refdef, leaf = make_mutable_arraydef(value)
if not isinstance(leaf, Repeated):
leaves.append(leaf)
if paths is not None:
paths.append(tuple(path)) # type: ignore
nodes.append(array_refdef)
else:
attributes.append((key, ARRAY_ATTR))
if paths is not None:
paths.append(tuple(path)) # type: ignore
nodes.append(array_refdef)
elif isinstance(value, (jax.Array, np.ndarray)):
attributes.append((key, ARRAY_ATTR))
if paths is not None:
paths.append(tuple(path)) # type: ignore
leaves.append(value)
leaves.append(value)
else:
attributes.append((key, Static(value)))

if path is not None:
path.pop()
path.pop()

return

Expand Down
39 changes: 31 additions & 8 deletions tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,46 @@ def __call__(self, x):


class TestGraphUtils(absltest.TestCase):
def test_data_is_not_static(self):
class Module(nnx.Module):
def __init__(self):
self.data = nnx.data({"a": jnp.ones((8, 8))})

module = Module()
# assert False
f1 = nnx.flatten(module)
abstract_module = nnx.eval_shape(Module)
f2 = nnx.flatten(abstract_module)
assert f1[0].attributes == f2[0].attributes

def test_flatten(self):
a = {'a': 1, 'b': nnx.Param(2)}
g = [a, 3, a, nnx.Param(4)]
class A(nnx.Module):
def __init__(self):
self.a = 1
self.b = nnx.Param(2)

a = A()
g = nnx.List([a, 3, a, nnx.Param(4)])

refmap = nnx.graph.RefMap()
graphdef, flat_state = nnx.graph.flatten(g, ref_index=refmap)

print(refmap)

assert flat_state[0][1].get_value() == 2
assert flat_state[1][1].get_value() == 4

assert len(refmap) == 2 # 2 Variables
assert a['b'] in refmap
assert len(refmap) == 4 # 2 Variables, 2 Modules
assert a.b in refmap
assert g[3] in refmap

def test_flatten_no_paths(self):
a = {'a': 1, 'b': nnx.Param(jnp.array(2))}
g = [a, 3, a, nnx.Param(jnp.array(4))]
class A(nnx.Module):
def __init__(self):
self.a = 1
self.b = nnx.Param(jnp.array(2))
a = A()
g = nnx.List([a, 3, a, nnx.Param(jnp.array(4))])

refmap = nnx.graph.RefMap()
graphdef, flat_state = nnx.graph.flatten(
Expand All @@ -65,8 +88,8 @@ def test_flatten_no_paths(self):
assert flat_state[0][...] == 2
assert flat_state[1][...] == 4

assert len(refmap) == 2 # 2 Variables
assert a['b'] in refmap
assert len(refmap) == 4 # 2 Variables, 2 Modules
assert a.b in refmap
assert g[3] in refmap

def test_unflatten(self):
Expand Down
Loading