Skip to content

Commit 07c96bd

Browse files
committed
MNIST TEST WORKS!
1 parent 42fa2db commit 07c96bd

File tree

3 files changed

+159
-97
lines changed

3 files changed

+159
-97
lines changed

examples/basic/mnist.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def train_and_evaluate_mnist():
8686
print(f"\nStarting training with batch size 32...")
8787

8888
# Training loop
89-
epochs = 25
89+
epochs = 10
9090
train_losses = []
9191

9292
for epoch in range(epochs):
@@ -107,12 +107,12 @@ def train_and_evaluate_mnist():
107107
# 3. Backward pass
108108
loss.backward()
109109

110-
print("After backward:")
111-
for i, param in enumerate(model.parameters()):
112-
if param.grad is not None:
113-
print(f" Param {i}: HAS gradient, shape {param.grad.shape}")
114-
else:
115-
print(f" Param {i}: NO gradient! (shape {param.data.shape})")
110+
# print("After backward:")
111+
# for i, param in enumerate(model.parameters()):
112+
# if param.grad is not None:
113+
# print(f" Param {i}: HAS gradient, shape {param.grad.shape}")
114+
# else:
115+
# print(f" Param {i}: NO gradient! (shape {param.data.shape})")
116116
# 4. Update parameters
117117
optimizer.step()
118118

@@ -189,7 +189,6 @@ def evaluate_model(model, X_test, y_test):
189189

190190
def main():
191191
"""Main function."""
192-
print("MNIST Classification Example")
193192
print("=" * 40)
194193

195194
print("Training new model...")

fit/core/autograd.py

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -161,17 +161,29 @@ def backward_fn():
161161
if output.grad is not None:
162162
grads = cls.backward(ctx, output.grad)
163163

164+
# Debug print for MatMul
165+
# if cls.__name__ == "MatMul":
166+
# print(f"MatMul backward: output.grad.shape = {output.grad.shape}")
167+
# for i, grad in enumerate(grads):
168+
# if grad is not None:
169+
# print(f" grad[{i}].shape = {grad.shape}")
170+
# print(f" Assigning gradients to {len([inp for inp in tensor_inputs if inp is not None and inp.requires_grad])} tensors")
171+
# for i, inp in enumerate(tensor_inputs):
172+
# if inp is not None and inp.requires_grad:
173+
# print(f" tensor[{i}].shape = {inp.data.shape}")
174+
164175
# Assign gradients to the correct tensors
165-
grad_idx = 0
176+
# We need to match gradients with their corresponding input tensors
166177
for i, inp in enumerate(tensor_inputs):
167-
if inp is not None and inp.requires_grad:
168-
grad = grads[grad_idx] if grad_idx < len(grads) else None
178+
if inp is not None and inp.requires_grad and i < len(grads):
179+
grad = grads[i]
169180
if grad is not None:
181+
# if cls.__name__ == "MatMul":
182+
# print(f" Assigning grad[{i}] (shape {grad.shape}) to tensor[{i}] (shape {inp.data.shape})")
170183
if inp.grad is None:
171184
inp.grad = grad
172185
else:
173186
inp.grad = inp.grad + grad
174-
grad_idx += 1
175187

176188
output._backward = backward_fn
177189

@@ -277,32 +289,27 @@ class MatMul(Function):
277289
def apply(ctx: Dict[str, Any], a: np.ndarray, b: np.ndarray) -> np.ndarray:
278290
ctx["a"] = a
279291
ctx["b"] = b
292+
ctx["a_shape"] = a.shape
293+
ctx["b_shape"] = b.shape
280294
return a @ b
281295

282296
@staticmethod
283297
def backward(
284298
ctx: Dict[str, Any], grad_output: np.ndarray
285299
) -> Tuple[np.ndarray, np.ndarray]:
286300
a, b = ctx["a"], ctx["b"]
301+
a_shape, b_shape = ctx["a_shape"], ctx["b_shape"]
287302

