Skip to content

Commit 1265270

Browse files
pschuhGoogle-ML-Automation
authored andcommitted
Update EnqueueExecution to return PjRtRawLoadedExecutable::RawExecuteResult
directly. PiperOrigin-RevId: 861421465
1 parent 1ddfc07 commit 1265270

File tree

2 files changed

+67
-69
lines changed

2 files changed

+67
-69
lines changed

xla/pjrt/pjrt_stream_executor_client.cc

Lines changed: 65 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1655,14 +1655,17 @@ PjRtStreamExecutorClient::RunAsync(
16551655
// converted on success.
16561656
// When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id`
16571657
// to initialize `run_options`.
1658-
absl::Status PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
1658+
absl::StatusOr<PjRtRawLoadedExecutable::RawExecuteResult>
1659+
PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
16591660
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
16601661
const RunId& run_id, const ExecuteOptions& options, PjRtDevice* device,
16611662
absl::Span<const tsl::RCReference<CommonPjRtRawBuffer>> flat_arguments,
16621663
absl::Span<const tsl::RCReference<CommonPjRtRawBuffer>> results,
16631664
PjRtDeviceEventSet& events,
16641665
std::shared_ptr<DeviceAssignment> device_assignment,
1665-
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const {
1666+
bool fill_future) const {
1667+
const uint64_t start_time_usecs = tsl::Env::Default()->NowMicros();
1668+
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
16661669
int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
16671670
->local_device_state()
16681671
->local_device_id()
@@ -1838,7 +1841,57 @@ absl::Status PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
18381841
device_assignment)}]() {});
18391842
}
18401843

1841-
return absl::OkStatus(); // std::move(results);
1844+
auto definition_event = [&]() -> tsl::RCReference<PjRtDeviceEvent> {
1845+
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1846+
se::Stream* stream = device_state->compute_stream();
1847+
1848+
auto definition_event_or = device_state->GetEventForComputeStreamSyncPoint(
1849+
device_state->GetNextComputeStreamSyncPoint(),
1850+
client_->async_work_runner());
1851+
if (!definition_event_or.ok()) {
1852+
StallStreamOnError(device_state, stream);
1853+
return client_->CreateErrorDeviceEvent(definition_event_or.status());
1854+
}
1855+
return tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
1856+
std::move(*definition_event_or), "PjRtStreamExecutorLoadedExecutable",
1857+
"Execute");
1858+
}();
1859+
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> buffers_to_release;
1860+
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
1861+
buffers_to_release.reserve(results.size() + flat_arguments.size());
1862+
for (auto& node : results) {
1863+
buffers_to_release.push_back(
1864+
tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(node.get())
1865+
->device_buffer());
1866+
}
1867+
for (auto& node : flat_arguments) {
1868+
buffers_to_release.push_back(
1869+
tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(node.get())
1870+
->device_buffer());
1871+
}
1872+
}
1873+
std::optional<Future<>> maybe_future;
1874+
if (fill_future) {
1875+
auto [promise, future] = MakePromise<>();
1876+
maybe_future = std::move(future);
1877+
compute_callbacks.push_back(
1878+
[promise = std::move(promise)]() mutable { promise.Set(); });
1879+
}
1880+
definition_event->AndThen(
1881+
[callbacks{std::move(compute_callbacks)},
1882+
buffers_to_release{std::move(buffers_to_release)}]() mutable {
1883+
for (auto& fn : callbacks) {
1884+
std::move(fn)();
1885+
}
1886+
callbacks.clear();
1887+
});
1888+
metrics::ReportExecutableEnqueueTime(tsl::Env::Default()->NowMicros() -
1889+
start_time_usecs);
1890+
1891+
PjRtRawLoadedExecutable::RawExecuteResult execute_results;
1892+
execute_results.future = std::move(maybe_future);
1893+
execute_results.primary_execute_event = std::move(definition_event);
1894+
return execute_results;
18421895
}
18431896

18441897
static absl::Status GetFirstInputError(
@@ -1866,7 +1919,6 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
18661919
absl::Span<PjRtBuffer* const> argument_handles, int replica, int partition,
18671920
const RunId& run_id, const ExecuteOptions& options, bool fill_future,
18681921
PjRtDevice* device) const {
1869-
const uint64_t start_time_usecs = tsl::Env::Default()->NowMicros();
18701922
std::shared_ptr<DeviceAssignment> device_assignment;
18711923
if (device == nullptr) {
18721924
CHECK(device_assignment_ != nullptr);
@@ -1919,7 +1971,6 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
19191971
VLOG(1) << "Replica " << replica << ", partition " << partition
19201972
<< " mapped to device ordinal for execution: " << device_ordinal;
19211973

1922-
std::vector<absl::AnyInvocable<void() &&>> compute_callbacks;
19231974
absl::InlinedVector<CommonPjRtBuffer::ScopedHold, 4> device_buffers;
19241975
device_buffers.reserve(argument_handles.size());
19251976
PjRtStreamExecutorDeviceEventSet events(argument_handles.size());
@@ -1936,81 +1987,28 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
19361987
result_shape_, device_buffers,
19371988
executable_->executable()->module().input_output_alias_config(),
19381989
device, output_memory_space_kind_ids_));
1939-
absl::Status status =
1990+
1991+
absl::StatusOr<PjRtRawLoadedExecutable::RawExecuteResult> status_or_results =
19401992
EnqueueExecution(argument_handles, replica, partition, run_id, options,
19411993
device, input_buffers, result_buffer, events,
1942-
std::move(device_assignment), compute_callbacks);
1994+
std::move(device_assignment), fill_future);
19431995

