Skip to content

Commit 83174b6

Browse files
committed
[tests] perturb jac values so that they are not always 1
1 parent 41d3e2e commit 83174b6

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

tests/autograd/test_graph_jacobian.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,11 @@ def forward(
114114
obs_b: torch.Tensor,
115115
idx_a: torch.Tensor,
116116
idx_b: torch.Tensor,
117+
mul_a: torch.Tensor,
118+
mul_b: torch.Tensor,
117119
) -> torch.Tensor:
118-
ra = self.A[idx_a] - obs_a
119-
rb = self.B[idx_b] - obs_b
120+
ra = (self.A[idx_a] - obs_a) * mul_a
121+
rb = (self.B[idx_b] - obs_b) * mul_b
120122
return torch.cat([ra, rb], dim=0)
121123

122124

@@ -140,14 +142,17 @@ def test_sparse_jacobian_cat_dim0_matches_torch_jacrev(device: str):
140142
idx_a = torch.randint(0, num_a, (n_a,), device=device, dtype=torch.int32)
141143
idx_b = torch.randint(0, num_b, (n_b,), device=device, dtype=torch.int32)
142144

145+
mul_a = torch.rand(n_a, dim, device=device, dtype=dtype) + 0.5
146+
mul_b = torch.rand(n_b, dim, device=device, dtype=dtype) + 0.5
147+
143148
model = CatResidual(A0, B0)
144-
out = model(obs_a, obs_b, idx_a, idx_b)
149+
out = model(obs_a, obs_b, idx_a, idx_b, mul_a, mul_b)
145150

146151
JA_sparse, JB_sparse = sparse_jacobian(out, [model.A, model.B])
147152

148153
def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
149-
ra = A[idx_a] - obs_a
150-
rb = B[idx_b] - obs_b
154+
ra = (A[idx_a] - obs_a) * mul_a
155+
rb = (B[idx_b] - obs_b) * mul_b
151156
return torch.cat([ra, rb], dim=0)
152157

153158
JA, JB = jacrev(f, argnums=(0, 1))(A0, B0)
@@ -175,8 +180,10 @@ def forward(
175180
obs_b: torch.Tensor,
176181
idx_a: torch.Tensor,
177182
idx_b: torch.Tensor,
183+
mul_a: torch.Tensor,
184+
mul_b: torch.Tensor,
178185
) -> torch.Tensor:
179-
pred = torch.cat([self.A[idx_a], self.B[idx_b]], dim=0)
186+
pred = torch.cat([self.A[idx_a] * mul_a, self.B[idx_b] * mul_b], dim=0)
180187
obs = torch.cat([obs_a, obs_b], dim=0)
181188
return pred - obs
182189

@@ -201,13 +208,16 @@ def test_sparse_jacobian_cat_minus_cat_matches_torch_jacrev(device: str):
201208
idx_a = torch.randint(0, num_a, (n_a,), device=device, dtype=torch.int32)
202209
idx_b = torch.randint(0, num_b, (n_b,), device=device, dtype=torch.int32)
203210

211+
mul_a = torch.rand(n_a, dim, device=device, dtype=dtype) + 0.5
212+
mul_b = torch.rand(n_b, dim, device=device, dtype=dtype) + 0.5
213+
204214
model = CatSubResidual(A0, B0)
205-
out = model(obs_a, obs_b, idx_a, idx_b)
215+
out = model(obs_a, obs_b, idx_a, idx_b, mul_a, mul_b)
206216

207217
JA_sparse, JB_sparse = sparse_jacobian(out, [model.A, model.B])
208218

209219
def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
210-
pred = torch.cat([A[idx_a], B[idx_b]], dim=0)
220+
pred = torch.cat([A[idx_a] * mul_a, B[idx_b] * mul_b], dim=0)
211221
obs = torch.cat([obs_a, obs_b], dim=0)
212222
return pred - obs
213223

0 commit comments

Comments
 (0)