@@ -26,6 +26,31 @@ def forward(
         return (a + b) - obs
 
 
+class ToyResidualCat(nn.Module):
+    def __init__(self, A: torch.Tensor, B: torch.Tensor):
+        super().__init__()
+        self.A = nn.Parameter(Track(A))
+        self.B = nn.Parameter(Track(B))
+
+    def forward(
+        self,
+        obs1: torch.Tensor,
+        obs2: torch.Tensor,
+        idx_a: torch.Tensor,
+        idx_b: torch.Tensor,
+        sel1: torch.Tensor,
+        sel2: torch.Tensor,
+    ) -> torch.Tensor:
+        a1 = self.A[idx_a][sel1]
+        b1 = self.B[idx_b][sel1]
+        r1 = (a1 + b1) - obs1[sel1]
+
+        a2 = self.A[idx_a][sel2]
+        b2 = self.B[idx_b][sel2]
+        r2 = (a2 + b2) - obs2[sel2]
+        return torch.cat([r1, r2], dim=0)
+
+
 def _flatten_jac(J: torch.Tensor) -> torch.Tensor:
     n, outdim, num, indim = J.shape
     return J.reshape(n * outdim, num * indim)
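A minimal sketch (not part of the diff) of the shape convention _flatten_jac handles, assuming jacrev comes from torch.func (the test file imports it elsewhere): for a residual of shape (n, outdim) in a parameter of shape (num, indim), jacrev returns a Jacobian of shape (n, outdim, num, indim), which _flatten_jac collapses into the 2-D (n * outdim, num * indim) matrix that the sparse Jacobians are compared against via .to_dense().

import torch
from torch.func import jacrev  # assumed import path for jacrev

num, dim, n = 4, 3, 2
A = torch.randn(num, dim, dtype=torch.float64)
idx = torch.tensor([1, 3])

J = jacrev(lambda A: A[idx] * 2.0)(A)  # shape: (n, dim, num, dim)
assert J.shape == (n, dim, num, dim)
J2d = J.reshape(n * dim, num * dim)    # the flattening _flatten_jac performs
assert J2d.shape == (n * dim, num * dim)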
@@ -71,3 +96,51 @@ def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
 
     assert torch.equal(J_sparse[0].col_indices(), idx_a[sel])
     assert torch.equal(J_sparse[1].col_indices(), idx_b[sel])
+
+
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+def test_sparse_jacobian_supports_cat_dim0(device: str):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    torch.manual_seed(0)
+    dtype = torch.float64
+
+    num_a, num_b = 5, 6
+    n = 9
+    dim = 3
+
+    A0 = torch.randn(num_a, dim, device=device, dtype=dtype, requires_grad=True)
+    B0 = torch.randn(num_b, dim, device=device, dtype=dtype, requires_grad=True)
+    obs1 = torch.randn(n, dim, device=device, dtype=dtype)
+    obs2 = torch.randn(n, dim, device=device, dtype=dtype)
+
+    idx_a = torch.randint(0, num_a, (n,), device=device, dtype=torch.int32)
+    idx_b = torch.randint(0, num_b, (n,), device=device, dtype=torch.int32)
+    sel1 = torch.tensor([0, 2, 5, 6], device=device, dtype=torch.int32)
+    sel2 = torch.tensor([1, 3, 4, 8], device=device, dtype=torch.int32)
+
+    model = ToyResidualCat(A0, B0)
+    out = model(obs1, obs2, idx_a, idx_b, sel1, sel2)
+
+    J_sparse = sparse_jacobian(out, [model.A, model.B])
+    assert len(J_sparse) == 2
+    assert all(j.layout == torch.sparse_bsr for j in J_sparse)
+
+    def f(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+        a1 = A[idx_a][sel1]
+        b1 = B[idx_b][sel1]
+        r1 = (a1 + b1) - obs1[sel1]
+
+        a2 = A[idx_a][sel2]
+        b2 = B[idx_b][sel2]
+        r2 = (a2 + b2) - obs2[sel2]
+        return torch.cat([r1, r2], dim=0)
+
+    JA, JB = jacrev(f, argnums=(0, 1))(A0, B0)
+
+    torch.testing.assert_close(J_sparse[0].to_dense(), _flatten_jac(JA), rtol=1e-10, atol=1e-10)
+    torch.testing.assert_close(J_sparse[1].to_dense(), _flatten_jac(JB), rtol=1e-10, atol=1e-10)
+
+    assert torch.equal(J_sparse[0].col_indices(), torch.cat([idx_a[sel1], idx_a[sel2]], dim=0))
+    assert torch.equal(J_sparse[1].col_indices(), torch.cat([idx_b[sel1], idx_b[sel2]], dim=0))
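Why the col_indices assertions hold, as a hedged sketch outside the diff: each residual row of (A[idx_a][sel] + B[idx_b][sel]) - obs[sel] depends on exactly one gathered parameter row, so the block-sparse Jacobian has a single dim x dim identity block per block row, placed at block column idx_a[sel][i]; after torch.cat along dim 0 the block rows of r1 and r2 are stacked, which is why torch.cat([idx_a[sel1], idx_a[sel2]]) is the expected column index vector. The snippet below builds that structure with plain torch only (names mirror the test; it does not call sparse_jacobian).

import torch

dim, num_a, n = 3, 5, 9
torch.manual_seed(0)
idx_a = torch.randint(0, num_a, (n,))
sel1 = torch.tensor([0, 2, 5, 6])
sel2 = torch.tensor([1, 3, 4, 8])
cols = torch.cat([idx_a[sel1], idx_a[sel2]], dim=0)  # one block column per output row
n_rows = cols.numel()

# BSR Jacobian of the concatenated residual w.r.t. A: one identity block per block row.
crow = torch.arange(n_rows + 1)                      # exactly one block in every block row
values = torch.eye(dim, dtype=torch.float64).repeat(n_rows, 1, 1)
J_bsr = torch.sparse_bsr_tensor(crow, cols, values, size=(n_rows * dim, num_a * dim))

# The same Jacobian written out densely, block by block, matches exactly.
J_dense = torch.zeros(n_rows * dim, num_a * dim, dtype=torch.float64)
for i, c in enumerate(cols.tolist()):
    J_dense[i * dim:(i + 1) * dim, c * dim:(c + 1) * dim] = torch.eye(dim, dtype=torch.float64)

torch.testing.assert_close(J_bsr.to_dense(), J_dense)
assert torch.equal(J_bsr.col_indices(), cols)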