From a04a17ade6c86d55a9753d07aaee562d3f0ef121 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 11 Mar 2025 22:38:29 -0300 Subject: [PATCH 1/4] fix(exla): stack should respect multi-device --- exla/lib/exla/backend.ex | 2 ++ exla/test/exla/backend_test.exs | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index 5135e1770f..f804c9e172 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -265,6 +265,7 @@ defmodule EXLA.Backend do copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend)) result = Nx.BinaryBackend.concatenate(out, copied, axis) Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)}) + Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])}) end @impl true @@ -272,6 +273,7 @@ defmodule EXLA.Backend do copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend)) result = Nx.BinaryBackend.stack(out, copied, axis) Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)}) + Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])}) end @impl true diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 22e3c60850..73d3d7ebc7 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -147,6 +147,30 @@ defmodule EXLA.BackendTest do assert %{device_id: 1, client_name: :other_host} = Nx.reshape(a, {1}).data.buffer end + @tag :multi_device + test "stack and concatenate should end up in the same client" do + t_0 = Nx.tensor([1], backend: {EXLA.Backend, client: :other_host, device_id: 0}) + t_1 = Nx.tensor([1], backend: {EXLA.Backend, client: :other_host, device_id: 1}) + + t_stack_0 = Nx.stack([t_0, t_1]) + t_concat_0 = Nx.concatenate([t_0, t_1]) + + assert t_stack_0.data.buffer.client_name == :other_host + assert t_stack_0.data.buffer.device_id == 1 + + assert t_concat_0.data.buffer.client_name == :other_host + assert t_concat_0.data.buffer.device_id == 1 + + t_stack_1 = Nx.stack([t_1, t_0]) + t_concat_1 = Nx.concatenate([t_1, t_0]) + + assert t_stack_1.data.buffer.client_name == :other_host + assert t_stack_1.data.buffer.device_id == 1 + + assert t_concat_1.data.buffer.client_name == :other_host + assert t_concat_1.data.buffer.device_id == 1 + end + test "Kernel.inspect/2" do t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend) From 3ab0c5925beff4f1f3796882836a0a54c15b0fd9 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 11 Mar 2025 22:48:47 -0300 Subject: [PATCH 2/4] fix: don't allow automatic transfers in test --- exla/config/runtime.exs | 3 ++- exla/lib/exla/backend.ex | 4 ++-- exla/test/exla/backend_test.exs | 16 ++++++++-------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/exla/config/runtime.exs b/exla/config/runtime.exs index 60b8f5102f..80e2d65da1 100644 --- a/exla/config/runtime.exs +++ b/exla/config/runtime.exs @@ -3,7 +3,8 @@ import Config config :exla, :clients, cuda: [platform: :cuda, memory_fraction: 0.8], rocm: [platform: :rocm, memory_fraction: 0.8], - other_host: [platform: :host] + other_host: [platform: :host], + no_automatic_transfers_host: [platform: :host, automatic_transfers: false] config :exla, default_client: String.to_atom(System.get_env("EXLA_TARGET", "host")) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index f804c9e172..e7b226831a 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -264,7 +264,6 @@ defmodule EXLA.Backend do def concatenate(out, tensors, axis) do copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend)) result = Nx.BinaryBackend.concatenate(out, copied, axis) - Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)}) Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])}) end @@ -272,7 +271,6 @@ defmodule EXLA.Backend do def stack(out, tensors, axis) do copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend)) result = Nx.BinaryBackend.stack(out, copied, axis) - Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)}) Nx.backend_transfer(result, {EXLA.Backend, jit_opts(tensors, [])}) end @@ -443,6 +441,8 @@ defmodule EXLA.Backend do opts[:device_id] || priority_did || backup_did || EXLA.Client.fetch!(client).default_device_id + require IEx + IEx.pry() [client: client, device_id: device_id] end end diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 73d3d7ebc7..6ee603700b 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -149,26 +149,26 @@ defmodule EXLA.BackendTest do @tag :multi_device test "stack and concatenate should end up in the same client" do - t_0 = Nx.tensor([1], backend: {EXLA.Backend, client: :other_host, device_id: 0}) - t_1 = Nx.tensor([1], backend: {EXLA.Backend, client: :other_host, device_id: 1}) + t_0 = Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 0}) + t_1 = Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 1}) t_stack_0 = Nx.stack([t_0, t_1]) t_concat_0 = Nx.concatenate([t_0, t_1]) - assert t_stack_0.data.buffer.client_name == :other_host + assert t_stack_0.data.buffer.client_name == :no_automatic_transfers_host assert t_stack_0.data.buffer.device_id == 1 - assert t_concat_0.data.buffer.client_name == :other_host + assert t_concat_0.data.buffer.client_name == :no_automatic_transfers_host assert t_concat_0.data.buffer.device_id == 1 t_stack_1 = Nx.stack([t_1, t_0]) t_concat_1 = Nx.concatenate([t_1, t_0]) - assert t_stack_1.data.buffer.client_name == :other_host - assert t_stack_1.data.buffer.device_id == 1 + assert t_stack_1.data.buffer.client_name == :no_automatic_transfers_host + assert t_stack_1.data.buffer.device_id == 0 - assert t_concat_1.data.buffer.client_name == :other_host - assert t_concat_1.data.buffer.device_id == 1 + assert t_concat_1.data.buffer.client_name == :no_automatic_transfers_host + assert t_concat_1.data.buffer.device_id == 0 end test "Kernel.inspect/2" do From 8c9c4be3b3b342e55c693731642b259fa1683e7f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 11 Mar 2025 22:49:49 -0300 Subject: [PATCH 3/4] chore: remove pry --- exla/lib/exla/backend.ex | 2 -- 1 file changed, 2 deletions(-) diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index e7b226831a..50747de4bb 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -441,8 +441,6 @@ defmodule EXLA.Backend do opts[:device_id] || priority_did || backup_did || EXLA.Client.fetch!(client).default_device_id - require IEx - IEx.pry() [client: client, device_id: device_id] end end From 6495d7a6d2d38a8fc7c8e8d7894466195ab32907 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 11 Mar 2025 22:55:46 -0300 Subject: [PATCH 4/4] chore: format --- exla/test/exla/backend_test.exs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 6ee603700b..7f99c8bc6f 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -149,8 +149,11 @@ defmodule EXLA.BackendTest do @tag :multi_device test "stack and concatenate should end up in the same client" do - t_0 = Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 0}) - t_1 = Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 1}) + t_0 = + Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 0}) + + t_1 = + Nx.tensor([1], backend: {EXLA.Backend, client: :no_automatic_transfers_host, device_id: 1}) t_stack_0 = Nx.stack([t_0, t_1]) t_concat_0 = Nx.concatenate([t_0, t_1])