should we have an extension point for model transforms out of tree? #790

Closed
@vkuzo

Description

In torchao, we have various low precision training features which are in prototype: MX, int8, bitnet. While we expect most of these to eventually end up in the main torchao APIs, it often takes ~months for a prototype to graduate.

torchtitan is extremely useful for helping us test low precision prototypes in real-world settings. For now, we've been creating unlanded PRs to test functionality (examples: #614, #778). Would torchtitan consider building an extension point to support this kind of experimentation fully out-of-tree?

An example of what this could look like:

  1. torchtitan provides a "model transformation" hook that it calls at a specified point in the initialization stage (for quantization, that should be after model init and before parallelization / torch.compile)
  2. user can provide a custom pass to transform the model (such as a prototype low precision training conversion pass)
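To make the two steps above concrete, here is a minimal sketch of such a hook in plain Python. Every name in it (`register_model_transform`, `apply_model_transforms`, `build_model_for_training`) is hypothetical and is not an existing torchtitan or torchao API. The model is treated as an opaque object so the ordering is the only thing being illustrated: init, then user transforms, then parallelization, then compile.

```python
from typing import Any, Callable, Dict, Iterable

# Hypothetical sketch: none of these names exist in torchtitan or torchao.
ModelTransform = Callable[[Any], Any]

# Registry of out-of-tree transforms, keyed by a user-chosen name.
_TRANSFORMS: Dict[str, ModelTransform] = {}


def register_model_transform(name: str) -> Callable[[ModelTransform], ModelTransform]:
    """Decorator that registers a user-provided model transform under `name`."""

    def decorator(fn: ModelTransform) -> ModelTransform:
        if name in _TRANSFORMS:
            raise ValueError(f"model transform {name!r} is already registered")
        _TRANSFORMS[name] = fn
        return fn

    return decorator


def apply_model_transforms(model: Any, names: Iterable[str]) -> Any:
    """Apply the named transforms in order; each one returns the new model."""
    for name in names:
        model = _TRANSFORMS[name](model)
    return model


def build_model_for_training(
    init_model: Callable[[], Any],
    transform_names: Iterable[str],
    parallelize: Callable[[Any], Any],
    compile_fn: Callable[[Any], Any],
) -> Any:
    """Run the hook after model init and before parallelization / compile."""
    model = init_model()
    model = apply_model_transforms(model, transform_names)
    model = parallelize(model)
    return compile_fn(model)
```

A prototype conversion pass would then live entirely in user code, decorated with `@register_model_transform("some_name")`, and be selected by name at launch time.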

I'm not entirely sure how this hook would be implemented, since the current interface of torchtitan is CLI-based, but I wanted to share the request and start the discussion.
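One possible way to fit this into a CLI-based interface is a repeatable flag that takes a `module:function` import path, resolved at startup with `importlib`. This is only a sketch of that idea: the `--model_transform` flag and the `load_transform` / `parse_transforms` helpers are assumptions made for illustration, not existing torchtitan options.

```python
import argparse
import importlib
from typing import Any, Callable, List, Optional


def load_transform(spec: str) -> Callable[[Any], Any]:
    """Resolve a 'package.module:function' string to a callable."""
    module_name, sep, attr = spec.partition(":")
    if not sep or not module_name or not attr:
        raise ValueError(f"expected 'module:function', got {spec!r}")
    fn = getattr(importlib.import_module(module_name), attr)
    if not callable(fn):
        raise TypeError(f"{spec!r} does not refer to a callable")
    return fn


def parse_transforms(argv: Optional[List[str]] = None) -> List[Callable[[Any], Any]]:
    """Parse the hypothetical --model_transform flag (may be given several times)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_transform",
        action="append",
        default=[],
        help="import path of a user transform, e.g. my_pkg.my_module:convert",
    )
    args = parser.parse_args(argv)
    # Transforms are returned in the order they appear on the command line.
    return [load_transform(spec) for spec in args.model_transform]
```

With this shape, the out-of-tree code only has to be importable in the launch environment; nothing needs to be added to the trainer's source tree.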

Labels

enhancement (New feature or request)