Skip to content

Commit 9fd7983

Browse files
authored
Expose embeddings dimensions in SentenceEmbeddingsModel (#371)
* Expose method to extract sentence embeddings dimensions * Updated changelog
1 parent 66944eb commit 9fd7983

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file. The format
66
- Addition of the [LongT5](https://arxiv.org/abs/2112.07916) model architecture and pretrained weights.
77
- Addition of `add_tokens` and `add_extra_ids` interface methods to the `TokenizerOption`. Allow building most pipelines with a custom tokenizer via `new_with_tokenizer`.
88
- Addition of `get_tokenizer` and `get_tokenizer_mut` methods to all pipelines allowing to get a (mutable) reference to the pipeline tokenizer.
9+
- Addition of a `get_embedding_dim` method to get the dimension of the embeddings for sentence embeddings pipelines.
910

1011
## Changed
1112
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.

src/pipelines/sentence_embeddings/pipeline.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ pub struct SentenceEmbeddingsModel {
167167
pooling_layer: Pooling,
168168
dense_layer: Option<Dense>,
169169
normalize_embeddings: bool,
170+
embeddings_dim: i64,
170171
}
171172

172173
impl SentenceEmbeddingsModel {
@@ -196,7 +197,6 @@ impl SentenceEmbeddingsModel {
196197
.validate()?;
197198

198199
// Setup tokenizer
199-
200200
let tokenizer_config = SentenceEmbeddingsTokenizerConfig::from_file(
201201
tokenizer_config_resource.get_local_path()?,
202202
);
@@ -223,7 +223,6 @@ impl SentenceEmbeddingsModel {
223223
)?;
224224

225225
// Setup transformer
226-
227226
let mut var_store = nn::VarStore::new(device);
228227
let transformer_config = ConfigOption::from_file(
229228
transformer_type,
@@ -234,15 +233,15 @@ impl SentenceEmbeddingsModel {
234233
var_store.load(transformer_weights_resource.get_local_path()?)?;
235234

236235
// Setup pooling layer
237-
238236
let pooling_config = PoolingConfig::from_file(pooling_config_resource.get_local_path()?);
237+
let mut embeddings_dim = pooling_config.word_embedding_dimension;
239238
let pooling_layer = Pooling::new(pooling_config);
240239

241240
// Setup dense layer
242-
243241
let dense_layer = if modules.dense_module().is_some() {
244242
let dense_config =
245243
DenseConfig::from_file(dense_config_resource.unwrap().get_local_path()?);
244+
embeddings_dim = dense_config.out_features;
246245
Some(Dense::new(
247246
dense_config,
248247
dense_weights_resource.unwrap().get_local_path()?,
@@ -264,6 +263,7 @@ impl SentenceEmbeddingsModel {
264263
pooling_layer,
265264
dense_layer,
266265
normalize_embeddings,
266+
embeddings_dim,
267267
})
268268
}
269269

@@ -282,6 +282,11 @@ impl SentenceEmbeddingsModel {
282282
self.tokenizer_truncation_strategy = truncation_strategy;
283283
}
284284

285+
/// Return the embedding output dimension
286+
pub fn get_embedding_dim(&self) -> Result<i64, RustBertError> {
287+
Ok(self.embeddings_dim)
288+
}
289+
285290
/// Tokenizes the inputs
286291
pub fn tokenize<S>(&self, inputs: &[S]) -> SentenceEmbeddingsTokenizerOutput
287292
where

0 commit comments

Comments
 (0)