Skip to content

Commit f68e3c0

Browse files
committed
Improve chapel/pytorch interop with unsafe pointers.
1 parent 5594567 commit f68e3c0

7 files changed

Lines changed: 58 additions & 15 deletions

File tree

bridge/.DS_Store

0 Bytes
Binary file not shown.

bridge/Bridge

110 KB
Binary file not shown.

bridge/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
all: bridge.o Bridge
23

34
bridge.o: lib/bridge.cpp include/bridge.h
45
g++ -c lib/bridge.cpp -I include -o bridge.o -I /Users/iainmoncrief/Documents/Github/ChAI/bridge/libtorch/include/torch/csrc/api/include -I /Users/iainmoncrief/Documents/Github/ChAI/bridge/libtorch/include --std=c++17

bridge/bridge.o

8.49 KB
Binary file not shown.

bridge/include/bridge.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ typedef struct bridge_tensor_t {
1313
int dim;
1414
} bridge_tensor_t;
1515

16+
typedef float float32_t;
17+
typedef double float64_t;
18+
1619
int baz(void);
1720

1821
void wrHello(void);
@@ -25,17 +28,16 @@ void increment(float* arr, int* sizes, int dim, float* output);
2528
bridge_tensor_t increment2(float* arr, int* sizes, int dim);
2629
bridge_tensor_t increment3(bridge_tensor_t arr);
2730

28-
// void convolve(
29-
// float* input,
30-
// int* input_sizes,
31-
// int input_dim,
32-
// float* kernel,
33-
// int* kernel_sizes,
34-
// int kernel_dim,
35-
// float* output,
36-
// int* output_sizes,
37-
// int output_dim
38-
// );
31+
bridge_tensor_t convolve2d(
32+
bridge_tensor_t input,
33+
bridge_tensor_t kernel,
34+
bridge_tensor_t bias,
35+
int stride,
36+
int padding
37+
);
38+
39+
float* unsafe(const float* arr);
40+
3941

4042
#ifdef __cplusplus
4143
}

bridge/lib/Bridge.chpl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ extern record bridge_tensor_t {
2323
extern proc increment2(arr: [] real(32), sizes: [] int(32), dim: int(32)): bridge_tensor_t;
2424
extern proc increment3(in arr: bridge_tensor_t): bridge_tensor_t;
2525

26+
extern proc convolve2d(
27+
in input: bridge_tensor_t,
28+
in kernel: bridge_tensor_t,
29+
in bias: bridge_tensor_t,
30+
in stride: int(32),
31+
in padding: int(32)): bridge_tensor_t;
32+
33+
extern proc unsafe(const ref arr: [] real(32)): c_ptr(real(32));
2634

2735
// baz();
2836

@@ -127,9 +135,9 @@ proc bridgeTensorToArray(param rank: int, package: bridge_tensor_t): [] real(32)
127135
}
128136

129137

130-
proc createBridgeTensor(ref data: [] real(32)): bridge_tensor_t {
138+
proc createBridgeTensor(const ref data: [] real(32)): bridge_tensor_t {
131139
var result: bridge_tensor_t;
132-
result.data = c_ptrTo(data);
140+
result.data = c_ptrToConst(data) : c_ptr(real(32));
133141
result.sizes = allocate(int(32),data.rank);
134142
const sizeArr = getSizeArray(data);
135143
for i in 0..<data.rank do
@@ -169,4 +177,19 @@ writeln(a);
169177
writeln("----------");
170178
writeln(chplIncrement(a));
171179
writeln("----------");
172-
writeln(a);
180+
writeln(a);
181+
182+
183+
var input: [domainFromShape(2,64,28,28)] real(32) = 1.0;
184+
var kernel: [domainFromShape(128,64,3,3)] real(32) = 2.0;
185+
var bias: [domainFromShape(128)] real(32) = 3.0;
186+
var stride: int(32) = 1;
187+
var padding: int(32) = 1;
188+
writeln("Begin.");
189+
var resultBT = convolve2d(createBridgeTensor(input), createBridgeTensor(kernel), createBridgeTensor(bias), stride, padding);
190+
var result = bridgeTensorToArray(4, resultBT);
191+
// writeln("Input: ", input);
192+
// writeln("Kernel: ", kernel);
193+
// writeln("Bias: ", bias);
194+
// writeln("Result: ", result);
195+
writeln("Result: ", result.size);

bridge/lib/bridge.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88

99
#include <cstdint>
1010

11-
using float32_t = float;
1211

1312

13+
extern "C" float32_t* unsafe(const float32_t* arr) {
14+
return const_cast<float32_t*>(arr);
15+
}
16+
1417
int bridge_tensor_elements(bridge_tensor_t &bt) {
1518
int size = 1;
1619
for (int i = 0; i < bt.dim; ++i) {
@@ -56,6 +59,20 @@ extern "C" bridge_tensor_t increment3(bridge_tensor_t arr) {
5659
return tensor_result_convert(incremented_tensor);
5760
}
5861

62+
extern "C" bridge_tensor_t convolve2d(
63+
bridge_tensor_t input,
64+
bridge_tensor_t kernel,
65+
bridge_tensor_t bias,
66+
int stride,
67+
int padding
68+
) {
69+
auto t_input = bridge_to_torch(input);
70+
auto t_kernel = bridge_to_torch(kernel);
71+
auto t_bias = bridge_to_torch(bias);
72+
auto output = torch::conv2d(t_input, t_kernel, t_bias, stride, padding);
73+
return tensor_result_convert(output);
74+
}
75+
5976

6077
extern "C" int baz(void) {
6178
printf("Hello from baz!\n");

0 commit comments

Comments
 (0)