@@ -161,17 +161,29 @@ def backward_fn():
161161 if output .grad is not None :
162162 grads = cls .backward (ctx , output .grad )
163163
164+ # Debug print for MatMul
165+ # if cls.__name__ == "MatMul":
166+ # print(f"MatMul backward: output.grad.shape = {output.grad.shape}")
167+ # for i, grad in enumerate(grads):
168+ # if grad is not None:
169+ # print(f" grad[{i}].shape = {grad.shape}")
170+ # print(f" Assigning gradients to {len([inp for inp in tensor_inputs if inp is not None and inp.requires_grad])} tensors")
171+ # for i, inp in enumerate(tensor_inputs):
172+ # if inp is not None and inp.requires_grad:
173+ # print(f" tensor[{i}].shape = {inp.data.shape}")
174+
164175 # Assign gradients to the correct tensors
165- grad_idx = 0
176+ # We need to match gradients with their corresponding input tensors
166177 for i , inp in enumerate (tensor_inputs ):
167- if inp is not None and inp .requires_grad :
168- grad = grads [grad_idx ] if grad_idx < len ( grads ) else None
178+ if inp is not None and inp .requires_grad and i < len ( grads ) :
179+ grad = grads [i ]
169180 if grad is not None :
181+ # if cls.__name__ == "MatMul":
182+ # print(f" Assigning grad[{i}] (shape {grad.shape}) to tensor[{i}] (shape {inp.data.shape})")
170183 if inp .grad is None :
171184 inp .grad = grad
172185 else :
173186 inp .grad = inp .grad + grad
174- grad_idx += 1
175187
176188 output ._backward = backward_fn
177189
@@ -277,32 +289,27 @@ class MatMul(Function):
277289 def apply (ctx : Dict [str , Any ], a : np .ndarray , b : np .ndarray ) -> np .ndarray :
278290 ctx ["a" ] = a
279291 ctx ["b" ] = b
292+ ctx ["a_shape" ] = a .shape
293+ ctx ["b_shape" ] = b .shape
280294 return a @ b
281295
282296 @staticmethod
283297 def backward (
284298 ctx : Dict [str , Any ], grad_output : np .ndarray
285299 ) -> Tuple [np .ndarray , np .ndarray ]:
286300 a , b = ctx ["a" ], ctx ["b" ]
301+ a_shape , b_shape = ctx ["a_shape" ], ctx ["b_shape" ]
287302
288- # Gradient for a: grad_output @ b.T
289- # Handle different dimensions correctly
290- if grad_output .ndim == 1 and b .ndim == 2 :
291- # Vector @ matrix case
292- grad_a = grad_output .reshape (1 , - 1 ) @ b .T
293- if a .ndim == 1 :
294- grad_a = grad_a .reshape (- 1 )
295- else :
296- grad_a = grad_output @ b .T
297-
298- # Gradient for b: a.T @ grad_output
299- # Handle different dimensions correctly
300- if a .ndim == 1 and grad_output .ndim > 0 :
301- # Vector @ matrix case
302- a_reshaped = a .reshape (- 1 , 1 )
303- grad_b = a_reshaped @ grad_output .reshape (1 , - 1 )
304- else :
305- grad_b = a .T @ grad_output
303+ # For matrix multiplication C = A @ B:
304+ # dA = grad_output @ B.T
305+ # dB = A.T @ grad_output
306+
307+ grad_a = grad_output @ b .T
308+ grad_b = a .T @ grad_output
309+
310+ # Ensure gradients have the correct shapes
311+ assert grad_a .shape == a_shape , f"grad_a shape { grad_a .shape } != a_shape { a_shape } "
312+ assert grad_b .shape == b_shape , f"grad_b shape { grad_b .shape } != b_shape { b_shape } "
306313
307314 return grad_a , grad_b
308315
@@ -345,7 +352,7 @@ def backward(
345352
346353class Mean (Function ):
347354 """Mean reduction function."""
348-
355+
349356 @staticmethod
350357 def apply (ctx : Dict [str , Any ], a : np .ndarray , axis = None , keepdims = False ) -> np .ndarray :
351358 ctx ["input_shape" ] = a .shape
@@ -434,73 +441,68 @@ class Reshape(Function):
434441 @staticmethod
435442 def apply (ctx : Dict [str , Any ], a : np .ndarray , shape : Tuple [int , ...]) -> np .ndarray :
436443 ctx ["input_shape" ] = a .shape
437- return np .reshape (a , shape )
444+ return a .reshape (shape )
438445
439446 @staticmethod
440447 def backward (
441448 ctx : Dict [str , Any ], grad_output : np .ndarray
442449 ) -> Tuple [np .ndarray , None ]:
443450 input_shape = ctx ["input_shape" ]
444- return np .reshape (grad_output , input_shape ), None
451+ return grad_output .reshape (input_shape ), None
445452
446453
class ReLU(Function):
    """Rectified Linear Unit activation function."""

    @staticmethod
    def apply(ctx: Dict[str, Any], a: np.ndarray) -> np.ndarray:
        """Forward pass: elementwise max(a, 0).

        Caches the strict-positivity mask; the subgradient at 0 is taken as 0.
        """
        positive = a > 0
        ctx["mask"] = positive
        return np.maximum(a, 0)

    @staticmethod
    def backward(ctx: Dict[str, Any], grad_output: np.ndarray) -> Tuple[np.ndarray,]:
        """Backward pass: gradient flows only where the forward input was positive."""
        return (ctx["mask"] * grad_output,)
459466
460467
class Tanh(Function):
    """Hyperbolic tangent activation function."""

    @staticmethod
    def apply(ctx: Dict[str, Any], a: np.ndarray) -> np.ndarray:
        """Forward pass: elementwise tanh; caches the result for backward."""
        result = np.tanh(a)
        ctx["result"] = result
        return result

    @staticmethod
    def backward(ctx: Dict[str, Any], grad_output: np.ndarray) -> Tuple[np.ndarray,]:
        """Backward pass: d/dx tanh(x) = 1 - tanh(x)^2, from the cached forward value."""
        result = ctx["result"]
        return (grad_output * (1 - result * result),)


# Function registry for dynamic lookup.
# NOTE(review): Tanh and its "tanh" entry are kept so existing
# get_function("tanh") callers don't break — confirm it wasn't
# intentionally relocated to another module.
_function_registry = {
    "add": Add,
    "multiply": Multiply,
    "matmul": MatMul,
    "sum": Sum,
    "mean": Mean,
    "exp": Exp,
    "log": Log,
    "reshape": Reshape,
    "relu": ReLU,
    "tanh": Tanh,
}
490480
491481
def get_function(name: str) -> Function:
    """
    Get a function by name from the registry.

    Args:
        name: Name of the function

    Returns:
        Function class

    Raises:
        ValueError: If the function is not found
    """
    fn = _function_registry.get(name)
    if fn is None:
        raise ValueError(f"Function '{name}' not found in registry")
    return fn
def register_function(name: str, function: Function) -> None:
    """
    Register a new function in the registry.

    An existing entry under the same name is silently replaced.

    Args:
        name: Name of the function
        function: Function class to register
    """
    _function_registry[name] = function
0 commit comments