Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 59 additions & 54 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,13 @@ class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
type: type[Node]
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]]

def node_dict(self, node: Node) -> dict[Key, Leaf]:
nodes, _ = self.flatten(node)
return dict(nodes)
def node_dict(self, node: Node) -> dict[Key, tp.Any]:
node_seq, _ = self.flatten(node)
nodes = {
key: node.value if isinstance(node, DataElem | StaticElem) else node
for key, node in node_seq
}
return nodes


@dataclasses.dataclass(frozen=True, slots=True)
Expand Down Expand Up @@ -533,32 +537,21 @@ def __treescope_repr__(self, path, subtree_renderer):


@dataclasses.dataclass(frozen=True, slots=True)
class ArrayAttr:
pass


ARRAY_ATTR = ArrayAttr()


@dataclasses.dataclass(frozen=True, slots=True)
class MutableArrayAttr:
class NodeAttr:
pass


MUTABLE_ARRAY_ATTR = MutableArrayAttr()

NODE_ATTR = NodeAttr()

@dataclasses.dataclass(frozen=True, slots=True)
class NodeAttr:
class LeafAttr:
pass


NODE_ATTR = NodeAttr()
LEAF_ATTR = LeafAttr()

AttrType = tp.Union[
NodeAttr,
ArrayAttr,
MutableArrayAttr,
LeafAttr,
'Static[tp.Any]',
]

