# Source: keras_nlp/models/task.py, index 7858b84709..a487058a0d,
# hunk @@ -321,6 +321,55 @@.
# The two functions below are the methods this hunk adds between the end of
# `from_preset(...)` and the `layers` property. They are written at column 0
# here; indent them one level when placing them in the class body.
# `keras` is expected to be in scope in that module already.


def load_task_weights(self, filepath, skip_mismatch=False):
    """Load only the task-specific weights, i.e. those not in the backbone.

    Args:
        filepath: Path of the weights file to read. `str(filepath)` must
            end in `.weights.h5`.
        skip_mismatch: Forwarded unchanged to Keras'
            `saving_lib._load_state`.

    Raises:
        ValueError: If `filepath` does not end in `.weights.h5`.
    """
    if not str(filepath).endswith(".weights.h5"):
        raise ValueError(
            "The filename must end in `.weights.h5`. "
            f"Received: filepath={filepath}"
        )
    # Collect the backbone's layers *before* opening the file so that a
    # failure here cannot leave an open HDF5 handle behind.
    backbone_layer_ids = {id(w) for w in self.backbone._flatten_layers()}
    weights_store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="r")
    try:
        # Seeding `visited_trackables` with the backbone layers is what keeps
        # the backbone out of the load (presumably `_load_state` skips
        # anything already marked visited -- private Keras API, re-check on
        # Keras upgrades).
        keras.src.saving.saving_lib._load_state(
            self,
            weights_store=weights_store,
            assets_store=None,
            inner_path="",
            skip_mismatch=skip_mismatch,
            visited_trackables=backbone_layer_ids,
            failed_trackables=set(),
        )
    finally:
        # Always release the file, even when loading raises.
        weights_store.close()


def save_task_weights(self, filepath):
    """Save only the task-specific weights, i.e. those not in the backbone.

    Args:
        filepath: Path of the weights file to write. `str(filepath)` must
            end in `.weights.h5`.

    Raises:
        ValueError: If `filepath` does not end in `.weights.h5`, or if every
            weight of the task also belongs to `self.backbone` (there would
            be nothing to save).
    """
    if not str(filepath).endswith(".weights.h5"):
        raise ValueError(
            "The filename must end in `.weights.h5`. "
            f"Received: filepath={filepath}"
        )
    task_weight_ids = {id(w) for w in self.weights}
    backbone_weight_ids = {id(w) for w in self.backbone.weights}
    backbone_layer_ids = {id(w) for w in self.backbone._flatten_layers()}
    # Checked before the file is opened, so no file is created on this path.
    if task_weight_ids.issubset(backbone_weight_ids):
        raise ValueError(
            f"Task {self} has no weights not in the `backbone`. "
            "`save_task_weights()` has nothing to save."
        )
    weights_store = keras.src.saving.saving_lib.H5IOStore(filepath, mode="w")
    try:
        # Backbone layers are pre-marked as visited so that only the
        # remaining (task) state is written (private Keras API, re-check on
        # Keras upgrades).
        keras.src.saving.saving_lib._save_state(
            self,
            weights_store=weights_store,
            assets_store=None,
            inner_path="",
            visited_trackables=backbone_layer_ids,
        )
    finally:
        # Always close the file, even when saving raises, so the handle is
        # not left open in write mode.
        weights_store.close()