288-
# Gradient for a: grad_output @ b.T
289-
# Handle different dimensions correctly
290-
if grad_output.ndim == 1 and b.ndim == 2:
291-
# Vector @ matrix case
292-
grad_a = grad_output.reshape(1, -1) @ b.T
293-
if a.ndim == 1:
294-
grad_a = grad_a.reshape(-1)
295-
else:
296-
grad_a = grad_output @ b.T
297-
298-
# Gradient for b: a.T @ grad_output
299-
# Handle different dimensions correctly
300-
if a.ndim == 1 and grad_output.ndim > 0:
301-
# Vector @ matrix case
302-
a_reshaped = a.reshape(-1, 1)
303-
grad_b = a_reshaped @ grad_output.reshape(1, -1)
304-
else:
305-
grad_b = a.T @ grad_output
303+
# For matrix multiplication C = A @ B:
304+
# dA = grad_output @ B.T
305+
# dB = A.T @ grad_output
306+
307+
grad_a = grad_output @ b.T
308+
grad_b = a.T @ grad_output
309+
310+
# Ensure gradients have the correct shapes
311+
assert grad_a.shape == a_shape, f"grad_a shape {grad_a.shape} != a_shape {a_shape}"
312+
assert grad_b.shape == b_shape, f"grad_b shape {grad_b.shape} != b_shape {b_shape}"
306313

307314
return grad_a, grad_b
308315

@@ -345,7 +352,7 @@ def backward(
345352

346353
class Mean(Function):
347354
"""Mean reduction function."""
348-
355+
349356
@staticmethod
350357
def apply(ctx: Dict[str, Any], a: np.ndarray, axis=None, keepdims=False) -> np.ndarray:
351358
ctx["input_shape"] = a.shape
@@ -434,73 +441,68 @@ class Reshape(Function):
434441
@staticmethod
435442
def apply(ctx: Dict[str, Any], a: np.ndarray, shape: Tuple[int, ...]) -> np.ndarray:
436443
ctx["input_shape"] = a.shape
437-
return np.reshape(a, shape)
444+
return a.reshape(shape)
438445

439446
@staticmethod
440447
def backward(
441448
ctx: Dict[str, Any], grad_output: np.ndarray
442449
) -> Tuple[np.ndarray, None]:
443450
input_shape = ctx["input_shape"]
444-
return np.reshape(grad_output, input_shape), None
451+
return grad_output.reshape(input_shape), None
445452

446453

447454
class ReLU(Function):
448-
"""Rectified Linear Unit function."""
455+
"""Rectified Linear Unit activation function."""
449456

450457
@staticmethod
451458
def apply(ctx: Dict[str, Any], a: np.ndarray) -> np.ndarray:
452459
ctx["mask"] = a > 0
453-
return np.maximum(a, 0)
460+
return np.maximum(0, a)
454461

455462
@staticmethod
456463
def backward(ctx: Dict[str, Any], grad_output: np.ndarray) -> Tuple[np.ndarray,]:
457464
mask = ctx["mask"]
458465
return (grad_output * mask,)
459466

460467

461-
class Tanh(Function):
462-
"""Hyperbolic tangent function."""
463-
464-
@staticmethod
465-
def apply(ctx: Dict[str, Any], a: np.ndarray) -> np.ndarray:
466-
result = np.tanh(a)
467-
ctx["result"] = result
468-
return result
469-
470-
@staticmethod
471-
def backward(ctx: Dict[str, Any], grad_output: np.ndarray) -> Tuple[np.ndarray,]:
472-
result = ctx["result"]
473-
# Derivative of tanh is 1 - tanh^2
474-
return (grad_output * (1 - result * result),)
475-
476-
477-
# Register functions for use with the tensor class
478-
function_registry = {
468+
# Function registry for dynamic lookup
469+
_function_registry = {
479470
"add": Add,
480471
"multiply": Multiply,
481472
"matmul": MatMul,
482473
"sum": Sum,
474+
"mean": Mean,
483475
"exp": Exp,
484476
"log": Log,
485477
"reshape": Reshape,
486478
"relu": ReLU,
487-
"mean": Mean,
488-
"tanh": Tanh,
489479
}
490480

491481

492-
# Function to get a registered function
493482
def get_function(name: str) -> Function:
494483
"""
495-
Get a registered autograd function by name.
484+
Get a function by name from the registry.
496485
497486
Args:
498487
name: Name of the function
499488
500489
Returns:
501490
Function class
491+
492+
Raises:
493+
ValueError: If the function is not found
494+
"""
495+
if name not in _function_registry:
496+
raise ValueError(f"Function '{name}' not found in registry")
497+
return _function_registry[name]
498+
499+
500+
def register_function(name: str, function: Function) -> None:
502501
"""
503-
if name not in function_registry:
504-
raise ValueError(f"Function {name} not found in registry")
502+
Register a new function in the registry.
505503
506-
return function_registry[name]
504+
Args:
505+
name: Name of the function
506+
function: Function class to register
507+
"""
508+
_function_registry[name] = function

0 commit comments

Comments
 (0)