Skip to content

Commit 5df09b9

Browse files
committed
Utilities to save and load only tasks weights
In addition to the private H5IOStore (which we rely on for lora), this also relies on `_save_state` and `_load_state`. We should probably think of a better contract with core Keras before we ship, but this will allow a prototype.
1 parent e674fd2 commit 5df09b9

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

keras_nlp/models/task.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,55 @@ def from_preset(
321321
config_overrides=kwargs,
322322
)
323323

324+
def load_task_weights(self, filepath, skip_mismatch=False):
325+
"""Load only the tasks specific weights not in the backbone."""
326+
if not str(filepath).endswith(".weights.h5"):
327+
raise ValueError(
328+
"The filename must end in `.weights.h5`. "
329+
f"Received: filepath={filepath}"
330+
)
331+
weights_store = keras.src.saving.saving_lib.H5IOStore(
332+
filepath, mode="r"
333+
)
334+
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
335+
keras.src.saving.saving_lib._load_state(
336+
self,
337+
weights_store=weights_store,
338+
assets_store=None,
339+
inner_path="",
340+
skip_mismatch=skip_mismatch,
341+
visited_trackables=backbone_layer_ids,
342+
failed_trackables=set(),
343+
)
344+
weights_store.close()
345+
346+
def save_task_weights(self, filepath):
347+
"""Save only the tasks specific weights not in the backbone."""
348+
if not str(filepath).endswith(".weights.h5"):
349+
raise ValueError(
350+
"The filename must end in `.weights.h5`. "
351+
f"Received: filepath={filepath}"
352+
)
353+
task_weight_ids = set(id(w) for w in self.weights)
354+
backbone_weight_ids = set(id(w) for w in self.backbone.weights)
355+
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
356+
if set(task_weight_ids) in set(backbone_weight_ids):
357+
raise ValueError(
358+
f"Task {self} has no weights not in the `backbone`. "
359+
"`save_task_weights()` has nothing to save."
360+
)
361+
weights_store = keras.src.saving.saving_lib.H5IOStore(
362+
filepath, mode="w"
363+
)
364+
keras.src.saving.saving_lib._save_state(
365+
self,
366+
weights_store=weights_store,
367+
assets_store=None,
368+
inner_path="",
369+
visited_trackables=backbone_layer_ids,
370+
)
371+
weights_store.close()
372+
324373
@property
325374
def layers(self):
326375
# Remove preprocessor from layers so it does not show up in the summary.

0 commit comments

Comments
 (0)