Description
Have you searched existing issues? 🔎
- I have searched and found no existing issues
Describe the bug
When calling TopicMapper.add_new_topics, the method appends new rows to mappings_ with None placeholders:
def add_new_topics(self, mappings: Mapping[int, int]):
    length = len(self.mappings_[0])
    for key, value in mappings.items():
        to_append = [key] + ([None] * (length - 2)) + [value]
        self.mappings_.append(to_append)
This works at runtime, but saving the model later with model.save() (serialization="safetensors") fails because _save_utils.save_topics tries to cast the entire mappings_ table with np.array(..., dtype=int):
  File ".../bertopic/_save_utils.py", line 442, in save_topics
    "topic_mapper": np.array(model.topic_mapper_.mappings_, dtype=int).tolist(),
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
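The failure can be shown in isolation without BERTopic. The sketch below is not the library's code: `mappings` is a plain list standing in for `mappings_`, and the free function `add_new_topics` is a hypothetical copy of the row-building logic quoted above. It only illustrates why a row containing `None` cannot be cast to an integer array.

```python
import numpy as np

# Stand-in for TopicMapper.mappings_ (assumption: one row per topic,
# one column per mapping step, three columns in this example).
mappings = [[-1, -1, -1], [0, 0, 0], [1, 1, 0]]


def add_new_topics(mappings, new_topics):
    """Hypothetical copy of the row-building logic from the issue."""
    length = len(mappings[0])
    for key, value in new_topics.items():
        # Middle columns are filled with None placeholders.
        mappings.append([key] + [None] * (length - 2) + [value])


add_new_topics(mappings, {2: 2})

error = None
try:
    # Same kind of cast that save_topics performs on the mapping table.
    np.array(mappings, dtype=int)
except TypeError as exc:
    error = exc

print(mappings[-1])
print(type(error).__name__)
```

Note that with only two columns, `length - 2` is zero and no placeholders are added, so the cast would only fail once the table has three or more columns.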
Reproduction
from sklearn.datasets import fetch_20newsgroups
from sklearn.cluster import Birch
from bertopic.vectorizers import OnlineCountVectorizer
from bertopic import BERTopic
# Prepare documents
all_docs = fetch_20newsgroups(subset="train", remove=('headers', 'footers', 'quotes'))["data"]
first_docs = all_docs[:50] # Making it small so that new clusters emerge with partial_fit and new mappings are added to the topic mapper
doc_chunks = [all_docs[50:][i:i+1000] for i in range(0, len(all_docs[50:]), 1000)]
# Prepare sub-models that support online learning
cluster_model = Birch(threshold=1.5, n_clusters=None)
vectorizer_model = OnlineCountVectorizer(stop_words="english", decay=.01)
# Train the model for the first time
topic_model = BERTopic(
    language="multilingual",
    hdbscan_model=cluster_model,
    vectorizer_model=vectorizer_model,
)
topic_model.fit_transform(documents=first_docs)
# Incremental fitting
for batch in doc_chunks:
    topic_model.partial_fit(batch)
topic_model.save("temp", serialization="safetensors") # -> This throws the error
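To confirm that `None` placeholders are the cause before calling `save`, the mapping table can be inspected directly. The helper below is a hypothetical diagnostic, not part of BERTopic; it works on any list of rows and is shown here on a small stand-in table.

```python
def rows_with_placeholders(mappings):
    """Return (row_index, row) pairs for every row that contains None."""
    return [
        (index, row)
        for index, row in enumerate(mappings)
        if any(value is None for value in row)
    ]


# Stand-in table; the last row mimics one appended by add_new_topics.
example = [[-1, -1, -1], [0, 0, 0], [2, None, 2]]
bad_rows = rows_with_placeholders(example)
print(bad_rows)
```

With the fitted model from the reproduction, the same check could be run as `rows_with_placeholders(topic_model.topic_mapper_.mappings_)`, using the attribute name that appears in the traceback.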
BERTopic Version
0.17.3