File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -63,6 +63,7 @@ class SqueezEncoderForLineClassification(PreTrainedModel):
6363 """
6464
6565 config_class = SqueezEncoderConfig
66+ supports_gradient_checkpointing = True
6667
6768 def __init__ (self , config : SqueezEncoderConfig ):
6869 super ().__init__ (config )
@@ -138,6 +139,11 @@ def forward(
138139 attentions = outputs .attentions ,
139140 )
140141
142+ def _set_gradient_checkpointing (self , module , value : bool = False ):
143+ """Delegate gradient checkpointing to the wrapped encoder when supported."""
144+ if module is self .encoder and hasattr (module , "gradient_checkpointing" ):
145+ module .gradient_checkpointing = value
146+
141147 # ------------------------------------------------------------------
142148 # Inference helpers
143149 # ------------------------------------------------------------------
You can’t perform that action at this time.
0 commit comments