11import sys
22import warnings
3- from typing import TYPE_CHECKING , List , Optional , Union
3+ from typing import TYPE_CHECKING , Iterator , List , Optional , Tuple , Type , Union
44
55if sys .version_info >= (3 , 11 ):
66 from typing import Self
77else :
88 from typing_extensions import Self
99
10+ from pydantic .fields import FieldInfo
11+
1012from hera .shared ._pydantic import _PYDANTIC_VERSION , get_field_annotations , get_fields
11- from hera .shared ._type_util import get_workflow_annotation , is_annotated
13+ from hera .shared ._type_util import get_workflow_annotation
1214from hera .shared .serialization import MISSING , serialize
1315from hera .workflows ._context import _context
1416from hera .workflows .artifact import Artifact
3941 BaseModel = object # type: ignore
4042
4143
44+ def _construct_io_from_fields (cls : Type [BaseModel ]) -> Iterator [Tuple [str , FieldInfo , Union [Parameter , Artifact ]]]:
45+ """Constructs a Parameter or Artifact object for all Pydantic fields based on their annotations.
46+
47+ If a field has a Parameter or Artifact annotation, a copy will be returned, with name added if missing.
48+ Otherwise, a Parameter object will be constructed.
49+ """
50+ annotations = get_field_annotations (cls )
51+ for field , field_info in get_fields (cls ).items ():
52+ if annotation := get_workflow_annotation (annotations [field ]):
53+ # Copy so as to not modify the fields themselves
54+ annotation_copy = annotation .copy ()
55+ annotation_copy .name = annotation .name or field
56+ yield field , field_info , annotation_copy
57+ else :
58+ yield field , field_info , Parameter (name = field )
59+
60+
4261class InputMixin (BaseModel ):
4362 def __new__ (cls , ** kwargs ):
4463 if _context .declaring :
@@ -62,14 +81,9 @@ def __init__(self, /, **kwargs):
6281 @classmethod
6382 def _get_parameters (cls , object_override : Optional [Self ] = None ) -> List [Parameter ]:
6483 parameters = []
65- annotations = get_field_annotations (cls )
6684
67- for field , field_info in get_fields (cls ).items ():
68- if (param := get_workflow_annotation (annotations [field ])) and isinstance (param , Parameter ):
69- # Copy so as to not modify the Input fields themselves
70- param = param .copy ()
71- if param .name is None :
72- param .name = field
85+ for field , field_info , param in _construct_io_from_fields (cls ):
86+ if isinstance (param , Parameter ):
7387 if param .default is not None :
7488 warnings .warn (
7589 "Using the default field for Parameters in Annotations is deprecated since v5.16"
@@ -81,29 +95,15 @@ def _get_parameters(cls, object_override: Optional[Self] = None) -> List[Paramet
8195 # Serialize the value (usually done in Parameter's validator)
8296 param .default = serialize (field_info .default ) # type: ignore
8397 parameters .append (param )
84- elif not is_annotated (annotations [field ]):
85- # Create a Parameter from basic type annotations
86- default = getattr (object_override , field ) if object_override else field_info .default
87-
88- # For users on Pydantic 2 but using V1 BaseModel, we still need to check if `default` is None
89- if default is None or default == PydanticUndefined :
90- default = MISSING
91-
92- parameters .append (Parameter (name = field , default = default ))
9398
9499 return parameters
95100
96101 @classmethod
97102 def _get_artifacts (cls ) -> List [Artifact ]:
98103 artifacts = []
99- annotations = get_field_annotations (cls )
100104
101- for field in get_fields (cls ):
102- if (artifact := get_workflow_annotation (annotations [field ])) and isinstance (artifact , Artifact ):
103- # Copy so as to not modify the Input fields themselves
104- artifact = artifact .copy ()
105- if artifact .name is None :
106- artifact .name = field
105+ for _ , _ , artifact in _construct_io_from_fields (cls ):
106+ if isinstance (artifact , Artifact ):
107107 if artifact .path is None :
108108 artifact .path = artifact ._get_default_inputs_path ()
109109 artifacts .append (artifact )
@@ -117,42 +117,33 @@ def _get_inputs(cls) -> List[Union[Artifact, Parameter]]:
117117 def _get_as_templated_arguments (cls ) -> Self :
118118 """Returns the Input with templated values to propagate through a DAG/Steps function."""
119119 object_dict = {}
120- cls_fields = get_fields (cls )
121- annotations = get_field_annotations (cls )
122120
123- for field in cls_fields :
124- if param_or_artifact := get_workflow_annotation (annotations [field ]):
125- if isinstance (param_or_artifact , Parameter ):
126- object_dict [field ] = "{{inputs.parameters." + f"{ param_or_artifact .name } " + "}}"
127- else :
128- object_dict [field ] = "{{inputs.artifacts." + f"{ param_or_artifact .name } " + "}}"
129- elif not is_annotated (annotations [field ]):
130- object_dict [field ] = "{{inputs.parameters." + f"{ field } " + "}}"
121+ for field , _ , annotation in _construct_io_from_fields (cls ):
122+ input_type = "parameters" if isinstance (annotation , Parameter ) else "artifacts"
123+ object_dict [field ] = "{{" + f"inputs.{ input_type } .{ annotation .name } " + "}}"
131124
132125 return cls .construct (None , ** object_dict )
133126
134127 def _get_as_arguments (self ) -> ModelArguments :
135128 params = []
136129 artifacts = []
137- annotations = get_field_annotations (type (self ))
138130
139131 if isinstance (self , V1BaseModel ):
140132 self_dict = self .dict ()
141133 elif _PYDANTIC_VERSION == 2 and isinstance (self , V2BaseModel ):
142134 self_dict = self .model_dump ()
143135
144- for field in get_fields (type (self )):
136+ for field , _ , annotation in _construct_io_from_fields (type (self )):
145137 # The value may be a static value (of any time) if it has a default value, so we need to serialize it
146138 # If it is a templated string, it will be unaffected as `"{{mystr}}" == serialize("{{mystr}}")``
147139 templated_value = serialize (self_dict [field ])
140+ name = annotation .name
141+ assert name is not None # guaranteed by _get_workflow_annotations
148142
149- if (param_or_artifact := get_workflow_annotation (annotations [field ])) and param_or_artifact .name :
150- if isinstance (param_or_artifact , Parameter ):
151- params .append (ModelParameter (name = param_or_artifact .name , value = templated_value ))
152- else :
153- artifacts .append (ModelArtifact (name = param_or_artifact .name , from_ = templated_value ))
154- elif not is_annotated (annotations [field ]):
155- params .append (ModelParameter (name = field , value = templated_value ))
143+ if isinstance (annotation , Parameter ):
144+ params .append (ModelParameter (name = name , value = templated_value ))
145+ else :
146+ artifacts .append (ModelArtifact (name = name , from_ = templated_value ))
156147
157148 return ModelArguments (parameters = params or None , artifacts = artifacts or None )
158149
@@ -178,37 +169,22 @@ def __init__(self, /, **kwargs):
178169 @classmethod
179170 def _get_outputs (cls , add_missing_path : bool = False ) -> List [Union [Artifact , Parameter ]]:
180171 outputs : List [Union [Artifact , Parameter ]] = []
181- annotations = get_field_annotations (cls )
182-
183- model_fields = get_fields (cls )
184172
185- for field in model_fields :
173+ for field , field_info , annotation in _construct_io_from_fields ( cls ) :
186174 if field in {"exit_code" , "result" }:
187175 continue
188- if param_or_artifact := get_workflow_annotation (annotations [field ]):
189- if isinstance (param_or_artifact , Parameter ):
190- if add_missing_path and (
191- param_or_artifact .value_from is None or param_or_artifact .value_from .path is None
192- ):
193- param_or_artifact .value_from = ValueFrom (
194- path = f"/tmp/hera-outputs/parameters/{ param_or_artifact .name } "
195- )
196- outputs .append (param_or_artifact )
197- else :
198- if add_missing_path and param_or_artifact .path is None :
199- param_or_artifact .path = f"/tmp/hera-outputs/artifacts/{ param_or_artifact .name } "
200- outputs .append (param_or_artifact )
201- elif not is_annotated (annotations [field ]):
202- # Create a Parameter from basic type annotations
203- default = model_fields [field ].default
204- if default is None or default == PydanticUndefined :
205- default = MISSING
206-
207- value_from = None
208- if add_missing_path :
209- value_from = ValueFrom (path = f"/tmp/hera-outputs/parameters/{ field } " )
210-
211- outputs .append (Parameter (name = field , default = default , value_from = value_from ))
176+ if isinstance (annotation , Parameter ):
177+ if annotation .default is None :
178+ default = field_info .default
179+ if default is not None and default != PydanticUndefined :
180+ annotation .default = serialize (default )
181+
182+ if add_missing_path and (annotation .value_from is None or annotation .value_from .path is None ):
183+ annotation .value_from = ValueFrom (path = f"/tmp/hera-outputs/parameters/{ annotation .name } " )
184+ else :
185+ if add_missing_path and annotation .path is None :
186+ annotation .path = f"/tmp/hera-outputs/artifacts/{ annotation .name } "
187+ outputs .append (annotation )
212188 return outputs
213189
214190 @classmethod
@@ -230,27 +206,21 @@ def _get_as_invocator_output(self) -> List[Union[Artifact, Parameter]]:
230206 This lets dags and steps hoist task/step outputs into its own outputs.
231207 """
232208 outputs : List [Union [Artifact , Parameter ]] = []
233- annotations = get_field_annotations (type (self ))
234209
235210 if isinstance (self , V1BaseModel ):
236211 self_dict = self .dict ()
237212 elif _PYDANTIC_VERSION == 2 and isinstance (self , V2BaseModel ):
238213 self_dict = self .model_dump ()
239214
240- for field in get_fields (type (self )):
215+ for field , _ , annotation in _construct_io_from_fields (type (self )):
241216 if field in {"exit_code" , "result" }:
242217 continue
243218
244219 templated_value = self_dict [field ] # a string such as `"{{tasks.task_a.outputs.parameter.my_param}}"`
245220
246- if (param_or_artifact := get_workflow_annotation (annotations [field ])) and param_or_artifact .name :
247- if isinstance (param_or_artifact , Parameter ):
248- outputs .append (
249- Parameter (name = param_or_artifact .name , value_from = ValueFrom (parameter = templated_value ))
250- )
251- else :
252- outputs .append (Artifact (name = param_or_artifact .name , from_ = templated_value ))
253- elif not is_annotated (annotations [field ]):
254- outputs .append (Parameter (name = field , value_from = ValueFrom (parameter = templated_value )))
221+ if isinstance (annotation , Parameter ):
222+ outputs .append (Parameter (name = annotation .name , value_from = ValueFrom (parameter = templated_value )))
223+ else :
224+ outputs .append (Artifact (name = annotation .name , from_ = templated_value ))
255225
256226 return outputs
0 commit comments