Skip to content

Commit 3a90551

Browse files
authored
Merge branch 'main' into im-dev
2 parents e7acb44 + 5263a9f commit 3a90551

4 files changed

Lines changed: 36 additions & 28 deletions

File tree

CMakeLists.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ if(APPLE)
2929
set(CMAKE_C_COMPILER "/usr/bin/clang")
3030
set(CMAKE_CXX_COMPILER "/usr/bin/clang++")
3131
endif()
32-
set(CMAKE_CXX_STANDARD 23)
32+
set(CMAKE_CXX_STANDARD 20)
3333

3434

3535
include(LibTorchDL)
@@ -137,7 +137,7 @@ set(BRIDGE_OBJECT_FILES $<TARGET_OBJECTS:bridge>)
137137

138138

139139
file(GLOB LIBTORCH_ALL_LIB_FILES
140-
"${LIBTORCH_DIR}/lib/*.a"
140+
# "${LIBTORCH_DIR}/lib/*.a"
141141
"${LIBTORCH_DIR}/lib/*.dylib"
142142
"${LIBTORCH_DIR}/lib/*.so")
143143

@@ -152,7 +152,6 @@ set(REQUIRED_LIBS
152152
"libtorch"
153153
"libtorch_cpu"
154154
"libc10"
155-
"libtorch_global_deps"
156155
)
157156

158157
set(DISALLOWED_LIBS

bridge/include/bridge.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
extern "C" {
88
#endif
99

10+
#define proto_bridge_simple(Name) \
11+
bridge_tensor_t Name(bridge_tensor_t input)
12+
1013
typedef float float32_t;
1114
typedef double float64_t;
1215
typedef char bool_t;
@@ -77,6 +80,28 @@ bridge_tensor_t max_pool2d(
7780
int dilation
7881
);
7982

83+
proto_bridge_simple(relu);
84+
85+
proto_bridge_simple(relu6);
86+
87+
proto_bridge_simple(gelu);
88+
89+
proto_bridge_simple(logsigmoid);
90+
91+
proto_bridge_simple(mish);
92+
93+
proto_bridge_simple(selu);
94+
95+
proto_bridge_simple(silu);
96+
97+
proto_bridge_simple(softmax);
98+
99+
proto_bridge_simple(softmin);
100+
101+
proto_bridge_simple(softsign);
102+
103+
proto_bridge_simple(tanhshrink);
104+
80105

81106
// bridge_tensor_t conv2d(
82107
// bridge_tensor_t input,
@@ -93,4 +118,4 @@ bridge_tensor_t max_pool2d(
93118
#endif
94119

95120
#endif // BRIDGE_H
96-
//hello
121+
//hello

bridge/lib/bridge.cpp

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
#include <vector>
1515
#include <cstdint>
1616

17-
17+
#define def_bridge_simple(Name) \
18+
extern "C" bridge_tensor_t Name(bridge_tensor_t input) { \
19+
auto t_input = bridge_to_torch(input); \
20+
auto t_output = torch::Name(t_input); \
21+
return torch_to_bridge(t_output); \
22+
}
1823

1924

2025

@@ -56,18 +61,6 @@ torch::Tensor bridge_to_torch(bridge_tensor_t &bt) {
5661
return torch::from_blob(bt.data, shape, torch::kFloat);
5762
}
5863

59-
60-
61-
62-
63-
64-
65-
66-
67-
68-
69-
70-
7164
extern "C" float32_t* unsafe(const float32_t* arr) {
7265
return const_cast<float32_t*>(arr);
7366
}
@@ -133,15 +126,6 @@ extern "C" bridge_tensor_t load_run_model(const uint8_t* model_path, bridge_tens
133126
return torch_to_bridge(output);
134127
}
135128

136-
137-
138-
139-
140-
141-
142-
143-
144-
145129
extern "C" bridge_tensor_t increment3(bridge_tensor_t arr) {
146130
auto t = bridge_to_torch(arr);
147131
// Increment the tensor
@@ -386,4 +370,4 @@ extern "C" float sumArray(float* arr, int* sizes, int dim) {
386370

387371
// auto t = torch::from_blob(arr, shape, torch::kFloat);
388372
// return t.sum().item<float>();
389-
}
373+
}

cmake/LibTorchDL.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,4 @@ endfunction()
111111

112112
# dlcache("${TORCH_DISTRIBUTION}" OUT url)
113113

114-
# message(STATUS "file loc >>>>> ${url}")
114+
# message(STATUS "file loc >>>>> ${url}")

0 commit comments

Comments
 (0)