
How to integrate huggingface.PyTorchModelHubMixin.save_pretrained() with Lightning Trainer (checkpointing & loading) #21366

@yilin404

Description & Motivation

📝 Summary

I am trying to train a custom policy module using PyTorch Lightning, where my model components (policy, preprocessor, postprocessor) all inherit from huggingface_hub.PyTorchModelHubMixin.

This mixin provides:

  • _save_pretrained()
  • _from_pretrained()

These back the mixin's public save_pretrained() and from_pretrained() methods and are very convenient for packaging model weights together with their config.
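For reference, this is roughly how I use the mixin on a standalone component today (a minimal sketch; MyPolicy and the paths are placeholders, not my real code):

from huggingface_hub import PyTorchModelHubMixin
from torch import nn

class MyPolicy(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Linear(hidden_dim, hidden_dim)

policy = MyPolicy(hidden_dim=256)
policy.save_pretrained("out/policy")             # writes weights + config.json
policy = MyPolicy.from_pretrained("out/policy")  # rebuilds the module from disk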

However, I am not sure how to properly integrate these HF-style save/load utilities into Lightning's standard training flow — especially Lightning's ModelCheckpoint callback.

📦 Minimal example

from lightning.pytorch import LightningModule

# BasePolicy and DataProcessorPipeline are my own components (definitions not
# shown); all of them inherit from huggingface_hub.PyTorchModelHubMixin.
class PolicyModule(LightningModule):
    def __init__(
        self,
        policy: BasePolicy,
        preprocessor: DataProcessorPipeline,
        postprocessor: DataProcessorPipeline,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters(
            logger=False, ignore=["policy", "preprocessor", "postprocessor"]
        )

        self.policy = policy
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def training_step(self, batch, batch_idx):
        batch = self.preprocessor(batch)
        loss, loss_dict = self.policy(batch)

        self.log_dict(
            {"train/_loss": loss},
            on_step=True, on_epoch=True, prog_bar=True, sync_dist=True,
        )
        self.log_dict(
            {f"train/{k}": v for k, v in loss_dict.items()},
            on_step=True, on_epoch=True, sync_dist=True,
        )
        return loss
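For the loading side, I currently rebuild everything outside of Lightning with something like this (a sketch; load_policy_module is a hypothetical helper, and it assumes the concrete class of each component is known at load time):

def load_policy_module(step_dir: str) -> PolicyModule:
    # from_pretrained() is a classmethod on the mixin, so it must be called
    # on the concrete class that was saved, not on an abstract base.
    policy = MyPolicy.from_pretrained(f"{step_dir}/policy")
    preprocessor = DataProcessorPipeline.from_pretrained(f"{step_dir}/preprocessor")
    postprocessor = DataProcessorPipeline.from_pretrained(f"{step_dir}/postprocessor")
    return PolicyModule(policy, preprocessor, postprocessor)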

❗ The problem

How can I make Lightning call save_pretrained() during checkpointing (and, symmetrically, restore components via from_pretrained() when loading)?
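One direction I have considered (a minimal sketch only, not a confirmed best practice; HubExportCallback, output_dir, and the step_... directory layout are all my own invention) is a custom Callback that piggybacks on Lightning's on_save_checkpoint hook:

from lightning.pytorch.callbacks import Callback

class HubExportCallback(Callback):
    """Mirror every Lightning checkpoint with HF-style save_pretrained() exports."""

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Export from rank 0 only, so DDP runs don't write the same files twice.
        if trainer.is_global_zero:
            step_dir = f"{self.output_dir}/step_{trainer.global_step}"
            pl_module.policy.save_pretrained(f"{step_dir}/policy")
            pl_module.preprocessor.save_pretrained(f"{step_dir}/preprocessor")
            pl_module.postprocessor.save_pretrained(f"{step_dir}/postprocessor")

But I am not sure whether on_save_checkpoint is the right place for this kind of side-effect export, or whether subclassing ModelCheckpoint itself would be more idiomatic.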

🙏 Additional context

PyTorch Lightning provides many excellent callback tools—especially ModelCheckpoint—which greatly simplify training workflows. Ideally, I would like to remain fully within the standard Lightning Trainer + callbacks framework, while still taking advantage of HuggingFace-style save_pretrained() / from_pretrained() for model components.
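Concretely, the wiring I am hoping for looks something like this (using the hypothetical HubExportCallback sketched above next to the stock ModelCheckpoint):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

trainer = Trainer(
    callbacks=[
        # Standard Lightning checkpointing stays exactly as usual ...
        ModelCheckpoint(dirpath="checkpoints", every_n_train_steps=1_000),
        # ... while the callback above adds the HF-style exports alongside it.
        HubExportCallback(output_dir="hub_exports"),
    ],
)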

Thanks in advance for any guidance or best practices!

Pitch

No response

Alternatives

No response

Additional context

No response

cc @lantiga


Labels: 3rd party, checkpointing, question
