
Commit c3c285d

Marcio Porto authored and facebook-github-bot committed
Move current_mask to perturbed tensor device
Summary: Currently `FeaturePermutation` and `FeatureAblation` both throw a device-mismatch error in https://fburl.com/code/9mfuidf4 because `current_mask` is always created on the CPU and never moved to the same device as `expanded_input` when CUDA is available.

Reviewed By: cyrjano

Differential Revision: D54969675
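The patched logic can be sketched as a standalone function (hypothetical names, not Captum's internal API): the mask is built from an `input_mask` that may live on the CPU, then explicitly moved to the device of `expanded_input` so the later elementwise arithmetic never mixes devices.

```python
import torch

def construct_current_mask(input_mask, start_feature, end_feature, expanded_input):
    # Build one boolean layer per feature id, stacked along a new dim.
    current_mask = torch.stack(
        [input_mask == j for j in range(start_feature, end_feature)], dim=0
    ).long()
    # The one-line fix from this commit: align the mask's device with the
    # input's device before any mixed-tensor arithmetic. A no-op on CPU,
    # required when expanded_input is a CUDA tensor.
    current_mask = current_mask.to(expanded_input.device)
    return current_mask

expanded_input = torch.randn(2, 3, 4)      # would be a CUDA tensor on GPU runs
input_mask = torch.arange(4).expand(3, 4)  # feature ids, created on CPU
mask = construct_current_mask(input_mask, 0, 4, expanded_input)
assert mask.device == expanded_input.device
```

Without the `.to(expanded_input.device)` call, a CUDA run fails with a "Expected all tensors to be on the same device" style RuntimeError at the first operation combining the two tensors.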
1 parent e9f43bd commit c3c285d

File tree

2 files changed: +2 −0 lines changed


captum/attr/_core/feature_ablation.py (+1)

@@ -559,6 +559,7 @@ def _construct_ablated_input(
         current_mask = torch.stack(
             [input_mask == j for j in range(start_feature, end_feature)], dim=0
         ).long()
+        current_mask = current_mask.to(expanded_input.device)
         ablated_tensor = (
             expanded_input * (1 - current_mask).to(expanded_input.dtype)
         ) + (baseline * current_mask.to(expanded_input.dtype))
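The ablation arithmetic in this hunk replaces masked-out features with a baseline value; a minimal sketch with toy values (not Captum's API) makes the effect concrete:

```python
import torch

# Features selected by current_mask (1s) are swapped for the baseline;
# unmasked features (0s) pass through unchanged.
expanded_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
current_mask = torch.tensor([[0, 1, 0, 1]])  # ablate features 1 and 3
baseline = 0.0

ablated_tensor = (
    expanded_input * (1 - current_mask).to(expanded_input.dtype)
) + (baseline * current_mask.to(expanded_input.dtype))
print(ablated_tensor)  # tensor([[1., 0., 3., 0.]])
```

This is exactly why the device fix matters: `expanded_input * (1 - current_mask)` is the first place a CPU-resident mask would collide with a CUDA `expanded_input`.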

captum/attr/_core/feature_permutation.py (+1)

@@ -301,6 +301,7 @@ def _construct_ablated_input(
         current_mask = torch.stack(
             [input_mask == j for j in range(start_feature, end_feature)], dim=0
         ).bool()
+        current_mask = current_mask.to(expanded_input.device)

         output = torch.stack(
             [

0 commit comments