Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Nx.to_pointer/2 and Nx.from_pointer/5 to raise on errors #1582

Merged
merged 4 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ FINE_NIF(mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND);

// ExlaBuffer Functions

std::variant<fine::Ok<uint64_t, uint64_t>,
fine::Ok<std::string, uint64_t, uint64_t>,
fine::Ok<std::string, uint64_t>, fine::Error<std::string>>
std::variant<std::tuple<fine::Atom, uint64_t, uint64_t>,
std::tuple<fine::Atom, std::string, uint64_t, uint64_t>,
std::tuple<fine::Atom, std::string, uint64_t>>
get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
fine::Term buffer_term, fine::Atom pointer_kind) {
auto buffer = decode_exla_buffer(env, buffer_term);
Expand All @@ -228,7 +228,7 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
uint64_t ptr = unwrap(buffer->GetDevicePointer(client->client()));

if (pointer_kind == "local") {
return fine::Ok(ptr, device_size);
return std::make_tuple(pointer_kind, ptr, device_size);
}

if (pointer_kind == "host_ipc") {
Expand All @@ -237,39 +237,38 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
auto fd = get_ipc_handle(handle_name.c_str(), device_size);

if (fd == -1) {
return fine::Error(std::string("unable to get IPC handle"));
throw std::runtime_error("unable to get IPC handle");
}

auto ipc_ptr = open_ipc_handle(fd, device_size);
if (ipc_ptr == nullptr) {
return fine::Error(std::string("unable to open IPC handle"));
throw std::runtime_error("unable to open IPC handle");
}

memcpy(ipc_ptr, reinterpret_cast<void *>(ptr), device_size);

return fine::Ok(handle_name, static_cast<uint64_t>(fd), device_size);
return std::make_tuple(pointer_kind, handle_name, static_cast<uint64_t>(fd),
device_size);
}

if (pointer_kind == "cuda_ipc") {
auto maybe_handle = get_cuda_ipc_handle(ptr);
if (!maybe_handle) {
return fine::Error(std::string("unable to get cuda IPC handle"));
throw std::runtime_error("unable to get cuda IPC handle");
}

return fine::Ok(maybe_handle.value(), device_size);
return std::make_tuple(pointer_kind, maybe_handle.value(), device_size);
}

throw std::invalid_argument("unexpected pointer type");
}

FINE_NIF(get_buffer_device_pointer, 0);

std::variant<fine::Ok<fine::ResourcePtr<ExlaBuffer>>, fine::Error<std::string>>
create_buffer_from_device_pointer(ErlNifEnv *env,
fine::ResourcePtr<ExlaClient> client,
fine::Atom pointer_kind,
fine::Term pointer_data, xla::Shape shape,
int64_t device_id) {
fine::ResourcePtr<ExlaBuffer> create_buffer_from_device_pointer(
ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
fine::Atom pointer_kind, fine::Term pointer_data, xla::Shape shape,
int64_t device_id) {
void *ptr = nullptr;
std::function<void()> on_delete_callback = []() {};

Expand All @@ -278,7 +277,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env,
auto maybe_pointer = get_pointer_for_ipc_handle(
cuda_ipc_handle_bin.data, cuda_ipc_handle_bin.size, device_id);
if (!maybe_pointer) {
return fine::Error<std::string>("unable to get pointer for IPC handle");
throw std::runtime_error("unable to get pointer for IPC handle");
}
ptr = maybe_pointer.value();
} else if (pointer_kind == "host_ipc") {
Expand All @@ -289,7 +288,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env,
auto device_size = xla::ShapeUtil::ByteSizeOf(shape);
ptr = open_ipc_handle(fd, device_size);
if (ptr == nullptr) {
return fine::Error<std::string>("unable to get pointer for IPC handle");
throw std::runtime_error("unable to get pointer for IPC handle");
}
on_delete_callback = [fd, memname, ptr, device_size]() {
close_ipc_handle(fd, ptr, memname.c_str(), device_size);
Expand All @@ -305,7 +304,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env,
client->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
auto buffer = unwrap(client->client()->CreateViewOfDeviceBuffer(
ptr, shape, device, on_delete_callback));
return fine::Ok(fine::make_resource<ExlaBuffer>(std::move(buffer)));
return fine::make_resource<ExlaBuffer>(std::move(buffer));
}

FINE_NIF(create_buffer_from_device_pointer, 0);
Expand Down
4 changes: 1 addition & 3 deletions exla/c_src/exla/exla_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ std::optional<std::string> get_cuda_ipc_handle(std::uintptr_t ptr) {
const size_t size = sizeof(cudaIpcMemHandle_t);

// Copy the memory handle to a buffer
std::string buffer;
buffer.resize(size);
memcpy(&(*(buffer.begin())), &ipc_handle, size);
auto buffer = std::string(reinterpret_cast<const char*>(&ipc_handle), size);

return buffer;
}
Expand Down
3 changes: 1 addition & 2 deletions exla/c_src/exla/ipc.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#include "ipc.h"

#include <cstdio>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

#include <iostream>

// Function to create or open a shared memory object and set its size
int get_ipc_handle(const char* memname, size_t memsize) {
int fd = shm_open(memname, O_CREAT | O_RDWR, 0666);
Expand Down
48 changes: 11 additions & 37 deletions exla/lib/exla/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -114,35 +114,15 @@ defmodule EXLA.Backend do
client = EXLA.Client.fetch!(buffer.client_name)

case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode) do
{:ok, ptr, size} when mode == :local and is_integer(ptr) ->
{:local, ptr, size} ->
# Pointer is an integer here
{:ok,
%Nx.Pointer{
kind: :local,
address: ptr,
data_size: size
}}

{:ok, handle_name, fd, size} when mode == :host_ipc ->
{:ok,
%Nx.Pointer{
kind: :ipc,
handle: handle_name,
address: fd,
data_size: size
}}

{:ok, handle, size} when mode === :cuda_ipc ->
{:ok,
%Nx.Pointer{
kind: :ipc,
handle: handle,
address: buffer.device_id,
data_size: size
}}

{:error, reason} ->
{:error, reason}
%Nx.Pointer{kind: :local, address: ptr, data_size: size}

{:host_ipc, handle_name, fd, size} ->
%Nx.Pointer{kind: :ipc, handle: handle_name, address: fd, data_size: size}

{:cuda_ipc, handle, size} ->
%Nx.Pointer{kind: :ipc, handle: handle, address: buffer.device_id, data_size: size}
end
end

Expand Down Expand Up @@ -180,7 +160,7 @@ defmodule EXLA.Backend do
{:cuda_ipc, handle}
end

result =
ref =
EXLA.NIF.create_buffer_from_device_pointer(
client.ref,
mode,
Expand All @@ -189,14 +169,8 @@ defmodule EXLA.Backend do
device_id
)

case result do
{:ok, ref} ->
buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, typespec)
{:ok, %{template | data: %EXLA.Backend{buffer: buffer}}}

error ->
error
end
buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, typespec)
%{template | data: %EXLA.Backend{buffer: buffer}}
end

