Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 107 additions & 57 deletions torchx/c_src/torchx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,28 @@
#include "torchx_nif_util.h"
#include <iostream>
#include <numeric>
#include <stdexcept>

namespace torchx {

// Register TorchTensor as a resource type
FINE_RESOURCE(TorchTensor);

// Helper macro to provide better error messages for PyTorch exceptions
#define TORCH_CATCH_ERROR(expr, operation_name) \
try { \
return expr; \
} catch (const c10::Error &e) { \
throw std::runtime_error(std::string(operation_name) + \
" failed: " + e.what()); \
} catch (const std::exception &e) { \
throw std::runtime_error(std::string(operation_name) + \
" failed: " + e.what()); \
} catch (...) { \
throw std::runtime_error(std::string(operation_name) + \
" failed with unknown error"); \
}

// Macro to register both _cpu and _io variants of a function
// Following EXLA's pattern - create wrapper functions
#define REGISTER_TENSOR_NIF(NAME) \
Expand Down Expand Up @@ -257,8 +273,11 @@ REGISTER_TENSOR_NIF(to_type);
fine::Ok<fine::ResourcePtr<TorchTensor>>
to_device(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor,
std::tuple<int64_t, int64_t> device_tuple) {
auto device = tuple_to_device(device_tuple);
return tensor_ok(get_tensor(tensor).to(device));
TORCH_CATCH_ERROR(({
auto device = tuple_to_device(device_tuple);
tensor_ok(get_tensor(tensor).to(device));
}),
"Device transfer");
}

REGISTER_TENSOR_NIF(to_device);
Expand Down Expand Up @@ -340,15 +359,18 @@ fine::Ok<fine::ResourcePtr<TorchTensor>>
index_put(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> input,
std::vector<fine::ResourcePtr<TorchTensor>> indices,
fine::ResourcePtr<TorchTensor> values, bool accumulate) {
TORCH_CATCH_ERROR(
[&]() {
c10::List<std::optional<at::Tensor>> torch_indices;
for (const auto &idx : indices) {
torch_indices.push_back(get_tensor(idx));
}

c10::List<std::optional<at::Tensor>> torch_indices;
for (const auto &idx : indices) {
torch_indices.push_back(get_tensor(idx));
}

torch::Tensor result = get_tensor(input).clone();
result.index_put_(torch_indices, get_tensor(values), accumulate);
return tensor_ok(result);
torch::Tensor result = get_tensor(input).clone();
result.index_put_(torch_indices, get_tensor(values), accumulate);
return tensor_ok(result);
}(),
"index_put");
}

REGISTER_TENSOR_NIF(index_put);
Expand Down Expand Up @@ -713,26 +735,30 @@ REGISTER_TENSOR_NIF(matmul);
fine::Ok<fine::ResourcePtr<TorchTensor>>
pad(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor,
fine::ResourcePtr<TorchTensor> constant, std::vector<int64_t> config) {
return tensor_ok(torch::constant_pad_nd(get_tensor(tensor),
vec_to_array_ref(config),
get_tensor(constant).item()));
TORCH_CATCH_ERROR(tensor_ok(torch::constant_pad_nd(
get_tensor(tensor), vec_to_array_ref(config),
get_tensor(constant).item())),
"Pad operation");
}

REGISTER_TENSOR_NIF(pad);

fine::Ok<fine::ResourcePtr<TorchTensor>>
triangular_solve(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> a,
fine::ResourcePtr<TorchTensor> b, bool transpose, bool upper) {
auto ts_a = get_tensor(a);
if (transpose) {
auto num_dims = ts_a.dim();
ts_a = torch::transpose(ts_a, num_dims - 2, num_dims - 1);
upper = !upper;
}

torch::Tensor result =
torch::linalg_solve_triangular(ts_a, get_tensor(b), upper, true, false);
return tensor_ok(result);
TORCH_CATCH_ERROR(({
auto ts_a = get_tensor(a);
if (transpose) {
auto num_dims = ts_a.dim();
ts_a =
torch::transpose(ts_a, num_dims - 2, num_dims - 1);
upper = !upper;
}
torch::Tensor result = torch::linalg_solve_triangular(
ts_a, get_tensor(b), upper, true, false);
tensor_ok(result);
}),
"Triangular solve");
}

REGISTER_TENSOR_NIF(triangular_solve);
Expand Down Expand Up @@ -952,20 +978,28 @@ REGISTER_TENSOR_NIF_ARITY(cholesky, cholesky_2);
fine::Ok<
std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
qr_1(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t) {
auto result = torch::linalg_qr(get_tensor(t), "reduced");
return fine::Ok(
std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result))));
TORCH_CATCH_ERROR(
({
auto result = torch::linalg_qr(get_tensor(t), "reduced");
fine::Ok(std::make_tuple(
fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result))));
}),
"QR decomposition");
}

fine::Ok<
std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
qr_2(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t, bool reduced) {
auto result =
torch::linalg_qr(get_tensor(t), reduced ? "reduced" : "complete");
return fine::Ok(
std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result))));
TORCH_CATCH_ERROR(
({
auto result =
torch::linalg_qr(get_tensor(t), reduced ? "reduced" : "complete");
fine::Ok(std::make_tuple(
fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result))));
}),
"QR decomposition");
}

