|
112 | 112 | "model": BasicModel_MultiLayer(), |
113 | 113 | "attribute_args": {"inputs": torch.randn(4, 3), "target": 1}, |
114 | 114 | }, |
115 | | - { |
116 | | - "name": "basic_single_target_cross_tensor_attributions", |
117 | | - "algorithms": [ |
118 | | - FeatureAblation, |
119 | | - FeaturePermutation, |
120 | | - ], |
121 | | - "model": BasicModel_MultiLayer(), |
122 | | - "attribute_args": { |
123 | | - "inputs": torch.randn(4, 3), |
124 | | - "target": 1, |
125 | | - "enable_cross_tensor_attribution": True, |
126 | | - }, |
127 | | - }, |
128 | 115 | { |
129 | 116 | "name": "basic_multi_input", |
130 | 117 | "algorithms": [ |
|
192 | 179 | }, |
193 | 180 | "dp_delta": 0.0005, |
194 | 181 | }, |
195 | | - { |
196 | | - "name": "basic_multi_input_multi_target_cross_tensor_attributions", |
197 | | - "algorithms": [ |
198 | | - FeatureAblation, |
199 | | - FeaturePermutation, |
200 | | - ], |
201 | | - "model": BasicModel_MultiLayer_MultiInput(), |
202 | | - "attribute_args": { |
203 | | - "inputs": (10 * torch.randn(6, 3), 5 * torch.randn(6, 3)), |
204 | | - "additional_forward_args": (2 * torch.randn(6, 3), 5), |
205 | | - "target": [0, 1, 1, 0, 0, 1], |
206 | | - "enable_cross_tensor_attribution": True, |
207 | | - }, |
208 | | - "dp_delta": 0.0005, |
209 | | - }, |
210 | 182 | { |
211 | 183 | "name": "basic_multiple_tuple_target", |
212 | 184 | "algorithms": [ |
|
230 | 202 | "additional_forward_args": (None, True), |
231 | 203 | }, |
232 | 204 | }, |
233 | | - { |
234 | | - "name": "basic_multiple_tuple_target_cross_tensor_attributions", |
235 | | - "algorithms": [ |
236 | | - FeatureAblation, |
237 | | - FeaturePermutation, |
238 | | - ], |
239 | | - "model": BasicModel_MultiLayer(), |
240 | | - "attribute_args": { |
241 | | - "inputs": torch.randn(4, 3), |
242 | | - "target": [(1, 0, 0), (0, 1, 1), (1, 1, 1), (0, 0, 0)], |
243 | | - "additional_forward_args": (None, True), |
244 | | - "enable_cross_tensor_attribution": True, |
245 | | - }, |
246 | | - }, |
247 | 205 | { |
248 | 206 | "name": "basic_tensor_single_target", |
249 | 207 | "algorithms": [ |
|
285 | 243 | "target": torch.tensor([1, 1, 0, 0]), |
286 | 244 | }, |
287 | 245 | }, |
288 | | - { |
289 | | - "name": "basic_tensor_multi_target_cross_tensor_attributions", |
290 | | - "algorithms": [ |
291 | | - FeatureAblation, |
292 | | - FeaturePermutation, |
293 | | - ], |
294 | | - "model": BasicModel_MultiLayer(), |
295 | | - "attribute_args": { |
296 | | - "inputs": torch.randn(4, 3), |
297 | | - "target": torch.tensor([1, 1, 0, 0]), |
298 | | - "enable_cross_tensor_attribution": True, |
299 | | - }, |
300 | | - }, |
301 | 246 | # Primary Configs with Baselines |
302 | 247 | { |
303 | 248 | "name": "basic_multiple_tuple_target_with_baselines", |
|
317 | 262 | "additional_forward_args": (None, True), |
318 | 263 | }, |
319 | 264 | }, |
320 | | - { |
321 | | - "name": "basic_multiple_tuple_target_with_baselines_cross_tensor_attributions", |
322 | | - "algorithms": [ |
323 | | - FeatureAblation, |
324 | | - ], |
325 | | - "model": BasicModel_MultiLayer(), |
326 | | - "attribute_args": { |
327 | | - "inputs": torch.randn(4, 3), |
328 | | - "baselines": 0.5 * torch.randn(4, 3), |
329 | | - "target": [(1, 0, 0), (0, 1, 1), (1, 1, 1), (0, 0, 0)], |
330 | | - "additional_forward_args": (None, True), |
331 | | - "enable_cross_tensor_attribution": True, |
332 | | - }, |
333 | | - }, |
334 | 265 | { |
335 | 266 | "name": "basic_tensor_single_target_with_baselines", |
336 | 267 | "algorithms": [ |
|
348 | 279 | "target": torch.tensor([0]), |
349 | 280 | }, |
350 | 281 | }, |
351 | | - { |
352 | | - "name": "basic_tensor_single_target_with_baselines_cross_tensor_attributions", |
353 | | - "algorithms": [ |
354 | | - FeatureAblation, |
355 | | - ], |
356 | | - "model": BasicModel_MultiLayer(), |
357 | | - "attribute_args": { |
358 | | - "inputs": torch.randn(4, 3), |
359 | | - "baselines": 0.5 * torch.randn(4, 3), |
360 | | - "target": torch.tensor([0]), |
361 | | - "enable_cross_tensor_attribution": True, |
362 | | - }, |
363 | | - }, |
364 | 282 | # Primary Configs with Internal Batching |
365 | 283 | { |
366 | 284 | "name": "basic_multiple_tuple_target_with_internal_batching", |
|
0 commit comments