Skip to content

Commit 29f93f0

Browse files
authored
Adjust Node.exportable() to catch up to Merlin Core (#209)
* Adjust `Node.exportable()` to catch up to Merlin Core As we prepare to export and/or run Merlin graphs in multiple contexts, the signature of `Node.exportable()` has changed a little. Further related changes are coming in #204, but these small changes make the tests pass with the latest `main` version of Core. * Adjust `tox.ini` to run the tests against latest version of Merlin Core
1 parent a6016fb commit 29f93f0

File tree

4 files changed

+57
-8
lines changed

4 files changed

+57
-8
lines changed

merlin/systems/dag/ensemble.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ def export(self, export_path, version=1):
6666
Write out an ensemble model configuration directory. The exported
6767
ensemble is designed for use with Triton Inference Server.
6868
"""
69+
backend = "ensemble"
70+
6971
# Create ensemble config
7072
ensemble_config = model_config.ModelConfig(
7173
name=self.name,
72-
platform="ensemble",
74+
platform=backend,
7375
# max_batch_size=configs[0].max_batch_size
7476
)
7577

@@ -95,14 +97,14 @@ def export(self, export_path, version=1):
9597
node_idx = 0
9698
node_id_lookup = {}
9799
for node in postorder_nodes:
98-
if node.exportable:
100+
if node.exportable(backend):
99101
node_id_lookup[node] = node_idx
100102
node_idx += 1
101103

102104
node_configs = []
103105
# Export node configs and add ensemble steps
104106
for node in postorder_nodes:
105-
if node.exportable:
107+
if node.exportable(backend):
106108
node_id = node_id_lookup.get(node, None)
107109
node_name = f"{node_id}_{node.export_name}"
108110

@@ -120,7 +122,9 @@ def export(self, export_path, version=1):
120122
)
121123

122124
for input_col_name, input_col_schema in node.input_schema.column_schemas.items():
123-
source = _find_column_source(node.parents_with_dependencies, input_col_name)
125+
source = _find_column_source(
126+
node.parents_with_dependencies, input_col_name, backend
127+
)
124128
source_id = node_id_lookup.get(source, None)
125129
in_suffix = f"_{source_id}" if source_id is not None else ""
126130
if input_col_schema.is_list and input_col_schema.is_ragged:
@@ -163,14 +167,14 @@ def export(self, export_path, version=1):
163167
return (ensemble_config, node_configs)
164168

165169

166-
def _find_column_source(upstream_nodes, column_name):
170+
def _find_column_source(upstream_nodes, column_name, backend):
167171
source_node = None
168172
for upstream_node in upstream_nodes:
169173
if column_name in upstream_node.output_columns.names:
170174
source_node = upstream_node
171175
break
172176

173-
if source_node and not source_node.exportable:
174-
return _find_column_source(source_node.parents_with_dependencies, column_name)
177+
if source_node and not source_node.exportable(backend):
178+
return _find_column_source(source_node.parents_with_dependencies, column_name, backend)
175179
else:
176180
return source_node

merlin/systems/dag/node.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,25 @@
2323
class InferenceNode(Node):
2424
"""Specialized node class used in Triton Ensemble DAGs"""
2525

26+
def exportable(self, backend: str = None):
27+
"""
28+
Determine whether the current node's operator is exportable for a given back-end
29+
30+
Parameters
31+
----------
32+
backend : str, optional
33+
The Merlin Systems (not Triton) back-end to use,
34+
either "ensemble" or "executor", by default None
35+
36+
Returns
37+
-------
38+
bool
39+
True if the node's operator is exportable for the supplied back-end
40+
"""
41+
backends = getattr(self.op, "exportable_backends", [])
42+
43+
return hasattr(self.op, "export") and backend in backends
44+
2645
def export(self, output_path: Union[str, os.PathLike], node_id: int = None, version: int = 1):
2746
"""
2847
Export a Triton config directory for this node.
@@ -58,6 +77,24 @@ def export_name(self):
5877
return self.op.export_name
5978

6079
def validate_schemas(self, root_schema, strict_dtypes=False):
80+
"""
81+
Checks that the output schema is valid given the previous
82+
nodes in the graph and following nodes in the graph, as
83+
well as any additional root inputs.
84+
85+
Parameters
86+
----------
87+
root_schema : Schema
88+
Schema of selection from the original data supplied
89+
strict_dtypes : bool, optional
90+
If True, raises an error when the dtypes in the input data
91+
do not match the dtypes in the schema, by default False
92+
93+
Raises
94+
------
95+
ValueError
96+
If an output column is produced but not used by child nodes
97+
"""
6198
super().validate_schemas(root_schema, strict_dtypes)
6299

63100
if self.children:

merlin/systems/dag/ops/operator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def export_name(self):
7171
"""
7272
return self.__class__.__name__.lower()
7373

74+
@property
75+
def exportable_backends(self):
76+
return ["ensemble"]
77+
7478
@abstractmethod
7579
def export(
7680
self,

tox.ini

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ commands =
1414
; Runs all CPU-based tests. NOTE: if you are using an M1 mac, this will fail. You need to
1515
; change the tensorflow dependency to `tensorflow-macos` in requirements/test-cpu.txt.
1616
deps = -rrequirements/test-cpu.txt
17-
commands = python -m pytest --cov-report term --cov=merlin -rxs tests/unit
17+
commands =
18+
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git
19+
python -m pytest --cov-report term --cov=merlin -rxs tests/unit
1820

1921
[testenv:test-gpu]
2022
sitepackages=true
@@ -28,6 +30,7 @@ deps =
2830
pytest
2931
pytest-cov
3032
commands =
33+
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git
3134
python -m pytest --cov-report term --cov merlin -rxs tests/unit
3235

3336
[testenv:test-merlin]
@@ -57,6 +60,7 @@ commands =
5760
; Install pre-commit-hooks to run these tests during development.
5861
deps = -rrequirements/dev.txt
5962
commands =
63+
python -m pip install --upgrade git+https://github.com/NVIDIA-Merlin/core.git
6064
flake8 setup.py merlin/ tests/
6165
black --check --diff merlin tests
6266
pylint merlin

0 commit comments

Comments
 (0)