Description
fsdp=16 model=16 global_batch_size=16 should work on 256 chips
The use case is being able to use a global batch size smaller than the total number of JAX processes.
This is supported in MaxText via the following trick: https://github.com/AI-Hypercomputer/maxtext/blob/4cf51b7f204e109df502cf2d54b4d5005f597b09/MaxText/train.py#L289-L291
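Roughly, the trick (my paraphrase of the linked lines; the names below are illustrative, not MaxText's exact identifiers) is to load a global batch padded up to a size that is divisible by the number of hosts, then slice every leaf of the batch pytree down to the batch size you actually train on:

```python
import jax

# Sketch of the MaxText-style workaround, assuming 64 JAX processes
# (one per v6e host) and a desired global batch of 16.
GLOBAL_BATCH_TO_LOAD = 64   # padded up so each host feeds >= 1 example
GLOBAL_BATCH_TO_TRAIN = 16  # the batch size we actually want

def shrink_batch(example_batch):
    # Keep only the first GLOBAL_BATCH_TO_TRAIN rows of every array in
    # the batch; the remaining rows exist only to satisfy the per-host
    # divisibility constraint and are discarded before the train step.
    return jax.tree_util.tree_map(
        lambda x: x[:GLOBAL_BATCH_TO_TRAIN], example_batch
    )
```

Something equivalent in axlearn's input pipeline would let the batcher accept a global batch size smaller than the process count.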
Trying to get the 405B model running on v6e-256 (fsdp=16, model=16), but hitting this error:
```
I1022 20:32:33.715831 139189201369088 trainer.py:323] [gpt_trainer process 19 step -1] Global mesh: Mesh('pipeline': 1, 'data': 1, 'expert': 1, 'fsdp': 16, 'seq': 1, 'model': 16)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/axlearn/common/launch_trainer_main.py", line 21, in <module>
    app.run(main)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/venv/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/axlearn/common/launch_trainer_main.py", line 16, in main
    launch_trainer.run_trainer(trainer_config)
  File "/root/axlearn/common/launch_trainer.py", line 129, in run_trainer
    trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/trainer.py", line 244, in __init__
    self._add_child("input", cfg.input.set(is_training=True))
  File "/root/axlearn/common/module.py", line 760, in _add_child
    module = child_config.instantiate(parent=self, **kwargs)
  File "/root/axlearn/common/config.py", line 734, in instantiate
    return self.klass(self, **kwargs)
  File "/root/axlearn/common/module.py", line 520, in __call__
    instance = super().__call__(*args, **kwds)
  File "/root/axlearn/common/input_tf_data.py", line 1185, in __init__
    self._batcher = maybe_set_config(cfg.batcher, is_training=cfg.is_training).instantiate()
  File "/root/axlearn/common/config.py", line 801, in instantiate
    return self.fn(*args, **kwargs)
  File "/root/axlearn/common/input_tf_data.py", line 799, in batch
    raise ValueError(
ValueError: global_batch_size (16.0) must be divisible by number of JAX processes (data feeds) (64).
```
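For context on where the 64 comes from (my arithmetic, assuming one JAX process per host): a v6e-256 slice has 4 chips per host, so 256 / 4 = 64 hosts, each acting as one data feed. The batcher then requires global_batch_size % 64 == 0, which 16 fails:

```python
import jax

# The constraint the batcher enforces, paraphrased from the error
# message (not the actual input_tf_data.py source):
num_feeds = jax.process_count()  # 64 on v6e-256 (256 chips / 4 per host)
global_batch_size = 16
if global_batch_size % num_feeds != 0:  # 16 % 64 == 16, so this raises
    raise ValueError(
        f"global_batch_size ({global_batch_size}) must be divisible by "
        f"number of JAX processes (data feeds) ({num_feeds})."
    )
```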