diff --git a/exla/config/runtime.exs b/exla/config/runtime.exs index 60b8f5102f..80e2d65da1 100644 --- a/exla/config/runtime.exs +++ b/exla/config/runtime.exs @@ -3,7 +3,8 @@ import Config config :exla, :clients, cuda: [platform: :cuda, memory_fraction: 0.8], rocm: [platform: :rocm, memory_fraction: 0.8], - other_host: [platform: :host] + other_host: [platform: :host], + no_automatic_transfers_host: [platform: :host, automatic_transfers: false] config :exla, default_client: String.to_atom(System.get_env("EXLA_TARGET", "host")) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 5135e1770f..50747de4bb 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -264,14 +264,14 @@ defmodule EXLA.Backend do def concatenate(out, tensors, axis) do copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend)) result = Nx.BinaryBackend.concatenate(out, copied, axis) - Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)}) + Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])}) end @impl true def stack(out, tensors, axis) do copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend)) result = Nx.BinaryBackend.stack(out, copied, axis) - Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)}) + Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])}) end @impl true diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 22e3c60850..7f99c8bc6f 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -147,6 +147,33 @@ defmodule EXLA.BackendTest do assert %{device_id: 1, client_name: :other_host} = Nx.reshape(a, {1}).data.buffer end + @tag :multi_device + test "stack and concatenate should end up in the same client" do + t_0 = + Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 0}) + + t_1 = + Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 1}) + + t_stack_0 = Nx.stack([t_0, t_1]) + t_concat_0 = Nx.concatenate([t_0, t_1]) + + assert t_stack_0.data.buffer.client_name == :no_automatic_transfers_host + assert t_stack_0.data.buffer.device_id == 1 + + assert t_concat_0.data.buffer.client_name == :no_automatic_transfers_host + assert t_concat_0.data.buffer.device_id == 1 + + t_stack_1 = Nx.stack([t_1, t_0]) + t_concat_1 = Nx.concatenate([t_1, t_0]) + + assert t_stack_1.data.buffer.client_name == :no_automatic_transfers_host + assert t_stack_1.data.buffer.device_id == 0 + + assert t_concat_1.data.buffer.client_name == :no_automatic_transfers_host + assert t_concat_1.data.buffer.device_id == 0 + end + test "Kernel.inspect/2" do t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)