1944-
if (!status.ok()) {
1945-
LOG(ERROR) << "Execution of replica " << replica << " failed: " << status;
1946-
return status;
1996+
if (!status_or_results.ok()) {
1997+
LOG(ERROR) << "Execution of replica " << replica
1998+
<< " failed: " << status_or_results.status();
1999+
return status_or_results.status();
19472000
}
19482001

1949-
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1950-
se::Stream* stream = device_state->compute_stream();
1951-
1952-
auto definition_event_or = device_state->GetEventForComputeStreamSyncPoint(
1953-
device_state->GetNextComputeStreamSyncPoint(),
1954-
client_->async_work_runner());
1955-
if (!definition_event_or.ok()) {
1956-
StallStreamOnError(device_state, stream);
1957-
for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
1958-
if (b.type() == CommonPjRtBuffer::ScopedHold::kDonation) {
1959-
// Even though there was an error we need to call ConfirmDonation, which
1960-
// renders b invalid, since the computation has been enqueued and b has
1961-
// been donated.
1962-
b.ConfirmDonation();
1963-
}
1964-
}
1965-
return definition_event_or.status();
1966-
}
1967-
std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> buffers_to_release;
1968-
if (device_state->allocation_model() == LocalDeviceState::kSynchronous) {
1969-
buffers_to_release.reserve(result_buffer.size() + device_buffers.size());
1970-
for (auto& node : result_buffer) {
1971-
buffers_to_release.push_back(
1972-
tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(node.get())
1973-
->device_buffer());
1974-
}
1975-
for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
1976-
if (b.type() == CommonPjRtBuffer::ScopedHold::kUsage) {
1977-
buffers_to_release.push_back(
1978-
tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(
1979-
b.buffer()->raw_buffer().get())
1980-
->device_buffer());
1981-
}
1982-
}
1983-
}
1984-
auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
1985-
*definition_event_or, "PjRtStreamExecutorLoadedExecutable", "Execute");
1986-
PjRtRawLoadedExecutable::RawExecuteResult results;
1987-
std::optional<Future<>>& maybe_future = results.future;
1988-
results.primary_execute_event = definition_event;
1989-
if (fill_future) {
1990-
auto [promise, future] = MakePromise<>();
1991-
maybe_future = std::move(future);
1992-
compute_callbacks.push_back(
1993-
[promise = std::move(promise)]() mutable { promise.Set(); });
1994-
}
1995-
definition_event->AndThen(
1996-
[callbacks{std::move(compute_callbacks)},
1997-
buffers_to_release{std::move(buffers_to_release)}]() mutable {
1998-
for (auto& fn : callbacks) {
1999-
std::move(fn)();
2000-
}
2001-
callbacks.clear();
2002-
});
2002+
auto& results = status_or_results.value();
20032003

20042004
for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
20052005
if (b.type() == CommonPjRtBuffer::ScopedHold::kUsage) {
2006-
b.ConvertUsageHold(definition_event);
2006+
b.ConvertUsageHold(results.primary_execute_event);
20072007
} else {
20082008
CHECK(b.type() == CommonPjRtBuffer::ScopedHold::kDonation);
20092009
b.ConfirmDonation();
20102010
}
20112011
}
2012-
metrics::ReportExecutableEnqueueTime(tsl::Env::Default()->NowMicros() -
2013-
start_time_usecs);
20142012
return PjRtLoadedExecutable::Result(
20152013
{/*future=*/std::move(results.future),
20162014
/*buffers=*/client()->CreateOutputs(

xla/pjrt/pjrt_stream_executor_client.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -656,15 +656,15 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
656656
// donated due to aliases that were specified by the computation.
657657
absl::Status SetUpDonation(bool tuple_inputs);
658658

659-
absl::Status EnqueueExecution(
659+
absl::StatusOr<PjRtRawLoadedExecutable::RawExecuteResult> EnqueueExecution(
660660
absl::Span<PjRtBuffer* const> argument_handles, int replica,
661661
int partition, const RunId& run_id, const ExecuteOptions& options,
662662
PjRtDevice* device,
663663
absl::Span<const tsl::RCReference<CommonPjRtRawBuffer>> flat_arguments,
664664
absl::Span<const tsl::RCReference<CommonPjRtRawBuffer>> results,
665665
PjRtDeviceEventSet& events,
666666
std::shared_ptr<DeviceAssignment> device_assignment,
667-
std::vector<absl::AnyInvocable<void() &&>>& compute_callbacks) const;
667+
bool fill_future) const;
668668

669669
absl::StatusOr<Result> ExecuteHelper(
670670
absl::Span<PjRtBuffer* const> argument_handles, int replica,

0 commit comments

Comments
 (0)