Commit 666e0d3
Update
[ghstack-poisoned]
wconstab committed May 18, 2024
1 parent 4a4f642 commit 666e0d3
Showing 1 changed file with 6 additions and 0 deletions.

torchtitan/parallelisms/parallelize_llama.py
@@ -205,6 +205,12 @@ def pipeline_llama_manual(

    logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}")

    # TODO, support this? or just guard against it inside the lib
    if job_config.training.batch_size % parallel_dims.pp != 0:
        raise ValueError(
            f"batch_size {job_config.training.batch_size} not divisible by pp dim, currently unsupported"
        )

    # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and
    # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the
    # layers of the model that map to this stage, not the whole model.
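For context on the new check: pipeline schedules split each training batch into equal-sized microbatches along the batch dimension, presumably one per pipeline-parallel rank given the `parallel_dims.pp` divisor, so a batch size that is not divisible by the pp degree would leave a ragged final microbatch. Below is a minimal sketch of that splitting, not code from this commit; the `batch_size` and `pp` values are illustrative.

    import torch

    batch_size, pp = 8, 4  # illustrative values; note 8 % 4 == 0
    batch = torch.randn(batch_size, 16)

    # The guard added in this commit rejects the non-divisible case up front.
    if batch_size % pp != 0:
        raise ValueError(f"batch_size {batch_size} not divisible by pp dim")

    # Split into equal microbatches; each holds batch_size // pp samples.
    microbatch_size = batch_size // pp
    microbatches = batch.split(microbatch_size, dim=0)
    assert all(mb.shape[0] == microbatch_size for mb in microbatches)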
