From 3efe1ebf827ff51d804656c9c5c241643268a32e Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Wed, 30 Oct 2024 09:47:11 -0700
Subject: [PATCH] update runtime error message for minibatch (#8243)

---
 test/spmd/test_xla_sharding.py | 2 +-
 torch_xla/core/xla_model.py    | 7 +++----
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py
index 29749a19596..45a7b2796c9 100644
--- a/test/spmd/test_xla_sharding.py
+++ b/test/spmd/test_xla_sharding.py
@@ -1373,7 +1373,7 @@ def test_data_loader_with_non_batch_size_and_mini_batch(self):
             mesh, ('data', None, None, None), minibatch=True))
     with self.assertRaisesRegex(
         RuntimeError,
-        "When minibatch is configured, batch dimension of the tensor must be divisible by local runtime device count*"
+        "When minibatch is configured, the per-host batch size must be divisible by local runtime device count. Per host input data shape *"
     ):
       data, _ = iter(train_device_loader).__next__()

diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 931115db6d8..607f1cb9c57 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -1309,10 +1309,9 @@ def convert_fn(tensors):
       if sharding and tensor.dim() > 0 and (tensor.size()[0] %
                                             local_runtime_device_count) != 0:
         raise RuntimeError(
-            "When minibatch is configured, batch dimension of the tensor " +
-            "must be divisible by local runtime device count.input data shape "
-            +
-            f"={tensor.size()}, local_runtime_device_count = {local_runtime_device_count}"
+            "When minibatch is configured, the per-host batch size must be divisible "
+            + "by local runtime device count. Per host input data shape "
+            + f"= {tensor.size()}, local_runtime_device_count = {local_runtime_device_count}"
         )
       xtensors = torch_xla._XLAC._xla_tensors_from_aten(tensors, devices,
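
Note on the condition behind the reworded message: with minibatch=True, each host
receives a per-host slice of the global batch, and the batch dimension of that
slice must split evenly across the host's local devices. Below is a minimal,
self-contained sketch of that divisibility check; check_minibatch_divisibility
and per_host_batch_shape are hypothetical names for illustration, not
torch_xla's actual code path.

    # Hypothetical sketch of the check whose error message this patch rewords.
    def check_minibatch_divisibility(per_host_batch_shape: tuple,
                                     local_runtime_device_count: int) -> None:
      # Dimension 0 is the per-host batch size; it must divide evenly
      # across the host's local runtime devices.
      if per_host_batch_shape[0] % local_runtime_device_count != 0:
        raise RuntimeError(
            "When minibatch is configured, the per-host batch size must be "
            "divisible by local runtime device count. Per host input data shape "
            f"= {per_host_batch_shape}, local_runtime_device_count = "
            f"{local_runtime_device_count}")

    check_minibatch_divisibility((8, 3, 224, 224), 4)  # OK: 8 % 4 == 0

    try:
      check_minibatch_divisibility((6, 3, 224, 224), 4)  # raises: 6 % 4 != 0
    except RuntimeError as e:
      print(e)

The test above exercises exactly this failure mode: it builds a loader whose
per-host batch size is not divisible by the local device count and asserts
that the RuntimeError message matches the new wording.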