Description
❓ Questions & Help
Hi everyone,
I tried to deploy a Transformers4rec
model using pre-trained embedding following the Transformers4rec with pre-trained embeddings example and the transformers-next-item-prediction-with-pretrained-embeddings.ipynb (for Tensorflow Merlin-models
). However, it seems to be problems to trace the PyTorch model with pre-trained embeddings.
Details
Based on the above examples, I made the following example:
data = tr.data.music_streaming_testing_data
schema = data.merlin_schema.select_by_name([
"item_id",
"item_category",
"item_recency",
"item_genres",
])
batch_size, max_length, pretrained_dim = 128, 20, 16
item_cardinality = schema["item_id"].int_domain.max + 1
np_emb_item_id = np.random.rand(item_cardinality, pretrained_dim)
embeddings_op = EmbeddingOperator(
np_emb_item_id, lookup_key="item_id", embedding_name="pretrained_item_id_embeddings"
)
# set dataloader with pre-trained embeddings
data_loader = MerlinDataLoader.from_schema(
schema,
Dataset(data.path, schema=schema),
max_sequence_length=max_length,
batch_size=batch_size,
transforms=[embeddings_op],
shuffle=False,
)
# set the model schema from data-loader
model_schema = data_loader.output_schema
inputs = tr.TabularSequenceFeatures.from_schema(
model_schema,
max_sequence_length=max_length,
pretrained_output_dims=8,
normalizer="layer-norm",
d_output=64,
masking="mlm",
)
transformer_config = tr.XLNetConfig.build(64, 4, 2, 20)
task = tr.NextItemPredictionTask(weight_tying=True)
model = transformer_config.to_torch_model(inputs, task, max_sequence_length=max_length)
args = T4RecTrainingArguments(
output_dir=".",
max_steps=5,
num_train_epochs=1,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size // 2,
max_sequence_length=max_length,
fp16=False,
report_to=[],
debug=["r"],
)
# Explicitly pass the merlin dataloader with pre-trained embeddings
recsys_trainer = Trainer(
model=model,
args=args,
schema=schema,
train_dataloader=data_loader,
eval_dataloader=data_loader,
compute_metrics=True,
)
recsys_trainer.train()
eval_metrics = recsys_trainer.evaluate(eval_dataset=data.path, metric_key_prefix="eval")
### Model export
topk = 20
model.top_k = topk
model.eval()
df = cudf.read_parquet(data.path, columns=model.input_schema.column_names)
table = TensorTable.from_df(df.loc[:10])
for column in table.columns:
table[column] = convert_col(table[column], TorchColumn)
model_input_dict = table.to_dict()
traced_model = torch.jit.trace(model, model_input_dict, strict=True)
input_schema = model.input_schema
output_schema = model.output_schema
torch_op = schema.column_names >> embeddings_op >> PredictPyTorch(
traced_model, input_schema, output_schema
)
ensemble = Ensemble(torch_op, schema)
ens_config, node_configs = ensemble.export(".")
As you can see below, a matrix shape mismatch error raises when tried to trace the PyTorch model:
Traceback (most recent call last):
File "/opt/ml/code/train.py", line 899, in test_trainer_with_pretrained_embeddings
traced_model = torch.jit.trace(model, model_input_dict, strict=True)
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 794, in trace
return trace_module(
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1056, in trace_module
module._c._create_method_from_trace(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 581, in forward
head(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/model/base.py", line 382, in forward
body_outputs = self.body(body_outputs, training=training, testing=testing, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 256, in forward
input = module(input, training=training, testing=testing)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/tabular/base.py", line 392, in __call__
outputs = super().__call__(inputs, *args, **kwargs) # noqa
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/features/sequence.py", line 259, in forward
outputs = self.projection_module(outputs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 252, in forward
input = module(input, **filtered_kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/config/schema.py", line 50, in __call__
return super().__call__(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers4rec/torch/block/base.py", line 260, in forward
input = module(input)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (220x128 and 136x64)
It seems that the torch.jit.trace()
function can't recognize the pre-trained embeddings provided by the dataloader.
Do you have any suggestion on how to deploy a Transformers4rec
model with pre-trained embeddings on Triton Inference Server?
Thanks for your amazing work!