1+ module Bridge {
2+
3+ import Utilities as util;
4+ use Utilities.Standard;
5+ use Allocators;
6+
7+
8+ extern record bridge_tensor_t {
9+ var data: c_ptr(real (32 ));
10+ var sizes: c_ptr(int (32 ));
11+ var dim: int (32 );
12+ }
13+
14+ proc tensorHandle (type eltType) type {
15+ if eltType == real (32 ) then
16+ return bridge_tensor_t;
17+ else {
18+ compilerWarning(" BridgeTensorHandle: Unsupported type" );
19+ return bridge_tensor_t;
20+ }
21+ }
22+
23+
24+ extern proc convolve2d(
25+ in input: bridge_tensor_t,
26+ in kernel: bridge_tensor_t,
27+ in bias: bridge_tensor_t,
28+ in stride: int (32 ),
29+ in padding: int (32 )): bridge_tensor_t;
30+
31+ extern proc unsafe(const ref arr: [] real (32 )): c_ptr(real (32 ));
32+
33+
34+ proc getSizeArray (const ref arr: [] ?eltType): [] int (32 ) {
35+ var sizes: [ 0 ..< arr.rank] int (32 );
36+ for i in 0 ..< arr.rank do
37+ sizes[ i] = arr.dim(i).size : int (32 );
38+ return sizes;
39+ }
40+
41+ proc bridgeTensorShape (param dim: int , result: bridge_tensor_t): dim* int {
42+ var shape: dim* int ;
43+ for i in 0 ..< dim do
44+ shape[ i] = result.sizes[ i] : int ;
45+ return shape;
46+ }
47+
48+ proc bridgeTensorToArray (param rank: int , package: bridge_tensor_t): [] real (32 ) {
49+ const shape = bridgeTensorShape(rank, package);
50+ const dom = util.domainFromShape((...shape));
51+ var result: [ dom] real (32 );
52+ forall (i,idx) in dom.everyZip() do
53+ result[ idx] = package.data[ i] ;
54+ deallocate(package.data);
55+ deallocate(package.sizes);
56+ return result;
57+ }
58+
59+
60+ proc bridgeTensorToExistingArray (ref existing: [] real (32 ), package: bridge_tensor_t) {
61+ const shape = bridgeTensorShape(existing.rank, package);
62+ if existing.shape != shape then
63+ util.err(" BridgeTensorToExistingArray: Shape mismatch" );
64+ const dom = existing.domain ;
65+ forall (i,idx) in dom.everyZip() do
66+ existing[ idx] = package.data[ i] ;
67+ deallocate(package.data);
68+ deallocate(package.sizes);
69+ }
70+
71+ proc createBridgeTensor (const ref data: [] real (32 )): bridge_tensor_t {
72+ var result: bridge_tensor_t;
73+ result.data = c_ptrToConst(data) : c_ptr(real (32 ));
74+ result.sizes = allocate(int (32 ),data.rank);
75+ const sizeArr = getSizeArray(data);
76+ for i in 0 ..< data.rank do
77+ result.sizes[ i] = sizeArr[ i] ;
78+
79+ result.dim = data.rank;
80+ return result;
81+ }
82+
83+
84+ }
0 commit comments