Skip to content

Commit 889c69b

Browse files
committed
Working with tensor handles in test file. Going to try using owned.
1 parent 6a8fffa commit 889c69b

File tree

4 files changed

+140
-2
lines changed

4 files changed

+140
-2
lines changed

Diff for: CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ download_libtorch(
3535
)
3636

3737

38-
add_library(bridge OBJECT ${BRIDGE_DIR}/include/bridge.h ${BRIDGE_DIR}/lib/bridge.cpp)
38+
add_library(bridge STATIC ${BRIDGE_DIR}/include/bridge.h ${BRIDGE_DIR}/lib/bridge.cpp)
3939

4040
target_include_directories(
4141
bridge

Diff for: lib/Bridge.chpl

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
module Bridge {
2+
3+
import Utilities as util;
4+
use Utilities.Standard;
5+
use Allocators;
6+
7+
8+
extern record bridge_tensor_t {
9+
var data: c_ptr(real(32));
10+
var sizes: c_ptr(int(32));
11+
var dim: int(32);
12+
}
13+
14+
proc tensorHandle(type eltType) type {
15+
if eltType == real(32) then
16+
return bridge_tensor_t;
17+
else {
18+
compilerWarning("BridgeTensorHandle: Unsupported type");
19+
return bridge_tensor_t;
20+
}
21+
}
22+
23+
24+
extern proc convolve2d(
25+
in input: bridge_tensor_t,
26+
in kernel: bridge_tensor_t,
27+
in bias: bridge_tensor_t,
28+
in stride: int(32),
29+
in padding: int(32)): bridge_tensor_t;
30+
31+
extern proc unsafe(const ref arr: [] real(32)): c_ptr(real(32));
32+
33+
34+
proc getSizeArray(const ref arr: [] ?eltType): [] int(32) {
35+
var sizes: [0..<arr.rank] int(32);
36+
for i in 0..<arr.rank do
37+
sizes[i] = arr.dim(i).size : int(32);
38+
return sizes;
39+
}
40+
41+
proc bridgeTensorShape(param dim: int, result: bridge_tensor_t): dim*int {
42+
var shape: dim*int;
43+
for i in 0..<dim do
44+
shape[i] = result.sizes[i] : int;
45+
return shape;
46+
}
47+
48+
proc bridgeTensorToArray(param rank: int, package: bridge_tensor_t): [] real(32) {
49+
const shape = bridgeTensorShape(rank, package);
50+
const dom = util.domainFromShape((...shape));
51+
var result: [dom] real(32);
52+
forall (i,idx) in dom.everyZip() do
53+
result[idx] = package.data[i];
54+
deallocate(package.data);
55+
deallocate(package.sizes);
56+
return result;
57+
}
58+
59+
60+
proc bridgeTensorToExistingArray(ref existing: [] real(32), package: bridge_tensor_t) {
61+
const shape = bridgeTensorShape(existing.rank, package);
62+
if existing.shape != shape then
63+
util.err("BridgeTensorToExistingArray: Shape mismatch");
64+
const dom = existing.domain;
65+
forall (i,idx) in dom.everyZip() do
66+
existing[idx] = package.data[i];
67+
deallocate(package.data);
68+
deallocate(package.sizes);
69+
}
70+
71+
proc createBridgeTensor(const ref data: [] real(32)): bridge_tensor_t {
72+
var result: bridge_tensor_t;
73+
result.data = c_ptrToConst(data) : c_ptr(real(32));
74+
result.sizes = allocate(int(32),data.rank);
75+
const sizeArr = getSizeArray(data);
76+
for i in 0..<data.rank do
77+
result.sizes[i] = sizeArr[i];
78+
79+
result.dim = data.rank;
80+
return result;
81+
}
82+
83+
84+
}

Diff for: lib/NDArray.chpl

+18
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import Utilities as util;
1414
use Utilities.Standard;
1515
use Utilities.Types;
1616

17+
import Bridge;
18+
1719
type domainType = _domain(?);
1820

1921
/* The most fundamental tensor type.
@@ -1245,6 +1247,22 @@ proc ndarray.degenerateFlatten(): [] eltType {
12451247
return flat;
12461248
}
12471249

1250+
proc ndarray.toBridgeTensor(): Bridge.tensorHandle(eltType) do
1251+
return Bridge.createBridgeTensor(this.data);
1252+
1253+
proc type ndarray.fromBridgeTensor(param rank: int, handle: Bridge.tensorHandle(eltType)): ndarray(rank,real(32)) {
1254+
const arr = Bridge.bridgeTensorToArray(rank,handle);
1255+
return new ndarray(arr);
1256+
}
1257+
1258+
1259+
proc ref ndarray.loadFromBridgeTensor(handle: Bridge.tensorHandle(eltType)): void {
1260+
const shape = Bridge.bridgeTensorShape(rank,handle);
1261+
if shape != this.shape then
1262+
this.reshapeDomain(util.domainFromShape((...shape)));
1263+
Bridge.bridgeTensorToExistingArray(this.data,handle);
1264+
}
1265+
12481266
proc ndarray.shapeArray(): [] int do
12491267
return util.tupleToArray((...this.shape));
12501268

Diff for: test/tiny/layer_test.chpl

+37-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ use Tensor;
22
use Layer;
33
use Network except ReLU, Linear, Flatten;
44

5+
use Bridge;
6+
use Utilities as util;
57

68
var x = Tensor.arange(2,3);
79
writeln(x);
@@ -33,4 +35,38 @@ var net2 = new Sequential(
3335

3436
y = net2(x);
3537
writeln(y);
36-
writeln(net2.signature);
38+
writeln(net2.signature);
39+
40+
41+
42+
var dom = {0..<10, 0..<10};
43+
var a: [dom] real(32);
44+
for (idx,i) in zip(dom,0..<dom.size) do
45+
a[idx] = i:real(32);
46+
47+
48+
49+
var input: [util.domainFromShape(2,64,28,28)] real(32) = 1.0;
50+
var kernel: [util.domainFromShape(128,64,3,3)] real(32) = 2.0;
51+
var bias: [util.domainFromShape(128)] real(32) = 3.0;
52+
var stride: int(32) = 1;
53+
var padding: int(32) = 1;
54+
writeln("Begin.");
55+
var resultBT = convolve2d(createBridgeTensor(input), createBridgeTensor(kernel), createBridgeTensor(bias), stride, padding);
56+
var result = bridgeTensorToArray(4, resultBT);
57+
// writeln("Input: ", input);
58+
// writeln("Kernel: ", kernel);
59+
// writeln("Bias: ", bias);
60+
// writeln("Result: ", result);
61+
writeln("Result: ", result.size);
62+
63+
64+
var arr = ndarray.arange(2,3);
65+
writeln(arr);
66+
writeln(arr.toBridgeTensor());
67+
68+
var bt = arr.toBridgeTensor();
69+
arr.loadFromBridgeTensor(bt);
70+
writeln(arr);
71+
writeln(arr.shape);
72+
writeln(ndarray.fromBridgeTensor(2,bt));

0 commit comments

Comments
 (0)