Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xla/pjrt/gpu/se_gpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1937,7 +1937,7 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetStreamExecutorGpuClient(
#if TENSORFLOW_USE_ROCM
auto pjrt_platform_name = xla::RocmName();
#elif TENSORFLOW_USE_SYCL
auto pjrt_platform_name = xla::SyclName();
auto pjrt_platform_name = xla::OneapiName();
#else // TENSORFLOW_USE_ROCM
auto pjrt_platform_name = xla::CudaName();
#endif // TENSORFLOW_USE_ROCM
Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/gpu/tfrt/tfrt_gpu_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1305,7 +1305,7 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> GetTfrtGpuClientInternal(
#if TENSORFLOW_USE_ROCM
const auto* pjrt_platform_name = xla::RocmName();
#elif TENSORFLOW_USE_SYCL
const auto* pjrt_platform_name = xla::SyclName();
const auto* pjrt_platform_name = xla::OneapiName();
#else // TENSORFLOW_USE_ROCM
const auto* pjrt_platform_name = xla::CudaName();
#endif // TENSORFLOW_USE_ROCM
Expand Down
17 changes: 12 additions & 5 deletions xla/pjrt/pjrt_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ inline const char* RocmName() {
static constexpr char kRocmName[] = "rocm";
return kRocmName;
}
inline const char* SyclName() {
static constexpr char kSyclName[] = "sycl";
return kSyclName;
inline const char* OneapiName() {
static constexpr char kOneapiName[] = "oneapi";
return kOneapiName;
}
inline const char* TpuName() {
static constexpr char kTpuName[] = "tpu";
Expand All @@ -79,9 +79,16 @@ inline PjRtPlatformId RocmId() {
static const PjRtPlatformId kRocmId = tsl::Fingerprint64(RocmName());
return kRocmId;
}
inline PjRtPlatformId OneapiId() {
static const PjRtPlatformId kOneapiId = tsl::Fingerprint64(OneapiName());
return kOneapiId;
}

// Temporarily keep SyclId() as there are references to it in Jaxlib.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please also add temporarily SyclName here? There are references to it from the Tensorflow repository. So the sequence of changes should be: this PR, then change tensorflow references, then delete SyclName.

// TODO(intel-tf): Remove this function once Jaxlib is updated to use
// OneapId() instead of SyclId()
inline PjRtPlatformId SyclId() {
static const PjRtPlatformId kSyclId = tsl::Fingerprint64(SyclName());
return kSyclId;
return OneapiId();
}
inline PjRtPlatformId TpuId() {
static const PjRtPlatformId kTpuId = tsl::Fingerprint64(TpuName());
Expand Down
2 changes: 1 addition & 1 deletion xla/pjrt/profiling/device_time_measurement.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void RecordDeviceTimeMeasurement(
inline DeviceTimeMeasurement::DeviceType GetDeviceType(
PjRtPlatformId platform_id) {
if (platform_id == CudaId() || platform_id == RocmId() ||
platform_id == SyclId()) {
platform_id == OneapiId()) {
return DeviceTimeMeasurement::DeviceType::kGpu;
}
if (platform_id == TpuId()) {
Expand Down
2 changes: 1 addition & 1 deletion xla/python/pjrt_ifrt/pjrt_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
auto callbacks = std::make_unique<std::vector<void*>>();
// Forward callbacks via FFI's ExecutionContext for CPU/GPU platforms only.
if (platform_id == CpuId() || platform_id == CudaId() ||
platform_id == RocmId() || platform_id == SyclId()) {
platform_id == RocmId() || platform_id == OneapiId()) {
for (const auto& loaded_host_callback : *all_loaded_host_callbacks_) {
auto* ffi_loaded_host_callback =
llvm::dyn_cast<PjRtFfiLoadedHostCallback>(loaded_host_callback.get());
Expand Down
2 changes: 1 addition & 1 deletion xla/service/hlo_runner_pjrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ bool HloRunnerPjRt::HasProperty(const HloRunnerPropertyTag::Type tag) const {
return pjrt_client_->platform_name() == CudaName();
}
if (tag == HloRunnerPropertyTag::kUsingGpuOneAPI) {
return pjrt_client_->platform_name() == SyclName();
return pjrt_client_->platform_name() == OneapiName();
}
return false;
}
Expand Down