15 changes: 10 additions & 5 deletions xla/pjrt/c/BUILD
@@ -2,6 +2,10 @@ load(
"@local_config_rocm//rocm:build_defs.bzl",
"if_rocm_is_configured",
)
load(
"@local_config_sycl//sycl:build_defs.bzl",
"if_sycl_is_configured",
)
load(
"//xla:xla.default.bzl",
"xla_cc_binary",
@@ -433,11 +437,10 @@ cc_library(
name = "pjrt_c_api_gpu_internal",
srcs = ["pjrt_c_api_gpu_internal.cc"],
hdrs = ["pjrt_c_api_gpu_internal.h"],
local_defines = (
if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([
"TENSORFLOW_USE_ROCM=1",
])
),
local_defines =
if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]) + if_sycl_is_configured([
"TENSORFLOW_USE_SYCL=1",
]),
visibility = ["//visibility:public"],
deps = [
":pjrt_c_api_custom_partitioner_extension_hdrs",
@@ -526,6 +529,8 @@ xla_cc_binary(
"//xla/stream_executor:cuda_platform",
]) + if_rocm_is_configured([
"//xla/stream_executor:rocm_platform",
]) + if_sycl_is_configured([
"//xla/stream_executor:sycl_platform",
]),
)

23 changes: 17 additions & 6 deletions xla/pjrt/c/pjrt_c_api_gpu_internal.cc
@@ -68,6 +68,8 @@ namespace gpu_plugin {

#if TENSORFLOW_USE_ROCM
#define PJRT_GPU_PLUGIN_PLATFORM_NAME "ROCM"
#elif TENSORFLOW_USE_SYCL
#define PJRT_GPU_PLUGIN_PLATFORM_NAME "ONEAPI"
#else
#define PJRT_GPU_PLUGIN_PLATFORM_NAME "CUDA"
#endif
@@ -265,13 +267,22 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create(
PJRT_TopologyDescription_Create_Args_STRUCT_SIZE, args->struct_size));

// Determine the platform ID and name based on the platform.
xla::PjRtPlatformId platform_id =
(std::string(PJRT_GPU_PLUGIN_PLATFORM_NAME) == "ROCM") ? xla::RocmId()
: xla::CudaId();
std::string platform_name =
(std::string(PJRT_GPU_PLUGIN_PLATFORM_NAME) == "ROCM") ? xla::RocmName()
: xla::CudaName();

xla::PjRtPlatformId platform_id;
std::string platform_name;

absl::string_view plugin_platform = PJRT_GPU_PLUGIN_PLATFORM_NAME;

if (plugin_platform == "ROCM") {
platform_id = xla::RocmId();
platform_name = xla::RocmName();
} else if (plugin_platform == "ONEAPI") {
platform_id = xla::OneapiId(); // Assuming this is defined in XLA
platform_name = xla::OneapiName(); // Assuming this is defined in XLA
} else {
platform_id = xla::CudaId();
platform_name = xla::CudaName();
}
absl::flat_hash_map<std::string, xla::PjRtValueType> create_options =
pjrt::ConvertFromPjRtNamedValueList(args->create_options,
args->num_options);
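The calls to xla::OneapiId() and xla::OneapiName() above carry "Assuming this is defined in XLA" comments. For context, a minimal sketch of what such helpers would look like if they follow the existing CudaId()/CudaName() pattern in xla/pjrt/pjrt_compiler.h; this is an assumption for illustration, not part of this change:

#include <cstdint>

#include "tsl/platform/fingerprint.h"

namespace xla {

// In XLA proper, PjRtPlatformId comes from xla/pjrt/pjrt_common.h.
using PjRtPlatformId = uint64_t;

// Canonical platform name reported for Intel GPUs via the oneAPI/SYCL stack.
inline const char* OneapiName() {
  static constexpr char kOneapiName[] = "oneapi";
  return kOneapiName;
}

// Stable platform id derived from the platform name, mirroring how CudaId()
// and RocmId() are derived from their names.
inline PjRtPlatformId OneapiId() {
  static const PjRtPlatformId kOneapiId = tsl::Fingerprint64(OneapiName());
  return kOneapiId;
}

}  // namespace xla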
22 changes: 16 additions & 6 deletions xla/pjrt/c/pjrt_c_api_gpu_test.cc
@@ -79,11 +79,15 @@ using ::testing::ElementsAreArray;
using ::testing::HasSubstr;
using ::testing::IsNull;

#ifdef TENSORFLOW_USE_ROCM
#if defined(TENSORFLOW_USE_ROCM)
const bool kUnused = (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); },
/*platform_name=*/"rocm"),
true);
#else // TENSORFLOW_USE_ROCM
#elif defined(TENSORFLOW_USE_SYCL)
const bool kUnused = (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); },
/*platform_name=*/"oneapi"),
true);
#else // Default to CUDA
const bool kUnused = (RegisterPjRtCApiTestFactory([]() { return GetPjrtApi(); },
/*platform_name=*/"cuda"),
true);
@@ -678,7 +682,7 @@ TEST(PjrtCApiGpuAllocatorTest, ValidOptionsParsing) {
for (const std::string& allocator_option : allocator_options) {
#ifdef TENSORFLOW_USE_ROCM
if (allocator_option == "cuda_async") {
VLOG(1) << "cuda_async allocator not available on ROCm!";
VLOG(1) << "cuda_async allocator not available on ROCm and SYCL!";
continue;
}
#endif
@@ -767,6 +771,7 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) {
auto api = GetPjrtApi();
std::string expected_platform_name_for_cuda = "cuda";
std::string expected_platform_name_for_rocm = "rocm";
std::string expected_platform_name_for_oneapi = "oneapi";
absl::flat_hash_map<std::string, xla::PjRtValueType> options = {
{"platform_name", static_cast<std::string>("gpu")},
{"allocator", static_cast<std::string>("default")},
@@ -799,7 +804,8 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) {
EXPECT_EQ(platform_name_error, nullptr);
EXPECT_THAT(platform_name_args.platform_name,
testing::AnyOf(expected_platform_name_for_cuda,
expected_platform_name_for_rocm));
expected_platform_name_for_rocm,
expected_platform_name_for_oneapi));

PJRT_Client_Destroy_Args destroy_args;
destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE;
@@ -949,7 +955,9 @@ TEST(PJRTGpuDeviceTopologyTest, CreateGpuTopology) {
EXPECT_TRUE((pjrt_topology->topology->platform_id() == xla::CudaId() &&
pjrt_topology->topology->platform_name() == xla::CudaName()) ||
(pjrt_topology->topology->platform_id() == xla::RocmId() &&
pjrt_topology->topology->platform_name() == xla::RocmName()));
pjrt_topology->topology->platform_name() == xla::RocmName()) ||
(pjrt_topology->topology->platform_id() == xla::OneapiId() &&
pjrt_topology->topology->platform_name() == xla::OneapiName()));

PJRT_TopologyDescription_Destroy_Args destroy_args;
destroy_args.struct_size = PJRT_TopologyDescription_Destroy_Args_STRUCT_SIZE;
@@ -1016,7 +1024,9 @@ TEST(PJRTGpuDeviceTopologyTest, CreateExplicitGpuTopologyAndTargetConfig) {
EXPECT_TRUE((pjrt_topology->topology->platform_id() == xla::CudaId() &&
pjrt_topology->topology->platform_name() == xla::CudaName()) ||
(pjrt_topology->topology->platform_id() == xla::RocmId() &&
pjrt_topology->topology->platform_name() == xla::RocmName()));
pjrt_topology->topology->platform_name() == xla::RocmName()) ||
(pjrt_topology->topology->platform_id() == xla::OneapiId() &&
pjrt_topology->topology->platform_name() == xla::OneapiName()));

EXPECT_EQ(pjrt_topology->topology->ProcessCount().value(), 16 * 2);
EXPECT_EQ(pjrt_topology->topology->DeviceDescriptions().size(), 16 * 2 * 4);
33 changes: 29 additions & 4 deletions xla/pjrt/gpu/BUILD
@@ -1,8 +1,9 @@
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
load("//xla:xla.default.bzl", "xla_cc_test")
load("//xla/pjrt/gpu:package_groups.bzl", "xla_gpu_internal_packages")
load("//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm")
load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured")
load("//xla/tests:build_defs.bzl", "xla_test")
load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")
@@ -50,7 +51,8 @@ cc_library(
name = "se_gpu_pjrt_client",
srcs = ["se_gpu_pjrt_client.cc"],
hdrs = ["se_gpu_pjrt_client.h"],
defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"]),
defines = if_cuda(["GOOGLE_CUDA=1"]) + if_rocm(["TENSORFLOW_USE_ROCM=1"])
+ if_sycl(["TENSORFLOW_USE_SYCL=1"]),
visibility = internal_visibility(["//xla/pjrt/gpu:legacy_gpu_client_users"]),
deps = [
":gpu_helpers",
@@ -170,7 +172,7 @@ cc_library(
"@tsl//tsl/profiler/lib:connected_traceme",
"@tsl//tsl/profiler/lib:nvtx_utils",
"@tsl//tsl/profiler/lib:traceme",
] + if_cuda_or_rocm([
] + if_gpu_is_configured([
# keep sorted
"//xla:debug_options_flags",
"//xla/service/gpu:gpu_compiler",
@@ -184,9 +186,11 @@ cc_library(
]) + if_rocm([
# keep sorted
"@local_config_rocm//rocm:rocm_headers",
]) + if_sycl([
# keep sorted
"@oneapi//:headers",
]),
)

xla_test(
name = "se_gpu_pjrt_client_test",
size = "large",
@@ -512,6 +516,24 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "se_gpu_pjrt_compiler_sycl_registration",
srcs = ["se_gpu_pjrt_compiler_sycl_registration.cc"],
compatible_with = [],
tags = [
"gpu",
"oneapi-only",
],
deps = [
":se_gpu_pjrt_compiler_impl",
"//xla/pjrt:pjrt_compiler",
"//xla/service/gpu:intel_gpu_compiler",
"//xla/stream_executor/platform:initialize",
"//xla/stream_executor/sycl:sycl_platform_id",
],
alwayslink = 1,
)

cc_library(
name = "se_gpu_pjrt_compiler",
hdrs = ["se_gpu_pjrt_compiler.h"],
@@ -532,6 +554,8 @@ cc_library(
":se_gpu_pjrt_compiler_cuda_registration",
]) + if_rocm([
":se_gpu_pjrt_compiler_rocm_registration",
]) + if_sycl([
":se_gpu_pjrt_compiler_sycl_registration",
]),
)

@@ -662,3 +686,4 @@ xla_cc_test(
"@com_google_googletest//:gtest_main",
],
)

22 changes: 11 additions & 11 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
@@ -127,7 +127,7 @@ limitations under the License.
#include "tsl/profiler/lib/nvtx_utils.h"
#include "tsl/profiler/lib/traceme.h"

#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) || defined(TENSORFLOW_USE_SYCL)
#include "xla/debug_options_flags.h"
#include "xla/pjrt/gpu/gpu_metrics.h"
#include "xla/pjrt/proto/compile_options.pb.h"
@@ -137,7 +137,7 @@ limitations under the License.
#include "xla/service/gpu/gpu_executable.h"
#include "xla/service/gpu/stream_executor_util.h"
#include "xla/xla.pb.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
@@ -1291,7 +1291,7 @@ StreamExecutorGpuClient::CompileAndLoad(mlir::ModuleOp module,
CompileOptions options) {
auto executable = PjRtStreamExecutorClient::CompileAndLoad(module, options);

#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) || defined(TENSORFLOW_USE_SYCL)
for (const PjRtDevice* device : addressable_devices()) {
LocalDeviceState* local_device_state =
tensorflow::down_cast<const PjRtStreamExecutorDevice*>(device)
@@ -1308,7 +1308,7 @@ StreamExecutorGpuClient::CompileAndLoad(mlir::ModuleOp module,
}
}
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL
return executable;
}

@@ -1318,7 +1318,7 @@ StreamExecutorGpuClient::CompileAndLoad(const XlaComputation& computation,
auto executable =
PjRtStreamExecutorClient::CompileAndLoad(computation, options);

#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) || defined(TENSORFLOW_USE_SYCL)
for (const PjRtDevice* device : addressable_devices()) {
LocalDeviceState* local_device_state =
tensorflow::down_cast<const PjRtStreamExecutorDevice*>(device)
@@ -1335,7 +1335,7 @@ StreamExecutorGpuClient::CompileAndLoad(const XlaComputation& computation,
}
}
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL
return executable;
}

@@ -1778,7 +1778,7 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
#if TENSORFLOW_USE_ROCM
auto pjrt_platform_name = xla::RocmName();
#elif TENSORFLOW_USE_SYCL
auto pjrt_platform_name = xla::SyclName();
auto pjrt_platform_name = xla::OneapiName();
#else // TENSORFLOW_USE_ROCM
auto pjrt_platform_name = xla::CudaName();
#endif // TENSORFLOW_USE_ROCM
@@ -1879,7 +1879,7 @@ std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> BuildLocalDevices(
return devices;
}

#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) || defined(TENSORFLOW_USE_SYCL)
static absl::Status CheckAlignment(const BufferAllocation& allocation,
se::DeviceAddressBase buffer, int arg_idx) {
const int64_t expected_alignment = [&] {
@@ -1900,7 +1900,7 @@ static absl::Status CheckAlignment(const BufferAllocation& allocation,
}
return absl::OkStatus();
}
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL

absl::StatusOr<PjRtStreamExecutorExecutionOutput>
StreamExecutorGpuClient::RunAsync(
@@ -1909,7 +1909,7 @@ StreamExecutorGpuClient::RunAsync(
absl::Span<const tsl::RCReference<CommonPjRtRawBuffer>> results,
ExecutableRunOptions run_options_inp, bool parameter_is_tupled_arguments,
absl::Span<const Shape> executable_parameter_shapes) {
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM) || defined(TENSORFLOW_USE_SYCL)
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(flat_arguments.size());
for (const Shape& arg_shape : executable_parameter_shapes) {
@@ -2131,7 +2131,7 @@ StreamExecutorGpuClient::RunAsync(
return PjRtStreamExecutorClient::RunAsync(
exec, device, flat_arguments, results, std::move(run_options_inp),
parameter_is_tupled_arguments, executable_parameter_shapes);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL
}

} // namespace xla
2 changes: 1 addition & 1 deletion xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
@@ -62,7 +62,7 @@ namespace {

bool IsGpuClient(const PjRtClient& client) {
return client.platform_id() == CudaId() || client.platform_id() == RocmId() ||
client.platform_id() == SyclId();
client.platform_id() == OneapiId();
}

bool IsSameTopology(const PjRtTopologyDescription& topology1,
30 changes: 30 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_compiler_sycl_registration.cc
@@ -0,0 +1,30 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "xla/pjrt/gpu/se_gpu_pjrt_compiler.h"
#include "xla/pjrt/pjrt_compiler.h"
#include "xla/stream_executor/platform/initialize.h"
#include "xla/stream_executor/sycl/sycl_platform_id.h"

namespace xla {

STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
PjRtRegisterCompiler(OneapiName(), std::make_unique<StreamExecutorGpuCompiler>(
stream_executor::sycl::kSyclPlatformId));
});

} // namespace xla
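The registration above depends on //xla/stream_executor/sycl:sycl_platform_id for stream_executor::sycl::kSyclPlatformId. A rough sketch of the declaration that header is assumed to provide, following the CUDA and ROCm platform-id pattern (shown only for context; the actual header ships with XLA's SYCL support):

#include "xla/stream_executor/platform.h"

namespace stream_executor {
namespace sycl {

// Opaque identifier for the SYCL/oneAPI StreamExecutor platform, analogous to
// the CUDA and ROCm platform id constants.
extern const Platform::Id kSyclPlatformId;

}  // namespace sycl
}  // namespace stream_executor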
2 changes: 1 addition & 1 deletion xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc
@@ -1273,7 +1273,7 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtGpuClientInternal(
#if TENSORFLOW_USE_ROCM
const auto* pjrt_platform_name = xla::RocmName();
#elif TENSORFLOW_USE_SYCL
const auto* pjrt_platform_name = xla::SyclName();
const auto* pjrt_platform_name = xla::OneapiName();
#else // TENSORFLOW_USE_ROCM
const auto* pjrt_platform_name = xla::CudaName();
#endif // TENSORFLOW_USE_ROCM