diff --git a/merlin/systems/dag/ops/hugectr.py b/merlin/systems/dag/ops/hugectr.py
new file mode 100644
index 000000000..562712d3a
--- /dev/null
+++ b/merlin/systems/dag/ops/hugectr.py
@@ -0,0 +1,451 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import json
+import os
+import pathlib
+
+import numpy as np
+import tritonclient.grpc.model_config_pb2 as model_config
+from google.protobuf import text_format
+
+from merlin.core.dispatch import make_df
+from merlin.dag import ColumnSelector
+from merlin.schema import ColumnSchema, Schema
+from merlin.schema.tags import Tags
+from merlin.systems.dag.ops.compat import pb_utils
+from merlin.systems.dag.ops.operator import (
+    InferenceDataFrame,
+    InferenceOperator,
+    PipelineableInferenceOperator,
+)
+
+
+def _convert(data, slot_size_array, categorical_columns, labels=None):
+    """Prepare data for a request to the HugeCTR predict interface.
+
+    Returns
+    -------
+    Tuple of dense, categorical, and row-index arrays,
+    corresponding to the three inputs required by a HugeCTR model.
+    """
+    labels = labels or []
+    dense_columns = list(set(data.columns) - set(categorical_columns + labels))
+    categorical_dim = len(categorical_columns)
+    batch_size = data.shape[0]
+
+    # Offset each column's keys into a single global embedding-key space.
+    shift = np.insert(np.cumsum(slot_size_array), 0, 0)[:-1].tolist()
+
+    # These dtypes are static for HugeCTR
+    dense = np.array([data[dense_columns].values.flatten().tolist()], dtype="float32")
+    cat = np.array([(data[categorical_columns] + shift).values.flatten().tolist()], dtype="int64")
+    rowptr = np.array([list(range(batch_size * categorical_dim + 1))], dtype="int32")
+
+    return dense, cat, rowptr
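For reference, here is a minimal sketch (not part of the patch) of what `_convert` produces for a two-row batch with `slot_size_array=[64, 64, 64]`; the column names and values are illustrative only:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "a": [0, 1],      # categorical, slot size 64
        "b": [2, 3],      # categorical, slot size 64
        "c": [4, 5],      # categorical, slot size 64
        "d": [0.5, 0.6],  # dense
        "label": [0.0, 0.0],
    }
)

# shift is the exclusive cumulative sum of the slot sizes: [0, 64, 128].
# Row 0's keys become [0 + 0, 2 + 64, 4 + 128] = [0, 66, 132].
dense, cat, rowptr = _convert(df, [64, 64, 64], ["a", "b", "c"], labels=["label"])

assert dense.shape == (1, 2)   # batch_size * num_dense values, flattened
assert cat.shape == (1, 6)     # batch_size * num_categorical keys, flattened
assert rowptr.shape == (1, 7)  # CSR-style offsets: one key per slot per row
```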
+
+
+class PredictHugeCTR(PipelineableInferenceOperator):
+    def __init__(self, model, input_schema: Schema, *, backend="python", **hugectr_params):
+        """Instantiate a HugeCTR inference operator.
+
+        Parameters
+        ----------
+        model : HugeCTR Model Instance
+            A HugeCTR model class.
+        input_schema : merlin.schema.Schema
+            The schema representing the input columns expected by the model.
+        backend : str
+            The Triton backend to use when running this operator.
+        **hugectr_params
+            The parameters to pass to the HugeCTR operator.
+        """
+        if model is not None:
+            self.hugectr_op = HugeCTR(model, **hugectr_params)
+
+        self.backend = backend
+        self.input_schema = input_schema
+        categorical_columns = self.input_schema.select_by_tag(Tags.CATEGORICAL).column_names
+        if not categorical_columns:
+            raise ValueError(
+                "HugeCTR requires categorical columns. "
+                "No columns with categorical tags were found in the input schema supplied."
+            )
+        self._hugectr_model_name = None
+
+    def compute_output_schema(
+        self,
+        input_schema: Schema,
+        col_selector: ColumnSelector,
+        prev_output_schema: Schema = None,
+    ) -> Schema:
+        """Return the output schema representing the columns this operator returns."""
+        return self.hugectr_op.compute_output_schema(
+            input_schema, col_selector, prev_output_schema=prev_output_schema
+        )
+
+    def compute_input_schema(
+        self,
+        root_schema: Schema,
+        parents_schema: Schema,
+        deps_schema: Schema,
+        selector: ColumnSelector,
+    ) -> Schema:
+        """Return the input schema representing the input columns this operator expects to use."""
+        return self.input_schema
+
+    def export(self, path, input_schema, output_schema, params=None, node_id=None, version=1):
+        """Export the class and related files to the path specified."""
+        hugectr_model_config = self.hugectr_op.export(
+            path,
+            input_schema,
+            output_schema,
+            params=params,
+            node_id=node_id,
+            version=version,
+        )
+        params = params or {}
+        params = {
+            **params,
+            "hugectr_model_name": hugectr_model_config.name,
+            "slot_sizes": hugectr_model_config.parameters["slot_sizes"].string_value,
+        }
+        return super().export(
+            path,
+            input_schema,
+            output_schema,
+            params=params,
+            node_id=node_id,
+            version=version,
+            backend=self.backend,
+        )
+
+    @classmethod
+    def from_config(cls, config: dict) -> "PredictHugeCTR":
+        """Instantiate the class from a dictionary representation.
+
+        Expected structure:
+        {
+            "input_dict": str  # JSON dict with input names and schemas
+            "params": str  # JSON dict with params saved at export
+        }
+
+        """
+
+        column_schemas = [
+            ColumnSchema(name, **schema_properties)
+            for name, schema_properties in json.loads(config["input_dict"]).items()
+        ]
+
+        input_schema = Schema(column_schemas)
+
+        cls_instance = cls(None, input_schema)
+        params = json.loads(config["params"])
+
+        cls_instance.slot_sizes = json.loads(params["slot_sizes"])
+        cls_instance.set_hugectr_model_name(params["hugectr_model_name"])
+        return cls_instance
+
+    @property
+    def hugectr_model_name(self):
+        return self._hugectr_model_name
+
+    def set_hugectr_model_name(self, hugectr_model_name):
+        self._hugectr_model_name = hugectr_model_name
+
+    def transform(self, df: InferenceDataFrame) -> InferenceDataFrame:
+        """Transform the dataframe by applying this HugeCTR operator to the set of input columns.
+
+        Parameters
+        ----------
+        df: InferenceDataFrame
+            A pandas or cudf dataframe that this operator will work on
+
+        Returns
+        -------
+        InferenceDataFrame
+            Returns a transformed dataframe for this operator"""
+        slot_sizes = [slot for slots in self.slot_sizes for slot in slots]
+        categorical_columns = self.input_schema.select_by_tag(Tags.CATEGORICAL).column_names
+        dict_to_pd = {k: v.ravel() for k, v in df}
+
+        df = make_df(dict_to_pd)
+        dense, cats, rowptr = _convert(df, slot_sizes, categorical_columns, labels=["label"])
+
+        inputs = [
+            pb_utils.Tensor("DES", dense),
+            pb_utils.Tensor("CATCOLUMN", cats),
+            pb_utils.Tensor("ROWINDEX", rowptr),
+        ]
+
+        inference_request = pb_utils.InferenceRequest(
+            model_name=self.hugectr_model_name,
+            requested_output_names=["OUTPUT0"],
+            inputs=inputs,
+        )
+        inference_response = inference_request.exec()
+        output0 = pb_utils.get_output_tensor_by_name(inference_response, "OUTPUT0")
+
+        return InferenceDataFrame({"OUTPUT0": output0})
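The intended wiring mirrors `test_predict_hugectr` later in this diff. A sketch, assuming `model` is a trained `hugectr.Model` and `dataset_schema` carries `CATEGORICAL`/`CONTINUOUS` tags on its columns:

```python
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.hugectr import PredictHugeCTR

model_op = PredictHugeCTR(model, dataset_schema, max_nnz=2, device_list=[0])
triton_chain = dataset_schema.column_names >> model_op
ensemble = Ensemble(triton_chain, dataset_schema)

# Writes a python-backend wrapper model, the hugectr-backend model,
# and the parameter-server config (ps.json) into the repository.
ensemble_config, node_configs = ensemble.export("/tmp/model_repository")
```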
+ """ + + def __init__( + self, + model, + max_batch_size=64, + device_list=None, + hit_rate_threshold=None, + gpucache=None, + freeze_sparse=None, + gpucacheper=None, + max_nnz=2, + embeddingkey_long_type=None, + ): + self.model = model + self.max_batch_size = max_batch_size + self.device_list = device_list or [] + embeddingkey_long_type = embeddingkey_long_type or "true" + gpucache = gpucache or "true" + gpucacheper = gpucacheper or 0.5 + + self.hugectr_params = dict( + hit_rate_threshold=hit_rate_threshold, + gpucache=gpucache, + freeze_sparse=freeze_sparse, + gpucacheper=gpucacheper, + max_nnz=max_nnz, + embeddingkey_long_type=embeddingkey_long_type, + ) + + super().__init__() + + def compute_input_schema( + self, + root_schema: Schema, + parents_schema: Schema, + deps_schema: Schema, + selector: ColumnSelector, + ): + """_summary_ + + Parameters + ---------- + root_schema : Schema + The original schema to the graph. + parents_schema : Schema + A schema comprised of the output schemas of all parent nodes. + deps_schema : Schema + A concatenation of the output schemas of all dependency nodes. + selector : ColumnSelector + Sub selection of columns required to compute the input schema. + + Returns + ------- + Schema + A schema describing the inputs of the model. + """ + return Schema( + [ + ColumnSchema("DES", dtype=np.float32), + ColumnSchema("CATCOLUMN", dtype=np.int64), + ColumnSchema("ROWINDEX", dtype=np.int32), + ] + ) + + def compute_output_schema( + self, + input_schema: Schema, + col_selector: ColumnSelector, + prev_output_schema: Schema = None, + ): + """Return output schema of the model. + + Parameters + ---------- + input_schema : Schema + Schema representing inputs to the model + col_selector : ColumnSelector + list of columns to focus on from input schema + prev_output_schema : Schema, optional + The output schema of the previous node, by default None + + Returns + ------- + Schema + Schema describing the output of the model. + """ + return Schema([ColumnSchema("OUTPUT0", dtype=np.float32)]) + + def export(self, path, input_schema, output_schema, node_id=None, params=None, version=1): + """Create and export the required config files for the hugectr model. + + Parameters + ---------- + path : current path of the model + _description_ + input_schema : Schema + Schema describing inputs to model + output_schema : Schema + Schema describing outputs of model + node_id : int, optional + The node's position in execution chain, by default None + version : int, optional + The version of the model, by default 1 + + Returns + ------- + config + Dictionary representation of config file in memory. 
+ """ + node_name = f"{node_id}_{self.export_name}" if node_id is not None else self.export_name + node_export_path = pathlib.Path(path) / node_name + node_export_path.mkdir(exist_ok=True) + model_name = node_name + hugectr_model_path = pathlib.Path(node_export_path) / str(version) + hugectr_model_path.mkdir(exist_ok=True) + + network_file = os.path.join(hugectr_model_path, f"{model_name}.json") + + self.model.graph_to_json(graph_config_file=network_file) + self.model.save_params_to_files(str(hugectr_model_path) + "/") + model_json = json.loads(open(network_file, "r").read()) + dense_pattern = "*_dense_*.model" + dense_path = [ + os.path.join(hugectr_model_path, path.name) + for path in hugectr_model_path.glob(dense_pattern) + if "opt" not in path.name + ][0] + sparse_pattern = "*_sparse_*.model" + sparse_paths = [ + os.path.join(hugectr_model_path, path.name) + for path in hugectr_model_path.glob(sparse_pattern) + if "opt" not in path.name + ] + + config_dict = dict() + config_dict["supportlonglong"] = True + + data_layer = model_json["layers"][0] + sparse_layers = [ + layer + for layer in model_json["layers"] + if layer["type"] == "DistributedSlotSparseEmbeddingHash" + ] + full_slots = [x["sparse_embedding_hparam"]["slot_size_array"] for x in sparse_layers] + num_cat_columns = sum(x["slot_num"] for x in data_layer["sparse"]) + vec_size = [x["sparse_embedding_hparam"]["embedding_vec_size"] for x in sparse_layers] + + model = dict() + model["model"] = model_name + model["slot_num"] = num_cat_columns + model["sparse_files"] = sparse_paths + model["dense_file"] = dense_path + model["maxnum_des_feature_per_sample"] = data_layer["dense"]["dense_dim"] + model["network_file"] = network_file + model["num_of_worker_buffer_in_pool"] = 4 + model["num_of_refresher_buffer_in_pool"] = 1 + model["deployed_device_list"] = self.device_list + model["max_batch_size"] = self.max_batch_size + model["default_value_for_each_table"] = [0.0] * len(sparse_layers) + model["hit_rate_threshold"] = 0.9 + model["gpucacheper"] = self.hugectr_params["gpucacheper"] + model["gpucache"] = True + model["cache_refresh_percentage_per_iteration"] = 0.2 + model["maxnum_catfeature_query_per_table_per_sample"] = [ + len(x["sparse_embedding_hparam"]["slot_size_array"]) for x in sparse_layers + ] + model["embedding_vecsize_per_table"] = vec_size + model["embedding_table_names"] = [x["top"] for x in sparse_layers] + config_dict["models"] = [model] + + parameter_server_config_path = str(node_export_path.parent / "ps.json") + with open(parameter_server_config_path, "w") as f: + f.write(json.dumps(config_dict)) + + self.hugectr_params["config"] = network_file + + # These are no longer required from hugectr_backend release 3.7 + self.hugectr_params["cat_feature_num"] = num_cat_columns + self.hugectr_params["des_feature_num"] = data_layer["dense"]["dense_dim"] + self.hugectr_params["embedding_vector_size"] = vec_size[0] + self.hugectr_params["slots"] = num_cat_columns + self.hugectr_params["label_dim"] = data_layer["label"]["label_dim"] + self.hugectr_params["slot_sizes"] = full_slots + config = _hugectr_config(node_name, self.hugectr_params, max_batch_size=self.max_batch_size) + + with open(os.path.join(node_export_path, "config.pbtxt"), "w") as o: + text_format.PrintMessage(config, o) + + return config + + +def _hugectr_config(name, parameters, max_batch_size=None): + """Create a config for a HugeCTR model. + + Parameters + ---------- + name : string + The name of the hugectr model. 
+
+
+def _hugectr_config(name, parameters, max_batch_size=None):
+    """Create a config for a HugeCTR model.
+
+    Parameters
+    ----------
+    name : string
+        The name of the hugectr model.
+    parameters : dictionary
+        Dictionary holding parameter values required by hugectr
+    max_batch_size : int, optional
+        The maximum batch size to be processed per inference request, by default None
+
+    Returns
+    -------
+    config
+        Protobuf ModelConfig for the hugectr model.
+    """
+    config = model_config.ModelConfig(name=name, backend="hugectr", max_batch_size=max_batch_size)
+
+    config.input.append(
+        model_config.ModelInput(name="DES", data_type=model_config.TYPE_FP32, dims=[-1])
+    )
+
+    config.input.append(
+        model_config.ModelInput(name="CATCOLUMN", data_type=model_config.TYPE_INT64, dims=[-1])
+    )
+
+    config.input.append(
+        model_config.ModelInput(name="ROWINDEX", data_type=model_config.TYPE_INT32, dims=[-1])
+    )
+
+    config.output.append(
+        model_config.ModelOutput(name="OUTPUT0", data_type=model_config.TYPE_FP32, dims=[-1])
+    )
+
+    config.instance_group.append(model_config.ModelInstanceGroup(gpus=[0], count=1, kind=1))
+
+    for parameter_key, parameter_value in parameters.items():
+        if parameter_value is None:
+            continue
+
+        if isinstance(parameter_value, list):
+            config.parameters[parameter_key].string_value = json.dumps(parameter_value)
+        elif isinstance(parameter_value, bool):
+            config.parameters[parameter_key].string_value = str(parameter_value).lower()
+        else:
+            config.parameters[parameter_key].string_value = str(parameter_value)
+
+    return config
diff --git a/merlin/systems/dag/ops/operator.py b/merlin/systems/dag/ops/operator.py
index a181a6619..3a005d220 100644
--- a/merlin/systems/dag/ops/operator.py
+++ b/merlin/systems/dag/ops/operator.py
@@ -12,7 +12,7 @@
 
 from merlin.dag import BaseOperator  # noqa
 from merlin.dag.selector import ColumnSelector  # noqa
-from merlin.schema import Schema  # noqa
+from merlin.schema import Schema, Tags  # noqa
 from merlin.systems.dag.node import InferenceNode  # noqa
 from merlin.systems.triton.export import _convert_dtype  # noqa
@@ -253,6 +253,7 @@ def _schema_to_dict(schema: Schema) -> dict:
             "dtype": col_schema.dtype.name,
             "is_list": col_schema.is_list,
             "is_ragged": col_schema.is_ragged,
+            "tags": [tag.value if isinstance(tag, Tags) else str(tag) for tag in col_schema.tags],
         }
 
     return schema_dict
diff --git a/merlin/systems/triton/utils.py b/merlin/systems/triton/utils.py
index 82f43ccf8..ec3879f58 100644
--- a/merlin/systems/triton/utils.py
+++ b/merlin/systems/triton/utils.py
@@ -14,7 +14,7 @@
 
 
 @contextlib.contextmanager
-def run_triton_server(modelpath):
+def run_triton_server(modelpath, backend_config="tensorflow,version=2"):
     """This function starts up a Triton server instance and returns a client to it.
 
     Parameters
@@ -32,7 +32,7 @@ def run_triton_server(modelpath):
         TRITON_SERVER_PATH,
         "--model-repository",
         modelpath,
-        "--backend-config=tensorflow,version=2",
+        f"--backend-config={backend_config}",
     ]
     env = os.environ.copy()
     env["CUDA_VISIBLE_DEVICES"] = "0"
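With the backend flag made configurable, serving a HugeCTR repository just means pointing the hugectr backend at the exported parameter-server config. A usage sketch, with paths illustrative and `inputs`/`outputs` assumed to be prepared as in the tests below:

```python
from merlin.systems.triton.utils import run_triton_server

with run_triton_server(
    "/tmp/model_repository",
    backend_config="hugectr,ps=/tmp/model_repository/ps.json",
) as client:
    response = client.infer("0_hugectr", inputs, outputs=outputs)
```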
diff --git a/tests/unit/systems/hugectr/__init__.py b/tests/unit/systems/hugectr/__init__.py
new file mode 100644
index 000000000..0b8ff56d3
--- /dev/null
+++ b/tests/unit/systems/hugectr/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/tests/unit/systems/hugectr/test_hugectr.py b/tests/unit/systems/hugectr/test_hugectr.py
new file mode 100644
index 000000000..4571ad1e1
--- /dev/null
+++ b/tests/unit/systems/hugectr/test_hugectr.py
@@ -0,0 +1,360 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+import numpy as np
+import pytest
+
+import nvtabular as nvt
+from merlin.core.dispatch import make_df
+from merlin.dag import ColumnSelector
+from merlin.schema import ColumnSchema, Schema
+from merlin.schema.tags import Tags
+from merlin.systems.dag.ensemble import Ensemble
+from merlin.systems.dag.ops.hugectr import HugeCTR, PredictHugeCTR, _convert
+from tests.unit.systems.utils.triton import _run_ensemble_on_tritonserver
+
+try:
+    import hugectr
+    from hugectr.inference import CreateInferenceSession, InferenceParams
+    from mpi4py import MPI  # noqa pylint: disable=unused-import
+except ImportError:
+    hugectr = None
+
+
+triton = pytest.importorskip("merlin.systems.triton")
+grpcclient = pytest.importorskip("tritonclient.grpc")
+cudf = pytest.importorskip("cudf")
+
+
+def _run_model(slot_sizes, source, dense_dim):
+    solver = hugectr.CreateSolver(
+        vvgpu=[[0]],
+        batchsize=10,
+        batchsize_eval=10,
+        max_eval_batches=50,
+        i64_input_key=True,
+        use_mixed_precision=False,
+        repeat_dataset=True,
+    )
+    # https://github.com/NVIDIA-Merlin/HugeCTR/blob/9e648f879166fc93931c676a5594718f70178a92/docs/source/api/python_interface.md#datareaderparams
+    reader = hugectr.DataReaderParams(
+        data_reader_type=hugectr.DataReaderType_t.Parquet,
+        source=[os.path.join(source, "_file_list.txt")],
+        eval_source=os.path.join(source, "_file_list.txt"),
+        check_type=hugectr.Check_t.Non,
+    )
+
+    optimizer = hugectr.CreateOptimizer(optimizer_type=hugectr.Optimizer_t.Adam)
+    model = hugectr.Model(solver, reader, optimizer)
+
+    model.add(
+        hugectr.Input(
+            label_dim=1,
+            label_name="label",
+            dense_dim=dense_dim,
+            dense_name="dense",
+            data_reader_sparse_param_array=[
+                hugectr.DataReaderSparseParam("data1", len(slot_sizes) + 1, True, len(slot_sizes))
+            ],
+        )
+    )
+    model.add(
+        hugectr.SparseEmbedding(
+            embedding_type=hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash,
+            workspace_size_per_gpu_in_mb=107,
+            embedding_vec_size=16,
+            combiner="sum",
+            sparse_embedding_name="sparse_embedding1",
+            bottom_name="data1",
+            slot_size_array=slot_sizes,
+            optimizer=optimizer,
+        )
+    )
+    model.add(
+        hugectr.DenseLayer(
+            layer_type=hugectr.Layer_t.InnerProduct,
+            bottom_names=["dense"],
+            top_names=["fc1"],
+            num_output=512,
+        )
+    )
+    model.add(
+        hugectr.DenseLayer(
+            layer_type=hugectr.Layer_t.Reshape,
+            bottom_names=["sparse_embedding1"],
+            top_names=["reshape1"],
+            leading_dim=48,  # 3 slots x 16-dim embeddings
+        )
+    )
+    model.add(
+        hugectr.DenseLayer(
+            layer_type=hugectr.Layer_t.InnerProduct,
+            bottom_names=["reshape1", "fc1"],
+            top_names=["fc2"],
+            num_output=1,
+        )
+    )
+    model.add(
+        hugectr.DenseLayer(
+            layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss,
+            bottom_names=["fc2", "label"],
+            top_names=["loss"],
+        )
+    )
+    model.compile()
+    model.summary()
+    model.fit(max_iter=20, display=100, eval_interval=200, snapshot=10)
+
+    return model
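Once the exports in the tests below have run, the repository laid out by `HugeCTR.export` (together with the snapshot files HugeCTR writes during training) looks roughly like this, assuming `node_id=0` and `version=1`:

```text
model_repository/
├── ps.json                  # parameter-server config consumed by the hugectr backend
└── 0_hugectr/
    ├── config.pbtxt         # Triton model config from _hugectr_config
    └── 1/
        ├── 0_hugectr.json   # network graph from model.graph_to_json
        ├── _dense_0.model   # dense weights snapshot
        └── 0_sparse_0.model # sparse embedding snapshot
```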
+
+
+@pytest.mark.skip(reason="Running more than one HugeCTR test per process results in a segfault")
+def test_hugectr_op(tmpdir):
+    cat_dtypes = {"a": int, "b": int, "c": int}
+
+    categorical_columns = list(cat_dtypes.keys())
+
+    gdf = make_df(
+        {
+            "a": np.arange(64, dtype=np.int64),
+            "b": np.arange(64, dtype=np.int64),
+            "c": np.arange(64, dtype=np.int64),
+            "d": np.random.rand(64).tolist(),
+            "label": [0] * 64,
+        },
+    )
+    gdf["label"] = gdf["label"].astype("float32")
+    gdf["d"] = gdf["d"].astype("float32")
+    train_dataset = nvt.Dataset(gdf)
+
+    dense_columns = ["d"]
+
+    dict_dtypes = {}
+    for col in dense_columns:
+        dict_dtypes[col] = np.float32
+
+    for col in categorical_columns:
+        dict_dtypes[col] = np.int64
+
+    for col in ["label"]:
+        dict_dtypes[col] = np.float32
+
+    train_path = os.path.join(tmpdir, "train/")
+    os.mkdir(train_path)
+
+    train_dataset.to_parquet(
+        output_path=train_path,
+        shuffle=nvt.io.Shuffle.PER_PARTITION,
+        cats=categorical_columns,
+        conts=dense_columns,
+        labels=["label"],
+        dtypes=dict_dtypes,
+    )
+
+    embeddings = {"a": (64, 16), "b": (64, 16), "c": (64, 16)}
+
+    total_cardinality = 0
+    slot_sizes = []
+
+    for column in cat_dtypes:
+        slot_sizes.append(embeddings[column][0])
+        total_cardinality += embeddings[column][0]
+
+    # slot sizes = list of cardinalities per column; the total is the sum of the individual sizes
+    model = _run_model(slot_sizes, train_path, len(dense_columns))
+
+    model_op = HugeCTR(model, max_nnz=2, device_list=[0])
+
+    model_repository_path = os.path.join(tmpdir, "model_repository")
+
+    input_schema = Schema(
+        [
+            ColumnSchema("DES", dtype=np.float32),
+            ColumnSchema("CATCOLUMN", dtype=np.int64),
+            ColumnSchema("ROWINDEX", dtype=np.int32),
+        ]
+    )
+    triton_chain = ColumnSelector(["DES", "CATCOLUMN", "ROWINDEX"]) >> model_op
+    ens = Ensemble(triton_chain, input_schema)
+
+    os.makedirs(model_repository_path)
+
+    enc_config, node_configs = ens.export(model_repository_path)
+
+    assert enc_config
+    assert len(node_configs) == 1
+    assert node_configs[0].name == "0_hugectr"
+
+    df = train_dataset.to_ddf().compute()[:5]
+    dense, cats, rowptr = _convert(df, slot_sizes, categorical_columns, labels=["label"])
+
+    inputs = [
+        grpcclient.InferInput("DES", dense.shape, triton.np_to_triton_dtype(dense.dtype)),
+        grpcclient.InferInput("CATCOLUMN", cats.shape, triton.np_to_triton_dtype(cats.dtype)),
+        grpcclient.InferInput("ROWINDEX", rowptr.shape, triton.np_to_triton_dtype(rowptr.dtype)),
+    ]
+    inputs[0].set_data_from_numpy(dense)
+    inputs[1].set_data_from_numpy(cats)
+    inputs[2].set_data_from_numpy(rowptr)
+
+    response = _run_ensemble_on_tritonserver(
+        model_repository_path,
+        ["OUTPUT0"],
+        inputs,
+        "0_hugectr",
+        backend_config=f"hugectr,ps={tmpdir}/model_repository/ps.json",
+    )
+    assert len(response.as_numpy("OUTPUT0")) == df.shape[0]
+
+    model_config = node_configs[0].parameters["config"].string_value
+
+    hugectr_name = node_configs[0].name
+    dense_path = f"{tmpdir}/model_repository/{hugectr_name}/1/_dense_0.model"
+    sparse_files = [f"{tmpdir}/model_repository/{hugectr_name}/1/0_sparse_0.model"]
+    out_predict = _predict(
+        dense, cats, rowptr, model_config, hugectr_name, dense_path, sparse_files
+    )
+
+    np.testing.assert_array_almost_equal(response.as_numpy("OUTPUT0"), np.array(out_predict))
+    del model
+
+
+def test_predict_hugectr(tmpdir):
+    cat_dtypes = {"a": int, "b": int, "c": int}
+
+    categorical_columns = ["a", "b", "c"]
+
+    gdf = make_df(
+        {
+            "a": np.arange(64, dtype=np.int64),
+            "b": np.arange(64, dtype=np.int64),
+            "c": np.arange(64, dtype=np.int64),
+            "d": np.random.rand(64).tolist(),
+            "label": [0] * 64,
+        },
+    )
+    gdf["label"] = gdf["label"].astype("float32")
+    gdf["d"] = gdf["d"].astype("float32")
+    train_dataset = nvt.Dataset(gdf)
+
+    dense_columns = ["d"]
+
+    dict_dtypes = {}
+    col_schemas = train_dataset.schema.column_schemas
+    for col in dense_columns:
+        col_schemas[col] = col_schemas[col].with_tags(Tags.CONTINUOUS)
+        dict_dtypes[col] = np.float32
+
+    for col in categorical_columns:
+        col_schemas[col] = col_schemas[col].with_tags(Tags.CATEGORICAL)
+        dict_dtypes[col] = np.int64
+
+    for col in ["label"]:
+        col_schemas[col] = col_schemas[col].with_tags(Tags.TARGET)
+        dict_dtypes[col] = np.float32
+
+    train_path = os.path.join(tmpdir, "train/")
+    os.mkdir(train_path)
+
+    train_dataset.to_parquet(
+        output_path=train_path,
+        shuffle=nvt.io.Shuffle.PER_PARTITION,
+        cats=categorical_columns,
+        conts=dense_columns,
+        labels=["label"],
+        dtypes=dict_dtypes,
+    )
+
+    embeddings = {"a": (64, 16), "b": (64, 16), "c": (64, 16)}
+
+    total_cardinality = 0
+    slot_sizes = []
+
+    for column in cat_dtypes:
+        slot_sizes.append(embeddings[column][0])
+        total_cardinality += embeddings[column][0]
+
+    # slot sizes = list of cardinalities per column; the total is the sum of the individual sizes
+    model = _run_model(slot_sizes, train_path, len(dense_columns))
+
+    model_op = PredictHugeCTR(model, train_dataset.schema, max_nnz=2, device_list=[0])
+
+    model_repository_path = os.path.join(tmpdir, "model_repository")
+
+    input_schema = train_dataset.schema
+    triton_chain = input_schema.column_names >> model_op
+    ens = Ensemble(triton_chain, input_schema)
+
+    os.makedirs(model_repository_path)
+
+    enc_config, node_configs = ens.export(model_repository_path)
+
+    assert enc_config
+    assert len(node_configs) == 1
+    assert node_configs[0].name == "0_predicthugectr"
+
+    df = train_dataset.to_ddf().compute()[:5]
+    dense, cats, rowptr = _convert(df, slot_sizes, categorical_columns, labels=["label"])
+
+    response = _run_ensemble_on_tritonserver(
+        model_repository_path,
+        ["OUTPUT0"],
+        df,
+        "ensemble_model",
+        backend_config=f"hugectr,ps={tmpdir}/model_repository/ps.json",
+    )
+    assert len(response.as_numpy("OUTPUT0")) == df.shape[0]
+
+    model_config = f"{tmpdir}/model_repository/0_hugectr/1/0_hugectr.json"
+
+    hugectr_name = "0_hugectr"
+    dense_path = f"{tmpdir}/model_repository/{hugectr_name}/1/_dense_0.model"
+    sparse_files = [f"{tmpdir}/model_repository/{hugectr_name}/1/0_sparse_0.model"]
+    out_predict = _predict(
+        dense, cats, rowptr, model_config, hugectr_name, dense_path, sparse_files
+    )
+
+    np.testing.assert_array_almost_equal(response.as_numpy("OUTPUT0"), np.array(out_predict))
+
+
+def test_no_categoricals():
+    with pytest.raises(ValueError) as exc_info:
+        PredictHugeCTR(None, Schema())
+    assert "HugeCTR requires categorical columns." in str(exc_info.value)
+
+
+def _predict(
+    dense_features, embedding_columns, row_ptrs, config_file, model_name, dense_path, sparse_paths
+):
+    inference_params = InferenceParams(
+        model_name=model_name,
+        max_batchsize=64,
+        hit_rate_threshold=0.5,
+        dense_model_file=dense_path,
+        sparse_model_files=sparse_paths,
+        device_id=0,
+        use_gpu_embedding_cache=True,
+        cache_size_percentage=0.2,
+        i64_input_key=True,
+        use_mixed_precision=False,
+    )
+    inference_session = CreateInferenceSession(config_file, inference_params)
+    output = inference_session.predict(
+        dense_features[0].tolist(), embedding_columns[0].tolist(), row_ptrs[0].tolist()
+    )
+    return output
diff --git a/tests/unit/systems/utils/triton.py b/tests/unit/systems/utils/triton.py
index 2c84e1ba8..682e9e69b 100644
--- a/tests/unit/systems/utils/triton.py
+++ b/tests/unit/systems/utils/triton.py
@@ -28,15 +28,15 @@
 
 
 def _run_ensemble_on_tritonserver(
-    tmpdir,
-    output_columns,
-    df,
-    model_name,
+    tmpdir, output_columns, df, model_name, backend_config="tensorflow,version=2"
 ):
-    inputs = triton.convert_df_to_triton_input(df.columns, df)
+    if not isinstance(df, list):
+        inputs = triton.convert_df_to_triton_input(df.columns, df)
+    else:
+        inputs = df
     outputs = [grpcclient.InferRequestedOutput(col) for col in output_columns]
     response = None
-    with run_triton_server(tmpdir) as client:
+    with run_triton_server(tmpdir, backend_config=backend_config) as client:
         response = client.infer(model_name, inputs, outputs=outputs)
 
     return response
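`_run_ensemble_on_tritonserver` now accepts either a dataframe or a prebuilt list of `InferInput` objects; both forms appear in the HugeCTR tests above. A sketch, with `repo`, `df`, and `inputs` assumed to be defined:

```python
# Dataframe form: columns are converted to Triton inputs for the ensemble model.
response = _run_ensemble_on_tritonserver(
    repo, ["OUTPUT0"], df, "ensemble_model",
    backend_config=f"hugectr,ps={repo}/ps.json",
)

# Prebuilt-inputs form: pass grpcclient.InferInput objects straight through.
response = _run_ensemble_on_tritonserver(
    repo, ["OUTPUT0"], inputs, "0_hugectr",
    backend_config=f"hugectr,ps={repo}/ps.json",
)
```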