Added huggingface model support for Top2Vec #207

Open · wants to merge 3 commits into master
19 changes: 15 additions & 4 deletions README.md
@@ -4,7 +4,7 @@
[![](https://img.shields.io/badge/arXiv-2008.09470-00ff00.svg)](http://arxiv.org/abs/2008.09470)


-**Update: Pre-trained Universal Sentence Encoders and BERT Sentence Transformer now available for embedding. Read [more](#pretrained).**
+**Update: Pre-trained [🤗](https://huggingface.co/) models now available for embedding. Read [more](#pretrained).**

Top2Vec
=======
@@ -46,7 +46,7 @@ attracted the documents to the dense area are the topic words.

### The Algorithm:

-#### 1. Create jointly embedded document and word vectors using [Doc2Vec](https://radimrehurek.com/gensim/models/doc2vec.html) or [Universal Sentence Encoder](https://tfhub.dev/google/collections/universal-sentence-encoder/1) or [BERT Sentence Transformer](https://www.sbert.net/).
+#### 1. Create jointly embedded document and word vectors using [Doc2Vec](https://radimrehurek.com/gensim/models/doc2vec.html) or [Universal Sentence Encoder](https://tfhub.dev/google/collections/universal-sentence-encoder/1) or [BERT Sentence Transformer](https://www.sbert.net/) or [🤗](https://huggingface.co/).
>Documents will be placed close to other similar documents and close to the most distinguishing words.

<!--![](https://raw.githubusercontent.com/ddangelov/Top2Vec/master/images/doc_word_embedding.svg?sanitize=true)-->
@@ -101,6 +101,10 @@ To install pre-trained BERT sentence transformer options:

pip install top2vec[sentence_transformers]

+To install pre-trained 🤗 transformers options:
+
+pip install top2vec[flair]

To install indexing options:

pip install top2vec[indexing]
@@ -144,18 +148,25 @@ Doc2Vec will be used by default to generate the joint word and document embeddin
* `universal-sentence-encoder`
* `universal-sentence-encoder-multilingual`
* `distiluse-base-multilingual-cased`
+* `flair`

```python
from top2vec import Top2Vec

model = Top2Vec(documents, embedding_model='universal-sentence-encoder')
```

+```python
+from top2vec import Top2Vec
+
+model = Top2Vec(documents, embedding_model='flair')
+```
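
Since `_check_model_status` in this PR falls back to `roberta-base` only when `embedding_model_path` is unset, any 🤗 model id should be loadable through that parameter. A minimal sketch, assuming `distilbert-base-uncased` as an illustrative (hypothetical) model id:

```python
from top2vec import Top2Vec

# Hypothetical model id: embedding_model_path can point at any 🤗 model
# compatible with flair's TransformerDocumentEmbeddings.
model = Top2Vec(documents,
                embedding_model='flair',
                embedding_model_path='distilbert-base-uncased')
```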

For large data sets and data sets with highly unique vocabulary doc2vec could
produce better results. This will train a doc2vec model from scratch. This method
is language agnostic. However, multiple languages will not be aligned.

-Using the universal sentence encoder options will be much faster since those are
+Using the universal sentence encoder or flair options will be much faster since those are
pre-trained and efficient models. The universal sentence encoder options are
suggested for smaller data sets. They are also good options for large data sets
that are in English or in languages covered by the multilingual model. It is also
@@ -166,7 +177,7 @@ for multilingual datasets and languages that are not covered by the multilingual
universal sentence encoder. The transformer is significantly slower than
the universal sentence encoder options.

-More information on [universal-sentence-encoder](https://tfhub.dev/google/universal-sentence-encoder/4), [universal-sentence-encoder-multilingual](https://tfhub.dev/google/universal-sentence-encoder-multilingual/3), and [distiluse-base-multilingual-cased](https://www.sbert.net/docs/pretrained_models.html).
+More information on [universal-sentence-encoder](https://tfhub.dev/google/universal-sentence-encoder/4), [universal-sentence-encoder-multilingual](https://tfhub.dev/google/universal-sentence-encoder-multilingual/3), [distiluse-base-multilingual-cased](https://www.sbert.net/docs/pretrained_models.html), and [flair](https://github.com/flairNLP/flair)/[🤗](https://huggingface.co/).


Citation
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -29,7 +29,7 @@
author = 'Dimo Angelov'

# The full version, including alpha/beta/rc tags
-release = '1.0.26'
+release = '1.0.27'


# -- General configuration ---------------------------------------------------
1 change: 1 addition & 0 deletions notebooks/fetch_20newsgroups.ipynb
@@ -0,0 +1 @@
{"cells":[{"cell_type":"code","execution_count":null,"source":["import numpy as np \r\n","import pandas as pd \r\n","import json\r\n","import os\r\n","import ipywidgets as widgets\r\n","from IPython.display import clear_output, display\r\n","from top2vec import Top2Vec"],"outputs":[],"metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true}},{"cell_type":"code","execution_count":null,"source":["from sklearn.datasets import fetch_20newsgroups"],"outputs":[],"metadata":{"trusted":true}},{"cell_type":"code","execution_count":null,"source":["newsgroups = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))"],"outputs":[],"metadata":{}},{"cell_type":"code","execution_count":null,"source":["newsgroups.data[0]"],"outputs":[],"metadata":{}},{"cell_type":"code","execution_count":null,"source":["model = Top2Vec(documents=newsgroups.data, embedding_model='flair', use_embedding_model_tokenizer=True)"],"outputs":[],"metadata":{}},{"cell_type":"code","execution_count":null,"source":["model.get_num_topics()"],"outputs":[],"metadata":{}},{"cell_type":"code","execution_count":null,"source":["topic_sizes, topic_nums = model.get_topic_sizes()"],"outputs":[],"metadata":{}},{"cell_type":"code","execution_count":null,"source":["topic_sizes, topic_nums"],"outputs":[],"metadata":{}},{"cell_type":"code","execution_count":null,"source":["model.generate_topic_wordcloud(0)"],"outputs":[],"metadata":{}}],"metadata":{"kernelspec":{"name":"python3","display_name":"Python 3.6.8 64-bit ('huggingface_env': conda)"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.8","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"},"interpreter":{"hash":"f925cffa23e71364e95bf19513be47693a86ee1d684cc94b0f4368f2c87d4403"}},"nbformat":4,"nbformat_minor":4}
2 changes: 2 additions & 0 deletions requirements.txt
@@ -10,3 +10,5 @@ tensorflow_text
torch
sentence_transformers
hnswlib
+flair
+transformers
7 changes: 6 additions & 1 deletion setup.py
@@ -6,7 +6,7 @@
setuptools.setup(
name="top2vec",
packages=["top2vec"],
version="1.0.26",
version="1.0.27",
author="Dimo Angelov",
author_email="[email protected]",
description="Top2Vec learns jointly embedded topic, document and word vectors.",
@@ -43,6 +43,11 @@
            'torch',
            'sentence_transformers',
        ],
+        'flair': [
+            "transformers>=4.0.0",
+            "torch>=1.5.0,!=1.8",
+            "flair==0.7"
+        ],
        'indexing': [
            'hnswlib',
        ],
65 changes: 54 additions & 11 deletions top2vec/Top2Vec.py
@@ -37,9 +37,17 @@
try:
    from sentence_transformers import SentenceTransformer

-    _HAVE_TORCH = True
+    _HAVE_SENTENCE_TRANSFORMERS = True
except ImportError:
-    _HAVE_TORCH = False
+    _HAVE_SENTENCE_TRANSFORMERS = False

+try:
+    from flair.data import Sentence
+    from flair.embeddings import TransformerDocumentEmbeddings
+
+    _HAVE_FLAIR = True
+except ImportError:
+    _HAVE_FLAIR = False

logger = logging.getLogger('top2vec')
logger.setLevel(logging.WARNING)
@@ -237,7 +245,8 @@ def __init__(self,

        acceptable_embedding_models = ["universal-sentence-encoder-multilingual",
                                       "universal-sentence-encoder",
-                                      "distiluse-base-multilingual-cased"]
+                                      "distiluse-base-multilingual-cased",
+                                      "flair"]

        self.embedding_model_path = embedding_model_path

@@ -337,7 +346,7 @@ def return_doc(doc):

        # embed words
        self.word_indexes = dict(zip(self.vocab, range(len(self.vocab))))
-        self.word_vectors = self._l2_normalize(np.array(self.embed(self.vocab)))
+        self.word_vectors = self._l2_normalize(np.array(self._get_embedding(self.vocab)))

        # embed documents
        if use_embedding_model_tokenizer:
@@ -533,11 +542,11 @@ def _embed_documents(self, train_corpus):
        extra = len(train_corpus) % batch_size

        for ind in range(0, batches):
-            document_vectors.append(self.embed(train_corpus[current:current + batch_size]))
+            document_vectors.append(self._get_embedding(train_corpus[current:current + batch_size]))
            current += batch_size

        if extra > 0:
-            document_vectors.append(self.embed(train_corpus[current:current + extra]))
+            document_vectors.append(self._get_embedding(train_corpus[current:current + extra]))

        document_vectors = self._l2_normalize(np.array(np.vstack(document_vectors)))

@@ -547,7 +556,28 @@ def _embed_query(self, query):
        self._check_import_status()
        self._check_model_status()

-        return self._l2_normalize(np.array(self.embed([query])[0]))
+        return self._l2_normalize(np.array(self._get_embedding([query])[0]))

+    def _get_embedding(self, documents):
+
+        if isinstance(documents, str):
+            documents = [documents]
+
+        if self.embedding_model == 'flair':
+            embeddings = []
+            for index, document in enumerate(documents):
+                try:
+                    sentence = Sentence(document) if document else Sentence("an empty document")
+                    self.embed(sentence)
+                except RuntimeError:
+                    sentence = Sentence("an empty document")
+                    self.embed(sentence)
+                embedding = sentence.embedding.detach().cpu().numpy()
+                embeddings.append(embedding)
+            embeddings = np.asarray(embeddings)
+            return embeddings
+        else:
+            return self.embed(documents)
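
The new `_get_embedding` wraps flair's per-document API: `TransformerDocumentEmbeddings.embed` mutates a `Sentence` in place rather than returning an array, and empty documents cannot be tokenized, hence the placeholder fallback. A standalone sketch of the same flow, assuming flair 0.7 with the default `roberta-base` (per `setup.py` and `_check_model_status`):

```python
import numpy as np
from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings

embedder = TransformerDocumentEmbeddings("roberta-base")

docs = ["Documents are embedded one at a time.", ""]
vectors = []
for doc in docs:
    # An empty string cannot become a Sentence, so substitute a
    # placeholder, mirroring the fallback in _get_embedding above.
    sentence = Sentence(doc) if doc else Sentence("an empty document")
    embedder.embed(sentence)  # attaches .embedding to the sentence
    vectors.append(sentence.embedding.detach().cpu().numpy())

vectors = np.asarray(vectors)  # shape: (num_docs, embedding_dim)
```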

    def _set_document_vectors(self, document_vectors):
        if self.embedding_model == 'doc2vec':
@@ -819,23 +849,37 @@ def _check_word_index_status(self):
"Call index_word_vectors method before setting use_index=True.")

    def _check_import_status(self):
-        if self.embedding_model != 'distiluse-base-multilingual-cased':
+        if self.embedding_model == 'flair':
+            if not _HAVE_FLAIR:
+                raise ImportError(f"{self.embedding_model} is not available.\n\n"
+                                  "Try: pip install top2vec[flair]\n\n"
+                                  "Alternatively try: pip install torch transformers flair")
+        elif self.embedding_model != 'distiluse-base-multilingual-cased':
            if not _HAVE_TENSORFLOW:
                raise ImportError(f"{self.embedding_model} is not available.\n\n"
                                  "Try: pip install top2vec[sentence_encoders]\n\n"
                                  "Alternatively try: pip install tensorflow tensorflow_hub tensorflow_text")
        else:
-            if not _HAVE_TORCH:
+            if not _HAVE_SENTENCE_TRANSFORMERS:
                raise ImportError(f"{self.embedding_model} is not available.\n\n"
                                  "Try: pip install top2vec[sentence_transformers]\n\n"
                                  "Alternatively try: pip install torch sentence_transformers")


    def _check_model_status(self):
        if self.embed is None:
            if self.verbose is False:
                logger.setLevel(logging.DEBUG)

-            if self.embedding_model != "distiluse-base-multilingual-cased":
+            if self.embedding_model == "flair":
+                if self.embedding_model_path is None:
+                    model = TransformerDocumentEmbeddings("roberta-base")
+                else:
+                    model = TransformerDocumentEmbeddings(self.embedding_model_path)
+                if "fine_tune" in model.__dict__:
+                    model.fine_tune = False
+                self.embed = model.embed
+            elif self.embedding_model != "distiluse-base-multilingual-cased":
                if self.embedding_model_path is None:
                    logger.info(f'Downloading {self.embedding_model} model')
                    if self.embedding_model == "universal-sentence-encoder-multilingual":
@@ -846,7 +890,6 @@
                    logger.info(f'Loading {self.embedding_model} model at {self.embedding_model_path}')
                    module = self.embedding_model_path
                self.embed = hub.load(module)
-
            else:
                if self.embedding_model_path is None:
                    logger.info(f'Downloading {self.embedding_model} model')
12 changes: 11 additions & 1 deletion top2vec/tests/test_top2vec.py
@@ -42,9 +42,19 @@
embedding_model='distiluse-base-multilingual-cased',
use_embedding_model_tokenizer=True)

+# test Flair
+top2vec_flair = Top2Vec(documents=newsgroups_documents,
+                        embedding_model='flair')
+
+# test Flair with model embedding
+top2vec_flair_model_embedding = Top2Vec(documents=newsgroups_documents,
+                                        embedding_model='flair',
+                                        use_embedding_model_tokenizer=True)

models = [top2vec, top2vec_docids, top2vec_no_docs, top2vec_corpus_file,
          top2vec_use, top2vec_use_multilang, top2vec_transformer_multilang,
-          top2vec_use_model_embedding, top2vec_transformer_model_embedding]
+          top2vec_use_model_embedding, top2vec_transformer_model_embedding,
+          top2vec_flair, top2vec_flair_model_embedding]


def get_model_vocab(top2vec_model):