Skip to content

Commit 8f579f0

Browse files
[ParallelRunStep] Enable mini_batch_size and retry_settings bind to pipeline input. (#33500)
* allow prs run settings binding to literal input * v1 * revert * revert * revert e2e test * reformat --------- Co-authored-by: Xiaole Wen <[email protected]>
1 parent caf2f97 commit 8f579f0

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

Diff for: sdk/ml/azure-ai-ml/azure/ai/ml/_schema/component/retry_settings.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from marshmallow import fields
66

77
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
8+
from azure.ai.ml._schema.core.fields import DataBindingStr, UnionField
89

910

1011
class RetrySettingsSchema(metaclass=PatchedSchemaMeta):
11-
timeout = fields.Int()
12-
max_retries = fields.Int()
12+
timeout = UnionField([fields.Int(), DataBindingStr])
13+
max_retries = UnionField([fields.Int(), DataBindingStr])

Diff for: sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/parallel.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from .._job.parallel.retry_settings import RetrySettings
3434
from .._job.pipeline._io import NodeOutput, NodeWithGroupInputMixin
3535
from .._util import convert_ordered_dict_to_dict, get_rest_dict_for_node_attrs, validate_attribute_type
36+
from ..._utils.utils import is_data_binding_expression
3637
from .base_node import BaseNode
3738

3839
module_logger = logging.getLogger(__name__)
@@ -166,7 +167,11 @@ def __init__(
166167

167168
self._task = task
168169

169-
if mini_batch_size is not None and not isinstance(mini_batch_size, int):
170+
if (
171+
mini_batch_size is not None
172+
and not isinstance(mini_batch_size, int)
173+
and not is_data_binding_expression(mini_batch_size)
174+
):
170175
"""Convert str to int.""" # pylint: disable=pointless-string-statement
171176
pattern = re.compile(r"^\d+([kKmMgG][bB])*$")
172177
if not pattern.match(mini_batch_size):

Diff for: sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/parallel/retry_settings.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class RetrySettings(RestTranslatableMixin, DictMixin):
2727
def __init__(
2828
self,
2929
*,
30-
timeout: Optional[int] = None,
31-
max_retries: Optional[int] = None,
30+
timeout: Optional[Union[int, str]] = None,
31+
max_retries: Optional[Union[int, str]] = None,
3232
**kwargs, # pylint: disable=unused-argument
3333
):
3434
self.timeout = timeout

0 commit comments

Comments
 (0)