diff --git a/bae/autograd/graph.py b/bae/autograd/graph.py
index 860e99d..f8ded3e 100644
--- a/bae/autograd/graph.py
+++ b/bae/autograd/graph.py
@@ -105,7 +105,18 @@ def backward(output_):
     index = output_.optrace[id(output_)][1]
     arg = output_.optrace[id(output_)][2]
-    # check if upstream index exists
+    # If the last operation is indexing, there is no downstream map op to
+    # populate Jacobian values. In this case, the Jacobian block values are
+    # identity matrices placed at the indexed columns.
+    if not hasattr(output_, 'jactrace'):
+        if output_.ndim == 1:
+            eye_blocks = torch.ones((output_.shape[0], 1, 1), device=output_.device, dtype=output_.dtype)
+        else:
+            block_dim = output_.shape[-1]
+            eye = torch.eye(block_dim, device=output_.device, dtype=output_.dtype)
+            eye_blocks = eye.unsqueeze(0).repeat(output_.shape[0], 1, 1)
+        output_.jactrace = (None, eye_blocks)
+
     if type(output_.jactrace) is tuple:
         if output_.jactrace[0] is not None:
             upstream_index = output_.jactrace[0]
 
@@ -124,7 +135,7 @@ def backward(output_):
 
 
 def jacobian(output, params):
-    assert output.optrace[id(output)][0] == 'map', "The last operation in compute graph being indexing transform is not meaningful"
+    assert output.optrace[id(output)][0] in ('map', 'index'), "Unsupported last operation in compute graph"
     backward(output)
     res = []
     for param in params:
diff --git a/tests/autograd/test_graph_jacobian.py b/tests/autograd/test_graph_jacobian.py
index d036939..3b3fc4c 100644
--- a/tests/autograd/test_graph_jacobian.py
+++ b/tests/autograd/test_graph_jacobian.py
@@ -71,3 +71,32 @@ def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 
     assert torch.equal(J_sparse[0].col_indices(), idx_a[sel])
     assert torch.equal(J_sparse[1].col_indices(), idx_b[sel])
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_sparse_jacobian_last_op_indexing_is_identity(device: str):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    torch.manual_seed(0)
+    dtype = torch.float64
+
+    num_a = 6
+    n = 8
+    dim = 4
+
+    A0 = torch.randn(num_a, dim, device=device, dtype=dtype, requires_grad=True)
+    idx_a = torch.randint(0, num_a, (n,), device=device, dtype=torch.int32)
+
+    model = nn.Parameter(Track(A0))
+    out = model[idx_a]
+
+    (J_sparse,) = sparse_jacobian(out, [model])
+    assert J_sparse.layout == torch.sparse_bsr
+
+    def f(A: torch.Tensor) -> torch.Tensor:
+        return A[idx_a]
+
+    (JA,) = jacrev(f, argnums=(0,))(A0)
+    torch.testing.assert_close(J_sparse.to_dense(), _flatten_jac(JA), rtol=1e-10, atol=1e-10)
+    assert torch.equal(J_sparse.col_indices(), idx_a)
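
Reviewer note: independent of this repo's Track/sparse_jacobian machinery, the invariant the patch and the new test encode can be checked with plain PyTorch. The sketch below (the helper name indexing_jacobian_bsr is illustrative, not part of the codebase) builds the Jacobian of y = A[idx] directly as a sparse BSR matrix with one dim x dim identity block per output row, placed at block column idx[i], and compares it against a dense reference from torch.func.jacrev:

import torch
from torch.func import jacrev

def indexing_jacobian_bsr(idx, num_a, dim, dtype=torch.float64, device="cpu"):
    # Row block i holds a single dim x dim identity at block column idx[i],
    # mirroring the eye_blocks fallback added in backward().
    n = idx.shape[0]
    crow_indices = torch.arange(n + 1, device=device)  # exactly one block per block row
    col_indices = idx.to(device=device, dtype=torch.int64)
    eye_blocks = torch.eye(dim, dtype=dtype, device=device).expand(n, dim, dim).contiguous()
    return torch.sparse_bsr_tensor(crow_indices, col_indices, eye_blocks,
                                   size=(n * dim, num_a * dim))

num_a, n, dim = 6, 8, 4
A = torch.randn(num_a, dim, dtype=torch.float64)
idx = torch.randint(0, num_a, (n,))

# Dense reference Jacobian of y = A[idx], flattened to (n*dim, num_a*dim).
J_dense = jacrev(lambda A: A[idx])(A).reshape(n * dim, num_a * dim)
J_bsr = indexing_jacobian_bsr(idx, num_a, dim)
torch.testing.assert_close(J_bsr.to_dense(), J_dense)

The sketch uses expand instead of the patch's repeat, which keeps the identity blocks as a view until .contiguous(); the two are behaviorally equivalent here.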