Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions bae/autograd/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,18 @@ def backward(output_):
index = output_.optrace[id(output_)][1]
arg = output_.optrace[id(output_)][2]

# check if upstream index exists
# If the last operation is indexing, there is no downstream map op to
# populate Jacobian values. In this case, the Jacobian block values are
# identity matrices placed at the indexed columns.
if not hasattr(output_, 'jactrace'):
if output_.ndim == 1:
eye_blocks = torch.ones((output_.shape[0], 1, 1), device=output_.device, dtype=output_.dtype)
else:
block_dim = output_.shape[-1]
eye = torch.eye(block_dim, device=output_.device, dtype=output_.dtype)
eye_blocks = eye.unsqueeze(0).repeat(output_.shape[0], 1, 1)
output_.jactrace = (None, eye_blocks)

if type(output_.jactrace) is tuple:
if output_.jactrace[0] is not None:
upstream_index = output_.jactrace[0]
Expand All @@ -124,7 +135,7 @@ def backward(output_):


def jacobian(output, params):
assert output.optrace[id(output)][0] == 'map', "The last operation in compute graph being indexing transform is not meaningful"
assert output.optrace[id(output)][0] in ('map', 'index'), "Unsupported last operation in compute graph"
backward(output)
res = []
for param in params:
Expand Down
29 changes: 29 additions & 0 deletions tests/autograd/test_graph_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,32 @@ def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:

assert torch.equal(J_sparse[0].col_indices(), idx_a[sel])
assert torch.equal(J_sparse[1].col_indices(), idx_b[sel])


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_sparse_jacobian_last_op_indexing_is_identity(device: str):
    """When the final op in the trace is pure indexing, the sparse Jacobian
    must be block-identity: one identity block per output row, placed at the
    column of the source row that was indexed. Verified against jacrev.
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    torch.manual_seed(0)
    float_kind = torch.float64

    # rows in the tracked parameter, number of indexed picks, block width
    rows, picks, block = 6, 8, 4

    base = torch.randn(rows, block, device=device, dtype=float_kind, requires_grad=True)
    selector = torch.randint(0, rows, (picks,), device=device, dtype=torch.int32)

    tracked = nn.Parameter(Track(base))
    indexed = tracked[selector]

    (sparse_J,) = sparse_jacobian(indexed, [tracked])
    assert sparse_J.layout == torch.sparse_bsr

    # Dense reference: differentiate the plain indexing function with jacrev.
    def reference(mat: torch.Tensor) -> torch.Tensor:
        return mat[selector]

    (dense_J,) = jacrev(reference, argnums=(0,))(base)
    torch.testing.assert_close(
        sparse_J.to_dense(), _flatten_jac(dense_J), rtol=1e-10, atol=1e-10
    )
    # The BSR column indices must be exactly the indices that were gathered.
    assert torch.equal(sparse_J.col_indices(), selector)