Skip to content

Commit dac884e

Browse files
committed
Gradient
1 parent 2d18ab0 commit dac884e

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

squeez/encoder/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class SqueezEncoderForLineClassification(PreTrainedModel):
6363
"""
6464

6565
config_class = SqueezEncoderConfig
66+
supports_gradient_checkpointing = True
6667

6768
def __init__(self, config: SqueezEncoderConfig):
6869
super().__init__(config)
@@ -138,6 +139,11 @@ def forward(
138139
attentions=outputs.attentions,
139140
)
140141

142+
def _set_gradient_checkpointing(self, module, value: bool = False):
143+
"""Delegate gradient checkpointing to the wrapped encoder when supported."""
144+
if module is self.encoder and hasattr(module, "gradient_checkpointing"):
145+
module.gradient_checkpointing = value
146+
141147
# ------------------------------------------------------------------
142148
# Inference helpers
143149
# ------------------------------------------------------------------

0 commit comments

Comments
 (0)