Commit 35e5500

add the index cat forward pass
1 parent 69c805a commit 35e5500

2 files changed: +88, −0 lines changed

bae/autograd/function.py

Lines changed: 15 additions & 0 deletions
@@ -36,6 +36,21 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
             result.optrace = {}
             index_edge = ("index", args[1], args[0])
             result.optrace[id(result)] = index_edge
+        elif func in (torch.cat, torch.concat):
+            if kwargs is None:
+                kwargs = {}
+            dim = kwargs.get("dim", args[1] if len(args) > 1 else 0)
+            if dim != 0:
+                raise NotImplementedError("Only torch.cat(..., dim=0) is supported as an indexing transform")
+
+            tensors = args[0]
+            merged_optrace = {}
+            for tensor in tensors:
+                if isinstance(tensor, torch.Tensor) and hasattr(tensor, 'optrace'):
+                    merged_optrace.update(tensor.optrace)
+
+            merged_optrace[id(result)] = ("index_cat", dim, tuple(tensors))
+            result.optrace = merged_optrace
         elif func in WHITELISTED_MAPS:
             merged_optrace = {}
             for arg in args:
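
For readers skimming the diff: the new branch intercepts torch.cat / torch.concat inside __torch_function__, merges the optrace dicts of all concatenated inputs, and records an ("index_cat", dim, tensors) edge keyed by the result's id; any dim other than 0 is rejected. Below is a minimal, self-contained sketch of that interception pattern. TracedTensor is a hypothetical stand-in for the repo's Track subclass, and only the cat branch is reproduced; it is illustrative, not the commit's actual code.

    import torch


    class TracedTensor(torch.Tensor):
        """Hypothetical stand-in for the repo's Track subclass: carries an
        `optrace` dict mapping id(tensor) -> the edge that produced it."""

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            result = super().__torch_function__(func, types, args, kwargs)

            if func in (torch.cat, torch.concat) and isinstance(result, torch.Tensor):
                # Resolve dim the same way the patch does: keyword first, then
                # the second positional argument, else the default of 0.
                dim = kwargs.get("dim", args[1] if len(args) > 1 else 0)
                if dim != 0:
                    raise NotImplementedError("only dim=0 concatenation is traced")

                tensors = args[0]
                merged = {}
                for t in tensors:
                    # Inherit whatever trace each input already carries.
                    merged.update(getattr(t, "optrace", {}))
                # Record the concatenation itself as an "index_cat" edge.
                merged[id(result)] = ("index_cat", dim, tuple(tensors))
                result.optrace = merged
            return result


    # Concatenating two traced tensors yields a result whose optrace holds the
    # inputs' histories plus the new index_cat edge.
    x = torch.randn(4, 3).as_subclass(TracedTensor)
    x.optrace = {}
    y = torch.randn(2, 3).as_subclass(TracedTensor)
    y.optrace = {}
    z = torch.cat([x, y], dim=0)
    print(z.optrace[id(z)][0])  # -> "index_cat"

Merging into a fresh dict and keying by id(result) mirrors what the existing WHITELISTED_MAPS branch appears to do for its arguments, and the recorded edge keeps enough information (the dim and the input tensors) for a downstream pass to stack per-input column indices in concatenation order.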

tests/autograd/test_graph_jacobian.py

Lines changed: 73 additions & 0 deletions
@@ -26,6 +26,31 @@ def forward(
         return (a + b) - obs
 
 
+class ToyResidualCat(nn.Module):
+    def __init__(self, A: torch.Tensor, B: torch.Tensor):
+        super().__init__()
+        self.A = nn.Parameter(Track(A))
+        self.B = nn.Parameter(Track(B))
+
+    def forward(
+        self,
+        obs1: torch.Tensor,
+        obs2: torch.Tensor,
+        idx_a: torch.Tensor,
+        idx_b: torch.Tensor,
+        sel1: torch.Tensor,
+        sel2: torch.Tensor,
+    ) -> torch.Tensor:
+        a1 = self.A[idx_a][sel1]
+        b1 = self.B[idx_b][sel1]
+        r1 = (a1 + b1) - obs1[sel1]
+
+        a2 = self.A[idx_a][sel2]
+        b2 = self.B[idx_b][sel2]
+        r2 = (a2 + b2) - obs2[sel2]
+        return torch.cat([r1, r2], dim=0)
+
+
 def _flatten_jac(J: torch.Tensor) -> torch.Tensor:
     n, outdim, num, indim = J.shape
     return J.reshape(n * outdim, num * indim)
@@ -71,3 +96,51 @@ def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 
     assert torch.equal(J_sparse[0].col_indices(), idx_a[sel])
     assert torch.equal(J_sparse[1].col_indices(), idx_b[sel])
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_sparse_jacobian_supports_cat_dim0(device: str):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    torch.manual_seed(0)
+    dtype = torch.float64
+
+    num_a, num_b = 5, 6
+    n = 9
+    dim = 3
+
+    A0 = torch.randn(num_a, dim, device=device, dtype=dtype, requires_grad=True)
+    B0 = torch.randn(num_b, dim, device=device, dtype=dtype, requires_grad=True)
+    obs1 = torch.randn(n, dim, device=device, dtype=dtype)
+    obs2 = torch.randn(n, dim, device=device, dtype=dtype)
+
+    idx_a = torch.randint(0, num_a, (n,), device=device, dtype=torch.int32)
+    idx_b = torch.randint(0, num_b, (n,), device=device, dtype=torch.int32)
+    sel1 = torch.tensor([0, 2, 5, 6], device=device, dtype=torch.int32)
+    sel2 = torch.tensor([1, 3, 4, 8], device=device, dtype=torch.int32)
+
+    model = ToyResidualCat(A0, B0)
+    out = model(obs1, obs2, idx_a, idx_b, sel1, sel2)
+
+    J_sparse = sparse_jacobian(out, [model.A, model.B])
+    assert len(J_sparse) == 2
+    assert all(j.layout == torch.sparse_bsr for j in J_sparse)
+
+    def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+        a1 = A[idx_a][sel1]
+        b1 = B[idx_b][sel1]
+        r1 = (a1 + b1) - obs1[sel1]
+
+        a2 = A[idx_a][sel2]
+        b2 = B[idx_b][sel2]
+        r2 = (a2 + b2) - obs2[sel2]
+        return torch.cat([r1, r2], dim=0)
+
+    JA, JB = jacrev(f, argnums=(0, 1))(A0, B0)
+
+    torch.testing.assert_close(J_sparse[0].to_dense(), _flatten_jac(JA), rtol=1e-10, atol=1e-10)
+    torch.testing.assert_close(J_sparse[1].to_dense(), _flatten_jac(JB), rtol=1e-10, atol=1e-10)
+
+    assert torch.equal(J_sparse[0].col_indices(), torch.cat([idx_a[sel1], idx_a[sel2]], dim=0))
+    assert torch.equal(J_sparse[1].col_indices(), torch.cat([idx_b[sel1], idx_b[sel2]], dim=0))
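
The two closing assertions capture why a dim-0 cat can be treated as an indexing transform: block row k of the concatenated output is residual k of r1 followed by r2, so its only nonzero block column with respect to A is idx_a[sel1] for the first four rows and idx_a[sel2] for the last four (and likewise for B with idx_b). The sketch below checks that sparsity argument with torch.func.jacrev alone; the shapes mirror the test, but it is illustrative and not part of the commit.

    import torch
    from torch.func import jacrev

    torch.manual_seed(0)
    num_a, n, dim = 5, 9, 3
    A0 = torch.randn(num_a, dim, dtype=torch.float64)
    obs = torch.randn(n, dim, dtype=torch.float64)
    idx_a = torch.randint(0, num_a, (n,))
    sel1 = torch.tensor([0, 2, 5, 6])
    sel2 = torch.tensor([1, 3, 4, 8])

    def f(A: torch.Tensor) -> torch.Tensor:
        # Same shape of computation as the test, reduced to one parameter tensor.
        r1 = A[idx_a][sel1] - obs[sel1]
        r2 = A[idx_a][sel2] - obs[sel2]
        return torch.cat([r1, r2], dim=0)

    J = jacrev(f)(A0)  # shape (8, dim, num_a, dim)
    block_cols = (J.abs().sum(dim=(1, 3)) > 0).nonzero()[:, 1]
    # Each block row touches exactly one parameter row, in concatenation order.
    assert torch.equal(block_cols, torch.cat([idx_a[sel1], idx_a[sel2]]))
    print(block_cols)

Assuming a standard pytest setup, the new test can be run in isolation with something like `pytest tests/autograd/test_graph_jacobian.py -k cat_dim0`.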
