
fsdp=16 model=16 gbs=16 should work on 256 chips #773

Open

Description

@samos123

fsdp=16 model=16 global_batch_size=16 should work on 256 chips

The use case is being able to use a global batch size smaller than the total number of JAX processes.

This is supported in maxtext via this trick: https://github.com/AI-Hypercomputer/maxtext/blob/4cf51b7f204e109df502cf2d54b4d5005f597b09/MaxText/train.py#L289-L291; a rough sketch of the idea follows below.
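
For context, here is a minimal sketch of that kind of workaround. This is illustrative only: `per_process_batch_size` is a hypothetical helper, not maxtext's or AXLearn's actual API. The idea is that when `global_batch_size < jax.process_count()`, only a subset of processes act as real data feeds while the rest produce same-shape dummy batches.

```python
# Sketch of the "fewer data feeds than processes" trick (hypothetical helper,
# not maxtext's or AXLearn's actual API).
import jax

def per_process_batch_size(global_batch_size: int) -> tuple[int, bool]:
    """Returns (local_batch_size, is_real_data_feed) for this process."""
    num_processes = jax.process_count()
    if global_batch_size % num_processes == 0:
        # Usual case: every process feeds an equal slice of the global batch.
        return global_batch_size // num_processes, True
    if global_batch_size < num_processes:
        # The trick: only the first `global_batch_size` processes act as real
        # data feeds (one example each); the rest synthesize a same-shape
        # dummy batch so collectives stay in sync, and the dummy examples are
        # discarded when the global array is assembled.
        return 1, jax.process_index() < global_batch_size
    raise ValueError(
        f"global_batch_size ({global_batch_size}) must be divisible by "
        f"the number of processes ({num_processes}) or smaller than it."
    )
```

With this scheme, `fsdp=16 model=16 global_batch_size=16` on 64 processes would make processes 0-15 real feeds and processes 16-63 dummy feeds.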

Trying to get the 405B model running on a v6e-256 (fsdp=16, model=16), but hitting this error:

```
I1022 20:32:33.715831 139189201369088 trainer.py:323] gpt_trainer process  19 step       -1] Global mesh: Mesh('pipeline': 1, 'data': 1, 'expert': 1, 'fsdp': 16, 'seq': 1, 'model': 16)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
    app.run(main)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
    launch_trainer.run_trainer(trainer_config)
  File "/root/axlearn/common/launch_trainer.py", line 129, in run_trainer
    trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/trainer.py", line 244, in __init__
    self._add_child("input", cfg.input.set(is_training=True))
  File "/root/axlearn/common/module.py", line 760, in _add_child
    module = child_config.instantiate(parent=self, **kwargs)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/input_tf_data.py", line 1185, in __init__
    self._batcher = maybe_set_config(cfg.batcher, is_training=cfg.is_training).instantiate()
  File "/root/axlearn/common/config.py", line 801, in instantiate
    return self.fn(*args, **kwargs)
  File "/root/axlearn/common/input_tf_data.py", line 799, in batch
    raise ValueError(
ValueError: global_batch_size (16.0) must be divisible by number of JAX processes (data feeds) (64).
```
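
The arithmetic behind the error: per the log, the 256-chip slice runs 64 JAX processes, each expected to feed an equal slice of the batch, and 16 is not divisible by 64, so each feed would have to produce a fractional number of examples.

```python
# Why the divisibility check in input_tf_data fires for this config.
global_batch_size = 16
num_feeds = 64                         # JAX processes reported in the error
print(global_batch_size % num_feeds)   # 16 -> not divisible, ValueError
print(global_batch_size / num_feeds)   # 0.25 examples per feed: impossible
```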
