@@ -114,9 +114,11 @@ def forward(
114114 obs_b : torch .Tensor ,
115115 idx_a : torch .Tensor ,
116116 idx_b : torch .Tensor ,
117+ mul_a : torch .Tensor ,
118+ mul_b : torch .Tensor ,
117119 ) -> torch .Tensor :
118- ra = self .A [idx_a ] - obs_a
119- rb = self .B [idx_b ] - obs_b
120+ ra = ( self .A [idx_a ] - obs_a ) * mul_a
121+ rb = ( self .B [idx_b ] - obs_b ) * mul_b
120122 return torch .cat ([ra , rb ], dim = 0 )
121123
122124
@@ -140,14 +142,17 @@ def test_sparse_jacobian_cat_dim0_matches_torch_jacrev(device: str):
140142 idx_a = torch .randint (0 , num_a , (n_a ,), device = device , dtype = torch .int32 )
141143 idx_b = torch .randint (0 , num_b , (n_b ,), device = device , dtype = torch .int32 )
142144
145+ mul_a = torch .rand (n_a , dim , device = device , dtype = dtype ) + 0.5
146+ mul_b = torch .rand (n_b , dim , device = device , dtype = dtype ) + 0.5
147+
143148 model = CatResidual (A0 , B0 )
144- out = model (obs_a , obs_b , idx_a , idx_b )
149+ out = model (obs_a , obs_b , idx_a , idx_b , mul_a , mul_b )
145150
146151 JA_sparse , JB_sparse = sparse_jacobian (out , [model .A , model .B ])
147152
148153 def f (A : torch .Tensor , B : torch .Tensor ) -> torch .Tensor :
149- ra = A [idx_a ] - obs_a
150- rb = B [idx_b ] - obs_b
154+ ra = ( A [idx_a ] - obs_a ) * mul_a
155+ rb = ( B [idx_b ] - obs_b ) * mul_b
151156 return torch .cat ([ra , rb ], dim = 0 )
152157
153158 JA , JB = jacrev (f , argnums = (0 , 1 ))(A0 , B0 )
@@ -175,8 +180,10 @@ def forward(
175180 obs_b : torch .Tensor ,
176181 idx_a : torch .Tensor ,
177182 idx_b : torch .Tensor ,
183+ mul_a : torch .Tensor ,
184+ mul_b : torch .Tensor ,
178185 ) -> torch .Tensor :
179- pred = torch .cat ([self .A [idx_a ], self .B [idx_b ]], dim = 0 )
186+ pred = torch .cat ([self .A [idx_a ] * mul_a , self .B [idx_b ] * mul_b ], dim = 0 )
180187 obs = torch .cat ([obs_a , obs_b ], dim = 0 )
181188 return pred - obs
182189
@@ -201,13 +208,16 @@ def test_sparse_jacobian_cat_minus_cat_matches_torch_jacrev(device: str):
201208 idx_a = torch .randint (0 , num_a , (n_a ,), device = device , dtype = torch .int32 )
202209 idx_b = torch .randint (0 , num_b , (n_b ,), device = device , dtype = torch .int32 )
203210
211+ mul_a = torch .rand (n_a , dim , device = device , dtype = dtype ) + 0.5
212+ mul_b = torch .rand (n_b , dim , device = device , dtype = dtype ) + 0.5
213+
204214 model = CatSubResidual (A0 , B0 )
205- out = model (obs_a , obs_b , idx_a , idx_b )
215+ out = model (obs_a , obs_b , idx_a , idx_b , mul_a , mul_b )
206216
207217 JA_sparse , JB_sparse = sparse_jacobian (out , [model .A , model .B ])
208218
209219 def f (A : torch .Tensor , B : torch .Tensor ) -> torch .Tensor :
210- pred = torch .cat ([A [idx_a ], B [idx_b ]], dim = 0 )
220+ pred = torch .cat ([A [idx_a ] * mul_a , B [idx_b ] * mul_b ], dim = 0 )
211221 obs = torch .cat ([obs_a , obs_b ], dim = 0 )
212222 return pred - obs
213223
0 commit comments