-
Notifications
You must be signed in to change notification settings - Fork 31
Open
Description
❓ Questions & Help
Details
I am not sure if this is a bug, but it basically breaks my workflow and I can't upgrade to Merlin version 23.
I used Merlin TensorFlow 22.12, where my exported ensemble has all features as arrays, and the config.pbtxt looks like this:
platform: "ensemble"
input {
name: "userid"
data_type: TYPE_INT64
dims: -1
dims: 1
}
input {
name: "contentid"
data_type: TYPE_INT64
dims: -1
dims: 1
}
However, after upgrading to 23.06 (or 23.05), all inputs have only one dim, my gRPC Triton Inference Server refuses to accept a batch array as input, and the config.pbtxt became the following:
platform: "merlin_executor"
input {
name: "userid"
data_type: TYPE_INT64
dims: -1
}
input {
name: "contentid"
data_type: TYPE_INT64
dims: -1
}
Here is my train_and_export script. I ran this same script in Merlin 22 and 23 and got the different results shown above:
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.tensorflow import PredictTensorflow
from merlin.systems.dag.ops.workflow import TransformWorkflow
from nvtabular.workflow import Workflow
from merlin.models.tf import BinaryOutput
import tensorflow as tf
from merlin.io.dataset import Dataset
from merlin.schema.tags import Tags
from merlin.models.tf import Loader
import merlin.models.tf as mm
import argparse
import json
import os
import numpy as np
from nvtabular.loader.tf_utils import configure_tensorflow
import nvtabular as nvt
from nvtabular.ops import *
from merlin.schema.tags import Tags
import tensorflow as tf
import shutil
import nvtabular as nvt
from merlin.io import Shuffle
from merlin.core.utils import device_mem_size
import shutil
from pathlib import Path
from typing import List, Optional, Union
from merlin.io import Dataset
# Train-and-export reproduction script: fit an NVTabular workflow, train a
# multi-task Merlin Models (TensorFlow) model, and export both as a Merlin
# Systems ensemble.

# Restrict CUDA to device 0 (ordered by PCI bus id) before TensorFlow is configured.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
configure_tensorflow()

# --- Preprocessing workflow -------------------------------------------------
# Target columns: fill missing values with 0 and tag them as
# binary-classification targets.
# NOTE(review): the plain string tag "target" is presumably what
# `Tags.TARGET` matches in the export step below -- confirm.
targets = ["watchfull", "liked", "commented", "scrolledtonext", "profileclicked", "followed"] >> FillMissing(
    fill_val=0) >> AddMetadata(tags=[Tags.BINARY_CLASSIFICATION, "target"])

DATA_FOLDER = "./data/"
train = Dataset("./data/trains/", engine="parquet", part_size="128MB")
valid = Dataset("./data/valids/", engine="parquet", part_size="128MB")
output_path = os.path.join(DATA_FOLDER, "processed")
category_temp_directory = os.path.join(DATA_FOLDER, "categories")

# Categorical encodings; user/item ids use a frequency threshold of 5.
user_id = ["userid"] >> Categorify(
    freq_threshold=5, out_path=category_temp_directory) >> TagAsUserID()
item_id = ["contentid"] >> Categorify(
    freq_threshold=5, out_path=category_temp_directory) >> TagAsItemID()
add_feat = [
    "width",
    "height",
    "durationinseconds"] >> Categorify(out_path=category_temp_directory)

outputs = user_id + item_id + targets + add_feat
workflow = nvt.Workflow(outputs)
workflow_stored_path = os.path.join("./data/processed", "workflow")

# Fit on the training split only, then write both transformed splits to parquet.
workflow.fit(train)
workflow.transform(train).to_parquet(output_path=output_path + "/train/")
workflow.transform(valid).to_parquet(output_path=output_path + "/valid/")
# workflow_fit_transform(outputs, train, valid, output_path)
workflow.save(workflow_stored_path)

# --- Model training ---------------------------------------------------------
# No arguments are registered on the parser and `args` is not read anywhere in
# this script; parse_args() is still called, so unknown CLI flags are rejected.
parser = argparse.ArgumentParser(
    description='Hyperparameters for model training'
)
args = parser.parse_args()

# Reload the processed splits written above (this rebinds `train` / `valid`).
train = Dataset(os.path.join("./data/processed/train", "*.parquet"), part_size="150MB")
valid = Dataset(os.path.join("./data/processed/valid", "*.parquet"), part_size="150MB")
# define schema object
schema = train.schema
batch_size = 512
LR = 0.03
MODEL = 'MULTI'  # not referenced anywhere else in this script

# Multi-task model: input block -> MLP -> CGC block -> one prediction task per
# target found in the schema.
inputs = mm.InputBlock(schema)
prediction_tasks = mm.PredictionTasks(schema)
block = mm.MLPBlock([256, 128, 64, 32])
cgc = mm.CGCBlock(prediction_tasks, expert_block=block, num_task_experts=5,
                  num_shared_experts=3)
model = mm.Model(inputs, block, cgc, prediction_tasks)
opt = tf.keras.optimizers.Adagrad(learning_rate=LR)
model.compile(optimizer=opt, run_eagerly=False)
model.fit(train, validation_data=valid, batch_size=batch_size)
model.save("model_multi")

# --- Ensemble export --------------------------------------------------------
# export ensemble
os.environ["TF_GPU_ALLOCATOR"]="cuda_malloc_async"
# Drop the label columns from the workflow so the serving graph does not
# require them as inputs.
label_columns = workflow.output_schema.select_by_tag(Tags.TARGET).column_names
workflow.remove_inputs(label_columns)
tf_model_path = os.path.join('./', 'model_multi')  # not referenced below; the in-memory `model` is exported
serving_operators = workflow.input_schema.column_names >> TransformWorkflow(workflow) >> PredictTensorflow(model)
ensemble = Ensemble(serving_operators, workflow.input_schema)
export_path = os.path.join('./data/ensembles', 'multimodel')
# The word "Metadata" that was fused onto this statement by page extraction is
# removed here; it made the line a syntax error.
ens_conf, node_confs = ensemble.export(export_path, name='multimodel')
Metadata
Assignees
Labels
No labels