Description
Compiling the provided mllama example fails during HLO generation of the context-encoding model: `torch.cat` in `perform_maskless_sdpa` raises a torch_xla shape-compatibility check failure (`ShapeUtil::CompatibleIgnoringElementType`), indicating the query tensor `Q` and the unsqueezed `mask_gen_vectors` do not have compatible shapes along the non-concatenated dimensions.
To reproduce:
run provided example generation_mllama.py
Result:
INFO:Neuron:Generating 4 hlos for key: context_encoding_model
INFO:Neuron:Started loading module context_encoding_model
INFO:Neuron:Finished loading module context_encoding_model in 2.552138328552246 seconds
INFO:Neuron:generating HLO: context_encoding_model, input example shape = torch.Size([1, 128])
/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/neuronx_distributed/parallel_layers/layers.py:522: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
with torch.cuda.amp.autocast(enabled=False):
Traceback (most recent call last):
File "/home/ubuntu/repos/HPC/NxDInference-2_21/examples/generation_mllama.py", line 142, in <module>
run_llama_generate()
File "/home/ubuntu/repos/HPC/NxDInference-2_21/examples/generation_mllama.py", line 70, in run_llama_generate
model.compile(traced_model_path)
File "/home/ubuntu/repos/HPC/NxDInference-2_21/src/neuronx_distributed_inference/models/application_base.py", line 145, in compile
traced_model = self.get_builder(debug).trace(initialize_model_weights=False)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/neuronx_distributed/trace/model_builder.py", line 229, in trace
hlo_artifact_collection = self._generate_hlo(key)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/neuronx_distributed/trace/model_builder.py", line 404, in _generate_hlo
hlo_artifacts = torch_neuronx.xla_impl.trace.generate_hlo(
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch_neuronx/xla_impl/trace.py", line 450, in generate_hlo
) = xla_trace(
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch_neuronx/xla_impl/hlo_conversion.py", line 138, in xla_trace
outputs = func(*example_inputs)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/repos/HPC/NxDInference-2_21/src/neuronx_distributed_inference/models/mllama/modeling_mllama.py", line 1207, in forward
vision_tokens = self.vision_model(pixel_values, aspect_ratios) * has_image.view(
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/repos/HPC/NxDInference-2_21/src/neuronx_distributed_inference/models/mllama/modeling_mllama_vision.py", line 665, in forward
vision_tokens = self.vision_encoder(images, aspect_ratios, aspect_ratios_ids)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/repos/HPC/NxDInference-2_21/src/neuronx_distributed_inference/models/mllama/modeling_mllama_vision.py", line 525, in forward
x, int_x = self.transformer(x, return_intermediate=self.return_intermediate, mask=attn_mask)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/repos/HPC/NxDInference-2_21/src/neuronx_distributed_inference/models/mllama/modeling_mllama_vision.py", line 365, in forward
x = r(x, mask=mask)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/repos/HPC/NxDInference-2_21/src/neuronx_distributed_inference/models/mllama/modeling_mllama_vision.py", line 311, in forward
x = x + _gate_attn * self.attn(self.ln_1(x), mask)[0]
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/venv_nxd_221/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/repos/HPC/NxDInference-2_21/src/neuronx_distributed_inference/modules/attention/attention_base.py", line 406, in forward
attn_output, flash_attn_strategy = self.perform_prefill(
File "/home/ubuntu/repos/HPC/NxDInference-2_21/src/neuronx_distributed_inference/models/mllama/modeling_mllama_vision.py", line 252, in perform_prefill
attn_output = self.perform_maskless_sdpa(
File "/home/ubuntu/repos/HPC/NxDInference-2_21/src/neuronx_distributed_inference/models/mllama/modeling_mllama_vision.py", line 186, in perform_maskless_sdpa
Q_cat = torch.cat([Q, mask_gen_vectors.unsqueeze(3).to(Q.dtype)], dim=3)
RuntimeError: torch_xla/csrc/tensor_methods.cpp:1161 : Check failed: xla::ShapeUtil::CompatibleIgnoringElementType(shapes.back(), tensor_shape)