
Lbl2TransformerVec(Lbl2Vec).predict_model_docs() stalls / lack of GPU utilization #11

@frza1


It appears that on larger label sets (>1000 labels), Lbl2TransformerVec(Lbl2Vec).predict_model_docs() stalls at the "calculate document vector <-> label vector similarities" step, possibly because of a memory issue. Tracing the problem, the cause may be the utils.top_similar_vectors function below. predict_model_docs() calls it inside an apply function, so it runs once per row, and each call converts the Torch tensors to numpy. Would it be possible to refactor it so the tensors stay on the GPU and are converted to numpy only once, outside this function, to improve performance?

The issue only seems to appear with label counts >1000.
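
One possible direction is to avoid ranking labels one document at a time. Instead, stack the document vectors and label vectors into two tensors on the same device, compute all cosine similarities with a matrix product, and move the result to the CPU once. This is a sketch under assumptions, not the library's implementation. The function name `top_similar_batched` and its `batch_size` parameter are hypothetical. Chunking over documents is an added assumption meant to cap each intermediate similarity matrix at `(batch_size, n_labels)` when the label count is large.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


@torch.no_grad()
def top_similar_batched(
    doc_vectors: torch.Tensor,
    label_vectors: torch.Tensor,
    batch_size: int = 1024,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: rank every label for every document.

    doc_vectors:   (n_docs, dim) tensor, on CPU or GPU
    label_vectors: (n_labels, dim) tensor on the same device
    Returns (scores, indices), each of shape (n_docs, n_labels), sorted by
    descending cosine similarity along dim 1 and left on the input device.
    """
    # Normalize once so a plain matrix product yields cosine similarities.
    labels = F.normalize(label_vectors, p=2, dim=1)
    score_chunks, index_chunks = [], []
    # Assumption: process documents in chunks so each step only allocates
    # a (batch_size, n_labels) similarity matrix on the device.
    for start in range(0, doc_vectors.shape[0], batch_size):
        docs = F.normalize(doc_vectors[start:start + batch_size], p=2, dim=1)
        sims = docs @ labels.T
        # k = n_labels gives a full descending ranking per document row.
        scores, indices = torch.topk(sims, k=labels.shape[0], dim=1)
        score_chunks.append(scores)
        index_chunks.append(indices)
    return torch.cat(score_chunks), torch.cat(index_chunks)


# Usage with random stand-in embeddings (not real Lbl2Vec data).
device = "cuda" if torch.cuda.is_available() else "cpu"
doc_embs = torch.randn(5, 8, device=device)
label_embs = torch.randn(3, 8, device=device)
scores, indices = top_similar_batched(doc_embs, label_embs, batch_size=2)
# A single device-to-host transfer for the whole result:
scores_np, indices_np = scores.cpu().numpy(), indices.cpu().numpy()
```

The concatenated result still holds `n_docs x n_labels` scores and indices on the device. For very large inputs, each chunk could instead be moved to the CPU as it is produced. How these arrays would be mapped back into the predict_model_docs() output is not shown here.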

utils.py

from typing import List

import numpy as np
import torch
from sentence_transformers import util


def top_similar_vectors(key_vector: np.array, candidate_vectors: List[np.array]) -> List[tuple]:
    '''
    Calculates the cosine similarities of a given key vector to a list of candidate vectors.

    Parameters
    ----------
    key_vector : `np.array`_
            The key embedding vector

    candidate_vectors : List[`np.array`_]
            A list of candidate embedding vectors

    Returns
    -------
    top_results : List[tuples]
            A list of (cos_similarity, list_idx) tuples, one per candidate vector, sorted by cosine similarity in descending order
    '''

    cos_scores = util.cos_sim(key_vector, np.asarray(candidate_vectors))[0]
    top_results = torch.topk(cos_scores, k=len(candidate_vectors))

    ## Return the tensors, then convert to numpy
    ## Consider refactoring the implementation to leave the tensors on the GPU instead of moving them to the CPU at this point
    top_cos_scores = top_results[0].detach().cpu().numpy()
    top_indices = top_results[1].detach().cpu().numpy()

    return list(zip(top_cos_scores, top_indices))
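
If the per-row apply structure is kept, a smaller change is a variant of `top_similar_vectors` that accepts tensors already on the target device and returns tensors, leaving the `.cpu().numpy()` conversion to the caller. This is a hypothetical sketch. `top_similar_vectors_tensor` and `to_tuple_list` are illustrative names, and the sketch computes cosine similarity with `torch.nn.functional.normalize` plus a matrix product rather than `util.cos_sim`. It also assumes the label vectors are stacked into one tensor once, up front, instead of being rebuilt with `np.asarray(candidate_vectors)` on every call.

```python
from typing import List, Tuple

import torch
import torch.nn.functional as F


@torch.no_grad()
def top_similar_vectors_tensor(
    key_vector: torch.Tensor,
    candidate_vectors: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical device-resident variant of top_similar_vectors.

    key_vector:        (dim,) or (1, dim) tensor
    candidate_vectors: (n_candidates, dim) tensor on the same device
    Returns (scores, indices) sorted by descending cosine similarity,
    still on the input device; no .cpu()/.numpy() conversion happens here.
    """
    key = F.normalize(key_vector.reshape(1, -1), p=2, dim=1)
    candidates = F.normalize(candidate_vectors, p=2, dim=1)
    cos_scores = (key @ candidates.T)[0]
    values, indices = torch.topk(cos_scores, k=candidates.shape[0])
    return values, indices


def to_tuple_list(scores: torch.Tensor, indices: torch.Tensor) -> List[tuple]:
    """Convert outside the hot path into the original List[tuple] shape."""
    return list(zip(scores.detach().cpu().numpy(), indices.detach().cpu().numpy()))
```

This avoids a device-to-host copy inside every call, but it still runs one small similarity computation per row. When there are many documents, batching them together is likely to matter more than this change alone.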