@impl true
Expand Down
4 changes: 2 additions & 2 deletions exla/test/exla/device_memory_sharing_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ defmodule EXLA.DeviceMemorySharingTest do

assert inspect(t1) =~ "1, 2, 3"

assert {:ok, pointer} = Nx.to_pointer(t1, mode: :local)
assert pointer = Nx.to_pointer(t1, mode: :local)

assert {:ok, t2} =
assert t2 =
Nx.from_pointer(
{EXLA.Backend, client: unquote(client_name)},
pointer,
Expand Down
28 changes: 14 additions & 14 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -16771,21 +16771,21 @@ defmodule Nx do

pointer = %Nx.Pointer{kind: :local, address: 1234}
Nx.from_pointer(MyBackend, pointer, {:s, 32}, {1, 3})
#Nx.Tensor<
s32[1][3]
[
[10, 20, 30]
]
>
#=> #Nx.Tensor<
#=> s32[1][3]
#=> [
#=> [10, 20, 30]
#=> ]
#=> >

pointer = %Nx.Pointer{kind: :ipc, handle: "some-ipc-handle"}
Nx.from_pointer({MyBackend, some: :opt}, pointer, {:s, 32}, {1, 3}, names: [nil, :col])
#Nx.Tensor<
s32[1][col: 3]
[
[10, 20, 30]
]
>
#=> #Nx.Tensor<
#=> s32[1][col: 3]
#=> [
#=> [10, 20, 30]
#=> ]
#=> >
"""
@doc type: :creation
def from_pointer(backend, pointer, type, shape, opts \\ [])
Expand Down Expand Up @@ -16823,11 +16823,11 @@ defmodule Nx do

t = Nx.u8([10, 20, 30])
Nx.to_pointer(t, mode: :local)
#=> {:ok, %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil}}
#=> %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil}

t = Nx.s32([1, 2, 3])
Nx.to_pointer(t, mode: :ipc)
#=> {:ok, %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"}}
#=> %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"}
"""
@doc type: :creation
def to_pointer(tensor, opts \\ []) do
Expand Down
4 changes: 2 additions & 2 deletions nx/lib/nx/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ defmodule Nx.Backend do
shape :: tuple(),
backend_opts :: keyword(),
opts :: keyword()
) :: {:ok, tensor} | {:error, term()}
@callback to_pointer(tensor, opts :: keyword) :: {:ok, term()} | {:error, term()}
) :: tensor | no_return()
@callback to_pointer(tensor, opts :: keyword) :: term() | no_return()

@callback as_type(out :: tensor, tensor) :: tensor
@callback bitcast(out :: tensor, tensor) :: tensor
Expand Down