Commit f5f7313

[feat] relax backward function to handle indexing as last operation (#10)
1 parent: 69c805a

File tree: 2 files changed, +42 -2 lines

    bae/autograd/graph.py
    tests/autograd/test_graph_jacobian.py


bae/autograd/graph.py (13 additions, 2 deletions)

@@ -105,7 +105,18 @@ def backward(output_):
     index = output_.optrace[id(output_)][1]
     arg = output_.optrace[id(output_)][2]

-    # check if upstream index exists
+    # If the last operation is indexing, there is no downstream map op to
+    # populate Jacobian values. In this case, the Jacobian block values are
+    # identity matrices placed at the indexed columns.
+    if not hasattr(output_, 'jactrace'):
+        if output_.ndim == 1:
+            eye_blocks = torch.ones((output_.shape[0], 1, 1), device=output_.device, dtype=output_.dtype)
+        else:
+            block_dim = output_.shape[-1]
+            eye = torch.eye(block_dim, device=output_.device, dtype=output_.dtype)
+            eye_blocks = eye.unsqueeze(0).repeat(output_.shape[0], 1, 1)
+        output_.jactrace = (None, eye_blocks)
+
     if type(output_.jactrace) is tuple:
         if output_.jactrace[0] is not None:
             upstream_index = output_.jactrace[0]
@@ -124,7 +135,7 @@ def backward(output_):


 def jacobian(output, params):
-    assert output.optrace[id(output)][0] == 'map', "The last operation in compute graph being indexing transform is not meaningful"
+    assert output.optrace[id(output)][0] in ('map', 'index'), "Unsupported last operation in compute graph"
     backward(output)
     res = []
     for param in params:
tests/autograd/test_graph_jacobian.py (29 additions, 0 deletions)

@@ -71,3 +71,32 @@ def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:

     assert torch.equal(J_sparse[0].col_indices(), idx_a[sel])
     assert torch.equal(J_sparse[1].col_indices(), idx_b[sel])
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_sparse_jacobian_last_op_indexing_is_identity(device: str):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    torch.manual_seed(0)
+    dtype = torch.float64
+
+    num_a = 6
+    n = 8
+    dim = 4
+
+    A0 = torch.randn(num_a, dim, device=device, dtype=dtype, requires_grad=True)
+    idx_a = torch.randint(0, num_a, (n,), device=device, dtype=torch.int32)
+
+    model = nn.Parameter(Track(A0))
+    out = model[idx_a]
+
+    (J_sparse,) = sparse_jacobian(out, [model])
+    assert J_sparse.layout == torch.sparse_bsr
+
+    def f(A: torch.Tensor) -> torch.Tensor:
+        return A[idx_a]
+
+    (JA,) = jacrev(f, argnums=(0,))(A0)
+    torch.testing.assert_close(J_sparse.to_dense(), _flatten_jac(JA), rtol=1e-10, atol=1e-10)
+    assert torch.equal(J_sparse.col_indices(), idx_a)
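
The layout the test asserts can also be assembled by hand, which makes the final two checks concrete: one identity block per block row, with col_indices equal to the index tensor. A sketch using only public torch APIs (a reconstruction for illustration, not the library's internal code path):

    import torch

    n, num_a, dim = 8, 6, 4
    idx = torch.randint(0, num_a, (n,), dtype=torch.int32)

    # One identity block per output block row, placed at block column idx[i].
    eye_blocks = torch.eye(dim, dtype=torch.float64).unsqueeze(0).repeat(n, 1, 1)
    J = torch.sparse_bsr_tensor(
        torch.arange(n + 1, dtype=torch.int32),  # crow_indices: one block per row
        idx,                                     # col_indices
        eye_blocks,                              # values: n identity blocks
        size=(n * dim, num_a * dim),
    )

    assert J.layout == torch.sparse_bsr
    assert torch.equal(J.col_indices(), idx)
    # Densifying recovers the identity block of any row at its indexed column.
    r = 2
    block = J.to_dense()[r * dim:(r + 1) * dim, idx[r] * dim:(idx[r] + 1) * dim]
    assert torch.equal(block, torch.eye(dim, dtype=torch.float64))
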
