Closed
Description
In torchao, we have various low precision training features that are still prototypes: MX, int8, bitnet. While we expect most of these to eventually end up in the main torchao APIs, it often takes months for a prototype to graduate.
torchtitan is extremely useful for helping us test low precision prototypes in real-world settings. For now, we've been creating unlanded PRs to test functionality (examples: #614, #778). Would torchtitan consider building an extension point to support this kind of experimentation fully out-of-tree?
An example of what this could look like:
- torchtitan provides a "model transformation" hook that it calls at a specified point in the initialization stage (for quantization, that should be after model init and before parallelization / torch.compile)
- user can provide a custom pass to transform the model (such as a prototype low precision training conversion pass)
I'm not entirely sure how this hook would be implemented, since the current interface of torchtitan is CLI-based, but I wanted to share the request and start the discussion.
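To make the request concrete, here is a minimal sketch of one way such a hook could work with a CLI-driven setup: the user passes a `"module:function"` string, and the trainer imports and calls it between model init and parallelization. Everything here is hypothetical and not an existing torchtitan API: the names `load_transform` and `build_model`, the spec format, and the callables standing in for torchtitan's init, parallelize, and compile steps are all assumptions. The model is typed as `Any` to keep the sketch dependency-free; in practice it would be a `torch.nn.Module`.

```python
import importlib
from typing import Any, Callable, Optional

# Hypothetical type: a transform takes a model and returns a (possibly new) model.
ModelTransform = Callable[[Any], Any]


def load_transform(spec: Optional[str]) -> Optional[ModelTransform]:
    """Resolve a "package.module:function" string to a callable (assumed format)."""
    if not spec:
        return None
    module_name, sep, attr = spec.partition(":")
    if not sep or not module_name or not attr:
        raise ValueError(f"expected 'module:function', got {spec!r}")
    fn = getattr(importlib.import_module(module_name), attr)
    if not callable(fn):
        raise TypeError(f"{spec!r} does not resolve to a callable")
    return fn


def build_model(
    init_model: Callable[[], Any],
    parallelize: Callable[[Any], Any],
    compile_model: Callable[[Any], Any],
    transform_spec: Optional[str] = None,
) -> Any:
    """Stand-in for the trainer's setup path, with the hook in the requested spot."""
    model = init_model()

    # The hook: after model init, before parallelization / compilation.
    transform = load_transform(transform_spec)
    if transform is not None:
        model = transform(model)

    model = parallelize(model)
    return compile_model(model)
```

With something like this, a prototype conversion pass could live entirely out-of-tree and be selected with a hypothetical flag such as `--model_transform my_pkg.passes:convert`, where `my_pkg.passes` is any module importable in the training environment.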