Commit fd42506

Update nn.c
1 parent a1e0733 commit fd42506

File tree

1 file changed: +10 −8 lines changed

src/nn.c

Lines changed: 10 additions & 8 deletions
@@ -22,12 +22,13 @@ static Tensor GradFn_relu(Tensor self, int i) {
 }
 
 Tensor nn_relu(Tensor self) {
-    Tensor res = Tensor_new(self.shape, self.node != NULL);
+    bool requires_grad = !cten_is_eval() && self.node != NULL;
+    Tensor res = Tensor_new(self.shape, requires_grad);
     for(int i = 0; i < self.data->numel; i++) {
         res.data->flex[i] = fmaxf(0, self.data->flex[i]);
     }
 
-    if(self.node != NULL) {
+    if(requires_grad) {
         res.node->grad_fn = GradFn_relu;
         res.node->inputs[0] = self;
         res.node->n_inputs = 1;
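
The new requires_grad gate combines two conditions: the input must already be part of the autograd graph (self.node != NULL) and the library must not be in evaluation mode. A minimal sketch of what the flag behind cten_is_eval() could look like, assuming a nestable begin/end pair (cten_begin_eval/cten_end_eval are illustrative names; only cten_is_eval() appears in this diff):

/* Sketch of the eval-mode gate assumed by this diff. Only cten_is_eval()
 * appears in the change; the counter and the begin/end pair are
 * illustrative, modeled on nestable no-grad scopes. */
#include <stdbool.h>

static int eval_depth = 0;

void cten_begin_eval(void) { eval_depth++; }
void cten_end_eval(void) { eval_depth--; }
bool cten_is_eval(void) { return eval_depth > 0; }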
@@ -51,8 +52,8 @@ static Tensor GradFn_softmax(Tensor self, int i) {
5152
}
5253

5354
Tensor nn_softmax(Tensor self) {
54-
Tensor res = Tensor_new(self.shape, self.node != NULL);
55-
55+
bool requires_grad = !cten_is_eval() && self.node != NULL;
56+
Tensor res = Tensor_new(self.shape, requires_grad);
5657
int self_dim = TensorShape_dim(self.shape);
5758
assert(self_dim > 0);
5859
int last_dim_size = self.shape[self_dim - 1];
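
Both activations now follow the same pattern: a single requires_grad local decides whether Tensor_new allocates an autograd node, and the same flag (rather than a fresh self.node != NULL check) guards the grad_fn wiring further down, so res.node is only touched when it was actually allocated.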
@@ -79,12 +80,11 @@ Tensor nn_softmax(Tensor self) {
         }
     }
 
-    if(self.node != NULL) {
+    if(requires_grad) {
         res.node->grad_fn = GradFn_softmax;
         res.node->inputs[0] = self;
         res.node->n_inputs = 1;
     }
-
     return res;
 }
 
@@ -100,11 +100,13 @@ Tensor nn_crossentropy(Tensor y_true, Tensor y_pred) {
     assert(n_samples == y_pred.shape[0]);
     assert(n_classes == y_pred.shape[1]);
 
-    Tensor res = Tensor_new((TensorShape){n_samples}, true);
+    bool requires_grad = !cten_is_eval() && (y_true.node != NULL || y_pred.node != NULL);
+    Tensor res = Tensor_new((TensorShape){n_samples}, requires_grad);
     for(int i = 0; i < n_samples; i++) {
         float loss = 0;
         for(int j = 0; j < n_classes; j++) {
-            loss += y_true.data->flex[i * n_classes + j] * logf(y_pred.data->flex[i * n_classes + j]);
+            loss +=
+                y_true.data->flex[i * n_classes + j] * logf(y_pred.data->flex[i * n_classes + j]);
         }
         res.data->flex[i] = -loss;
     }
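
Taken together, the three changes let an inference pass skip graph construction entirely. A hypothetical usage sketch, assuming the begin/end pair from the flag sketch above and pre-built tensors x and y (none of which appear in this diff):

cten_begin_eval();
Tensor h = nn_relu(x);                   /* no autograd node allocated */
Tensor probs = nn_softmax(h);
Tensor loss = nn_crossentropy(y, probs); /* no longer hardcoded to track grads */
cten_end_eval();

Before this commit, nn_crossentropy passed true to Tensor_new unconditionally, so the loss tensor built graph state even during evaluation; it now also requires at least one input to carry a node.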

0 commit comments

Comments
 (0)