Skip to content

Commit 41d3e2e

Browse files
committed
add the cat - cat test
1 parent 514adb2 commit 41d3e2e

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

tests/autograd/test_graph_jacobian.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,64 @@ def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
161161
assert JB_sparse.crow_indices()[n_a].item() == 0
162162
assert JB_sparse.crow_indices()[-1].item() == n_b
163163
assert torch.equal(JB_sparse.col_indices(), idx_b)
164+
165+
166+
class CatSubResidual(nn.Module):
167+
def __init__(self, A: torch.Tensor, B: torch.Tensor):
168+
super().__init__()
169+
self.A = nn.Parameter(Track(A))
170+
self.B = nn.Parameter(Track(B))
171+
172+
def forward(
173+
self,
174+
obs_a: torch.Tensor,
175+
obs_b: torch.Tensor,
176+
idx_a: torch.Tensor,
177+
idx_b: torch.Tensor,
178+
) -> torch.Tensor:
179+
pred = torch.cat([self.A[idx_a], self.B[idx_b]], dim=0)
180+
obs = torch.cat([obs_a, obs_b], dim=0)
181+
return pred - obs
182+
183+
184+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
185+
def test_sparse_jacobian_cat_minus_cat_matches_torch_jacrev(device: str):
186+
if device == "cuda" and not torch.cuda.is_available():
187+
pytest.skip("CUDA not available")
188+
189+
torch.manual_seed(0)
190+
dtype = torch.float64
191+
192+
num_a, num_b = 6, 8
193+
n_a, n_b = 5, 7
194+
dim = 3
195+
196+
A0 = torch.randn(num_a, dim, device=device, dtype=dtype)
197+
B0 = torch.randn(num_b, dim, device=device, dtype=dtype)
198+
obs_a = torch.randn(n_a, dim, device=device, dtype=dtype)
199+
obs_b = torch.randn(n_b, dim, device=device, dtype=dtype)
200+
201+
idx_a = torch.randint(0, num_a, (n_a,), device=device, dtype=torch.int32)
202+
idx_b = torch.randint(0, num_b, (n_b,), device=device, dtype=torch.int32)
203+
204+
model = CatSubResidual(A0, B0)
205+
out = model(obs_a, obs_b, idx_a, idx_b)
206+
207+
JA_sparse, JB_sparse = sparse_jacobian(out, [model.A, model.B])
208+
209+
def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
210+
pred = torch.cat([A[idx_a], B[idx_b]], dim=0)
211+
obs = torch.cat([obs_a, obs_b], dim=0)
212+
return pred - obs
213+
214+
JA, JB = jacrev(f, argnums=(0, 1))(A0, B0)
215+
torch.testing.assert_close(JA_sparse.to_dense(), _flatten_jac(JA), rtol=1e-10, atol=1e-10)
216+
torch.testing.assert_close(JB_sparse.to_dense(), _flatten_jac(JB), rtol=1e-10, atol=1e-10)
217+
218+
assert JA_sparse.crow_indices()[n_a].item() == n_a
219+
assert JA_sparse.crow_indices()[-1].item() == n_a
220+
assert torch.equal(JA_sparse.col_indices(), idx_a)
221+
222+
assert JB_sparse.crow_indices()[n_a].item() == 0
223+
assert JB_sparse.crow_indices()[-1].item() == n_b
224+
assert torch.equal(JB_sparse.col_indices(), idx_b)

0 commit comments

Comments
 (0)