Skip to content

Commit 1a237e1

Browse files
Change Nx.to_pointer/2 and Nx.from_pointer/5 to raise on errors (#1582)
Co-authored-by: Paulo Valente <[email protected]>
1 parent 5d30718 commit 1a237e1

File tree

7 files changed

+48
-78
lines changed

7 files changed

+48
-78
lines changed

exla/c_src/exla/exla.cc

+17-18
Original file line numberDiff line numberDiff line change
@@ -217,9 +217,9 @@ FINE_NIF(mlir_compile, ERL_NIF_DIRTY_JOB_CPU_BOUND);
217217

218218
// ExlaBuffer Functions
219219

220-
std::variant<fine::Ok<uint64_t, uint64_t>,
221-
fine::Ok<std::string, uint64_t, uint64_t>,
222-
fine::Ok<std::string, uint64_t>, fine::Error<std::string>>
220+
std::variant<std::tuple<fine::Atom, uint64_t, uint64_t>,
221+
std::tuple<fine::Atom, std::string, uint64_t, uint64_t>,
222+
std::tuple<fine::Atom, std::string, uint64_t>>
223223
get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
224224
fine::Term buffer_term, fine::Atom pointer_kind) {
225225
auto buffer = decode_exla_buffer(env, buffer_term);
@@ -228,7 +228,7 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
228228
uint64_t ptr = unwrap(buffer->GetDevicePointer(client->client()));
229229

230230
if (pointer_kind == "local") {
231-
return fine::Ok(ptr, device_size);
231+
return std::make_tuple(pointer_kind, ptr, device_size);
232232
}
233233

234234
if (pointer_kind == "host_ipc") {
@@ -237,39 +237,38 @@ get_buffer_device_pointer(ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
237237
auto fd = get_ipc_handle(handle_name.c_str(), device_size);
238238

239239
if (fd == -1) {
240-
return fine::Error(std::string("unable to get IPC handle"));
240+
throw std::runtime_error("unable to get IPC handle");
241241
}
242242

243243
auto ipc_ptr = open_ipc_handle(fd, device_size);
244244
if (ipc_ptr == nullptr) {
245-
return fine::Error(std::string("unable to open IPC handle"));
245+
throw std::runtime_error("unable to open IPC handle");
246246
}
247247

248248
memcpy(ipc_ptr, reinterpret_cast<void *>(ptr), device_size);
249249

250-
return fine::Ok(handle_name, static_cast<uint64_t>(fd), device_size);
250+
return std::make_tuple(pointer_kind, handle_name, static_cast<uint64_t>(fd),
251+
device_size);
251252
}
252253

253254
if (pointer_kind == "cuda_ipc") {
254255
auto maybe_handle = get_cuda_ipc_handle(ptr);
255256
if (!maybe_handle) {
256-
return fine::Error(std::string("unable to get cuda IPC handle"));
257+
throw std::runtime_error("unable to get cuda IPC handle");
257258
}
258259

259-
return fine::Ok(maybe_handle.value(), device_size);
260+
return std::make_tuple(pointer_kind, maybe_handle.value(), device_size);
260261
}
261262

262263
throw std::invalid_argument("unexpected pointer type");
263264
}
264265

265266
FINE_NIF(get_buffer_device_pointer, 0);
266267

267-
std::variant<fine::Ok<fine::ResourcePtr<ExlaBuffer>>, fine::Error<std::string>>
268-
create_buffer_from_device_pointer(ErlNifEnv *env,
269-
fine::ResourcePtr<ExlaClient> client,
270-
fine::Atom pointer_kind,
271-
fine::Term pointer_data, xla::Shape shape,
272-
int64_t device_id) {
268+
fine::ResourcePtr<ExlaBuffer> create_buffer_from_device_pointer(
269+
ErlNifEnv *env, fine::ResourcePtr<ExlaClient> client,
270+
fine::Atom pointer_kind, fine::Term pointer_data, xla::Shape shape,
271+
int64_t device_id) {
273272
void *ptr = nullptr;
274273
std::function<void()> on_delete_callback = []() {};
275274

@@ -278,7 +277,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env,
278277
auto maybe_pointer = get_pointer_for_ipc_handle(
279278
cuda_ipc_handle_bin.data, cuda_ipc_handle_bin.size, device_id);
280279
if (!maybe_pointer) {
281-
return fine::Error<std::string>("unable to get pointer for IPC handle");
280+
throw std::runtime_error("unable to get pointer for IPC handle");
282281
}
283282
ptr = maybe_pointer.value();
284283
} else if (pointer_kind == "host_ipc") {
@@ -289,7 +288,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env,
289288
auto device_size = xla::ShapeUtil::ByteSizeOf(shape);
290289
ptr = open_ipc_handle(fd, device_size);
291290
if (ptr == nullptr) {
292-
return fine::Error<std::string>("unable to get pointer for IPC handle");
291+
throw std::runtime_error("unable to get pointer for IPC handle");
293292
}
294293
on_delete_callback = [fd, memname, ptr, device_size]() {
295294
close_ipc_handle(fd, ptr, memname.c_str(), device_size);
@@ -305,7 +304,7 @@ create_buffer_from_device_pointer(ErlNifEnv *env,
305304
client->client()->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
306305
auto buffer = unwrap(client->client()->CreateViewOfDeviceBuffer(
307306
ptr, shape, device, on_delete_callback));
308-
return fine::Ok(fine::make_resource<ExlaBuffer>(std::move(buffer)));
307+
return fine::make_resource<ExlaBuffer>(std::move(buffer));
309308
}
310309

311310
FINE_NIF(create_buffer_from_device_pointer, 0);

exla/c_src/exla/exla_cuda.cc

+1-3
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@ std::optional<std::string> get_cuda_ipc_handle(std::uintptr_t ptr) {
2020
const size_t size = sizeof(cudaIpcMemHandle_t);
2121

2222
// Copy the memory handle to a buffer
23-
std::string buffer;
24-
buffer.resize(size);
25-
memcpy(&(*(buffer.begin())), &ipc_handle, size);
23+
auto buffer = std::string(reinterpret_cast<const char*>(&ipc_handle), size);
2624

2725
return buffer;
2826
}

exla/c_src/exla/ipc.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#include "ipc.h"
22

3+
#include <cstdio>
34
#include <fcntl.h>
45
#include <sys/mman.h>
56
#include <sys/stat.h>
67
#include <unistd.h>
78

8-
#include <iostream>
9-
109
// Function to create or open a shared memory object and set its size
1110
int get_ipc_handle(const char* memname, size_t memsize) {
1211
int fd = shm_open(memname, O_CREAT | O_RDWR, 0666);

exla/lib/exla/backend.ex

+11-37
Original file line numberDiff line numberDiff line change
@@ -114,35 +114,15 @@ defmodule EXLA.Backend do
114114
client = EXLA.Client.fetch!(buffer.client_name)
115115

116116
case EXLA.NIF.get_buffer_device_pointer(client.ref, buffer.ref, mode) do
117-
{:ok, ptr, size} when mode == :local and is_integer(ptr) ->
117+
{:local, ptr, size} ->
118118
# Pointer is an integer here
119-
{:ok,
120-
%Nx.Pointer{
121-
kind: :local,
122-
address: ptr,
123-
data_size: size
124-
}}
125-
126-
{:ok, handle_name, fd, size} when mode == :host_ipc ->
127-
{:ok,
128-
%Nx.Pointer{
129-
kind: :ipc,
130-
handle: handle_name,
131-
address: fd,
132-
data_size: size
133-
}}
134-
135-
{:ok, handle, size} when mode === :cuda_ipc ->
136-
{:ok,
137-
%Nx.Pointer{
138-
kind: :ipc,
139-
handle: handle,
140-
address: buffer.device_id,
141-
data_size: size
142-
}}
143-
144-
{:error, reason} ->
145-
{:error, reason}
119+
%Nx.Pointer{kind: :local, address: ptr, data_size: size}
120+
121+
{:host_ipc, handle_name, fd, size} ->
122+
%Nx.Pointer{kind: :ipc, handle: handle_name, address: fd, data_size: size}
123+
124+
{:cuda_ipc, handle, size} ->
125+
%Nx.Pointer{kind: :ipc, handle: handle, address: buffer.device_id, data_size: size}
146126
end
147127
end
148128

@@ -180,7 +160,7 @@ defmodule EXLA.Backend do
180160
{:cuda_ipc, handle}
181161
end
182162

183-
result =
163+
ref =
184164
EXLA.NIF.create_buffer_from_device_pointer(
185165
client.ref,
186166
mode,
@@ -189,14 +169,8 @@ defmodule EXLA.Backend do
189169
device_id
190170
)
191171

192-
case result do
193-
{:ok, ref} ->
194-
buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, typespec)
195-
{:ok, %{template | data: %EXLA.Backend{buffer: buffer}}}
196-
197-
error ->
198-
error
199-
end
172+
buffer = EXLA.DeviceBuffer.from_ref(ref, client, device_id, typespec)
173+
%{template | data: %EXLA.Backend{buffer: buffer}}
200174
end
201175

202176
@impl true

exla/test/exla/device_memory_sharing_test.exs

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ defmodule EXLA.DeviceMemorySharingTest do
1111

1212
assert inspect(t1) =~ "1, 2, 3"
1313

14-
assert {:ok, pointer} = Nx.to_pointer(t1, mode: :local)
14+
assert pointer = Nx.to_pointer(t1, mode: :local)
1515

16-
assert {:ok, t2} =
16+
assert t2 =
1717
Nx.from_pointer(
1818
{EXLA.Backend, client: unquote(client_name)},
1919
pointer,

nx/lib/nx.ex

+14-14
Original file line numberDiff line numberDiff line change
@@ -16771,21 +16771,21 @@ defmodule Nx do
1677116771
1677216772
pointer = %Nx.Pointer{kind: :local, address: 1234}
1677316773
Nx.from_pointer(MyBackend, pointer, {:s, 32}, {1, 3})
16774-
#Nx.Tensor<
16775-
s32[1][3]
16776-
[
16777-
[10, 20, 30]
16778-
]
16779-
>
16774+
#=> #Nx.Tensor<
16775+
#=> s32[1][3]
16776+
#=> [
16777+
#=> [10, 20, 30]
16778+
#=> ]
16779+
#=> >
1678016780
1678116781
pointer = %Nx.Pointer{kind: :ipc, handle: "some-ipc-handle"}
1678216782
Nx.from_pointer({MyBackend, some: :opt}, pointer, {:s, 32}, {1, 3}, names: [nil, :col])
16783-
#Nx.Tensor<
16784-
s32[1][col: 3]
16785-
[
16786-
[10, 20, 30]
16787-
]
16788-
>
16783+
#=> #Nx.Tensor<
16784+
#=> s32[1][col: 3]
16785+
#=> [
16786+
#=> [10, 20, 30]
16787+
#=> ]
16788+
#=> >
1678916789
"""
1679016790
@doc type: :creation
1679116791
def from_pointer(backend, pointer, type, shape, opts \\ [])
@@ -16823,11 +16823,11 @@ defmodule Nx do
1682316823
1682416824
t = Nx.u8([10, 20, 30])
1682516825
Nx.to_pointer(t, mode: :local)
16826-
#=> {:ok, %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil}}
16826+
#=> %Nx.Pointer{kind: :local, address: 1234, data_size: 3, handle: nil}
1682716827
1682816828
t = Nx.s32([1, 2, 3])
1682916829
Nx.to_pointer(t, mode: :ipc)
16830-
#=> {:ok, %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"}}
16830+
#=> %Nx.Pointer{kind: :ipc, address: nil, data_size: 32, handle: "some-ipc-handle"}
1683116831
"""
1683216832
@doc type: :creation
1683316833
def to_pointer(tensor, opts \\ []) do

nx/lib/nx/backend.ex

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ defmodule Nx.Backend do
5757
shape :: tuple(),
5858
backend_opts :: keyword(),
5959
opts :: keyword()
60-
) :: {:ok, tensor} | {:error, term()}
61-
@callback to_pointer(tensor, opts :: keyword) :: {:ok, term()} | {:error, term()}
60+
) :: tensor | no_return()
61+
@callback to_pointer(tensor, opts :: keyword) :: term() | no_return()
6262

6363
@callback as_type(out :: tensor, tensor) :: tensor
6464
@callback bitcast(out :: tensor, tensor) :: tensor

0 commit comments

Comments
 (0)