diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 548fbcb5..54ac7104 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -205,6 +205,12 @@ def pipeline_llama_manual(
 
     logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}")
 
+    # TODO, support this? or just guard against it inside the lib
+    if job_config.training.batch_size % parallel_dims.pp != 0:
+        raise ValueError(
+            f"batch_size {job_config.training.batch_size} not divisible by pp dim, currently unsupported"
+        )
+
     # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
     # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
     # layers of the model that map to this stage, not the whole model.
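
Note on why the guard matters: pipeline schedules split the global batch into microbatches, and the modulus check implies that in this setup the microbatch count is tied to the pp degree. A minimal sketch of the failure mode, assuming the split uses a `chunk`-style operation (this is illustrative, not torchtitan's actual split code):

```python
import torch

batch_size, pp = 10, 4  # hypothetical values; 10 % 4 != 0
x = torch.randn(batch_size, 16)

# torch.chunk rounds the chunk size up, so the last microbatch comes out
# smaller: sizes are [3, 3, 3, 1] here. That mismatched shape would break
# the hardcoded input-shape inference mentioned in the TODO(whc) comment.
microbatches = x.chunk(pp, dim=0)
print([mb.shape[0] for mb in microbatches])  # [3, 3, 3, 1]
```

Raising early in `pipeline_llama_manual` surfaces the constraint as a clear config error instead of a downstream shape mismatch; the TODO leaves open whether the library should eventually handle uneven microbatches itself.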