Skip to content

Commit 583d14b

Browse files
committed
Fixing autograd and tensor modules (no T attribute, autograd chain)
1 parent dbf969a commit 583d14b

File tree

3 files changed

+19
-94
lines changed

3 files changed

+19
-94
lines changed

fit/core/autograd.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
import numpy as np
99
from typing import Dict, List, Set, Callable, Optional, Tuple, Any, Union
10-
from fit.core.tensor import Tensor
11-
1210

1311
class Node:
1412
"""
@@ -164,8 +162,8 @@ def forward(cls, *inputs: "Tensor") -> "Tensor":
164162
tensor_inputs = [
165163
inp for inp in inputs if isinstance(inp, Tensor) and inp.requires_grad
166164
]
167-
output.parents = set(tensor_inputs)
168-
165+
output._prev = set(tensor_inputs)
166+
169167
# Define backward function
170168
def backward_fn():
171169
if output.grad is not None:
@@ -177,7 +175,7 @@ def backward_fn():
177175
else:
178176
inp.grad = inp.grad + grad
179177

180-
output.backward_fn = backward_fn
178+
output._backward = backward_fn
181179

182180
return output
183181

@@ -352,21 +350,21 @@ class Mean(Function):
352350
"""Mean reduction function."""
353351

354352
@staticmethod
355-
def apply(
356-
ctx: Dict[str, Any],
357-
a: np.ndarray,
358-
axis: Optional[int] = None,
359-
keepdims: bool = False,
360-
) -> np.ndarray:
353+
def apply(ctx: Dict[str, Any], a: np.ndarray, axis=None, keepdims=False) -> np.ndarray:
361354
ctx["input_shape"] = a.shape
355+
356+
# Convert 0-d array to None
357+
if hasattr(axis, 'ndim') and axis.ndim == 0:
358+
axis = None
359+
362360
ctx["axis"] = axis
363361
ctx["keepdims"] = keepdims
364-
362+
365363
if axis is None:
366364
ctx["size"] = a.size
367365
else:
368366
ctx["size"] = a.shape[axis]
369-
367+
370368
return np.mean(a, axis=axis, keepdims=keepdims)
371369

372370
@staticmethod

fit/core/tensor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,11 @@ def build_topo(v):
311311
# Call backward functions in reverse topological order
312312
for node in reversed(topo):
313313
node._backward()
314+
315+
@property
316+
def T(self):
317+
"""Transpose property for 2D tensors."""
318+
if len(self.data.shape) == 2:
319+
return Tensor(self.data.T, requires_grad=self.requires_grad)
320+
else:
321+
raise ValueError("T property only works for 2D tensors")

fit/run_all_tests.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

0 commit comments

Comments
 (0)