Skip to content

Commit 5263a9f

Browse files
committed
Add macro for automating the C++-side of the simpler functions to implement.
1 parent 0b1c8b8 commit 5263a9f

4 files changed

Lines changed: 66 additions & 34 deletions

File tree

CMakeLists.txt

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ if(APPLE)
2929
set(CMAKE_C_COMPILER "/usr/bin/clang")
3030
set(CMAKE_CXX_COMPILER "/usr/bin/clang++")
3131
endif()
32-
set(CMAKE_CXX_STANDARD 23)
32+
set(CMAKE_CXX_STANDARD 20)
3333

3434

3535
include(LibTorchDL)
@@ -38,7 +38,6 @@ download_libtorch(
3838
DESTINATION ${LIBTORCH_DIR}
3939
)
4040

41-
4241
include(FetchContent)
4342
include(ExternalProject)
4443

@@ -170,7 +169,7 @@ set(BRIDGE_OBJECT_FILES $<TARGET_OBJECTS:bridge>)
170169

171170

172171
file(GLOB LIBTORCH_ALL_LIB_FILES
173-
"${LIBTORCH_DIR}/lib/*.a"
172+
# "${LIBTORCH_DIR}/lib/*.a"
174173
"${LIBTORCH_DIR}/lib/*.dylib"
175174
"${LIBTORCH_DIR}/lib/*.so")
176175

@@ -185,7 +184,6 @@ set(REQUIRED_LIBS
185184
"libtorch"
186185
"libtorch_cpu"
187186
"libc10"
188-
"libtorch_global_deps"
189187
)
190188

191189
set(DISALLOWED_LIBS
@@ -206,10 +204,13 @@ foreach(lib_name IN LISTS LIBTORCH_ALL_LIBS)
206204
list(APPEND LIBTORCH_LIBS_LINKER_ARGS "-l${lib_name_short}")
207205
endforeach()
208206

209-
# cmake_print_variables(LIBTORCH_LIBS_LINKER_ARGS)
210-
# cmake_print_variables(${BRIDGE_OBJECT_FILES})
211-
# cmake_print_variables(BRIDGE_OBJECT_FILES)
212-
207+
if(LINUX)
208+
set(LIBTORCH_LINKER_ARGS
209+
"-ltorch"
210+
"-ltorch_cpu"
211+
"-lc10"
212+
)
213+
endif()
213214

214215
add_executable(TorchBridge ${BRIDGE_DIR}/lib/Bridge.chpl)
215216
add_dependencies(TorchBridge bridge)

bridge/include/bridge.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
extern "C" {
88
#endif
99

10+
#define proto_bridge_simple(Name) \
11+
bridge_tensor_t Name(bridge_tensor_t input)
12+
1013
typedef float float32_t;
1114
typedef double float64_t;
1215
typedef char bool_t;
@@ -73,6 +76,28 @@ bridge_tensor_t max_pool2d(
7376
int dilation
7477
);
7578

79+
proto_bridge_simple(relu);
80+
81+
proto_bridge_simple(relu6);
82+
83+
proto_bridge_simple(gelu);
84+
85+
proto_bridge_simple(logsigmoid);
86+
87+
proto_bridge_simple(mish);
88+
89+
proto_bridge_simple(selu);
90+
91+
proto_bridge_simple(silu);
92+
93+
proto_bridge_simple(softmax);
94+
95+
proto_bridge_simple(softmin);
96+
97+
proto_bridge_simple(softsign);
98+
99+
proto_bridge_simple(tanhshrink);
100+
76101

77102
// bridge_tensor_t conv2d(
78103
// bridge_tensor_t input,
@@ -89,4 +114,4 @@ bridge_tensor_t max_pool2d(
89114
#endif
90115

91116
#endif // BRIDGE_H
92-
//hello
117+
//hello

bridge/lib/bridge.cpp

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
#include <vector>
1515
#include <cstdint>
1616

17-
17+
#define def_bridge_simple(Name) \
18+
extern "C" bridge_tensor_t Name(bridge_tensor_t input) { \
19+
auto t_input = bridge_to_torch(input); \
20+
auto t_output = torch::Name(t_input); \
21+
return torch_to_bridge(t_output); \
22+
}
1823

1924

2025
int bridge_tensor_elements(bridge_tensor_t &bt) {
@@ -54,18 +59,6 @@ torch::Tensor bridge_to_torch(bridge_tensor_t &bt) {
5459
return torch::from_blob(bt.data, shape, torch::kFloat);
5560
}
5661

57-
58-
59-
60-
61-
62-
63-
64-
65-
66-
67-
68-
6962
extern "C" float32_t* unsafe(const float32_t* arr) {
7063
return const_cast<float32_t*>(arr);
7164
}
@@ -131,15 +124,6 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
131124
return torch_to_bridge(output);
132125
}
133126

134-
135-
136-
137-
138-
139-
140-
141-
142-
143127
extern "C" bridge_tensor_t increment3(bridge_tensor_t arr) {
144128
auto t = bridge_to_torch(arr);
145129
// Increment the tensor
@@ -206,6 +190,28 @@ extern "C" bridge_tensor_t max_pool2d(
206190
return torch_to_bridge(output);
207191
}
208192

193+
def_bridge_simple(relu);
194+
195+
def_bridge_simple(relu6);
196+
197+
def_bridge_simple(gelu);
198+
199+
def_bridge_simple(logsigmoid);
200+
201+
def_bridge_simple(mish);
202+
203+
def_bridge_simple(selu);
204+
205+
def_bridge_simple(silu);
206+
207+
def_bridge_simple(softmax);
208+
209+
def_bridge_simple(softmin);
210+
211+
def_bridge_simple(softsign);
212+
213+
def_bridge_simple(tanhshrink);
214+
209215

210216
// extern "C"
211217

@@ -306,4 +312,4 @@ extern "C" float sumArray(float* arr, int* sizes, int dim) {
306312

307313
// auto t = torch::from_blob(arr, shape, torch::kFloat);
308314
// return t.sum().item<float>();
309-
}
315+
}

cmake/LibTorchDL.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ endif()
2323
if(APPLE)
2424
set(TORCH_DISTRIBUTION "${TORCH_URL_PREFIX}/libtorch-macos-arm64-latest.zip")
2525
elseif(LINUX)
26-
set(TORCH_DISTRIBUTION "$${TORCH_URL_PREFIX}/libtorch-shared-with-deps-latest.zip")
26+
set(TORCH_DISTRIBUTION "${TORCH_URL_PREFIX}/libtorch-cxx11-abi-shared-with-deps-latest.zip")
2727
endif()
2828

2929
function(download_libtorch)
@@ -112,4 +112,4 @@ endfunction()
112112

113113
# dlcache("${TORCH_DISTRIBUTION}" OUT url)
114114

115-
# message(STATUS "file loc >>>>> ${url}")
115+
# message(STATUS "file loc >>>>> ${url}")

0 commit comments

Comments
 (0)