Skip to content

Commit 94d453a

Browse files
committed
Fix cten_elemwise_broadcast function
1 parent 7616863 commit 94d453a

File tree

1 file changed

+44
-24
lines changed

1 file changed

+44
-24
lines changed

src/utils.c

+44-24
Original file line numberDiff line numberDiff line change
@@ -33,46 +33,66 @@ void cten_assert_dim(const char* title, int a, int b) {
3333
bool cten_elemwise_broadcast(Tensor* a, Tensor* b) {
3434
int a_dim = TensorShape_dim(a->shape);
3535
int b_dim = TensorShape_dim(b->shape);
36-
if(a_dim != b_dim) return false;
36+
37+
if (a_dim != b_dim) return false;
38+
3739
int a_broadcast = -1;
38-
for(int i = 0; i < a_dim; i++) {
39-
if(a->shape[i] == b->shape[i]) continue;
40-
if(a->shape[i] == 1) {
41-
if(a_broadcast == 0) return false;
40+
41+
for (int i = 0; i < a_dim; i++) {
42+
if (a->shape[i] == b->shape[i]) continue;
43+
if (a->shape[i] == 1) {
44+
if (a_broadcast == 0) return false;
4245
a_broadcast = 1;
43-
} else if(b->shape[i] == 1) {
44-
if(a_broadcast == 1) return false;
46+
} else if (b->shape[i] == 1) {
47+
if (a_broadcast == 1) return false;
4548
a_broadcast = 0;
4649
} else {
4750
return false;
4851
}
4952
}
50-
if(a_broadcast != -1) {
51-
if(a_broadcast == 0) {
53+
54+
if (a_broadcast != -1) {
55+
if (a_broadcast == 0) {
5256
Tensor* tmp = a;
5357
a = b;
5458
b = tmp;
5559
a_broadcast = 1;
5660
}
57-
Tensor a_ = Tensor_new(b->shape, a->node != NULL);
58-
for(int i = 0; i < a_.shape[0]; i++) {
59-
int i_ = a->shape[0] == 1 ? 0 : i;
60-
for(int j = 0; j < a_.shape[1]; j++) {
61-
int j_ = a->shape[1] == 1 ? 0 : j;
62-
for(int k = 0; k < a_.shape[2]; k++) {
63-
int k_ = a->shape[2] == 1 ? 0 : k;
64-
for(int l = 0; l < a_.shape[3]; l++) {
65-
int l_ = a->shape[3] == 1 ? 0 : l;
66-
// a_[i][j][k][l] = a[i_][j_][k_][l_]
67-
a_.data->flex[i * a_.shape[1] * a_.shape[2] * a_.shape[3] +
68-
j * a_.shape[2] * a_.shape[3] + k * a_.shape[3] + l] =
69-
a->data->flex[i_ * a->shape[1] * a->shape[2] * a->shape[3] +
70-
j_ * a->shape[2] * a->shape[3] + k_ * a->shape[3] + l_];
61+
62+
Tensor a_ = Tensor_zeros(b->shape, a->node != NULL);
63+
64+
int stride_a_1 = (a_dim > 1) ? a->shape[1] : 1;
65+
int stride_a_2 = (a_dim > 2) ? a->shape[2] : 1;
66+
int stride_a_3 = (a_dim > 3) ? a->shape[3] : 1;
67+
68+
int stride_a_1_new = (a_dim > 1) ? a_.shape[1] : 1;
69+
int stride_a_2_new = (a_dim > 2) ? a_.shape[2] : 1;
70+
int stride_a_3_new = (a_dim > 3) ? a_.shape[3] : 1;
71+
72+
for (int i = 0; i < a_.shape[0]; i++) {
73+
int i_ = (a->shape[0] == 1) ? 0 : i;
74+
for (int j = 0; j < ((a_dim > 1) ? a_.shape[1] : 1); j++) {
75+
int j_ = (a_dim > 1 && a->shape[1] == 1) ? 0 : j;
76+
for (int k = 0; k < ((a_dim > 2) ? a_.shape[2] : 1); k++) {
77+
int k_ = (a_dim > 2 && a->shape[2] == 1) ? 0 : k;
78+
for (int l = 0; l < ((a_dim > 3) ? a_.shape[3] : 1); l++) {
79+
int l_ = (a_dim > 3 && a->shape[3] == 1) ? 0 : l;
80+
81+
int dst_idx = i * stride_a_1_new * stride_a_2_new * stride_a_3_new +
82+
j * stride_a_2_new * stride_a_3_new +
83+
k * stride_a_3_new + l;
84+
85+
int src_idx = i_ * stride_a_1 * stride_a_2 * stride_a_3 +
86+
j_ * stride_a_2 * stride_a_3 +
87+
k_ * stride_a_3 + l_;
88+
89+
a_.data->flex[dst_idx] = a->data->flex[src_idx];
7190
}
7291
}
7392
}
7493
}
7594
*a = a_;
7695
}
96+
7797
return true;
78-
}
98+
}

0 commit comments

Comments
 (0)