Skip to content

Commit cf35078

Browse files
authored
fix(exla): respect device id when automatic transfers are disabled (#1592)
1 parent 9526198 commit cf35078

File tree

3 files changed

+31
-3
lines changed

3 files changed

+31
-3
lines changed

exla/config/runtime.exs

+2-1
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@ import Config
33
config :exla, :clients,
44
cuda: [platform: :cuda, memory_fraction: 0.8],
55
rocm: [platform: :rocm, memory_fraction: 0.8],
6-
other_host: [platform: :host]
6+
other_host: [platform: :host],
7+
no_automatic_transfers_host: [platform: :host, automatic_transfers: false]
78

89
config :exla, default_client: String.to_atom(System.get_env("EXLA_TARGET", "host"))
910

exla/lib/exla/backend.ex

+2-2
Original file line number | Diff line number | Diff line change
@@ -264,14 +264,14 @@ defmodule EXLA.Backend do
264264
def concatenate(out, tensors, axis) do
265265
copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
266266
result = Nx.BinaryBackend.concatenate(out, copied, axis)
267-
Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
267+
Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])})
268268
end
269269

270270
@impl true
271271
def stack(out, tensors, axis) do
272272
copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
273273
result = Nx.BinaryBackend.stack(out, copied, axis)
274-
Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
274+
Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])})
275275
end
276276

277277
@impl true

exla/test/exla/backend_test.exs

+27
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,33 @@ defmodule EXLA.BackendTest do
147147
assert %{device_id: 1, client_name: :other_host} = Nx.reshape(a, {1}).data.buffer
148148
end
149149

150+
@tag :multi_device
151+
test "stack and concatenate should end up in the same client" do
152+
t_0 =
153+
Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 0})
154+
155+
t_1 =
156+
Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 1})
157+
158+
t_stack_0 = Nx.stack([t_0, t_1])
159+
t_concat_0 = Nx.concatenate([t_0, t_1])
160+
161+
assert t_stack_0.data.buffer.client_name == :no_automatic_transfers_host
162+
assert t_stack_0.data.buffer.device_id == 1
163+
164+
assert t_concat_0.data.buffer.client_name == :no_automatic_transfers_host
165+
assert t_concat_0.data.buffer.device_id == 1
166+
167+
t_stack_1 = Nx.stack([t_1, t_0])
168+
t_concat_1 = Nx.concatenate([t_1, t_0])
169+
170+
assert t_stack_1.data.buffer.client_name == :no_automatic_transfers_host
171+
assert t_stack_1.data.buffer.device_id == 0
172+
173+
assert t_concat_1.data.buffer.client_name == :no_automatic_transfers_host
174+
assert t_concat_1.data.buffer.device_id == 0
175+
end
176+
150177
test "Kernel.inspect/2" do
151178
t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
152179

0 commit comments

Comments
 (0)