split on second dimension of 2D array not working #8640

Open
@jeffhataws

Description

🐛 Bug

When running a small example that splits a 2D tensor along its second dimension, the resulting tensors do not contain the expected data: the results differ between CPU and XLA-CPU.

To Reproduce

Run the following test:

import torch
import torch_xla

a_golden = torch.arange(12, device="cpu").reshape(3, 4)
b_golden, c_golden = a_golden.split([3, 1], dim=-1)
a_xla = torch.arange(12, device="xla").reshape(3, 4)
b_xla, c_xla = a_xla.split([3, 1], dim=-1)

print("a original:", a_golden)
print("b golden :", b_golden)
print("b xla :", b_xla)
print("c golden :", c_golden)
print("c xla :", c_xla)

torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
torch.testing.assert_close(c_golden, c_xla.cpu(), rtol=0, atol=0)

Save as test_split.py and run:

PJRT_DEVICE=CPU python test_split.py
WARNING:torch_neuron:RANK environment variable is not set, defaulting to 0.
WARNING:torch_neuron:LOCAL RANK environment variable is not set to 0, defaulting to 0.
a original: tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
b golden : tensor([[ 0,  1,  2],
        [ 4,  5,  6],
        [ 8,  9, 10]])
b xla : tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]], device='xla:0')
c golden : tensor([[ 3],
        [ 7],
        [11]])
c xla : tensor([[3],
        [4],
        [5]], device='xla:0')
Traceback (most recent call last):
  File "/home/ubuntu/transformers/examples/pytorch/text-classification/test_split.py", line 15, in <module>
    torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
  File "/home/ubuntu/aws_neuron_venv_pt26/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!

Mismatched elements: 6 / 9 (66.7%)
Greatest absolute difference: 2 at index (2, 0)
Greatest relative difference: 0.3333333432674408 at index (1, 0)
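
One pattern in the output above may help with triage. It is an observation about the printed values, not a confirmed root cause: the XLA results are exactly what you get by reading the flat buffer `0..11` with the contiguous strides of each output shape, instead of keeping the parent's row stride of 4. The sketch below reproduces both sets of numbers using plain Python lists so it runs without torch; `strided_read` is a hypothetical helper written for this illustration and is not part of torch or torch_xla.

```python
def strided_read(flat, shape, strides, offset=0):
    """Read a 2D view out of a flat buffer using explicit strides."""
    rows, cols = shape
    row_stride, col_stride = strides
    return [
        [flat[offset + r * row_stride + c * col_stride] for c in range(cols)]
        for r in range(rows)
    ]


# Flat storage behind the 3x4 tensor from the repro (torch.arange(12)).
flat = list(range(12))

# Expected split views: each piece keeps the parent's row stride of 4.
b_expected = strided_read(flat, (3, 3), (4, 1), offset=0)
c_expected = strided_read(flat, (3, 1), (4, 1), offset=3)

# Assumption being illustrated: each piece is read as if it were contiguous,
# i.e. its row stride equals its own width (3 for b, 1 for c).
b_contig = strided_read(flat, (3, 3), (3, 1), offset=0)
c_contig = strided_read(flat, (3, 1), (1, 1), offset=3)

print("b expected  :", b_expected)
print("b contiguous:", b_contig)
print("c expected  :", c_expected)
print("c contiguous:", c_contig)
```

The "expected" lists match the `b golden` and `c golden` tensors, and the "contiguous" lists match the `b xla` and `c xla` tensors printed above. That is consistent with the row stride being dropped somewhere in the XLA path, though the repro alone does not show where.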

Expected behavior

The XLA CPU results should match the CPU results.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
  • torch_xla version: 2.6 (also 2.1)
