Commit 7616863

Implement and fix nn GradFns
Fix GradFn_softmax to index correctly and store only the vector-Jacobian product. Implement GradFn_crossentropy for proper backpropagation.
1 parent fd42506 · commit 7616863

1 file changed — src/nn.c: 46 additions, 8 deletions
@@ -39,13 +39,23 @@ Tensor nn_relu(Tensor self) {
 /* nn.softmax */
 static Tensor GradFn_softmax(Tensor self, int i) {
     Tensor input = self.node->inputs[i];
-    Tensor res = Tensor_new(input.shape, false);
-    for(int j = 0; j < input.data->numel; j++) {
-        float softmax_j = self.data->flex[j];
-        for(int k = 0; k < input.data->numel; k++) {
-            float softmax_k = self.data->flex[k];
-            float delta_jk = (j == k) ? 1.0f : 0.0f;
-            res.data->flex[j * input.data->numel + k] = softmax_j * (delta_jk - softmax_k);
+    int num_classes = input.shape[TensorShape_dim(input.shape) - 1];
+    int batch_size = input.data->numel / num_classes;
+
+    Tensor res = Tensor_zeros(input.shape, false);
+
+    for (int batch = 0; batch < batch_size; batch++) {
+        float grad_sum = 0.0f;
+
+        for (int k = 0; k < num_classes; k++) {
+            grad_sum += self.data->flex[batch * num_classes + k] * self.node->grad.data->flex[batch * num_classes + k];
+        }
+
+        for (int j = 0; j < num_classes; j++) {
+            float softmax_j = self.data->flex[batch * num_classes + j];
+            float grad_j = self.node->grad.data->flex[batch * num_classes + j];
+
+            res.data->flex[batch * num_classes + j] = softmax_j * (grad_j - grad_sum);
         }
     }
     return res;
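
The removed loop was building the full num_classes × num_classes Jacobian, J_jk = softmax_j * (delta_jk - softmax_k), but writing it into a tensor allocated with only input.shape elements, so the j * numel + k indexing ran past the buffer and also ignored batch boundaries. The new code never materializes the Jacobian: multiplying an upstream gradient g through J collapses, row by row, to the identity dL/dx_j = s_j * (g_j - sum_k s_k * g_k), which is exactly what the batch loop stores. A minimal standalone check of that identity (plain C with illustrative values, independent of cten's types):

#include <math.h>
#include <stdio.h>

int main(void) {
    float x[3] = {1.0f, 2.0f, 0.5f};   /* logits for one row */
    float g[3] = {0.3f, -0.1f, 0.7f};  /* upstream gradient */
    float s[3], sum = 0.0f;

    for (int j = 0; j < 3; j++) sum += expf(x[j]);
    for (int j = 0; j < 3; j++) s[j] = expf(x[j]) / sum;

    /* (1) multiply g through the explicit Jacobian J[k][j] = s[k] * (delta_kj - s[j]) */
    float full[3];
    for (int j = 0; j < 3; j++) {
        full[j] = 0.0f;
        for (int k = 0; k < 3; k++) {
            float delta = (j == k) ? 1.0f : 0.0f;
            full[j] += g[k] * s[k] * (delta - s[j]);
        }
    }

    /* (2) fused form used by the new GradFn_softmax: s_j * (g_j - sum_k s_k * g_k) */
    float dot = 0.0f;
    for (int k = 0; k < 3; k++) dot += s[k] * g[k];
    for (int j = 0; j < 3; j++)
        printf("j=%d full=%.6f fused=%.6f\n", j, full[j], s[j] * (g[j] - dot));
    return 0;
}

Both columns print the same values. Fusing the dot product this way costs O(batch_size * num_classes) in time and memory instead of O(batch_size * num_classes^2) for the explicit Jacobian.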
@@ -88,6 +98,27 @@ Tensor nn_softmax(Tensor self) {
     return res;
 }

+Tensor GradFn_crossentropy(Tensor self, int i) {
+    Tensor y_true = self.node->inputs[0];
+    Tensor y_pred = self.node->inputs[1];
+
+    Tensor grad = Tensor_zeros(y_pred.shape, false);
+
+    if (i == 1) {
+        // f'(y_true, y_pred) = -y_true / y_pred
+        int n_samples = y_pred.shape[0];
+        int n_classes = y_pred.shape[1];
+
+        for (int s = 0; s < n_samples; s++) {
+            for (int c = 0; c < n_classes; c++) {
+                int idx = s * n_classes + c;
+                grad.data->flex[idx] = -y_true.data->flex[idx] / y_pred.data->flex[idx];
+            }
+        }
+    }
+    return grad;
+}
+
 /* nn.cross_entropy */
 Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
     // y_true: [None, n_classes]
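
For a one-sample loss L = -sum_c y_true_c * log(y_pred_c), the partial derivative is dL/dy_pred_c = -y_true_c / y_pred_c, which is the quantity the i == 1 branch fills in (i == 0 is the gradient with respect to y_true, left at zero since labels need no gradient). A quick standalone sanity check of the analytic form against a central finite difference (plain C, illustrative values, not cten code):

#include <math.h>
#include <stdio.h>

/* Cross-entropy for one sample: -sum_c y_true[c] * log(y_pred[c]). */
static float xent(const float* y_true, float* y_pred, int n) {
    float loss = 0.0f;
    for (int c = 0; c < n; c++) loss += y_true[c] * logf(y_pred[c]);
    return -loss;
}

int main(void) {
    float y_true[3] = {0.0f, 1.0f, 0.0f};     /* one-hot label */
    float y_pred[3] = {0.2f, 0.7f, 0.1f};     /* predicted probabilities */
    const float eps = 1e-3f;

    for (int c = 0; c < 3; c++) {
        float analytic = -y_true[c] / y_pred[c];

        /* central finite difference on y_pred[c] */
        float saved = y_pred[c];
        y_pred[c] = saved + eps;
        float hi = xent(y_true, y_pred, 3);
        y_pred[c] = saved - eps;
        float lo = xent(y_true, y_pred, 3);
        y_pred[c] = saved;

        printf("c=%d analytic=%.4f numeric=%.4f\n", c, analytic, (hi - lo) / (2 * eps));
    }
    return 0;
}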
@@ -101,7 +132,7 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
     assert(n_classes == y_pred.shape[1]);

     bool requires_grad = !cten_is_eval() && (y_true.node != NULL || y_pred.node != NULL);
-    Tensor res = Tensor_new((TensorShape){n_samples}, requires_grad);
+    Tensor res = Tensor_new((TensorShape){n_samples, 1}, requires_grad);
     for(int i = 0; i < n_samples; i++) {
         float loss = 0;
         for(int j = 0; j < n_classes; j++) {
@@ -110,5 +141,12 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
         }
         res.data->flex[i] = -loss;
     }
+
+    if (requires_grad) {
+        res.node->grad_fn = GradFn_crossentropy;
+        res.node->inputs[0] = y_true;
+        res.node->inputs[1] = y_pred;
+        res.node->n_inputs = 2;
+    }
     return Tensor_mean(res);
 }
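
Since nn_crossentropy returns Tensor_mean(res) over the [n_samples, 1] tensor of per-sample losses, the gradient flowing back into GradFn_crossentropy carries the mean's 1/n_samples factor, so the effective gradient on y_pred is -y_true / (n_samples * y_pred). A small worked sketch of that scaling (plain C, illustrative values; chaining the local gradient with the upstream factor is done by cten's autograd machinery, which is outside this diff):

#include <stdio.h>

int main(void) {
    enum { N = 2, C = 3 };  /* 2 samples, 3 classes */
    float y_true[N][C] = {{1, 0, 0}, {0, 0, 1}};
    float y_pred[N][C] = {{0.6f, 0.3f, 0.1f}, {0.2f, 0.2f, 0.6f}};

    for (int s = 0; s < N; s++)
        for (int c = 0; c < C; c++) {
            /* local gradient, as produced by GradFn_crossentropy */
            float local = -y_true[s][c] / y_pred[s][c];
            /* effective gradient after the mean's 1/N factor */
            float effective = local / N;
            printf("s=%d c=%d local=%.4f effective=%.4f\n", s, c, local, effective);
        }
    return 0;
}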
