Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,56 @@ jobs:
- name: Run tests
if: ${{ env.should_run == 'true' }}
run: mix test

macos:
name: macOS ARM (torchx, ${{ matrix.elixir }}, ${{ matrix.otp }})
runs-on: macos-14
needs: detect-changes
strategy:
fail-fast: false
matrix:
include:
- elixir: "1.18.4"
otp: "27.3"
defaults:
run:
working-directory: torchx
env:
MIX_ENV: test
TORCHX_DEFAULT_DEVICE: mps
PYTORCH_MPS_HIGH_WATERMARK_RATIO: "0.0"
steps:
- name: Set conditional variables
working-directory: .
run: |
if [ "${{ needs.detect-changes.outputs.torchx-changed }}" = "true" ]; then
echo "should_run=true" >> $GITHUB_ENV
else
echo "should_run=true" >> $GITHUB_ENV
fi
- uses: actions/checkout@v2
- uses: erlef/setup-beam@v1
with:
otp-version: ${{ matrix.otp }}
elixir-version: ${{ matrix.elixir }}
- name: Install libomp
if: ${{ env.should_run == 'true' }}
run: brew install libomp
- name: Retrieve dependencies cache
if: ${{ steps.mix-cache.outputs.cache-hit != 'true' && env.should_run == 'true' }}
env:
cache-name: cache-mix-deps
uses: actions/cache@v3
id: mix-cache # id to use in retrieve action
with:
path: ${{ github.workspace }}/torchx/deps
key: ${{ runner.os }}-Elixir-v${{ matrix.elixir }}-OTP-${{ matrix.otp }}-${{ hashFiles('torchx/mix.lock') }}-v1
- name: Install dependencies
if: ${{ steps.mix-cache.outputs.cache-hit != 'true' && env.should_run == 'true' }}
run: mix deps.get
- name: Compile and check warnings
if: ${{ env.should_run == 'true' }}
run: mix compile --warnings-as-errors
- name: Run tests
if: ${{ env.should_run == 'true' }}
run: mix test
37 changes: 22 additions & 15 deletions torchx/c_src/torchx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,30 @@ REGISTER_TENSOR_NIF(delete_tensor);
// Builds a tensor from an Erlang binary. The diff rendering interleaved the
// pre- and post-change bodies; this is the single, post-change version with
// the whole body wrapped in try/catch so libtorch errors surface as
// std::runtime_error (which the NIF layer converts to an Erlang exception)
// instead of aborting the VM.
fine::Ok<fine::ResourcePtr<TorchTensor>>
from_blob(ErlNifEnv *env, ErlNifBinary blob, std::vector<int64_t> shape,
          fine::Atom type_atom, std::tuple<int64_t, int64_t> device_tuple) {
  try {
    auto type = string2type(type_atom.to_string());
    auto device = tuple_to_device(device_tuple);

    // Check if binary is large enough for the requested shape. Dividing the
    // byte count by the element size (rather than multiplying elem_count by
    // the element size) avoids integer overflow on huge shapes.
    // NOTE(review): presumably dtype_sizes maps every atom that string2type
    // accepts, so operator[] never default-inserts here — confirm.
    if (blob.size / dtype_sizes[type_atom.to_string()] < elem_count(shape)) {
      throw std::invalid_argument(
          "Binary size is too small for the requested shape");
    }

    // torch::from_blob does NOT take ownership of blob.data, so the result
    // is only a view over the NIF binary; we must clone (CPU) or copy to
    // the target device before returning.
    auto tensor = torch::from_blob(blob.data, vec_to_array_ref(shape),
                                   torch::device(torch::kCPU).dtype(type));

    if (device.type() == torch::kCPU) {
      return tensor_ok(tensor.clone());
    } else {
      // .to() copies to the requested device, detaching from blob.data.
      return tensor_ok(tensor.to(device));
    }
  } catch (const c10::Error &e) {
    // c10::Error derives from std::exception, so it must be caught first
    // to use the cleaner msg() (no stack trace) in the Erlang-facing error.
    throw std::runtime_error("from_blob failed: " + e.msg());
  } catch (const std::exception &e) {
    throw std::runtime_error("from_blob failed: " + std::string(e.what()));
  } catch (...) {
    throw std::runtime_error("from_blob failed with unknown error");
  }
}

Expand Down
Loading