Commit 29282fa

CopyFromBuffer Device2Device BufferInstance transfer workaround (#2277)
### Ticket
Revealed by #2258.

### Problem description
After #1657, output BufferInstances no longer have a host runtime tensor. If we try to do a (PJRT) device-to-device transfer (even though that is not really meaningful in the way we currently model the MeshDevice/PJRT Device Instance), the copyFromBuffer path assumes an existing host runtime tensor on the source buffer instance.

### What's changed
In copyFromBuffer, if a device runtime tensor exists on the copy source buffer instance, transfer it to host and use that as the source of truth.

### Checklist
- [x] New/Existing tests provide coverage for changes
1 parent ab2d8a8 commit 29282fa
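
To make the fallback easier to follow before reading the diff, here is a rough restatement of the source-selection logic that copyFromBuffer now performs. It is only an illustrative sketch of the C++ change in buffer_instance.cc below; the function and parameter names (`pick_source_host_tensor`, `to_host`, etc.) are hypothetical and not part of the tt-xla codebase.

```python
# Illustrative sketch only; the authoritative version is the C++ change in
# buffer_instance.cc below. All names here are hypothetical.
def pick_source_host_tensor(prepared_device_tensor, host_tensor, to_host):
    """Choose the host tensor that a newly constructed buffer is memcpy'd from."""
    if prepared_device_tensor is not None:
        # Device-to-device case: the source only lives on device, so round-trip
        # it through host (the destination device is unknown at this call site).
        host_tensors = to_host(prepared_device_tensor, untilize=True)
        assert len(host_tensors) == 1, "expected a single host tensor"
        return host_tensors[0]
    if host_tensor is not None:
        # The source already has a host runtime tensor; use it directly.
        return host_tensor
    raise AssertionError("source buffer has no data to copy from")
```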

File tree

4 files changed: +96 −9 lines changed

pjrt_implementation/inc/api/buffer_instance.h

Lines changed: 2 additions & 0 deletions
@@ -181,6 +181,8 @@ class BufferInstance {
       std::optional<uint32_t> device_id = std::nullopt);
 
   // Copies the tensor inside the src_buffer to the tensor of this buffer.
+  // Currently only used for device to device transfer in copy construction
+  // of new buffer instance.
   void copyFromBuffer(const BufferInstance *src_buffer);
 
   // Calculates required tensor shape.

pjrt_implementation/src/api/buffer_instance.cc

Lines changed: 36 additions & 7 deletions
@@ -249,21 +249,49 @@ void BufferInstance::copyFromHost(
 }
 
 void BufferInstance::copyFromBuffer(const BufferInstance *src_buffer) {
+  DLOG_F(LOG_DEBUG, "BufferInstance::copyFromBuffer");
   ::tt::target::DataType runtime_data_type =
       tt::pjrt::data_type_utils::convertPJRTToRuntimeDataType(
          src_buffer->m_data_type);
+
   std::uint32_t element_size =
       tt::runtime::utils::dataTypeElementSize(runtime_data_type);
   std::vector<std::uint32_t> shape = calculateShape(
       src_buffer->getDimensionsRaw(), src_buffer->getNumberOfDimensions());
   std::vector<std::uint32_t> strides = calculateStrides(
       src_buffer->getNumberOfDimensions(), nullptr, 0, element_size);
 
+  // This function is expected to be used for device-to-device buffer
+  // initialization of a new buffer instance, so destination buffer must not
+  // have data yet, or it will be overwritten.
+  assert((!m_host_runtime_tensor.has_value() &&
+          !m_prepared_runtime_tensor.has_value()) &&
+         "Destination buffer already has data");
+
+  tt::runtime::Tensor source_host_runtime_tensor;
+
+  if (src_buffer->m_prepared_runtime_tensor != std::nullopt) {
+    DLOG_F(WARNING,
+           "BufferInstance::copyFromBuffer: Device-Device transfer is "
+           "inefficient due to PJRT device modeling limitations. This will "
+           "actually copy src to host, and fill dst host tensor, because at "
+           "this callsite we do not know what dst device is.");
+    std::vector<tt::runtime::Tensor> host_runtime_tensors = tt::runtime::toHost(
+        src_buffer->m_prepared_runtime_tensor.value(), /*untilize=*/true);
+
+    assert(host_runtime_tensors.size() == 1 &&
+           "Expected single host tensor when copying from device buffer");
+
+    source_host_runtime_tensor = host_runtime_tensors[0];
+  } else if (src_buffer->m_host_runtime_tensor != std::nullopt) {
+    source_host_runtime_tensor = *src_buffer->m_host_runtime_tensor;
+  } else {
+    assert(false && "Source buffer has no data to copy from");
+  }
+
   m_host_runtime_tensor = tt::runtime::createOwnedHostTensor(
       /* data= */ nullptr, shape, strides, element_size, runtime_data_type);
-
-  tt::runtime::memcpy(*m_host_runtime_tensor,
-                      *src_buffer->m_host_runtime_tensor);
+  tt::runtime::memcpy(*m_host_runtime_tensor, source_host_runtime_tensor);
   tt::runtime::setTensorRetain(*m_host_runtime_tensor, /*retain=*/true);
 
   markAsDataReady();
@@ -347,7 +375,8 @@ tt_pjrt_status BufferInstance::copyToHost(void *host_buffer,
       [](std::unique_lock<std::mutex> copy_lock, void *host_buffer,
          tt::runtime::Tensor runtime_tensor, EventInstance *event,
          PJRT_Buffer_Type data_type, size_t host_buffer_size,
-         std::optional<uint32_t> device_id, bool already_on_host) {
+         std::optional<uint32_t> device_id, bool already_on_host,
+         uint64_t buffer_uid) {
         // Acquire lock to serialize all copy-to-host operations across all
         // BufferInstances since any metal dispatch in this async thread will
         // cause ND segfaults as metal is not thread safe.
@@ -368,9 +397,9 @@ tt_pjrt_status BufferInstance::copyToHost(void *host_buffer,
         }
         DLOG_F(LOG_DEBUG,
                "Returning tensor to host with host_runtime_tensors ct = %ld "
-               "from device %d",
+               "from device %d with buffer UID %zu",
               host_runtime_tensors.size(),
-              device_id.has_value() ? device_id.value() : 0);
+              device_id.has_value() ? device_id.value() : 0, buffer_uid);
 
         // If device_id is not set, we are returning a replicated input
         // buffer instance to host (eg. cache position for update). This means
@@ -411,7 +440,7 @@ tt_pjrt_status BufferInstance::copyToHost(void *host_buffer,
       },
       std::move(copy_lock), host_buffer, runtime_tensor_to_retrieve,
       event.get(), m_data_type, host_buffer_size, m_device_id,
-      is_tensor_on_host);
+      is_tensor_on_host, m_uid);
 
   // responsible for calling `PJRT_Event_Destroy` on the event.
   *out_copy_done_event = event.release();

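For context, the copyFromBuffer change above is only reached when the framework requests a PJRT device-to-device copy; the regression test added at the end of this commit shows the JAX-level call pattern (device_put with a sharding while "tt" is listed first in jax_platforms) that triggers it.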
pjrt_implementation/src/api/flatbuffer_loaded_executable_instance.cc

Lines changed: 4 additions & 2 deletions
@@ -199,8 +199,10 @@ void FlatbufferLoadedExecutableInstance::fillPJRTOutputLists(
         m_addressable_devices[device_index]->getDefaultMemory(),
         expected_output_data_types[output_index], device_index);
     DLOG_F(LOG_DEBUG,
-           "Filled output at output_index %zu device_index %d with shape %s",
-           output_index, device_index, output_buffer->toShapeStr().c_str());
+           "Filled output at output_index %zu device_index %d with shape %s "
+           "and UID %zu",
+           output_index, device_index, output_buffer->toShapeStr().c_str(),
+           output_buffer->getUID());
 
     output_buffer->markAsDataReady();
 
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+from infra.connectors.device_connector import DeviceType
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+
+
+@pytest.mark.nightly
+@pytest.mark.push
+def test_sharded_copyFromBuffer():
+    """
+    Test basic tensor sharding with device_put - no operations.
+
+    This requires a revert of the jax_platforms test config set in the autouse
+    initialize_device_connectors conftest fixture, which is monkeypatched around the test.
+    This results in the sharding happening on-device and induces a copyFromBuffer call by the framework.
+
+    This is not the expected usage pattern for tt-xla users, but is instead a backup check that the
+    copyFromBuffer path works correctly, as there is no legitimate use case for it right now.
+    Users will encounter this path if they don't set the jax platforms config to CPU **first**, as is done in the conftest fixture.
+
+    Expected log when running locally:
+    > [...] buffer_instance.cc:295 WARN| BufferInstance::copyFromBuffer: Device-Device transfer
+    is inefficient due to PJRT device modeling limitations. This will actually copy src to host,
+    and fill dst host tensor, because at this callsite we do not know what dst device is.
+    """
+    original_platforms = jax.config.jax_platforms
+
+    try:
+        jax.config.update(
+            "jax_platforms",
+            ",".join([device.value for device in [DeviceType.TT, DeviceType.CPU]]),
+        )
+
+        devices = jax.devices("tt")
+        mesh = Mesh(np.array(devices), axis_names=("data",))
+
+        # Create tensor on CPU
+        with jax.default_device(jax.devices("cpu")[0]):
+            a = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
+
+        # Shard tensor across data dimension
+        a_tt = jax.device_put(a, NamedSharding(mesh, PartitionSpec("data")))
+
+        # Verify sharding
+        assert a_tt.sharding is not None
+    finally:
+        # Restore original config
+        jax.config.update("jax_platforms", original_platforms)