REGISTER_TENSOR_NIF_ARITY(qr, qr_1);
Expand All @@ -976,22 +1010,30 @@ fine::Ok<
std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>,
fine::ResourcePtr<TorchTensor>>>
svd_1(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t) {
auto result = torch::linalg_svd(get_tensor(t), true);
return fine::Ok(
std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result)),
fine::make_resource<TorchTensor>(std::get<2>(result))));
TORCH_CATCH_ERROR(
({
auto result = torch::linalg_svd(get_tensor(t), true);
fine::Ok(std::make_tuple(
fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result)),
fine::make_resource<TorchTensor>(std::get<2>(result))));
}),
"SVD decomposition");
}

fine::Ok<
std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>,
fine::ResourcePtr<TorchTensor>>>
svd_2(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t, bool full_matrices) {
auto result = torch::linalg_svd(get_tensor(t), full_matrices);
return fine::Ok(
std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result)),
fine::make_resource<TorchTensor>(std::get<2>(result))));
TORCH_CATCH_ERROR(
({
auto result = torch::linalg_svd(get_tensor(t), full_matrices);
fine::Ok(std::make_tuple(
fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result)),
fine::make_resource<TorchTensor>(std::get<2>(result))));
}),
"SVD decomposition");
}

REGISTER_TENSOR_NIF_ARITY(svd, svd_1);
Expand All @@ -1001,15 +1043,18 @@ fine::Ok<
std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>,
fine::ResourcePtr<TorchTensor>>>
lu(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> t) {
std::tuple<torch::Tensor, torch::Tensor> lu_result =
torch::linalg_lu_factor(get_tensor(t));
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> plu =
torch::lu_unpack(std::get<0>(lu_result), std::get<1>(lu_result));

return fine::Ok(
std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(plu)),
fine::make_resource<TorchTensor>(std::get<1>(plu)),
fine::make_resource<TorchTensor>(std::get<2>(plu))));
TORCH_CATCH_ERROR(
({
std::tuple<torch::Tensor, torch::Tensor> lu_result =
torch::linalg_lu_factor(get_tensor(t));
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> plu =
torch::lu_unpack(std::get<0>(lu_result), std::get<1>(lu_result));
fine::Ok(std::make_tuple(
fine::make_resource<TorchTensor>(std::get<0>(plu)),
fine::make_resource<TorchTensor>(std::get<1>(plu)),
fine::make_resource<TorchTensor>(std::get<2>(plu))));
}),
"LU decomposition");
}

REGISTER_TENSOR_NIF(lu);
Expand All @@ -1035,19 +1080,24 @@ REGISTER_TENSOR_NIF(amin);
fine::Ok<
std::tuple<fine::ResourcePtr<TorchTensor>, fine::ResourcePtr<TorchTensor>>>
eigh(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensor) {
auto result = torch::linalg_eigh(get_tensor(tensor));
return fine::Ok(
std::make_tuple(fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result))));
TORCH_CATCH_ERROR(
({
auto result = torch::linalg_eigh(get_tensor(tensor));
fine::Ok(std::make_tuple(
fine::make_resource<TorchTensor>(std::get<0>(result)),
fine::make_resource<TorchTensor>(std::get<1>(result))));
}),
"Eigenvalue decomposition (eigh)");
}

REGISTER_TENSOR_NIF(eigh);

fine::Ok<fine::ResourcePtr<TorchTensor>>
solve(ErlNifEnv *env, fine::ResourcePtr<TorchTensor> tensorA,
fine::ResourcePtr<TorchTensor> tensorB) {
return tensor_ok(
torch::linalg_solve(get_tensor(tensorA), get_tensor(tensorB)));
TORCH_CATCH_ERROR(
tensor_ok(torch::linalg_solve(get_tensor(tensorA), get_tensor(tensorB))),
"Linear solve");
}

REGISTER_TENSOR_NIF(solve);
Expand Down
22 changes: 19 additions & 3 deletions torchx/lib/torchx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,19 @@ defmodule Torchx do
def eye(size, type, device), do: eye(size, size, type, device)
defdevice eye(m, n, type, device)
defdevice from_blob(blob, shape, type, device)
defdevice to_device(tensor, device)

@torch_function {:to_device, 2}
def to_device(tensor, device) do
{[tensor_ref], _current_device} = prepare_tensors!([tensor])
{user_device, index} = normalize_device!(device)
target_device_struct = torch_device!(user_device, index)

case user_device do
:cpu -> Torchx.NIF.to_device_cpu(tensor_ref, target_device_struct)
_ -> Torchx.NIF.to_device_io(tensor_ref, target_device_struct)
end
|> unwrap_tensor!(user_device)
end

## Manipulation

Expand Down Expand Up @@ -466,7 +478,9 @@ defmodule Torchx do
ref

{other_dev, ref} when is_tensor(other_dev, ref) ->
raise ArgumentError, "cannot perform operation across devices #{dev} and #{other_dev}"
# Auto-transfer tensor to target device
{^dev, new_ref} = Torchx.to_device({other_dev, ref}, dev)
new_ref

bad_tensor ->
raise ArgumentError, "expected a Torchx tensor, got: #{inspect(bad_tensor)}"
Expand All @@ -484,7 +498,9 @@ defmodule Torchx do
{ref, dev}

{dev, ref}, other_dev when is_tensor(dev, ref) ->
raise ArgumentError, "cannot perform operation across devices #{dev} and #{other_dev}"
# Auto-transfer tensor to target device
{^other_dev, new_ref} = Torchx.to_device({dev, ref}, other_dev)
{new_ref, other_dev}

[{dev, ref} | _] = tensors, nil when is_tensor(dev, ref) ->
prepare_tensors_list!(tensors, dev)
Expand Down
Loading