From 750527aa8ce145e090b7a3c828d8e1110b912e2b Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 12 Dec 2025 15:41:43 -0500 Subject: [PATCH 1/2] flatten respect nnx.data --- flax/nnx/graph.py | 111 ++++++++++++++++++---------------- flax/nnx/pytreelib.py | 11 +++- tests/nnx/graph_utils_test.py | 20 +++++- 3 files changed, 84 insertions(+), 58 deletions(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 508f6771a..ce3e1bf91 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -179,9 +179,13 @@ class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): type: type[Node] flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]] - def node_dict(self, node: Node) -> dict[Key, Leaf]: + def node_dict(self, node: Node) -> dict[Key, tp.Any]: nodes, _ = self.flatten(node) - return dict(nodes) + nodes = { + key: node.value if isinstance(node, DataElem | StaticElem) else node + for key, node in nodes + } + return nodes @dataclasses.dataclass(frozen=True, slots=True) @@ -533,32 +537,21 @@ def __treescope_repr__(self, path, subtree_renderer): @dataclasses.dataclass(frozen=True, slots=True) -class ArrayAttr: - pass - - -ARRAY_ATTR = ArrayAttr() - - -@dataclasses.dataclass(frozen=True, slots=True) -class MutableArrayAttr: +class NodeAttr: pass -MUTABLE_ARRAY_ATTR = MutableArrayAttr() - +NODE_ATTR = NodeAttr() @dataclasses.dataclass(frozen=True, slots=True) -class NodeAttr: +class LeafAttr: pass - -NODE_ATTR = NodeAttr() +LEAF_ATTR = LeafAttr() AttrType = tp.Union[ NodeAttr, - ArrayAttr, - MutableArrayAttr, + LeafAttr, 'Static[tp.Any]', ] @@ -710,6 +703,14 @@ def flatten( # type: ignore[invalid-annotation] else: return graphdef, leaves +@dataclasses.dataclass(frozen=True, slots=True) +class DataElem: + value: tp.Any + + +@dataclasses.dataclass(frozen=True, slots=True) +class StaticElem: + value: tp.Any def _graph_flatten( node: Node, @@ -827,6 +828,18 @@ def make_mutable_arraydef(value: variablelib.Ref): nodes.append(nodedef) for key, value in values: + is_data = None + if isinstance(value, DataElem): + value = value.value + is_data = True + elif isinstance(value, StaticElem): + value = value.value + is_data = False + + if is_data is False: + attributes.append((key, Static(value))) + continue + value_node_impl = get_node_impl(value) if path is not None: path.append(key) @@ -844,15 +857,15 @@ def make_mutable_arraydef(value: variablelib.Ref): paths, ) elif variablelib.is_array_ref(value): - attributes.append((key, MUTABLE_ARRAY_ATTR)) + attributes.append((key, NODE_ATTR)) array_refdef, leaf = make_mutable_arraydef(value) if not isinstance(leaf, Repeated): leaves.append(leaf) if paths is not None: paths.append(tuple(path)) # type: ignore nodes.append(array_refdef) - elif isinstance(value, (jax.Array, np.ndarray)): - attributes.append((key, ARRAY_ATTR)) + elif isinstance(value, (jax.Array, np.ndarray)) or is_data: + attributes.append((key, LEAF_ATTR)) if paths is not None: paths.append(tuple(path)) # type: ignore leaves.append(value) @@ -1092,41 +1105,33 @@ def _get_children() -> list[tuple[Key, tp.Any]]: key, value = next(attribute_iter) if type(value) is Static: children.append((key, value.value)) # type: ignore[attribute-error] - elif type(value) is MutableArrayAttr: - array_refdef = next(node_iter) - assert ( - type(array_refdef) is ArrayRefDef or type(array_refdef) is NodeRef - ) - if type(array_refdef) is NodeRef: - array_ref = index_ref[array_refdef.index] - else: - assert type(array_refdef) is ArrayRefDef + elif type(value) is LeafAttr: + leaf = next(leaves_iter) + children.append((key, leaf)) + elif type(value) is NodeAttr: + node_def = next(node_iter) + if isinstance(node_def, NodeRef): + node = index_ref[node_def.index] + elif isinstance(node_def, ArrayRefDef): leaf = next(leaves_iter) - array_ref = get_mutable_array(array_refdef, leaf) - children.append((key, array_ref)) - elif type(value) is ArrayAttr: - array = next(leaves_iter) - children.append((key, array)) + node = get_mutable_array(node_def, leaf) + elif isinstance(node_def, NodeDef | VariableDef): + value_node_impl = get_node_impl_for_type(node_def.type) + node = _graph_unflatten( + node_def, + value_node_impl, + node_iter, + attribute_iter, + leaves_iter, + index_ref, + outer_index_outer_ref, + copy_variables, + ) + else: + raise RuntimeError(f'Unknown node definition: {node_def!r}') + children.append((key, node)) elif type(value) is NodeRef: children.append((key, index_ref[value.index])) # type: ignore[attribute-error] - elif type(value) is NodeAttr: - # if the key is a subgraph we create an empty node - subgraphdef = next(node_iter) - if type(subgraphdef) is NodeDef: - value_node_impl = get_node_impl_for_type(subgraphdef.type) # type: ignore[attribute-error] - else: - value_node_impl = None - subnode = _graph_unflatten( - subgraphdef, - value_node_impl, - node_iter, - attribute_iter, - leaves_iter, - index_ref, - outer_index_outer_ref, - copy_variables, - ) - children.append((key, subnode)) else: raise RuntimeError(f'Unknown static field: {key!r}') diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index b84c1bad6..9c55e20e3 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -24,6 +24,7 @@ import warnings from flax.nnx import variablelib +from flax import nnx import jax import numpy as np import treescope # type: ignore[import-untyped] @@ -917,8 +918,14 @@ def _pytree__unflatten( # Graph Definition # ------------------------- def _graph_node_flatten(self): - nodes = vars(self) - nodes = sorted(nodes.items(), key=self._pytree__key_sort_fn) + pytree_nodes = self._pytree__nodes + nodes = ( + (name, nnx.graph.DataElem(value) + if name in pytree_nodes and pytree_nodes[name] + else nnx.graph.StaticElem(value)) + for name, value in vars(self).items() + ) + nodes = sorted(nodes, key=self._pytree__key_sort_fn) return nodes, type(self) def _graph_node_set_key(self, key: str, value: tp.Any): diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index f02e51124..7daddeef2 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -629,7 +629,7 @@ def __init__(self): self.assertFalse(hasattr(ctx, 'ctxtag')) self.assertIsInstance(graphdef1.nodes[0], nnx.graph.NodeDef) self.assertIsInstance(graphdef2.nodes[0], nnx.graph.NodeRef) - self.assertLen(nnx.to_flat_state(state1), 1) + self.assertLen(nnx.to_flat_state(state1), 2) self.assertLen(nnx.to_flat_state(state2), 0) @jax.jit @@ -717,7 +717,7 @@ def __init__(self): assert isinstance(t2, nnx.NodeStates) self.assertIsInstance(t1.graphdef.nodes[0], nnx.graph.NodeDef) self.assertIsInstance(t2.graphdef.nodes[0], nnx.graph.NodeRef) - self.assertLen(nnx.to_flat_state(t1.states[0]), 1) + self.assertLen(nnx.to_flat_state(t1.states[0]), 2) self.assertLen(nnx.to_flat_state(t2.states[0]), 0) @jax.jit @@ -744,7 +744,7 @@ def f(pure_tree): assert isinstance(t2, nnx.NodeStates) self.assertIsInstance(t1.graphdef.nodes[0], nnx.graph.NodeDef) self.assertIsInstance(t2.graphdef.nodes[0], nnx.graph.NodeRef) - self.assertLen(nnx.to_flat_state(t1.states[0]), 1) + self.assertLen(nnx.to_flat_state(t1.states[0]), 2) self.assertLen(nnx.to_flat_state(t2.states[0]), 0) return pure_tree2 @@ -762,6 +762,20 @@ def f(pure_tree): self.assertEqual(m.b[...], 1) # type: ignore self.assertEqual(impure_tree2[1], 1) + def test_graph_flatten_with_data_wrapper(self): + class Foo(nnx.Pytree): + def __init__(self, value, static): + self.value = nnx.data(value) + self.static = nnx.static(static) + + tree = Foo(1, 2) + state = nnx.state(tree) + + self.assertIn('value', state) + self.assertIsInstance(state['value'], int) + self.assertEqual(state['value'], 1) + self.assertNotIn('static', state) + def test_to_tree_consistent_prefix(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) impure_tree = (m, 1, {'b': m}) From ebd786852d36a8eccaebeba888c5e3342d121589 Mon Sep 17 00:00:00 2001 From: Sam Anklesaria Date: Mon, 29 Dec 2025 10:38:53 -0600 Subject: [PATCH 2/2] Fix type error --- flax/nnx/graph.py | 4 ++-- flax/nnx/pytreelib.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index ce3e1bf91..2ee785b3a 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -180,10 +180,10 @@ class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]] def node_dict(self, node: Node) -> dict[Key, tp.Any]: - nodes, _ = self.flatten(node) + node_seq, _ = self.flatten(node) nodes = { key: node.value if isinstance(node, DataElem | StaticElem) else node - for key, node in nodes + for key, node in node_seq } return nodes diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index 9c55e20e3..e1e848f92 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -920,9 +920,14 @@ def _pytree__unflatten( def _graph_node_flatten(self): pytree_nodes = self._pytree__nodes nodes = ( - (name, nnx.graph.DataElem(value) - if name in pytree_nodes and pytree_nodes[name] - else nnx.graph.StaticElem(value)) + ( + name, + value + if not self._pytree__is_pytree + else nnx.graph.DataElem(value) + if name in pytree_nodes and pytree_nodes[name] + else nnx.graph.StaticElem(value) + ) for name, value in vars(self).items() ) nodes = sorted(nodes, key=self._pytree__key_sort_fn)