Skip to content

Commit 3e58f22

Browse files
Fix multiple bugs in IO mixins (#1193)
**Pull Request Checklist**

- [X] Fixes #1190, fixes #1165
- [X] Tests added
- [ ] Documentation/examples added
- [X] [Good commit messages](https://cbea.ms/git-commit/) and/or PR title

**Description of PR**

Currently, the methods that consume annotations in InputMixin and OutputMixin have multiple bugs:

- ignoring fields with a workflow annotation with no name
- ignoring fields with a Pydantic annotation and no workflow annotation
- `OutputMixin._get_outputs` additionally was:
  * mutating the original annotations without copying them
  * ignoring the model default if a workflow annotation was present but had no default

This PR pulls out a common function to iterate through fields and yield annotations, which:

- copies the annotations so they cannot be accidentally changed
- ignores unrecognized annotations (this is a requirement of the Annotated spec, and fixes the issue with Pydantic annotations)
- defaults the annotation name to the field name if unset

This ensures all functions treat an explicit Parameter and a missing workflow annotation the same way, solving the second issue with `OutputMixin._get_outputs`.

---------

Signed-off-by: Alice Purcell <alicederyn@gmail.com>
Signed-off-by: Elliot Gunton <elliotgunton@gmail.com>
Co-authored-by: Elliot Gunton <elliotgunton@gmail.com>
1 parent af5b19f commit 3e58f22

2 files changed

Lines changed: 618 additions & 93 deletions

File tree

Lines changed: 53 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import sys
22
import warnings
3-
from typing import TYPE_CHECKING, List, Optional, Union
3+
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Type, Union
44

55
if sys.version_info >= (3, 11):
66
from typing import Self
77
else:
88
from typing_extensions import Self
99

10+
from pydantic.fields import FieldInfo
11+
1012
from hera.shared._pydantic import _PYDANTIC_VERSION, get_field_annotations, get_fields
11-
from hera.shared._type_util import get_workflow_annotation, is_annotated
13+
from hera.shared._type_util import get_workflow_annotation
1214
from hera.shared.serialization import MISSING, serialize
1315
from hera.workflows._context import _context
1416
from hera.workflows.artifact import Artifact
@@ -39,6 +41,23 @@
3941
BaseModel = object # type: ignore
4042

4143

44+
def _construct_io_from_fields(cls: Type[BaseModel]) -> Iterator[Tuple[str, FieldInfo, Union[Parameter, Artifact]]]:
45+
"""Constructs a Parameter or Artifact object for all Pydantic fields based on their annotations.
46+
47+
If a field has a Parameter or Artifact annotation, a copy will be returned, with name added if missing.
48+
Otherwise, a Parameter object will be constructed.
49+
"""
50+
annotations = get_field_annotations(cls)
51+
for field, field_info in get_fields(cls).items():
52+
if annotation := get_workflow_annotation(annotations[field]):
53+
# Copy so as to not modify the fields themselves
54+
annotation_copy = annotation.copy()
55+
annotation_copy.name = annotation.name or field
56+
yield field, field_info, annotation_copy
57+
else:
58+
yield field, field_info, Parameter(name=field)
59+
60+
4261
class InputMixin(BaseModel):
4362
def __new__(cls, **kwargs):
4463
if _context.declaring:
@@ -62,14 +81,9 @@ def __init__(self, /, **kwargs):
6281
@classmethod
6382
def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Parameter]:
6483
parameters = []
65-
annotations = get_field_annotations(cls)
6684

67-
for field, field_info in get_fields(cls).items():
68-
if (param := get_workflow_annotation(annotations[field])) and isinstance(param, Parameter):
69-
# Copy so as to not modify the Input fields themselves
70-
param = param.copy()
71-
if param.name is None:
72-
param.name = field
85+
for field, field_info, param in _construct_io_from_fields(cls):
86+
if isinstance(param, Parameter):
7387
if param.default is not None:
7488
warnings.warn(
7589
"Using the default field for Parameters in Annotations is deprecated since v5.16"
@@ -81,29 +95,15 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet
8195
# Serialize the value (usually done in Parameter's validator)
8296
param.default = serialize(field_info.default) # type: ignore
8397
parameters.append(param)
84-
elif not is_annotated(annotations[field]):
85-
# Create a Parameter from basic type annotations
86-
default = getattr(object_override, field) if object_override else field_info.default
87-
88-
# For users on Pydantic 2 but using V1 BaseModel, we still need to check if `default` is None
89-
if default is None or default == PydanticUndefined:
90-
default = MISSING
91-
92-
parameters.append(Parameter(name=field, default=default))
9398

9499
return parameters
95100

96101
@classmethod
97102
def _get_artifacts(cls) -> List[Artifact]:
98103
artifacts = []
99-
annotations = get_field_annotations(cls)
100104

101-
for field in get_fields(cls):
102-
if (artifact := get_workflow_annotation(annotations[field])) and isinstance(artifact, Artifact):
103-
# Copy so as to not modify the Input fields themselves
104-
artifact = artifact.copy()
105-
if artifact.name is None:
106-
artifact.name = field
105+
for _, _, artifact in _construct_io_from_fields(cls):
106+
if isinstance(artifact, Artifact):
107107
if artifact.path is None:
108108
artifact.path = artifact._get_default_inputs_path()
109109
artifacts.append(artifact)
@@ -117,42 +117,33 @@ def _get_inputs(cls) -> List[Union[Artifact, Parameter]]:
117117
def _get_as_templated_arguments(cls) -> Self:
118118
"""Returns the Input with templated values to propagate through a DAG/Steps function."""
119119
object_dict = {}
120-
cls_fields = get_fields(cls)
121-
annotations = get_field_annotations(cls)
122120

123-
for field in cls_fields:
124-
if param_or_artifact := get_workflow_annotation(annotations[field]):
125-
if isinstance(param_or_artifact, Parameter):
126-
object_dict[field] = "{{inputs.parameters." + f"{param_or_artifact.name}" + "}}"
127-
else:
128-
object_dict[field] = "{{inputs.artifacts." + f"{param_or_artifact.name}" + "}}"
129-
elif not is_annotated(annotations[field]):
130-
object_dict[field] = "{{inputs.parameters." + f"{field}" + "}}"
121+
for field, _, annotation in _construct_io_from_fields(cls):
122+
input_type = "parameters" if isinstance(annotation, Parameter) else "artifacts"
123+
object_dict[field] = "{{" + f"inputs.{input_type}.{annotation.name}" + "}}"
131124

132125
return cls.construct(None, **object_dict)
133126

134127
def _get_as_arguments(self) -> ModelArguments:
135128
params = []
136129
artifacts = []
137-
annotations = get_field_annotations(type(self))
138130

139131
if isinstance(self, V1BaseModel):
140132
self_dict = self.dict()
141133
elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel):
142134
self_dict = self.model_dump()
143135

144-
for field in get_fields(type(self)):
136+
for field, _, annotation in _construct_io_from_fields(type(self)):
145137
# The value may be a static value (of any type) if it has a default value, so we need to serialize it
146138
# If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")`
147139
templated_value = serialize(self_dict[field])
140+
name = annotation.name
141+
assert name is not None  # guaranteed by _construct_io_from_fields, which defaults the name to the field name
148142

149-
if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name:
150-
if isinstance(param_or_artifact, Parameter):
151-
params.append(ModelParameter(name=param_or_artifact.name, value=templated_value))
152-
else:
153-
artifacts.append(ModelArtifact(name=param_or_artifact.name, from_=templated_value))
154-
elif not is_annotated(annotations[field]):
155-
params.append(ModelParameter(name=field, value=templated_value))
143+
if isinstance(annotation, Parameter):
144+
params.append(ModelParameter(name=name, value=templated_value))
145+
else:
146+
artifacts.append(ModelArtifact(name=name, from_=templated_value))
156147

157148
return ModelArguments(parameters=params or None, artifacts=artifacts or None)
158149

@@ -178,37 +169,22 @@ def __init__(self, /, **kwargs):
178169
@classmethod
179170
def _get_outputs(cls, add_missing_path: bool = False) -> List[Union[Artifact, Parameter]]:
180171
outputs: List[Union[Artifact, Parameter]] = []
181-
annotations = get_field_annotations(cls)
182-
183-
model_fields = get_fields(cls)
184172

185-
for field in model_fields:
173+
for field, field_info, annotation in _construct_io_from_fields(cls):
186174
if field in {"exit_code", "result"}:
187175
continue
188-
if param_or_artifact := get_workflow_annotation(annotations[field]):
189-
if isinstance(param_or_artifact, Parameter):
190-
if add_missing_path and (
191-
param_or_artifact.value_from is None or param_or_artifact.value_from.path is None
192-
):
193-
param_or_artifact.value_from = ValueFrom(
194-
path=f"/tmp/hera-outputs/parameters/{param_or_artifact.name}"
195-
)
196-
outputs.append(param_or_artifact)
197-
else:
198-
if add_missing_path and param_or_artifact.path is None:
199-
param_or_artifact.path = f"/tmp/hera-outputs/artifacts/{param_or_artifact.name}"
200-
outputs.append(param_or_artifact)
201-
elif not is_annotated(annotations[field]):
202-
# Create a Parameter from basic type annotations
203-
default = model_fields[field].default
204-
if default is None or default == PydanticUndefined:
205-
default = MISSING
206-
207-
value_from = None
208-
if add_missing_path:
209-
value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{field}")
210-
211-
outputs.append(Parameter(name=field, default=default, value_from=value_from))
176+
if isinstance(annotation, Parameter):
177+
if annotation.default is None:
178+
default = field_info.default
179+
if default is not None and default != PydanticUndefined:
180+
annotation.default = serialize(default)
181+
182+
if add_missing_path and (annotation.value_from is None or annotation.value_from.path is None):
183+
annotation.value_from = ValueFrom(path=f"/tmp/hera-outputs/parameters/{annotation.name}")
184+
else:
185+
if add_missing_path and annotation.path is None:
186+
annotation.path = f"/tmp/hera-outputs/artifacts/{annotation.name}"
187+
outputs.append(annotation)
212188
return outputs
213189

214190
@classmethod
@@ -230,27 +206,21 @@ def _get_as_invocator_output(self) -> List[Union[Artifact, Parameter]]:
230206
This lets dags and steps hoist task/step outputs into its own outputs.
231207
"""
232208
outputs: List[Union[Artifact, Parameter]] = []
233-
annotations = get_field_annotations(type(self))
234209

235210
if isinstance(self, V1BaseModel):
236211
self_dict = self.dict()
237212
elif _PYDANTIC_VERSION == 2 and isinstance(self, V2BaseModel):
238213
self_dict = self.model_dump()
239214

240-
for field in get_fields(type(self)):
215+
for field, _, annotation in _construct_io_from_fields(type(self)):
241216
if field in {"exit_code", "result"}:
242217
continue
243218

244219
templated_value = self_dict[field] # a string such as `"{{tasks.task_a.outputs.parameter.my_param}}"`
245220

246-
if (param_or_artifact := get_workflow_annotation(annotations[field])) and param_or_artifact.name:
247-
if isinstance(param_or_artifact, Parameter):
248-
outputs.append(
249-
Parameter(name=param_or_artifact.name, value_from=ValueFrom(parameter=templated_value))
250-
)
251-
else:
252-
outputs.append(Artifact(name=param_or_artifact.name, from_=templated_value))
253-
elif not is_annotated(annotations[field]):
254-
outputs.append(Parameter(name=field, value_from=ValueFrom(parameter=templated_value)))
221+
if isinstance(annotation, Parameter):
222+
outputs.append(Parameter(name=annotation.name, value_from=ValueFrom(parameter=templated_value)))
223+
else:
224+
outputs.append(Artifact(name=annotation.name, from_=templated_value))
255225

256226
return outputs

0 commit comments

Comments
 (0)