diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 2108de5c6e..ed3ce31a03 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -217,9 +217,9 @@ FINE_NIF(mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND); // ExlaBuffer Functions -std::variant, - fine::Ok, - fine::Ok, fine::Error> +std::variant, + std::tuple, + std::tuple> get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr client, fine::Term buffer_term, fine::Atom pointer_kind) { auto buffer = decode_exla_buffer(env, buffer_term); @@ -228,7 +228,7 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr client, uint64_t ptr = unwrap(buffer->GetDevicePointer(client->client())); if (pointer_kind == "local") { - return fine::Ok(ptr, device_size); + return std::make_tuple(pointer_kind, ptr, device_size); } if (pointer_kind == "host_ipc") { @@ -237,26 +237,27 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr client, auto fd = get_ipc_handle(handle_name.c_str(), device_size); if (fd == -1) { - return fine::Error(std::string("unable to get IPC handle")); + throw std::runtime_error("unable to get IPC handle"); } auto ipc_ptr = open_ipc_handle(fd, device_size); if (ipc_ptr == nullptr) { - return fine::Error(std::string("unable to open IPC handle")); + throw std::runtime_error("unable to open IPC handle"); } memcpy(ipc_ptr, reinterpret_cast(ptr), device_size); - return fine::Ok(handle_name, static_cast(fd), device_size); + return std::make_tuple(pointer_kind, handle_name, static_cast(fd), + device_size); } if (pointer_kind == "cuda_ipc") { auto maybe_handle = get_cuda_ipc_handle(ptr); if (!maybe_handle) { - return fine::Error(std::string("unable to get cuda IPC handle")); + throw std::runtime_error("unable to get cuda IPC handle"); } - return fine::Ok(maybe_handle.value(), device_size); + return std::make_tuple(pointer_kind, maybe_handle.value(), device_size); } throw std::invalid_argument("unexpected pointer type"); @@ -264,12 +265,10 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr client, FINE_NIF(get_buffer_device_pointer, 0); -std::variant>, fine::Error> -create_buffer_from_device_pointer(ErlNifEnv *env, - fine::ResourcePtr client, - fine::Atom pointer_kind, - fine::Term pointer_data, xla::Shape shape, - int64_t device_id) { +fine::ResourcePtr create_buffer_from_device_pointer( + ErlNifEnv *env, fine::ResourcePtr client, + fine::Atom pointer_kind, fine::Term pointer_data, xla::Shape shape, + int64_t device_id) { void *ptr = nullptr; std::function on_delete_callback = []() {}; @@ -278,7 +277,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env, auto maybe_pointer = get_pointer_for_ipc_handle( cuda_ipc_handle_bin.data, cuda_ipc_handle_bin.size, device_id); if (!maybe_pointer) { - return fine::Error("unable to get pointer for IPC handle"); + throw std::runtime_error("unable to get pointer for IPC handle"); } ptr = maybe_pointer.value(); } else if (pointer_kind == "host_ipc") { @@ -289,7 +288,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env, auto device_size = xla::ShapeUtil::ByteSizeOf(shape); ptr = open_ipc_handle(fd, device_size); if (ptr == nullptr) { - return fine::Error("unable to get pointer for IPC handle"); + throw std::runtime_error("unable to get pointer for IPC handle"); } on_delete_callback = [fd, memname, ptr, device_size]() { close_ipc_handle(fd, ptr, memname.c_str(), device_size); @@ -305,7 +304,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env, client->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); auto buffer = unwrap(client->client()->CreateViewOfDeviceBuffer( ptr, shape, device, on_delete_callback)); - return fine::Ok(fine::make_resource(std::move(buffer))); + return fine::make_resource(std::move(buffer)); } FINE_NIF(create_buffer_from_device_pointer, 0); diff --git a/exla/c_src/exla/exla_cuda.cc b/exla/c_src/exla/exla_cuda.cc index 395fce3f9b..6f5bbe7b97 100644 --- a/exla/c_src/exla/exla_cuda.cc +++ b/exla/c_src/exla/exla_cuda.cc @@ -20,9 +20,7 @@ std::optional get_cuda_ipc_handle(std::uintptr_t ptr) { const size_t size = sizeof(cudaIpcMemHandle_t); // Copy the memory handle to a buffer - std::string buffer; - buffer.resize(size); - memcpy(&(*(buffer.begin())), &ipc_handle, size); + auto buffer = std::string(reinterpret_cast(&ipc_handle), size); return buffer; } diff --git a/exla/c_src/exla/ipc.cc b/exla/c_src/exla/ipc.cc index b1616c046b..afcbb09859 100644 --- a/exla/c_src/exla/ipc.cc +++ b/exla/c_src/exla/ipc.cc @@ -1,12 +1,11 @@ #include "ipc.h" +#include #include #include #include #include -#include - // Function to create or open a shared memory object and set its size int get_ipc_handle(const char* memname, size_t memsize) { int fd = shm_open(memname, O_CREAT | O_RDWR, 0666); diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index a94511cb19..2ac5c62646 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -114,35 +114,15 @@ defmodule EXLA.Backend do client = EXLA.Client.fetch!(buffer.client_name) case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode) do - {:ok, ptr, size} when mode == :local and is_integer(ptr) -> + {:local, ptr, size} -> # Pointer is an integer here - {:ok, - %Nx.Pointer{ - kind: :local, - address: ptr, - data_size: size - }} - - {:ok, handle_name, fd, size} when mode == :host_ipc -> - {:ok, - %Nx.Pointer{ - kind: :ipc, - handle: handle_name, - address: fd, - data_size: size - }} - - {:ok, handle, size} when mode === :cuda_ipc -> - {:ok, - %Nx.Pointer{ - kind: :ipc, - handle: handle, - address: buffer.device_id, - data_size: size - }} - - {:error, reason} -> - {:error, reason} + %Nx.Pointer{kind: :local, address: ptr, data_size: size} + + {:host_ipc, handle_name, fd, size} -> + %Nx.Pointer{kind: :ipc, handle: handle_name, address: fd, data_size: size} + + {:cuda_ipc, handle, size} -> + %Nx.Pointer{kind: :ipc, handle: handle, address: buffer.device_id, data_size: size} end end @@ -180,7 +160,7 @@ defmodule EXLA.Backend do {:cuda_ipc, handle} end - result = + ref = EXLA.NIF.create_buffer_from_device_pointer( client.ref, mode, @@ -189,14 +169,8 @@ defmodule EXLA.Backend do device_id ) - case result do - {:ok, ref} -> - buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, typespec) - {:ok, %{template | data: %EXLA.Backend{buffer: buffer}}} - - error -> - error - end + buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, typespec) + %{template | data: %EXLA.Backend{buffer: buffer}} end @impl true diff --git a/exla/test/exla/device_memory_sharing_test.exs b/exla/test/exla/device_memory_sharing_test.exs index 09e54a42eb..7ef3b165ef 100644 --- a/exla/test/exla/device_memory_sharing_test.exs +++ b/exla/test/exla/device_memory_sharing_test.exs @@ -11,9 +11,9 @@ defmodule EXLA.DeviceMemorySharingTest do assert inspect(t1) =~ "1, 2, 3" - assert {:ok, pointer} = Nx.to_pointer(t1, mode: :local) + assert pointer = Nx.to_pointer(t1, mode: :local) - assert {:ok, t2} = + assert t2 = Nx.from_pointer( {EXLA.Backend, client: unquote(client_name)}, pointer, diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 8a1cd582cc..170652e4d7 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -16771,21 +16771,21 @@ defmodule Nx do pointer = %Nx.Pointer{kind: :local, address: 1234} Nx.from_pointer(MyBackend, pointer, {:s, 32}, {1, 3}) - #Nx.Tensor< - s32[1][3] - [ - [10, 20, 30] - ] - > + #=> #Nx.Tensor< + #=> s32[1][3] + #=> [ + #=> [10, 20, 30] + #=> ] + #=> > pointer = %Nx.Pointer{kind: :ipc, handle: "some-ipc-handle"} Nx.from_pointer({MyBackend, some: :opt}, pointer, {:s, 32}, {1, 3}, names: [nil, :col]) - #Nx.Tensor< - s32[1][col: 3] - [ - [10, 20, 30] - ] - > + #=> #Nx.Tensor< + #=> s32[1][col: 3] + #=> [ + #=> [10, 20, 30] + #=> ] + #=> > """ @doc type: :creation def from_pointer(backend, pointer, type, shape, opts \\ []) @@ -16823,11 +16823,11 @@ defmodule Nx do t = Nx.u8([10, 20, 30]) Nx.to_pointer(t, mode: :local) - #=> {:ok, %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil}} + #=> %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil} t = Nx.s32([1, 2, 3]) Nx.to_pointer(t, mode: :ipc) - #=> {:ok, %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"}} + #=> %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"} """ @doc type: :creation def to_pointer(tensor, opts \\ []) do diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 1193a0b8c4..638a1dfdde 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -57,8 +57,8 @@ defmodule Nx.Backend do shape :: tuple(), backend_opts :: keyword(), opts :: keyword() - ) :: {:ok, tensor} | {:error, term()} - @callback to_pointer(tensor, opts :: keyword) :: {:ok, term()} | {:error, term()} + ) :: tensor | no_return() + @callback to_pointer(tensor, opts :: keyword) :: term() | no_return() @callback as_type(out :: tensor, tensor) :: tensor @callback bitcast(out :: tensor, tensor) :: tensor