@@ -1655,14 +1655,17 @@ PjRtStreamExecutorClient::RunAsync(
16551655// converted on success.
16561656// When `options` has non-zero `launch_id`, use `launch_id` instead of `run_id`
16571657// to initialize `run_options`.
1658- absl::Status PjRtStreamExecutorLoadedExecutable::EnqueueExecution (
1658+ absl::StatusOr<PjRtRawLoadedExecutable::RawExecuteResult>
1659+ PjRtStreamExecutorLoadedExecutable::EnqueueExecution (
16591660 absl::Span<PjRtBuffer* const > argument_handles, int replica, int partition,
16601661 const RunId& run_id, const ExecuteOptions& options, PjRtDevice* device,
16611662 absl::Span<const tsl::RCReference<CommonPjRtRawBuffer>> flat_arguments,
16621663 absl::Span<const tsl::RCReference<CommonPjRtRawBuffer>> results,
16631664 PjRtDeviceEventSet& events,
16641665 std::shared_ptr<DeviceAssignment> device_assignment,
1665- std::vector<absl::AnyInvocable<void () &&>>& compute_callbacks) const {
1666+ bool fill_future) const {
1667+ const uint64_t start_time_usecs = tsl::Env::Default ()->NowMicros ();
1668+ std::vector<absl::AnyInvocable<void () &&>> compute_callbacks;
16661669 int device_ordinal = tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)
16671670 ->local_device_state ()
16681671 ->local_device_id ()
@@ -1838,7 +1841,57 @@ absl::Status PjRtStreamExecutorLoadedExecutable::EnqueueExecution(
18381841 device_assignment)}]() {});
18391842 }
18401843
1841- return absl::OkStatus (); // std::move(results);
1844+ auto definition_event = [&]() -> tsl::RCReference<PjRtDeviceEvent> {
1845+ LocalDeviceState* device_state = &(client_->device_state (device_ordinal));
1846+ se::Stream* stream = device_state->compute_stream ();
1847+
1848+ auto definition_event_or = device_state->GetEventForComputeStreamSyncPoint (
1849+ device_state->GetNextComputeStreamSyncPoint (),
1850+ client_->async_work_runner ());
1851+ if (!definition_event_or.ok ()) {
1852+ StallStreamOnError (device_state, stream);
1853+ return client_->CreateErrorDeviceEvent (definition_event_or.status ());
1854+ }
1855+ return tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
1856+ std::move (*definition_event_or), " PjRtStreamExecutorLoadedExecutable" ,
1857+ " Execute" );
1858+ }();
1859+ std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> buffers_to_release;
1860+ if (device_state->allocation_model () == LocalDeviceState::kSynchronous ) {
1861+ buffers_to_release.reserve (results.size () + flat_arguments.size ());
1862+ for (auto & node : results) {
1863+ buffers_to_release.push_back (
1864+ tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(node.get ())
1865+ ->device_buffer ());
1866+ }
1867+ for (auto & node : flat_arguments) {
1868+ buffers_to_release.push_back (
1869+ tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(node.get ())
1870+ ->device_buffer ());
1871+ }
1872+ }
1873+ std::optional<Future<>> maybe_future;
1874+ if (fill_future) {
1875+ auto [promise, future] = MakePromise<>();
1876+ maybe_future = std::move (future);
1877+ compute_callbacks.push_back (
1878+ [promise = std::move (promise)]() mutable { promise.Set (); });
1879+ }
1880+ definition_event->AndThen (
1881+ [callbacks{std::move (compute_callbacks)},
1882+ buffers_to_release{std::move (buffers_to_release)}]() mutable {
1883+ for (auto & fn : callbacks) {
1884+ std::move (fn)();
1885+ }
1886+ callbacks.clear ();
1887+ });
1888+ metrics::ReportExecutableEnqueueTime (tsl::Env::Default ()->NowMicros () -
1889+ start_time_usecs);
1890+
1891+ PjRtRawLoadedExecutable::RawExecuteResult execute_results;
1892+ execute_results.future = std::move (maybe_future);
1893+ execute_results.primary_execute_event = std::move (definition_event);
1894+ return execute_results;
18421895}
18431896
18441897static absl::Status GetFirstInputError (
@@ -1866,7 +1919,6 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
18661919 absl::Span<PjRtBuffer* const > argument_handles, int replica, int partition,
18671920 const RunId& run_id, const ExecuteOptions& options, bool fill_future,
18681921 PjRtDevice* device) const {
1869- const uint64_t start_time_usecs = tsl::Env::Default ()->NowMicros ();
18701922 std::shared_ptr<DeviceAssignment> device_assignment;
18711923 if (device == nullptr ) {
18721924 CHECK (device_assignment_ != nullptr );
@@ -1919,7 +1971,6 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
19191971 VLOG (1 ) << " Replica " << replica << " , partition " << partition
19201972 << " mapped to device ordinal for execution: " << device_ordinal;
19211973
1922- std::vector<absl::AnyInvocable<void () &&>> compute_callbacks;
19231974 absl::InlinedVector<CommonPjRtBuffer::ScopedHold, 4 > device_buffers;
19241975 device_buffers.reserve (argument_handles.size ());
19251976 PjRtStreamExecutorDeviceEventSet events (argument_handles.size ());
@@ -1936,81 +1987,28 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
19361987 result_shape_, device_buffers,
19371988 executable_->executable ()->module ().input_output_alias_config (),
19381989 device, output_memory_space_kind_ids_));
1939- absl::Status status =
1990+
1991+ absl::StatusOr<PjRtRawLoadedExecutable::RawExecuteResult> status_or_results =
19401992 EnqueueExecution (argument_handles, replica, partition, run_id, options,
19411993 device, input_buffers, result_buffer, events,
1942- std::move (device_assignment), compute_callbacks );
1994+ std::move (device_assignment), fill_future );
19431995
1944- if (!status.ok ()) {
1945- LOG (ERROR) << " Execution of replica " << replica << " failed: " << status;
1946- return status;
1996+ if (!status_or_results.ok ()) {
1997+ LOG (ERROR) << " Execution of replica " << replica
1998+ << " failed: " << status_or_results.status ();
1999+ return status_or_results.status ();
19472000 }
19482001
1949- LocalDeviceState* device_state = &(client_->device_state (device_ordinal));
1950- se::Stream* stream = device_state->compute_stream ();
1951-
1952- auto definition_event_or = device_state->GetEventForComputeStreamSyncPoint (
1953- device_state->GetNextComputeStreamSyncPoint (),
1954- client_->async_work_runner ());
1955- if (!definition_event_or.ok ()) {
1956- StallStreamOnError (device_state, stream);
1957- for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
1958- if (b.type () == CommonPjRtBuffer::ScopedHold::kDonation ) {
1959- // Even though there was an error we need to call ConfirmDonation, which
1960- // renders b invalid, since the computation has been enqueued and b has
1961- // been donated.
1962- b.ConfirmDonation ();
1963- }
1964- }
1965- return definition_event_or.status ();
1966- }
1967- std::vector<tsl::AsyncValueRef<RawSEDeviceMemory>> buffers_to_release;
1968- if (device_state->allocation_model () == LocalDeviceState::kSynchronous ) {
1969- buffers_to_release.reserve (result_buffer.size () + device_buffers.size ());
1970- for (auto & node : result_buffer) {
1971- buffers_to_release.push_back (
1972- tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(node.get ())
1973- ->device_buffer ());
1974- }
1975- for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
1976- if (b.type () == CommonPjRtBuffer::ScopedHold::kUsage ) {
1977- buffers_to_release.push_back (
1978- tensorflow::down_cast<PjRtStreamExecutorRawBuffer*>(
1979- b.buffer ()->raw_buffer ().get ())
1980- ->device_buffer ());
1981- }
1982- }
1983- }
1984- auto definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
1985- *definition_event_or, " PjRtStreamExecutorLoadedExecutable" , " Execute" );
1986- PjRtRawLoadedExecutable::RawExecuteResult results;
1987- std::optional<Future<>>& maybe_future = results.future ;
1988- results.primary_execute_event = definition_event;
1989- if (fill_future) {
1990- auto [promise, future] = MakePromise<>();
1991- maybe_future = std::move (future);
1992- compute_callbacks.push_back (
1993- [promise = std::move (promise)]() mutable { promise.Set (); });
1994- }
1995- definition_event->AndThen (
1996- [callbacks{std::move (compute_callbacks)},
1997- buffers_to_release{std::move (buffers_to_release)}]() mutable {
1998- for (auto & fn : callbacks) {
1999- std::move (fn)();
2000- }
2001- callbacks.clear ();
2002- });
2002+ auto & results = status_or_results.value ();
20032003
20042004 for (CommonPjRtBuffer::ScopedHold& b : device_buffers) {
20052005 if (b.type () == CommonPjRtBuffer::ScopedHold::kUsage ) {
2006- b.ConvertUsageHold (definition_event );
2006+ b.ConvertUsageHold (results. primary_execute_event );
20072007 } else {
20082008 CHECK (b.type () == CommonPjRtBuffer::ScopedHold::kDonation );
20092009 b.ConfirmDonation ();
20102010 }
20112011 }
2012- metrics::ReportExecutableEnqueueTime (tsl::Env::Default ()->NowMicros () -
2013- start_time_usecs);
20142012 return PjRtLoadedExecutable::Result (
20152013 {/* future=*/ std::move (results.future ),
20162014 /* buffers=*/ client ()->CreateOutputs (
0 commit comments