77
88import numpy as np
99from typing import Dict , List , Set , Callable , Optional , Tuple , Any , Union
10- from fit .core .tensor import Tensor
11-
1210
1311class Node :
1412 """
@@ -164,8 +162,8 @@ def forward(cls, *inputs: "Tensor") -> "Tensor":
164162 tensor_inputs = [
165163 inp for inp in inputs if isinstance (inp , Tensor ) and inp .requires_grad
166164 ]
167- output .parents = set (tensor_inputs )
168-
165+ output ._prev = set (tensor_inputs )
166+
169167 # Define backward function
170168 def backward_fn ():
171169 if output .grad is not None :
@@ -177,7 +175,7 @@ def backward_fn():
177175 else :
178176 inp .grad = inp .grad + grad
179177
180- output .backward_fn = backward_fn
178+ output ._backward = backward_fn
181179
182180 return output
183181
@@ -352,21 +350,21 @@ class Mean(Function):
352350 """Mean reduction function."""
353351
354352 @staticmethod
355- def apply (
356- ctx : Dict [str , Any ],
357- a : np .ndarray ,
358- axis : Optional [int ] = None ,
359- keepdims : bool = False ,
360- ) -> np .ndarray :
353+ def apply (ctx : Dict [str , Any ], a : np .ndarray , axis = None , keepdims = False ) -> np .ndarray :
361354 ctx ["input_shape" ] = a .shape
355+
356+ # Convert 0-d array to None
357+ if hasattr (axis , 'ndim' ) and axis .ndim == 0 :
358+ axis = None
359+
362360 ctx ["axis" ] = axis
363361 ctx ["keepdims" ] = keepdims
364-
362+
365363 if axis is None :
366364 ctx ["size" ] = a .size
367365 else :
368366 ctx ["size" ] = a .shape [axis ]
369-
367+
370368 return np .mean (a , axis = axis , keepdims = keepdims )
371369
372370 @staticmethod
0 commit comments