Description
Outline & Motivation
DeepSpeed works by using a configuration file (dictionary) that allows customizing all of its aspects: https://www.deepspeed.ai/docs/config-json/
The DeepSpeedStrategy supports two ways of defining this:
- Passing a config file, in which case every other argument is ignored: https://github.com/Lightning-AI/lightning/blob/b792c90ea7148d61af192fde6c338ebbd355702f/src/lightning/fabric/strategies/deepspeed.py#L191
- Exposing several of these options as __init__ arguments, which are used to build a base config: https://github.com/Lightning-AI/lightning/blob/b792c90ea7148d61af192fde6c338ebbd355702f/src/lightning/fabric/strategies/deepspeed.py#L242-L271
The second option does not scale because:
- It forces us to duplicate all arguments
- Our docstrings might become outdated
- Our strategy defaults might diverge from the defaults in DeepSpeed
- It forces the user to either create an entire config or use these arguments
- Arguments might differ depending on the installed DeepSpeed version, since we support more than one version
- When DeepSpeed adds an argument that we don't expose, users have to switch to using the config
Pitch
Remove all of these exposed arguments and keep a single config argument that accepts any of:
- Passing a path to a config file
DeepSpeedStrategy(config="my/config/path.json")
- Passing a full config object
config = ds.runtime.config.DeepSpeedConfig({"train_micro_batch_size_per_gpu": 2})
DeepSpeedStrategy(config=config)
- Passing a config dictionary (or a subset of it) that will update the default config
config = {"zero_optimization": {"offload_optimizer": {"device": "cpu"}}}
DeepSpeedStrategy(config=config)
The default config is created by calling: https://github.com/microsoft/DeepSpeed/blob/085981bf1caf5d7d0b26d05f7c7e9487e1b35190/deepspeed/runtime/config.py#L674
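A rough sketch of how the overloaded config argument could be resolved. Everything here is hypothetical and only illustrates the idea: resolve_config, deep_update and EXAMPLE_DEFAULTS are made-up names, the defaults are placeholders rather than DeepSpeed's real defaults, and it assumes a config file is used as-is while a dictionary is merged recursively into the defaults. The full config object case is left out to keep the sketch free of the deepspeed dependency.

```python
import copy
import json
from pathlib import Path
from typing import Any, Dict, Mapping, Optional, Union

# Placeholder defaults for illustration only; NOT DeepSpeed's actual defaults.
EXAMPLE_DEFAULTS: Dict[str, Any] = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 2, "offload_optimizer": {"device": "none"}},
}


def deep_update(base: Mapping[str, Any], overrides: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a copy of `base` with `overrides` merged in recursively."""
    merged = copy.deepcopy(dict(base))
    for key, value in overrides.items():
        if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
            # Both sides are nested sections: merge them instead of replacing.
            merged[key] = deep_update(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def resolve_config(
    config: Optional[Union[str, Path, Mapping[str, Any]]],
    defaults: Mapping[str, Any] = EXAMPLE_DEFAULTS,
) -> Dict[str, Any]:
    """Hypothetical helper: turn the user-facing `config` value into a plain dict."""
    if config is None:
        return copy.deepcopy(dict(defaults))
    if isinstance(config, (str, Path)):
        # Assumption: a file holds a complete config and is used as-is.
        with open(config) as f:
            return json.load(f)
    if isinstance(config, Mapping):
        # A (possibly partial) dictionary updates the defaults.
        return deep_update(defaults, config)
    raise TypeError(f"Unsupported config type: {type(config).__name__}")


# A partial dictionary only overrides the keys it names.
resolved = resolve_config({"zero_optimization": {"offload_optimizer": {"device": "cpu"}}})
```

With this approach, the stage value from the defaults survives while only the offload device changes, which is the behaviour the third bullet above asks for.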
Additional context
DeepSpeed is considered experimental, so we could make this breaking change: https://github.com/Lightning-AI/lightning/blob/b792c90ea7148d61af192fde6c338ebbd355702f/src/lightning/fabric/strategies/deepspeed.py#L99