Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions timm/layers/create_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,16 @@


def get_act_fn(name: Optional[LayerType] = 'relu'):
""" Activation Function Factory
Fetching activation fns by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""Fetch an activation function by name.

Returns export/script-friendly or memory-efficient activation functions
dynamically based on current config.

Args:
name: Activation function name, callable, or None.

Returns:
Activation function or None.
"""
if not name:
return None
Expand All @@ -108,9 +115,16 @@ def get_act_fn(name: Optional[LayerType] = 'relu'):


def get_act_layer(name: Optional[LayerType] = 'relu'):
""" Activation Layer Factory
Fetching activation layers by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""Fetch an activation layer by name.

Returns export/script-friendly or memory-efficient activation layers
dynamically based on current config.

Args:
name: Activation layer name, type, callable, or None.

Returns:
Activation layer class or None.
"""
if name is None:
return None
Expand All @@ -131,6 +145,19 @@ def create_act_layer(
inplace: Optional[bool] = None,
**kwargs
):
"""Create an activation layer instance by name.

Handles inplace argument for activations that support it, gracefully
falling back for those that don't.

Args:
name: Activation layer name or type.
inplace: Enable inplace operation if supported by the activation.
**kwargs: Additional arguments passed to activation layer.

Returns:
Instantiated activation layer or None.
"""
act_layer = get_act_layer(name)
if act_layer is None:
return None
Expand Down