Skip to content

Commit 62f0a59

Browse files
committed
Communication with libtorch now only uses tensor_result_t.
1 parent a7d5ad3 commit 62f0a59

5 files changed

Lines changed: 66 additions & 11 deletions

File tree

bridge/Bridge

67.2 KB
Binary file not shown.

bridge/bridge.o

728 Bytes
Binary file not shown.

bridge/include/bridge.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ float sumArray(float* arr, int* sizes, int dim);
2323

2424
void increment(float* arr, int* sizes, int dim, float* output);
2525
tensor_result_t increment2(float* arr, int* sizes, int dim);
26+
tensor_result_t increment3(tensor_result_t arr);
2627

2728
// void convolve(
2829
// float* input,
@@ -33,8 +34,8 @@ tensor_result_t increment2(float* arr, int* sizes, int dim);
3334
// int kernel_dim,
3435
// float* output,
3536
// int* output_sizes,
36-
// int output_dim,
37-
// )
37+
// int output_dim
38+
// );
3839

3940
#ifdef __cplusplus
4041
}

bridge/lib/Bridge.chpl

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ extern record tensor_result_t {
1414
}
1515

1616
extern proc increment2(arr: [] real(32), sizes: [] int(32), dim: int(32)): tensor_result_t;
17+
extern proc increment3(in arr: tensor_result_t): tensor_result_t;
1718

1819

1920
// baz();
@@ -86,18 +87,58 @@ proc domainFromShape(shape: int ...?rank): domain(rank,int) {
8687
return {(...ranges)};
8788
}
8889

89-
var c = increment2(a,shape.sizes,shape.rank);
9090

91-
var cShape = getResultTensorShape(shape.rank, c);
91+
proc tensorResultToArray(param rank: int, package: tensor_result_t): [] real(32) {
92+
var shape = getResultTensorShape(rank, package);
93+
var dom = domainFromShape((...shape));
94+
var result: [dom] real(32);
95+
forall i in 0..<dom.size {
96+
var idx = dom.orderToIndex(i);
97+
result[idx] = package.data[i];
98+
}
99+
return result;
100+
}
101+
// var c = increment2(a,shape.sizes,shape.rank);
92102

93-
var cDom = domainFromShape((...cShape));
103+
// var cShape = getResultTensorShape(shape.rank, c);
94104

95-
var C: [cDom] real(32);
96-
forall i in 0..<cDom.size {
97-
var idx = cDom.orderToIndex(i);
98-
C[idx] = c.data[i];
99-
}
105+
// var cDom = domainFromShape((...cShape));
106+
107+
// var C: [cDom] real(32);
108+
// forall i in 0..<cDom.size {
109+
// var idx = cDom.orderToIndex(i);
110+
// C[idx] = c.data[i];
111+
// }
112+
113+
var c = tensorResultToArray(shape.rank, increment2(a,shape.sizes,shape.rank));
114+
115+
116+
writeln("C: ", c);
117+
118+
use Allocators;
100119

101120

102-
writeln("C: ", C);
103121

122+
proc createTensorResult(ref data: [] real(32)): tensor_result_t {
123+
// var alloc = new mallocWrapper();
124+
// alloc.allocate(1024);
125+
126+
127+
var result: tensor_result_t;
128+
result.data = c_ptrTo(data);
129+
result.sizes = allocate(int(32),data.rank);
130+
const sizeArr = getSizeArray(data);
131+
for i in 0..<data.rank do
132+
result.sizes[i] = sizeArr[i];
133+
134+
// result.sizes = c_ptrTo(getSizeArray(data));
135+
result.dim = data.rank;
136+
return result;
137+
138+
// var res = newWithAllocator(alloc, tensor_result_t, result);
139+
// return res;
140+
}
141+
142+
writeln(createTensorResult(a));
143+
144+
writeln(tensorResultToArray(2,increment3(createTensorResult(a))));

bridge/lib/bridge.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,19 @@ extern "C" tensor_result_t increment2(float* arr, int* sizes, int dim) {
105105
return tensor_result_convert(incremented_tensor);
106106
}
107107

108+
extern "C" tensor_result_t increment3(tensor_result_t arr) {
109+
// Convert sizes to std::vector<int64_t>
110+
std::vector<int64_t> sizes_vec(arr.sizes, arr.sizes + arr.dim);
111+
auto shape = at::IntArrayRef(sizes_vec);
112+
auto t = torch::from_blob(arr.data, shape, torch::kFloat);
113+
114+
// Increment the tensor
115+
auto incremented_tensor = t + 1;
116+
117+
return tensor_result_convert(incremented_tensor);
118+
}
119+
120+
108121
extern "C" float sumArray(float* arr, int* sizes, int dim) {
109122
// Convert sizes to std::vector<int64_t>
110123

0 commit comments

Comments
 (0)