Expand Down Expand Up @@ -710,6 +703,14 @@ def flatten( # type: ignore[invalid-annotation]
else:
return graphdef, leaves

@dataclasses.dataclass(frozen=True, slots=True)
class DataElem:
value: tp.Any


@dataclasses.dataclass(frozen=True, slots=True)
class StaticElem:
value: tp.Any

def _graph_flatten(
node: Node,
Expand Down Expand Up @@ -827,6 +828,18 @@ def make_mutable_arraydef(value: variablelib.Ref):
nodes.append(nodedef)

for key, value in values:
is_data = None
if isinstance(value, DataElem):
value = value.value
is_data = True
elif isinstance(value, StaticElem):
value = value.value
is_data = False

if is_data is False:
attributes.append((key, Static(value)))
continue

value_node_impl = get_node_impl(value)
if path is not None:
path.append(key)
Expand All @@ -844,15 +857,15 @@ def make_mutable_arraydef(value: variablelib.Ref):
paths,
)
elif variablelib.is_array_ref(value):
attributes.append((key, MUTABLE_ARRAY_ATTR))
attributes.append((key, NODE_ATTR))
array_refdef, leaf = make_mutable_arraydef(value)
if not isinstance(leaf, Repeated):
leaves.append(leaf)
if paths is not None:
paths.append(tuple(path)) # type: ignore
nodes.append(array_refdef)
elif isinstance(value, (jax.Array, np.ndarray)):
attributes.append((key, ARRAY_ATTR))
elif isinstance(value, (jax.Array, np.ndarray)) or is_data:
attributes.append((key, LEAF_ATTR))
if paths is not None:
paths.append(tuple(path)) # type: ignore
leaves.append(value)
Expand Down Expand Up @@ -1092,41 +1105,33 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
key, value = next(attribute_iter)
if type(value) is Static:
children.append((key, value.value)) # type: ignore[attribute-error]
elif type(value) is MutableArrayAttr:
array_refdef = next(node_iter)
assert (
type(array_refdef) is ArrayRefDef or type(array_refdef) is NodeRef
)
if type(array_refdef) is NodeRef:
array_ref = index_ref[array_refdef.index]
else:
assert type(array_refdef) is ArrayRefDef
elif type(value) is LeafAttr:
leaf = next(leaves_iter)
children.append((key, leaf))
elif type(value) is NodeAttr:
node_def = next(node_iter)
if isinstance(node_def, NodeRef):
node = index_ref[node_def.index]
elif isinstance(node_def, ArrayRefDef):
leaf = next(leaves_iter)
array_ref = get_mutable_array(array_refdef, leaf)
children.append((key, array_ref))
elif type(value) is ArrayAttr:
array = next(leaves_iter)
children.append((key, array))
node = get_mutable_array(node_def, leaf)
elif isinstance(node_def, NodeDef | VariableDef):
value_node_impl = get_node_impl_for_type(node_def.type)
node = _graph_unflatten(
node_def,
value_node_impl,
node_iter,
attribute_iter,
leaves_iter,
index_ref,
outer_index_outer_ref,
copy_variables,
)
else:
raise RuntimeError(f'Unknown node definition: {node_def!r}')
children.append((key, node))
elif type(value) is NodeRef:
children.append((key, index_ref[value.index])) # type: ignore[attribute-error]
elif type(value) is NodeAttr:
# if the key is a subgraph we create an empty node
subgraphdef = next(node_iter)
if type(subgraphdef) is NodeDef:
value_node_impl = get_node_impl_for_type(subgraphdef.type) # type: ignore[attribute-error]
else:
value_node_impl = None
subnode = _graph_unflatten(
subgraphdef,
value_node_impl,
node_iter,
attribute_iter,
leaves_iter,
index_ref,
outer_index_outer_ref,
copy_variables,
)
children.append((key, subnode))
else:
raise RuntimeError(f'Unknown static field: {key!r}')

Expand Down
11 changes: 9 additions & 2 deletions flax/nnx/pytreelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import warnings

from flax.nnx import variablelib
from flax import nnx
import jax
import numpy as np
import treescope # type: ignore[import-untyped]
Expand Down Expand Up @@ -917,8 +918,14 @@ def _pytree__unflatten(
# Graph Definition
# -------------------------
def _graph_node_flatten(self):
nodes = vars(self)
nodes = sorted(nodes.items(), key=self._pytree__key_sort_fn)
pytree_nodes = self._pytree__nodes
nodes = (
(name, nnx.graph.DataElem(value)
if name in pytree_nodes and pytree_nodes[name]
else nnx.graph.StaticElem(value))
for name, value in vars(self).items()
)
nodes = sorted(nodes, key=self._pytree__key_sort_fn)
return nodes, type(self)

def _graph_node_set_key(self, key: str, value: tp.Any):
Expand Down
20 changes: 17 additions & 3 deletions tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def __init__(self):
self.assertFalse(hasattr(ctx, 'ctxtag'))
self.assertIsInstance(graphdef1.nodes[0], nnx.graph.NodeDef)
self.assertIsInstance(graphdef2.nodes[0], nnx.graph.NodeRef)
self.assertLen(nnx.to_flat_state(state1), 1)
self.assertLen(nnx.to_flat_state(state1), 2)
self.assertLen(nnx.to_flat_state(state2), 0)

@jax.jit
Expand Down Expand Up @@ -717,7 +717,7 @@ def __init__(self):
assert isinstance(t2, nnx.NodeStates)
self.assertIsInstance(t1.graphdef.nodes[0], nnx.graph.NodeDef)
self.assertIsInstance(t2.graphdef.nodes[0], nnx.graph.NodeRef)
self.assertLen(nnx.to_flat_state(t1.states[0]), 1)
self.assertLen(nnx.to_flat_state(t1.states[0]), 2)
self.assertLen(nnx.to_flat_state(t2.states[0]), 0)

@jax.jit
Expand All @@ -744,7 +744,7 @@ def f(pure_tree):
assert isinstance(t2, nnx.NodeStates)
self.assertIsInstance(t1.graphdef.nodes[0], nnx.graph.NodeDef)
self.assertIsInstance(t2.graphdef.nodes[0], nnx.graph.NodeRef)
self.assertLen(nnx.to_flat_state(t1.states[0]), 1)
self.assertLen(nnx.to_flat_state(t1.states[0]), 2)
self.assertLen(nnx.to_flat_state(t2.states[0]), 0)

return pure_tree2
Expand All @@ -762,6 +762,20 @@ def f(pure_tree):
self.assertEqual(m.b[...], 1) # type: ignore
self.assertEqual(impure_tree2[1], 1)

def test_graph_flatten_with_data_wrapper(self):
class Foo(nnx.Pytree):
def __init__(self, value, static):
self.value = nnx.data(value)
self.static = nnx.static(static)

tree = Foo(1, 2)
state = nnx.state(tree)

self.assertIn('value', state)
self.assertIsInstance(state['value'], int)
self.assertEqual(state['value'], 1)
self.assertNotIn('static', state)

def test_to_tree_consistent_prefix(self):
m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
impure_tree = (m, 1, {'b': m})
Expand Down
Loading