Skip to content

Commit 9cde086

Browse files
committed
Tensor Handle allocator working better.
1 parent 889c69b commit 9cde086

File tree

4 files changed

+18
-6
lines changed

4 files changed

+18
-6
lines changed

bridge/include/bridge.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@
77
extern "C" {
88
#endif
99

10+
typedef float float32_t;
11+
typedef double float64_t;
12+
typedef char bool_t;
13+
1014
typedef struct bridge_tensor_t {
1115
float* data;
1216
int* sizes;
1317
int dim;
18+
bool_t created_by_c;
1419
} bridge_tensor_t;
1520

16-
typedef float float32_t;
17-
typedef double float64_t;
1821

1922
int baz(void);
2023

bridge/lib/bridge.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ void store_tensor(torch::Tensor &input, float32_t* dest) {
3434

3535
bridge_tensor_t torch_to_bridge(torch::Tensor &tensor) {
3636
bridge_tensor_t result;
37+
result.created_by_c = true;
3738
result.dim = tensor.dim();
3839
result.sizes = new int[result.dim];
3940
for (int i = 0; i < result.dim; ++i) {

lib/Bridge.chpl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ module Bridge {
99
var data: c_ptr(real(32));
1010
var sizes: c_ptr(int(32));
1111
var dim: int(32);
12+
var created_by_c: bool;
1213
}
1314

1415
proc tensorHandle(type eltType) type {
@@ -51,8 +52,10 @@ module Bridge {
5152
var result: [dom] real(32);
5253
forall (i,idx) in dom.everyZip() do
5354
result[idx] = package.data[i];
54-
deallocate(package.data);
55-
deallocate(package.sizes);
55+
if package.created_by_c {
56+
deallocate(package.data);
57+
deallocate(package.sizes);
58+
}
5659
return result;
5760
}
5861

@@ -64,8 +67,10 @@ module Bridge {
6467
const dom = existing.domain;
6568
forall (i,idx) in dom.everyZip() do
6669
existing[idx] = package.data[i];
67-
deallocate(package.data);
68-
deallocate(package.sizes);
70+
if package.created_by_c {
71+
deallocate(package.data);
72+
deallocate(package.sizes);
73+
}
6974
}
7075

7176
proc createBridgeTensor(const ref data: [] real(32)): bridge_tensor_t {

test/tiny/layer_test.chpl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ writeln(arr);
6666
writeln(arr.toBridgeTensor());
6767

6868
var bt = arr.toBridgeTensor();
69+
writeln(bt);
6970
arr.loadFromBridgeTensor(bt);
7071
writeln(arr);
7172
writeln(arr.shape);
73+
writeln(bt);
74+
7275
writeln(ndarray.fromBridgeTensor(2,bt));

0 commit comments

Comments
 (0)