NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --training.dataset c4
fails with:
traceback : Traceback (most recent call last):
File "/data/users/ezyang/a/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
return f(*args, **kwargs)
File "/data/users/ezyang/a/torchtitan/torchtitan/train.py", line 351, in __init__
model = self.train_spec.parallelize_fn(
File "/data/users/ezyang/a/torchtitan/torchtitan/experiments/auto_parallel/parallelize_llama.py", line 67, in parallelize_llama
with AutoParallel(model, input_fn, world_mesh, mp_policy=mp_policy) as autop:
File "/data/users/ezyang/a/autoparallel/autoparallel/api.py", line 206, in __enter__
sharding_optimizer = ShardingOptimizer(
File "/data/users/ezyang/a/autoparallel/autoparallel/optimize_sharding.py", line 131, in __init__
self.strats = self.build_sharding_metadata()
File "/data/users/ezyang/a/autoparallel/autoparallel/optimize_sharding.py", line 179, in build_sharding_metadata
strat = get_placement_options(
File "/data/users/ezyang/a/autoparallel/autoparallel/utils.py", line 166, in get_placement_options
out_strat = get_op_strategy(op, op_schema)
File "/data/users/ezyang/a/autoparallel/autoparallel/dtensor_util/utils.py", line 228, in get_op_strategy
return propagator.op_strategy_funcs[op](op_schema)
File "/data/users/ezyang/a/pytorch/torch/distributed/tensor/_ops/_embedding_ops.py", line 233, in embedding_strategy
return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies)
File "/data/users/ezyang/a/pytorch/torch/distributed/tensor/_ops/utils.py", line 332, in expand_to_full_mesh_op_strategy
assert len(input_specs) == len(input_args_strategy), (input_specs, input_args_strategy)
AssertionError: ([], (<torch.distributed.tensor._op_schema.OpStrategy object at 0x7ffa40fff280>, <torch.distributed.tensor._op_schema.OpStrategy object at 0x7ffa40fff250>))
I confirmed that the model does work if I use the plain llama3 model instead of llama3_auto_parallel. This is not a big deal in terms of functionality, but it does make smoke testing torchtitan on a single GPU more annoying.
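For anyone trying to narrow this down, here is a minimal isolation sketch of my own (an assumption about where the problem lives, not part of the original repro): it dispatches an embedding op through plain DTensor on a 1-device mesh, bypassing the auto_parallel path entirely. It assumes a recent PyTorch where torch.distributed.tensor exposes distribute_tensor/Replicate and that gloo is available.

```python
# Hypothetical isolation sketch (mine, not from the issue): run an embedding
# op through plain DTensor on a single-device mesh, outside autoparallel.
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
# Single-rank gloo group so the sketch also works without a GPU.
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))

# Replicated weight and indices; F.embedding on DTensor inputs should go
# through the same embedding_strategy / expand_to_full_mesh_op_strategy
# path that the traceback points at.
weight = distribute_tensor(torch.randn(16, 8), mesh, [Replicate()])
indices = distribute_tensor(torch.randint(0, 16, (2, 4)), mesh, [Replicate()])

out = F.embedding(indices, weight)
print(type(out), out.shape)

dist.destroy_process_group()
```

If this runs cleanly on a 1-device mesh, that would suggest the empty input_specs comes from the OpSchema that autoparallel's get_op_strategy builds rather than from the upstream embedding strategy itself, but that is only a guess.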