Commit

Fix TransformationRobustness doc formatting & add missing RedirectedReLU forward docs
ProGamerGov authored Jul 6, 2022
1 parent 07c9e60 commit 953780e
Showing 3 changed files with 22 additions and 10 deletions.
22 changes: 13 additions & 9 deletions captum/optim/_param/image/transforms.py
@@ -1251,6 +1251,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self._center_crop(x)
 
 
+# Define TransformationRobustness defaults externally for easier Sphinx docs formatting
+_TR_TRANSLATE: List[int] = [4] * 10
+_TR_SCALE: List[float] = [0.995**n for n in range(-5, 80)] + [
+    0.998**n for n in 2 * list(range(20, 40))
+]
+_TR_DEGREES: List[int] = (
+    list(range(-20, 20)) + list(range(-10, 10)) + list(range(-5, 5)) + 5 * [0]
+)
+
+
 class TransformationRobustness(nn.Module):
     """
     This transform combines the standard transforms (:class:`.RandomSpatialJitter`,
@@ -1269,15 +1279,9 @@ class TransformationRobustness(nn.Module):
     def __init__(
         self,
         padding_transform: Optional[nn.Module] = nn.ConstantPad2d(2, value=0.5),
-        translate: Optional[Union[int, List[int]]] = [4] * 10,
-        scale: Optional[NumSeqOrTensorOrProbDistType] = [
-            0.995**n for n in range(-5, 80)
-        ]
-        + [0.998**n for n in 2 * list(range(20, 40))],
-        degrees: Optional[NumSeqOrTensorOrProbDistType] = list(range(-20, 20))
-        + list(range(-10, 10))
-        + list(range(-5, 5))
-        + 5 * [0],
+        translate: Optional[Union[int, List[int]]] = _TR_TRANSLATE,
+        scale: Optional[NumSeqOrTensorOrProbDistType] = _TR_SCALE,
+        degrees: Optional[NumSeqOrTensorOrProbDistType] = _TR_DEGREES,
         final_translate: Optional[int] = 2,
         crop_or_pad_output: bool = False,
     ) -> None:
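
The refactor above is purely cosmetic: the defaults are now evaluated once at module level instead of inline in the signature. A quick standalone sketch (mirroring the diff rather than importing captum; the old_* names are mine) confirming the externalized constants equal the old inline defaults:

from typing import List

# New module-level defaults, copied from the diff above.
_TR_TRANSLATE: List[int] = [4] * 10
_TR_SCALE: List[float] = [0.995**n for n in range(-5, 80)] + [
    0.998**n for n in 2 * list(range(20, 40))
]
_TR_DEGREES: List[int] = (
    list(range(-20, 20)) + list(range(-10, 10)) + list(range(-5, 5)) + 5 * [0]
)

# Old inline defaults, copied from the removed signature lines.
old_scale = [0.995**n for n in range(-5, 80)] + [0.998**n for n in 2 * list(range(20, 40))]
old_degrees = list(range(-20, 20)) + list(range(-10, 10)) + list(range(-5, 5)) + 5 * [0]

assert _TR_TRANSLATE == [4] * 10
assert _TR_SCALE == old_scale and _TR_DEGREES == old_degrees  # values unchanged
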
2 changes: 1 addition & 1 deletion captum/optim/_utils/circuits.py
@@ -44,7 +44,7 @@ def extract_expanded_weights(
     Args:
 
         model (nn.Module): The reference to PyTorch model instance.
-        target1 (nn.module): The starting target layer. Must be below the layer
+        target1 (nn.Module): The starting target layer. Must be below the layer
            specified for ``target2``.
         target2 (nn.Module): The end target layer. Must be above the layer
            specified for ``target1``.
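
For context, a hedged usage sketch of the function whose docstring type is fixed here; the googlenet helper and the mixed3b/mixed4a layer names are assumptions for illustration, not part of this commit:

from captum.optim._utils.circuits import extract_expanded_weights
from captum.optim.models import googlenet  # assumed model helper

model = googlenet(pretrained=True)
# target1 (mixed3b) must sit below target2 (mixed4a), per the corrected docstring.
expanded = extract_expanded_weights(model, model.mixed3b, model.mixed4a)
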
8 changes: 8 additions & 0 deletions captum/optim/models/_common.py
@@ -68,6 +68,14 @@ class RedirectedReluLayer(nn.Module):
 
     @torch.jit.ignore
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): A tensor to pass through RedirectedReLU.
+
+        Returns:
+            x (torch.Tensor): The output of RedirectedReLU.
+        """
         return RedirectedReLU.apply(input)
 
 
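
A minimal sketch exercising the newly documented forward method (assumes captum's optim module is importable). On the forward pass RedirectedReluLayer matches torch.relu; the gradient redirection it implements only matters on backward:

import torch
from captum.optim.models._common import RedirectedReluLayer

layer = RedirectedReluLayer()
x = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
out = layer(x)
assert torch.equal(out, torch.relu(x))  # forward pass behaves like plain ReLU
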
