@@ -33,46 +33,66 @@ void cten_assert_dim(const char* title, int a, int b) {
33
33
bool cten_elemwise_broadcast (Tensor * a , Tensor * b ) {
34
34
int a_dim = TensorShape_dim (a -> shape );
35
35
int b_dim = TensorShape_dim (b -> shape );
36
- if (a_dim != b_dim ) return false;
36
+
37
+ if (a_dim != b_dim ) return false;
38
+
37
39
int a_broadcast = -1 ;
38
- for (int i = 0 ; i < a_dim ; i ++ ) {
39
- if (a -> shape [i ] == b -> shape [i ]) continue ;
40
- if (a -> shape [i ] == 1 ) {
41
- if (a_broadcast == 0 ) return false;
40
+
41
+ for (int i = 0 ; i < a_dim ; i ++ ) {
42
+ if (a -> shape [i ] == b -> shape [i ]) continue ;
43
+ if (a -> shape [i ] == 1 ) {
44
+ if (a_broadcast == 0 ) return false;
42
45
a_broadcast = 1 ;
43
- } else if (b -> shape [i ] == 1 ) {
44
- if (a_broadcast == 1 ) return false;
46
+ } else if (b -> shape [i ] == 1 ) {
47
+ if (a_broadcast == 1 ) return false;
45
48
a_broadcast = 0 ;
46
49
} else {
47
50
return false;
48
51
}
49
52
}
50
- if (a_broadcast != -1 ) {
51
- if (a_broadcast == 0 ) {
53
+
54
+ if (a_broadcast != -1 ) {
55
+ if (a_broadcast == 0 ) {
52
56
Tensor * tmp = a ;
53
57
a = b ;
54
58
b = tmp ;
55
59
a_broadcast = 1 ;
56
60
}
57
- Tensor a_ = Tensor_new (b -> shape , a -> node != NULL );
58
- for (int i = 0 ; i < a_ .shape [0 ]; i ++ ) {
59
- int i_ = a -> shape [0 ] == 1 ? 0 : i ;
60
- for (int j = 0 ; j < a_ .shape [1 ]; j ++ ) {
61
- int j_ = a -> shape [1 ] == 1 ? 0 : j ;
62
- for (int k = 0 ; k < a_ .shape [2 ]; k ++ ) {
63
- int k_ = a -> shape [2 ] == 1 ? 0 : k ;
64
- for (int l = 0 ; l < a_ .shape [3 ]; l ++ ) {
65
- int l_ = a -> shape [3 ] == 1 ? 0 : l ;
66
- // a_[i][j][k][l] = a[i_][j_][k_][l_]
67
- a_ .data -> flex [i * a_ .shape [1 ] * a_ .shape [2 ] * a_ .shape [3 ] +
68
- j * a_ .shape [2 ] * a_ .shape [3 ] + k * a_ .shape [3 ] + l ] =
69
- a -> data -> flex [i_ * a -> shape [1 ] * a -> shape [2 ] * a -> shape [3 ] +
70
- j_ * a -> shape [2 ] * a -> shape [3 ] + k_ * a -> shape [3 ] + l_ ];
61
+
62
+ Tensor a_ = Tensor_zeros (b -> shape , a -> node != NULL );
63
+
64
+ int stride_a_1 = (a_dim > 1 ) ? a -> shape [1 ] : 1 ;
65
+ int stride_a_2 = (a_dim > 2 ) ? a -> shape [2 ] : 1 ;
66
+ int stride_a_3 = (a_dim > 3 ) ? a -> shape [3 ] : 1 ;
67
+
68
+ int stride_a_1_new = (a_dim > 1 ) ? a_ .shape [1 ] : 1 ;
69
+ int stride_a_2_new = (a_dim > 2 ) ? a_ .shape [2 ] : 1 ;
70
+ int stride_a_3_new = (a_dim > 3 ) ? a_ .shape [3 ] : 1 ;
71
+
72
+ for (int i = 0 ; i < a_ .shape [0 ]; i ++ ) {
73
+ int i_ = (a -> shape [0 ] == 1 ) ? 0 : i ;
74
+ for (int j = 0 ; j < ((a_dim > 1 ) ? a_ .shape [1 ] : 1 ); j ++ ) {
75
+ int j_ = (a_dim > 1 && a -> shape [1 ] == 1 ) ? 0 : j ;
76
+ for (int k = 0 ; k < ((a_dim > 2 ) ? a_ .shape [2 ] : 1 ); k ++ ) {
77
+ int k_ = (a_dim > 2 && a -> shape [2 ] == 1 ) ? 0 : k ;
78
+ for (int l = 0 ; l < ((a_dim > 3 ) ? a_ .shape [3 ] : 1 ); l ++ ) {
79
+ int l_ = (a_dim > 3 && a -> shape [3 ] == 1 ) ? 0 : l ;
80
+
81
+ int dst_idx = i * stride_a_1_new * stride_a_2_new * stride_a_3_new +
82
+ j * stride_a_2_new * stride_a_3_new +
83
+ k * stride_a_3_new + l ;
84
+
85
+ int src_idx = i_ * stride_a_1 * stride_a_2 * stride_a_3 +
86
+ j_ * stride_a_2 * stride_a_3 +
87
+ k_ * stride_a_3 + l_ ;
88
+
89
+ a_ .data -> flex [dst_idx ] = a -> data -> flex [src_idx ];
71
90
}
72
91
}
73
92
}
74
93
}
75
94
* a = a_ ;
76
95
}
96
+
77
97
return true;
78
- }
98
+ }
0 commit comments