diff --git a/WORKSPACE b/WORKSPACE index 8a119d8a236..44c02a0785c 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -50,7 +50,7 @@ new_local_repository( # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update the sha256 with the result. -xla_hash = 'd28bfbdc366627c9ac9f57fcaa512ff04de19d6f' +xla_hash = '8d06f3680ad046ea44f8e7159f52c728bb66c069' http_archive( name = "xla", diff --git a/setup.py b/setup.py index 1a808ec00aa..0e0f4971c05 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) USE_NIGHTLY = True # whether to use nightly or stable libtpu and jax -_date = '20250106' +_date = '20250113' _libtpu_version = f'0.0.8' _jax_version = f'0.4.39' _jaxlib_version = f'0.4.39' diff --git a/test/cpp/test_aten_xla_tensor_2.cpp b/test/cpp/test_aten_xla_tensor_2.cpp index 92d3bce02ed..28b107b1e6f 100755 --- a/test/cpp/test_aten_xla_tensor_2.cpp +++ b/test/cpp/test_aten_xla_tensor_2.cpp @@ -513,8 +513,6 @@ TEST_F(AtenXlaTensorTest, TestLinalgVectorNormInDimsKeepDtype) { } TEST_F(AtenXlaTensorTest, TestLinalgEigh) { - // TODO: Broken by XLA pin update on 20250106. - GTEST_SKIP(); // Hardcode the test input to avoid numerical instability from randomness, // which is a problem in eigenvalue decomposition. auto complex64 = [](float real, float imag) { diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp index a2310f61d35..2f174a4af22 100644 --- a/torch_xla/csrc/dl_convertor.cpp +++ b/torch_xla/csrc/dl_convertor.cpp @@ -154,7 +154,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) { pack->shape = std::vector(pjrt_buffer->dimensions().begin(), pjrt_buffer->dimensions().end()); - xla::Layout xla_layout = xla::GetXlaLayoutUnsafe(pjrt_buffer->layout()); + xla::Layout xla_layout = pjrt_buffer->layout()->xla_layout(); pack->strides = StridesForShape(pjrt_buffer->element_type(), pjrt_buffer->dimensions(), xla_layout); dt.shape = reinterpret_cast(pack->shape.data()); diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index 213761748c7..0e5b714cd1e 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -63,6 +63,7 @@ cc_library( "@xla//xla:literal_util", "@xla//xla/client:xla_computation", "@xla//xla/hlo/ir:hlo", + "@xla//xla/pjrt:pjrt_client", ], ) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 30e648919b3..8caad6d230f 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -62,8 +62,7 @@ std::unordered_map build_index_map( xla::Shape host_output_shape(xla::PjRtBuffer* buffer) { xla::Shape shape = xla::ShapeUtil::MakeShape( buffer->element_type(), buffer->logical_dimensions().value()); - *shape.mutable_layout() = xla::GetXlaLayoutUnsafe(buffer->layout()); - + *shape.mutable_layout() = buffer->layout()->xla_layout(); return xla::ShapeUtil::DeviceShapeToHostShape(shape); }