Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,21 +478,20 @@ def fit_transform(
if documents.Document.values[0] is None:
custom_documents = self._images_to_text(documents, embeddings)

# Extract topics by calculating c-TF-IDF
self._extract_topics(custom_documents, embeddings=embeddings)
self._create_topic_vectors(documents=documents, embeddings=embeddings)

# Reduce topics
# Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
self._extract_topics(custom_documents, embeddings=embeddings, fine_tune_representation=not self.nr_topics)
if self.nr_topics:
custom_documents = self._reduce_topics(custom_documents)
self._create_topic_vectors(documents=documents, embeddings=embeddings)

# Save the top 3 most representative documents per topic
self._save_representative_docs(custom_documents)
else:
# Extract topics by calculating c-TF-IDF
self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose)

# Reduce topics
else:
# Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
self._extract_topics(
documents, embeddings=embeddings, verbose=self.verbose, fine_tune_representation=not self.nr_topics
)
if self.nr_topics:
documents = self._reduce_topics(documents)

Expand Down Expand Up @@ -3972,6 +3971,7 @@ def _extract_topics(
embeddings: np.ndarray = None,
mappings=None,
verbose: bool = False,
fine_tune_representation: bool = True,
):
"""Extract topics from the clusters using a class-based TF-IDF.

Expand All @@ -3980,16 +3980,27 @@ def _extract_topics(
embeddings: The document embeddings
mappings: The mappings from topic to word
verbose: Whether to log the process of extracting topics
fine_tune_representation: If True, the topic representation will be fine-tuned using representation models.
If False, the topic representation will remain as the base c-TF-IDF representation.

Returns:
c_tf_idf: The resulting matrix giving a value (importance score) for each word per topic
"""
if verbose:
logger.info("Representation - Extracting topics from clusters using representation models.")
action = "Fine-tuning" if fine_tune_representation else "Extracting"
method = "representation models" if fine_tune_representation else "c-TF-IDF for topic reduction"
logger.info(f"Representation - {action} topics using {method}.")

documents_per_topic = documents.groupby(["Topic"], as_index=False).agg({"Document": " ".join})
self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic)
self.topic_representations_ = self._extract_words_per_topic(words, documents)
self.topic_representations_ = self._extract_words_per_topic(
words,
documents,
fine_tune_representation=fine_tune_representation,
calculate_aspects=fine_tune_representation,
)
self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)

if verbose:
logger.info("Representation - Completed \u2713")

Expand Down Expand Up @@ -4245,6 +4256,7 @@ def _extract_words_per_topic(
words: List[str],
documents: pd.DataFrame,
c_tf_idf: csr_matrix = None,
fine_tune_representation: bool = True,
calculate_aspects: bool = True,
) -> Mapping[str, List[Tuple[str, float]]]:
"""Based on tf_idf scores per topic, extract the top n words per topic.
Expand All @@ -4258,6 +4270,8 @@ def _extract_words_per_topic(
words: List of all words (sorted according to tf_idf matrix position)
documents: DataFrame with documents and their topic IDs
c_tf_idf: A c-TF-IDF matrix from which to calculate the top words
fine_tune_representation: If True, the topic representation will be fine-tuned using representation models.
If False, the topic representation will remain as the base c-TF-IDF representation.
calculate_aspects: Whether to calculate additional topic aspects

Returns:
Expand Down Expand Up @@ -4288,15 +4302,15 @@ def _extract_words_per_topic(

# Fine-tune the topic representations
topics = base_topics.copy()
if not self.representation_model:
if not self.representation_model or not fine_tune_representation:
# Default representation: c_tf_idf + top_n_words
topics = {label: values[: self.top_n_words] for label, values in topics.items()}
elif isinstance(self.representation_model, list):
elif fine_tune_representation and isinstance(self.representation_model, list):
for tuner in self.representation_model:
topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, BaseRepresentation):
elif fine_tune_representation and isinstance(self.representation_model, BaseRepresentation):
topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, dict):
elif fine_tune_representation and isinstance(self.representation_model, dict):
if self.representation_model.get("Main"):
main_model = self.representation_model["Main"]
if isinstance(main_model, BaseRepresentation):
Expand Down Expand Up @@ -4350,6 +4364,12 @@ def _reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) -> p
if isinstance(self.nr_topics, int):
if self.nr_topics < initial_nr_topics:
documents = self._reduce_to_n_topics(documents, use_ctfidf)
else:
logger.info(
f"Topic reduction - Number of topics ({self.nr_topics}) is equal or higher than the clustered topics({len(self.get_topics())})."
)
self._extract_topics(documents, verbose=self.verbose)
return documents
elif isinstance(self.nr_topics, str):
documents = self._auto_reduce_topics(documents, use_ctfidf)
else:
Expand Down Expand Up @@ -4412,7 +4432,7 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)

# Update representations
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._extract_topics(documents, mappings=mappings, verbose=self.verbose)

self._update_topic_size(documents)
return documents
Expand Down Expand Up @@ -4468,7 +4488,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
# Update documents and topics
self.topic_mapper_.add_mappings(mapped_topics, topic_model=self)
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._extract_topics(documents, mappings=mappings, verbose=self.verbose)
self._update_topic_size(documents)
return documents

Expand Down
1 change: 1 addition & 0 deletions tests/test_representation/test_representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def test_topic_reduction_edge_cases(model, documents, request):
topics = np.random.randint(-1, nr_topics - 1, len(documents))
old_documents = pd.DataFrame({"Document": documents, "ID": range(len(documents)), "Topic": topics})
topic_model._update_topic_size(old_documents)
old_documents = topic_model._sort_mappings_by_frequency(old_documents)
topic_model._extract_topics(old_documents)
old_freq = topic_model.get_topic_freq()

Expand Down