Bug description
I am trying to extract embeddings, but neither of the following options works.
Option 1:
I tried both of these calls, but neither works:
model_transformer.query_embeddings(train, index='session_id')
or
model_transformer.query_embeddings(train, batch_size=1024, index='session_id')
Option 2:
I am able to generate session embeddings for a single batch, but if I iterate over the loader batch by batch, it crashes.
this works:
model_transformer.query_encoder(batch[0])
but iterating over the loader batch by batch does not:
all_sess_embeddings = []
for batch, _ in iter(loader):
    embds = model_transformer.query_encoder(batch).numpy()
    del batch
    gc.collect()
    all_sess_embeddings.append(embds)
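For reference, this is the accumulation behavior I expect from the loop above. A minimal sketch with the model stubbed out: `fake_encoder` and the list-based `loader` below are hypothetical stand-ins for `model_transformer.query_encoder` and the real dataloader, kept in plain NumPy so only the per-batch concatenation logic is shown:

```python
import numpy as np

def fake_encoder(batch):
    # Hypothetical stand-in for model_transformer.query_encoder:
    # maps a (batch_size, n_features) array to (batch_size, 8) embeddings.
    rng = np.random.default_rng(0)
    w = rng.standard_normal((batch.shape[1], 8))
    return batch @ w

# Simulate a loader yielding (features, target) pairs: 3 batches of 4 sessions.
loader = [(np.ones((4, 16)), None) for _ in range(3)]

all_sess_embeddings = []
for batch, _ in loader:
    embds = fake_encoder(batch)  # one (4, 8) block per batch
    all_sess_embeddings.append(embds)

# Stack the per-batch blocks into one (n_sessions, embedding_dim) matrix.
session_embeddings = np.concatenate(all_sess_embeddings, axis=0)
print(session_embeddings.shape)  # (12, 8)
```

The crash happens inside `query_encoder` / `.numpy()` in the real loop, not in this accumulation step.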
Steps/Code to reproduce bug
The code to reproduce the issue is in this gist:
https://gist.github.com/rnyak/d70822084c26ba6972615512e8a78bb2
Expected behavior
We should be able to extract session embeddings from the query model of the transformer model without any issues.
Environment details
- Merlin version:
- Platform:
- Python version:
- PyTorch version (GPU?):
- Tensorflow version (GPU?): Using tensorflow 23.06 image with the latest branches pulled.