split on second dimension of 2D array not working with XLA_DISABLE_FUNCTIONALIZATION=1 #8640

Open
@jeffhataws

Description

🐛 Bug

When running a small example that splits a 2D array along its second dimension, the resulting tensors don't contain the expected data. The results differ between CPU and XLA-CPU.

To Reproduce

Run the following test:

import torch
import torch_xla

a_golden = torch.arange(12, device="cpu").reshape(3, 4)
b_golden, c_golden = a_golden.split([3, 1], dim=-1)
a_xla = torch.arange(12, device="xla").reshape(3, 4)
b_xla, c_xla = a_xla.split([3, 1], dim=-1)

print("a original:", a_golden)
print("b golden :", b_golden)
print("b xla :", b_xla)
print("c golden :", c_golden)
print("c xla :", c_xla)

torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
torch.testing.assert_close(c_golden, c_xla.cpu(), rtol=0, atol=0)

Save as test_split.py and run:

XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CPU python test_split.py
WARNING:torch_neuron:RANK environment variable is not set, defaulting to 0.
WARNING:torch_neuron:LOCAL RANK environment variable is not set to 0, defaulting to 0.
a original: tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
b golden : tensor([[ 0,  1,  2],
        [ 4,  5,  6],
        [ 8,  9, 10]])
b xla : tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]], device='xla:0')
c golden : tensor([[ 3],
        [ 7],
        [11]])
c xla : tensor([[3],
        [4],
        [5]], device='xla:0')
Traceback (most recent call last):
  File "/home/ubuntu/transformers/examples/pytorch/text-classification/test_split.py", line 15, in <module>
    torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
  File "/home/ubuntu/aws_neuron_venv_pt26/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!

Mismatched elements: 6 / 9 (66.7%)
Greatest absolute difference: 2 at index (2, 0)
Greatest relative difference: 0.3333333432674408 at index (1, 0)
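For what it's worth, the wrong XLA outputs above look exactly like contiguous reads of the underlying row-major buffer starting at each split's storage offset, as if the view's strides were ignored. This is an observation about the reported numbers, not a confirmed root cause; it can be reproduced with plain CPU tensors (no torch_xla needed):

```python
import torch

a = torch.arange(12).reshape(3, 4)
flat = a.flatten()  # the underlying row-major buffer: 0..11

# b_xla matches the first 3*3 buffer elements, reshaped to the split's shape
b_like_xla = flat[:9].reshape(3, 3)
# c_xla matches 3 buffer elements starting at the split's storage offset (3)
c_like_xla = flat[3:6].reshape(3, 1)

print(b_like_xla)  # tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) -- same as b_xla
print(c_like_xla)  # tensor([[3], [4], [5]]) -- same as c_xla
```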

Expected behavior

The XLA-CPU results should match the CPU results.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
  • torch_xla version: 2.6 (also 2.1)

Additional context
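On CPU, split with a list of sizes is equivalent to successive narrow() calls along the given dimension, so the slice/narrow lowering is presumably where XLA diverges when functionalization is disabled (an assumption, not verified against the torch_xla source):

```python
import torch

a = torch.arange(12).reshape(3, 4)

# split([3, 1], dim=-1) is equivalent to two narrow() views:
b = a.narrow(-1, 0, 3)  # columns 0..2, shape (3, 3)
c = a.narrow(-1, 3, 1)  # column 3, shape (3, 1)

b_split, c_split = a.split([3, 1], dim=-1)
assert torch.equal(b, b_split)
assert torch.equal(c, c_split)
```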

Labels

  • functionalization-disabled — Issues specifically for when functionalization is disabled.
  • pytorch divergence — XLA behavior doesn't match PyTorch eager frontend.
