[BUG] Unable to extract session embeddings from a session-based transformer model #163

Open
@rnyak

Description

Bug description

I am trying to extract session embeddings, but neither of the following options works.

Option 1:

I tried the following calls, but neither works:

model_transformer.query_embeddings(train, index='session_id')

or 

model_transformer.query_embeddings(train, batch_size = 1024,  index='session_id')

Option 2:

I can generate session embeddings for a single batch, but when I iterate over the loader batch by batch, it crashes.

This works:
model_transformer.query_encoder(batch[0])

But iterating over the loader batch by batch does not:

all_sess_embeddings = []
for batch, _ in loader:
    embds = model_transformer.query_encoder(batch).numpy()
    all_sess_embeddings.append(embds)
    del batch
    gc.collect()
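For reference, the batch-wise pattern Option 2 attempts can be sketched with a stand-in encoder. Here `encode_batch` is a hypothetical placeholder for `model_transformer.query_encoder`, and the fake loader mimics the `(inputs, targets)` pairs the Merlin loader yields; the idea is to convert each batch to a host-memory NumPy array as it is produced and concatenate once at the end:

```python
import numpy as np

def encode_batch(batch):
    # Hypothetical stand-in for model_transformer.query_encoder(batch):
    # maps a batch of 8 feature values to 4-dim "embeddings".
    return np.tile(batch.reshape(-1, 1), (1, 4)).astype("float32")

# Fake loader yielding (inputs, targets) pairs, like the real dataloader.
loader = [(np.arange(i, i + 8, dtype="float32"), None) for i in range(0, 32, 8)]

all_sess_embeddings = []
for batch, _ in loader:
    embds = encode_batch(batch)  # per-batch NumPy array in host memory
    all_sess_embeddings.append(embds)

# Stack the per-batch arrays into one (num_sessions, embedding_dim) matrix.
session_embeddings = np.concatenate(all_sess_embeddings, axis=0)
print(session_embeddings.shape)  # (32, 4)
```

With the real model, the `.numpy()` conversion inside the loop plays the role of `encode_batch` here; the final `np.concatenate` is what turns the list of per-batch arrays into a single session-embedding matrix.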

Steps/Code to reproduce bug

The gist at the following link contains the code to reproduce the issue:

https://gist.github.com/rnyak/d70822084c26ba6972615512e8a78bb2

Expected behavior

We should be able to extract session embeddings from the transformer model's query_model without any issues.

Environment details

  • Merlin version:
  • Platform:
  • Python version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?): Using tensorflow 23.06 image with the latest branches pulled.

Metadata

Labels

P0, bug (Something isn't working)
