@@ -161,3 +161,64 @@ def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
161161 assert JB_sparse .crow_indices ()[n_a ].item () == 0
162162 assert JB_sparse .crow_indices ()[- 1 ].item () == n_b
163163 assert torch .equal (JB_sparse .col_indices (), idx_b )
164+
165+
class CatSubResidual(nn.Module):
    """Residual module computing ``cat(A[idx_a], B[idx_b]) - cat(obs_a, obs_b)``.

    Both tables are passed through ``Track`` and registered as parameters.
    NOTE(review): ``Track`` is defined outside this block; presumably it marks
    the tensors so that ``sparse_jacobian`` can trace them -- confirm.
    """

    def __init__(self, A: torch.Tensor, B: torch.Tensor):
        """Register ``A`` and ``B`` (wrapped in ``Track``) as parameters."""
        super().__init__()
        self.A = nn.Parameter(Track(A))
        self.B = nn.Parameter(Track(B))

    def forward(
        self,
        obs_a: torch.Tensor,
        obs_b: torch.Tensor,
        idx_a: torch.Tensor,
        idx_b: torch.Tensor,
    ) -> torch.Tensor:
        """Return predictions minus observations, stacked along dim 0.

        Args:
            obs_a: Observations matched against the rows ``A[idx_a]``.
            obs_b: Observations matched against the rows ``B[idx_b]``.
            idx_a: Index tensor selecting rows of ``self.A``.
            idx_b: Index tensor selecting rows of ``self.B``.

        Returns:
            ``cat([A[idx_a], B[idx_b]]) - cat([obs_a, obs_b])``; the A-part
            occupies the leading rows, the B-part the trailing rows.
        """
        # Gather the selected rows of each table and stack A-rows before B-rows.
        pred = torch.cat([self.A[idx_a], self.B[idx_b]], dim=0)
        # Observations are stacked in the same order so rows line up with pred.
        obs = torch.cat([obs_a, obs_b], dim=0)
        return pred - obs
182+
183+
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_sparse_jacobian_cat_minus_cat_matches_torch_jacrev(device: str):
    """Compare ``sparse_jacobian`` with ``jacrev`` on a ``cat(...) - cat(...)`` residual.

    The residual is ``cat(A[idx_a], B[idx_b]) - cat(obs_a, obs_b)``, evaluated
    once through ``CatSubResidual`` (for ``sparse_jacobian``) and once through
    a plain function of ``A`` and ``B`` (for ``jacrev``). The test checks that
    the densified sparse Jacobians match the reference, then checks the
    compressed-row index structure of each result.
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    torch.manual_seed(0)
    dtype = torch.float64

    num_a, num_b = 6, 8  # number of rows in the tables A and B
    n_a, n_b = 5, 7  # number of gathered rows (and observations) per table
    dim = 3

    A0 = torch.randn(num_a, dim, device=device, dtype=dtype)
    B0 = torch.randn(num_b, dim, device=device, dtype=dtype)
    obs_a = torch.randn(n_a, dim, device=device, dtype=dtype)
    obs_b = torch.randn(n_b, dim, device=device, dtype=dtype)

    idx_a = torch.randint(0, num_a, (n_a,), device=device, dtype=torch.int32)
    idx_b = torch.randint(0, num_b, (n_b,), device=device, dtype=torch.int32)

    model = CatSubResidual(A0, B0)
    out = model(obs_a, obs_b, idx_a, idx_b)

    JA_sparse, JB_sparse = sparse_jacobian(out, [model.A, model.B])

    def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        # Same residual as CatSubResidual.forward, as a pure function for jacrev.
        pred = torch.cat([A[idx_a], B[idx_b]], dim=0)
        obs = torch.cat([obs_a, obs_b], dim=0)
        return pred - obs

    JA, JB = jacrev(f, argnums=(0, 1))(A0, B0)
    torch.testing.assert_close(
        JA_sparse.to_dense(), _flatten_jac(JA), rtol=1e-10, atol=1e-10
    )
    torch.testing.assert_close(
        JB_sparse.to_dense(), _flatten_jac(JB), rtol=1e-10, atol=1e-10
    )

    # A only feeds the first n_a output rows: the row pointer reaches n_a
    # there and stays flat to the end, with one column entry per gathered row.
    # NOTE(review): the counts imply one stored entry per (block) row, i.e. a
    # block-compressed layout -- confirm against sparse_jacobian.
    assert JA_sparse.crow_indices()[n_a].item() == n_a
    assert JA_sparse.crow_indices()[-1].item() == n_a
    assert torch.equal(JA_sparse.col_indices(), idx_a)

    # B only feeds the trailing n_b output rows: nothing is stored in the
    # first n_a rows, and n_b entries are stored in total.
    assert JB_sparse.crow_indices()[n_a].item() == 0
    assert JB_sparse.crow_indices()[-1].item() == n_b
    assert torch.equal(JB_sparse.col_indices(), idx_b)
0 commit comments