Skip to content

Commit df2d381

Browse files
committed
Add _graph_node_set_key method for List class
1 parent 9048b4e commit df2d381

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

flax/nnx/helpers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from flax.nnx.rnglib import Rngs
2828
from flax.nnx.statelib import State
2929
from flax.training.train_state import struct
30+
from flax.nnx.variablelib import Variable
3031

3132
A = tp.TypeVar('A')
3233
M = tp.TypeVar('M', bound=Module)
@@ -185,6 +186,17 @@ def __setitem__(self, index: int | slice, value: A | tp.Iterable[A]) -> None:
185186
else:
186187
raise TypeError('Invalid index type')
187188

189+
def _graph_node_set_key(self, key: str, value: tp.Any):
190+
if not isinstance(key, int):
191+
raise KeyError(f'Invalid key: {key}')
192+
elif key < len(self):
193+
if isinstance(variable := self[key], Variable) and isinstance(value, Variable):
194+
variable.update_from_state(value)
195+
else:
196+
self[key] = value
197+
else:
198+
self.insert(key, value)
199+
188200
def __delitem__(self, index: int | slice) -> None:
189201
if isinstance(index, int):
190202
if index < 0:

tests/nnx/graph_utils_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,17 @@ def swap(path, node):
11801180
self.assertEqual(bar2[1].d, -20)
11811181
self.assertEqual(n, 2)
11821182

1183+
def test_recursive_map_with_list(self):
1184+
rngs = nnx.Rngs(0)
1185+
model = nnx.Sequential(nnx.Linear(2, 3, rngs=rngs), nnx.relu, nnx.Linear(3, 4, rngs=rngs))
1186+
1187+
def add_rank2_lora(_, node):
1188+
if isinstance(node, nnx.Linear):
1189+
return nnx.LoRA(node.in_features, 2, node.out_features, base_module=node, rngs=rngs)
1190+
return node
1191+
1192+
self.assertEqual(len(nnx.recursive_map(add_rank2_lora, model).layers), 3)
1193+
11831194
def test_graphdef_hash_with_sequential(self):
11841195
rngs = nnx.Rngs(0)
11851196
net = nnx.Sequential(

0 commit comments

Comments
 (0)