Skip to content

Commit ecc302a

Browse files
committed
Fix MNIST example and mean() handling: drop the Softmax layer before CrossEntropyLoss in examples/basic/mnist.py, handle NaN/array axis values in autograd Mean.apply, and compute Tensor.mean directly with numpy plus a manual backward pass.
1 parent 583d14b commit ecc302a

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

examples/basic/mnist.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,19 @@ def train_and_evaluate_mnist():
7979

8080
# Create model: 784 -> 128 -> 64 -> 10
8181
model = Sequential(
82-
Linear(784, 128), ReLU(), Linear(128, 64), ReLU(), Linear(64, 10), Softmax()
82+
Linear(784, 128), ReLU(),
83+
Linear(128, 64), ReLU(),
84+
Linear(64, 10)
8385
)
84-
85-
print("Model architecture:")
86-
print("784 (input) -> 128 -> ReLU -> 64 -> ReLU -> 10 -> Softmax (output)")
87-
86+
8887
# Create dataset and dataloader
8988
train_dataset = Dataset(X_train_subset, y_train_subset)
9089
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
9190

9291
# Create loss function and optimizer
9392
criterion = CrossEntropyLoss()
93+
94+
9495
optimizer = Adam(model.parameters(), lr=0.001)
9596

9697
print(f"\nStarting training with batch size 32...")

fit/core/autograd.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,17 @@ class Mean(Function):
353353
def apply(ctx: Dict[str, Any], a: np.ndarray, axis=None, keepdims=False) -> np.ndarray:
354354
ctx["input_shape"] = a.shape
355355

356-
# Convert 0-d array to None
357-
if hasattr(axis, 'ndim') and axis.ndim == 0:
356+
# Handle problematic axis values
357+
if isinstance(axis, np.ndarray):
358+
if axis.ndim == 0: # 0-d array
359+
axis_val = axis.item() # Extract the scalar value
360+
if np.isnan(axis_val):
361+
axis = None
362+
else:
363+
axis = int(axis_val)
364+
else:
365+
axis = None # Multi-dimensional axis arrays not supported
366+
elif axis is not None and np.isnan(axis):
358367
axis = None
359368

360369
ctx["axis"] = axis

fit/core/tensor.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,26 @@ def mean(self, axis=None, keepdims=False):
225225
Returns:
226226
A new tensor containing the mean value
227227
"""
228-
# Use the Mean function from autograd
229-
mean_fn = get_function("mean")
230-
return mean_fn.forward(self, axis, keepdims)
228+
# Just do it directly with numpy, forget the autograd for now
229+
result_data = np.mean(self.data, axis=axis, keepdims=keepdims)
230+
result = Tensor(result_data, requires_grad=self.requires_grad)
231+
232+
if self.requires_grad:
233+
def _backward():
234+
if result.grad is not None:
235+
# Gradient of mean is 1/n
236+
if axis is None:
237+
grad = np.full_like(self.data, result.grad / self.data.size)
238+
else:
239+
# Handle axis case
240+
grad = np.full_like(self.data, result.grad / self.data.shape[axis])
241+
242+
self.grad = grad if self.grad is None else self.grad + grad
243+
244+
result._backward = _backward
245+
result._prev = {self}
246+
247+
return result
231248

232249
def exp(self):
233250
"""

0 commit comments

Comments
 (0)