From 315b9a0209edc69e1f40aab6a07ee3c4b567d19e Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Tue, 10 Mar 2026 00:23:48 +0200 Subject: [PATCH 1/7] working fix --- metaflow/plugins/argo/argo_workflows.py | 73 +++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index a4b9427266d..1dbf4dd38c8 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1190,6 +1190,27 @@ def _is_conditional_join_node(self, node): return node.name in self.conditional_join_nodes def _many_in_funcs_all_conditional(self, node): + # avoid situation where none of the in_funcs are actually conditional, + # because this might be a nested static-split inside a conditional branch, + # which means all of the in_funcs will execute, not just some. + def _same_join(in_func): + split_parents = self.graph[in_func].split_parents + if not split_parents: + return False + return self.graph[split_parents[-1]].matching_join == node.name + + def _split_parent(in_func): + split_parents = self.graph[in_func].split_parents + if not split_parents: + return None + return self.graph[split_parents[-1]] + + if ( + all(_same_join(in_func) for in_func in node.in_funcs) + and len(set(_split_parent(in_func) for in_func in node.in_funcs)) > 1 + ): + return False + cond_in_funcs = [ in_func for in_func in node.in_funcs @@ -1412,6 +1433,58 @@ def _visit( ] required_deps = [] + # join steps in_funcs need special handling, as there can be disjoint sets of always-executing and conditional branches. + # for example + # switch_step -> a, b, shared_join + # a --static-split-> a1, a2, a3 -> shared_join + # b --static-split-> b1,b2,b3 -> shared_join + # + # the shared_join needs to handle dependencies (a1&&a2&&a3) || (b1&&b2&&b3) || switch_step + if self.graph[node.name].type == "join" and any( + self._is_conditional_node(self.graph[fn]) for fn in node.in_funcs + ): + # NOTE: The groupings for the in_funcs are formed by traversing up each funcs + # relative path until we encounter a split-switch. + # when a split is encountered, next we determine if the split is joined **before** we reach the join-node, or if it remains + # conditional. + + node_groups = {} + for fn in node.in_funcs: + if self.graph[fn].split_branches: + # This is the latest split in the DAG. + last_split = self.graph[fn].split_branches[-1] + new_funcs = node_groups.get(last_split, []) + new_funcs.append(fn) + node_groups[last_split] = new_funcs + + conditional_deps = [] + required_deps = [] + for lsplit, in_funcs in node_groups.items(): + if ( + self.graph[lsplit].type == "split-switch" + and len(in_funcs) > 1 + ): + # we have an unresolved conditional split leading to a join. + required_deps.append( + "(%s)" + % "||".join( + [ + "%s.Succeeded" % self._sanitize(in_func) + for in_func in in_funcs + ] + ) + ) + else: + required_deps.append( + "(%s)" + % "&&".join( + [ + "%s.Succeeded" % self._sanitize(in_func) + for in_func in in_funcs + ] + ) + ) + both_conditions = required_deps and conditional_deps depends_str = "{required}{_and}{conditional}".format( From ddb004c45b9012a8fd88e2f21f7956dfd3c57826 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Wed, 11 Mar 2026 23:31:18 +0200 Subject: [PATCH 2/7] more thorough fix for join step parsing --- metaflow/plugins/argo/argo_workflows.py | 91 ++++++++++++++++++------- 1 file changed, 67 insertions(+), 24 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 1dbf4dd38c8..2001bd799fb 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1448,42 +1448,85 @@ def _visit( # when a split is encountered, next we determine if the split is joined **before** we reach the join-node, or if it remains # conditional. + # TODO: This needs to cover ALL split types in order to tackle every use case. + # not covered is: static-split -> static-split -> conditional join node. + def _split_switch_ancestor(node, first_ancestor): + acc = [] + for in_fn in self.graph[node].in_funcs: + if self.graph[in_fn].type == "split-switch": + acc.append(in_fn) + if not in_fn == first_ancestor: + acc.extend( + _split_switch_ancestor(in_fn, first_ancestor) + ) + + return acc + node_groups = {} + node_switch_ancestors = {} for fn in node.in_funcs: if self.graph[fn].split_branches: # This is the latest split in the DAG. last_split = self.graph[fn].split_branches[-1] + switch_ancestors = _split_switch_ancestor( + fn, node.split_parents[-1] + ) + if switch_ancestors: + node_switch_ancestors[fn] = switch_ancestors new_funcs = node_groups.get(last_split, []) new_funcs.append(fn) node_groups[last_split] = new_funcs + def build_ancestor_tree(node_groups, switch_ancestors): + result = {} + for parent, children in node_groups.items(): + nodes = [ + n + for g in children + for n in (g if isinstance(g, list) else [g]) + ] + + # Group nodes by their ancestor set + by_anc = defaultdict(list) + for n in nodes: + by_anc[frozenset(switch_ancestors.get(n, []))].append(n) + + # Sort from most specific (most ancestors) to least + groups = sorted( + by_anc.items(), key=lambda x: len(x[0]), reverse=True + ) + + # Greedily build chains: add to a chain if this key is a subset of its first (largest) key + chains = [] + for key, grp in groups: + for chain in chains: + if key <= chain[0][0]: + chain.append((key, grp)) + break + else: + chains.append([(key, grp)]) + + result[parent] = [[g for _, g in chain] for chain in chains] + return result + conditional_deps = [] required_deps = [] - for lsplit, in_funcs in node_groups.items(): - if ( - self.graph[lsplit].type == "split-switch" - and len(in_funcs) > 1 - ): - # we have an unresolved conditional split leading to a join. - required_deps.append( - "(%s)" - % "||".join( - [ - "%s.Succeeded" % self._sanitize(in_func) - for in_func in in_funcs - ] - ) - ) - else: - required_deps.append( - "(%s)" - % "&&".join( - [ - "%s.Succeeded" % self._sanitize(in_func) - for in_func in in_funcs - ] + for parent, chains in build_ancestor_tree( + node_groups, node_switch_ancestors + ).items(): + parts = [] + for chain in chains: + # TODO: fix double-braces, though only cosmetic. + groups = [ + "({})".format( + " || ".join( + "%s.Succeeded" % self._sanitize(g) for g in grp + ) ) - ) + for grp in chain + ] + parts.append("({})".format(" || ".join(groups))) + required_deps.append("&&".join(parts)) both_conditions = required_deps and conditional_deps From 6f35872de837857472afe418ec1d92eb0e833481 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Wed, 11 Mar 2026 23:55:25 +0200 Subject: [PATCH 3/7] cleanup --- metaflow/plugins/argo/argo_workflows.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 2001bd799fb..b3416f4b436 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1443,21 +1443,15 @@ def _visit( if self.graph[node.name].type == "join" and any( self._is_conditional_node(self.graph[fn]) for fn in node.in_funcs ): - # NOTE: The groupings for the in_funcs are formed by traversing up each funcs - # relative path until we encounter a split-switch. - # when a split is encountered, next we determine if the split is joined **before** we reach the join-node, or if it remains - # conditional. - - # TODO: This needs to cover ALL split types in order to tackle every use case. - # not covered is: static-split -> static-split -> conditional join node. - def _split_switch_ancestor(node, first_ancestor): + + def _split_switch_ancestors(node, first_ancestor): acc = [] for in_fn in self.graph[node].in_funcs: if self.graph[in_fn].type == "split-switch": acc.append(in_fn) if not in_fn == first_ancestor: acc.extend( - _split_switch_ancestor(in_fn, first_ancestor) + _split_switch_ancestors(in_fn, first_ancestor) ) return acc @@ -1468,7 +1462,7 @@ def _split_switch_ancestor(node, first_ancestor): if self.graph[fn].split_branches: # This is the latest split in the DAG. last_split = self.graph[fn].split_branches[-1] - switch_ancestors = _split_switch_ancestor( + switch_ancestors = _split_switch_ancestors( fn, node.split_parents[-1] ) if switch_ancestors: @@ -1516,7 +1510,6 @@ def build_ancestor_tree(node_groups, switch_ancestors): ).items(): parts = [] for chain in chains: - # TODO: fix double-braces, though only cosmetic. groups = [ "({})".format( " || ".join( From 3d6f27c73edf735d13bf7c51cd390a6424faf28b Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 12 Mar 2026 00:42:51 +0200 Subject: [PATCH 4/7] cleanup --- metaflow/plugins/argo/argo_workflows.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index b3416f4b436..7c765432aef 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1440,13 +1440,13 @@ def _visit( # b --static-split-> b1,b2,b3 -> shared_join # # the shared_join needs to handle dependencies (a1&&a2&&a3) || (b1&&b2&&b3) || switch_step - if self.graph[node.name].type == "join" and any( + if node.type == "join" and any( self._is_conditional_node(self.graph[fn]) for fn in node.in_funcs ): - def _split_switch_ancestors(node, first_ancestor): + def _split_switch_ancestors(step_name, first_ancestor): acc = [] - for in_fn in self.graph[node].in_funcs: + for in_fn in self.graph[step_name].in_funcs: if self.graph[in_fn].type == "split-switch": acc.append(in_fn) if not in_fn == first_ancestor: From 5132ee7a5a0eed5f624c1eb42d9333235cf5e913 Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 12 Mar 2026 00:46:48 +0200 Subject: [PATCH 5/7] add guard against empty node_groups --- metaflow/plugins/argo/argo_workflows.py | 34 +++++++++++++------------ 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 7c765432aef..994f955348f 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1503,23 +1503,25 @@ def build_ancestor_tree(node_groups, switch_ancestors): result[parent] = [[g for _, g in chain] for chain in chains] return result - conditional_deps = [] - required_deps = [] - for parent, chains in build_ancestor_tree( - node_groups, node_switch_ancestors - ).items(): - parts = [] - for chain in chains: - groups = [ - "({})".format( - " || ".join( - "%s.Succeeded" % self._sanitize(g) for g in grp + if node_groups: + conditional_deps = [] + required_deps = [] + for parent, chains in build_ancestor_tree( + node_groups, node_switch_ancestors + ).items(): + parts = [] + for chain in chains: + groups = [ + "({})".format( + " || ".join( + "%s.Succeeded" % self._sanitize(g) + for g in grp + ) ) - ) - for grp in chain - ] - parts.append("({})".format(" || ".join(groups))) - required_deps.append("&&".join(parts)) + for grp in chain + ] + parts.append("({})".format(" || ".join(groups))) + required_deps.append("&&".join(parts)) both_conditions = required_deps and conditional_deps From ba824096f328c079ebaa3fe72bee77fda24a643a Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 12 Mar 2026 01:32:21 +0200 Subject: [PATCH 6/7] remove unnecessary guard --- metaflow/plugins/argo/argo_workflows.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 994f955348f..4243e889c81 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1190,27 +1190,6 @@ def _is_conditional_join_node(self, node): return node.name in self.conditional_join_nodes def _many_in_funcs_all_conditional(self, node): - # avoid situation where none of the in_funcs are actually conditional, - # because this might be a nested static-split inside a conditional branch, - # which means all of the in_funcs will execute, not just some. - def _same_join(in_func): - split_parents = self.graph[in_func].split_parents - if not split_parents: - return False - return self.graph[split_parents[-1]].matching_join == node.name - - def _split_parent(in_func): - split_parents = self.graph[in_func].split_parents - if not split_parents: - return None - return self.graph[split_parents[-1]] - - if ( - all(_same_join(in_func) for in_func in node.in_funcs) - and len(set(_split_parent(in_func) for in_func in node.in_funcs)) > 1 - ): - return False - cond_in_funcs = [ in_func for in_func in node.in_funcs From 8939b41830070cb7aeac98a097c4589f72f11f5c Mon Sep 17 00:00:00 2001 From: Sakari Ikonen Date: Thu, 12 Mar 2026 01:50:58 +0200 Subject: [PATCH 7/7] cleanup --- metaflow/plugins/argo/argo_workflows.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 4243e889c81..b7b25c8c69d 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1413,12 +1413,6 @@ def _visit( required_deps = [] # join steps in_funcs need special handling, as there can be disjoint sets of always-executing and conditional branches. - # for example - # switch_step -> a, b, shared_join - # a --static-split-> a1, a2, a3 -> shared_join - # b --static-split-> b1,b2,b3 -> shared_join - # - # the shared_join needs to handle dependencies (a1&&a2&&a3) || (b1&&b2&&b3) || switch_step if node.type == "join" and any( self._is_conditional_node(self.graph[fn]) for fn in node.in_funcs ):