diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index a4b9427266d..b7b25c8c69d 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -1412,6 +1412,90 @@ def _visit( ] required_deps = [] + # join steps in_funcs need special handling, as there can be disjoint sets of always-executing and conditional branches. + if node.type == "join" and any( + self._is_conditional_node(self.graph[fn]) for fn in node.in_funcs + ): + + def _split_switch_ancestors(step_name, first_ancestor): + acc = [] + for in_fn in self.graph[step_name].in_funcs: + if self.graph[in_fn].type == "split-switch": + acc.append(in_fn) + if not in_fn == first_ancestor: + acc.extend( + _split_switch_ancestors(in_fn, first_ancestor) + ) + + return acc + + node_groups = {} + node_switch_ancestors = {} + for fn in node.in_funcs: + if self.graph[fn].split_branches: + # This is the latest split in the DAG. + last_split = self.graph[fn].split_branches[-1] + switch_ancestors = _split_switch_ancestors( + fn, node.split_parents[-1] + ) + if switch_ancestors: + node_switch_ancestors[fn] = switch_ancestors + new_funcs = node_groups.get(last_split, []) + new_funcs.append(fn) + node_groups[last_split] = new_funcs + + def build_ancestor_tree(node_groups, switch_ancestors): + result = {} + for parent, children in node_groups.items(): + nodes = [ + n + for g in children + for n in (g if isinstance(g, list) else [g]) + ] + + # Group nodes by their ancestor set + by_anc = defaultdict(list) + for n in nodes: + by_anc[frozenset(switch_ancestors.get(n, []))].append(n) + + # Sort from most specific (most ancestors) to least + groups = sorted( + by_anc.items(), key=lambda x: len(x[0]), reverse=True + ) + + # Greedily build chains: add to a chain if this key is a subset of its first (largest) key + chains = [] + for key, grp in groups: + for chain in chains: + if key <= chain[0][0]: + chain.append((key, grp)) + break + else: + chains.append([(key, grp)]) + + result[parent] = [[g for _, g in chain] for chain in chains] + return result + + if node_groups: + conditional_deps = [] + required_deps = [] + for parent, chains in build_ancestor_tree( + node_groups, node_switch_ancestors + ).items(): + parts = [] + for chain in chains: + groups = [ + "({})".format( + " || ".join( + "%s.Succeeded" % self._sanitize(g) + for g in grp + ) + ) + for grp in chain + ] + parts.append("({})".format(" || ".join(groups))) + required_deps.append("&&".join(parts)) + both_conditions = required_deps and conditional_deps depends_str = "{required}{_and}{conditional}".format(