Skip to content

Commit 914af07

Browse files
ahmedlone127prabodAbdullahMubeenAnwarDevinTDHa
authored
[SPARKNLP-1231] Implement ModernBertEmbedings (#14736)
* Add ModernBertTokenizer and ModernBertEmbeddings with comprehensive tests - Implemented ModernBertTokenizer for tokenization using BPE. - Developed ModernBertEmbeddings for token-level embeddings with support for various engines (TensorFlow, ONNX, OpenVINO). - Added extensive test cases for ModernBertEmbeddings covering various scenarios including empty tokens, special characters, long sentences, and batch processing. - Included parameter validation for maxSentenceLength to ensure it adheres to model constraints. - Enhanced documentation for both classes and their usage examples. * SPARKNLP-1231 fixing bugs + improving BPE toknization for ModernBert * Add ModernBertEmbeddings docs and example * ModernBert slow tests --------- Co-authored-by: Prabod Rathnayaka <prabod@rathnayaka.me> Co-authored-by: Abdullah mubeen <bdllhmubeen@gmail.com> Co-authored-by: Devin Ha <devin@trungducha.de>
1 parent 5826dbc commit 914af07

14 files changed

Lines changed: 3172 additions & 2 deletions

File tree

docs/en/annotators.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ Additionally, these transformers are available.
173173
{% include templates/anno_table_entry.md path="./transformers" name="MiniLMEmbeddings" summary="Sentence embeddings using MiniLM, a lightweight and efficient sentence embedding model that can generate text embeddings for various NLP tasks."%}
174174
{% include templates/anno_table_entry.md path="./transformers" name="MistralTransformer" summary="MistralTransformer loads Mistral models, efficient dense and mixture-of-experts (MoE) language models optimized for high performance on reasoning and coding tasks." %}
175175
{% include templates/anno_table_entry.md path="./transformers" name="MLLamaForMultimodal" summary="MLLamaForMultimodal is an open-source multimodal model that combines a vision encoder with a large language model."%}
176+
{% include templates/anno_table_entry.md path="./transformers" name="ModernBertEmbeddings" summary="Token-level embeddings using ModernBERT, a modernized bidirectional encoder that is 8x faster and uses 5x less memory than traditional BERT, with support for sequences up to 8192 tokens."%}
176177
{% include templates/anno_table_entry.md path="./transformers" name="MPNetEmbeddings" summary="Sentence embeddings using MPNet."%}
177178
{% include templates/anno_table_entry.md path="./transformers" name="MPNetForQuestionAnswering" summary="MPNet Models with a span classification head on top for extractive question-answering tasks like SQuAD."%}
178179
{% include templates/anno_table_entry.md path="./transformers" name="MPNetForSequenceClassification" summary="MPNet Models with sequence classification/regression head on top e.g. for multi-class document classification tasks."%}
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
{%- capture title -%}
2+
ModernBertEmbeddings
3+
{%- endcapture -%}
4+
5+
{%- capture description -%}
6+
Token-level embeddings using ModernBERT (Modern Bidirectional Encoder Representations from Transformers) a state-of-the-art encoder model designed for improved efficiency and performance compared to traditional BERT
7+
models. It incorporates modern improvements including Flash Attention, unpadding, and GeGLU activation functions,
8+
and supports sequence lengths up to 8192 tokens.
9+
10+
Pretrained models can be loaded with `pretrained` of the companion object:
11+
```
12+
val embeddings = ModernBertEmbeddings.pretrained()
13+
.setInputCols("token", "document")
14+
.setOutputCol("modernbert_embeddings")
15+
```
16+
The default model is `"modernbert-base"`, if no name is provided.
17+
18+
For available pretrained models please see the [Models Hub](https://sparknlp.org/models?task=Embeddings).
19+
20+
For extended examples of usage, see [ModernBertEmbeddings.ipynb](https://github.com/JohnSnowLabs/spark-nlp/tree/master/examples/python/annotation/text/english/embeddings/ModernBertEmbeddings.ipynb) and [ModernBertEmbeddingsTestSpec](https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/ModernBertEmbeddingsTestSpec.scala).
21+
22+
**Sources** :
23+
24+
[Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663)
25+
26+
[https://huggingface.co/answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)
27+
28+
**Paper abstract**
29+
30+
*Encoder-only transformer models such as BERT offer a great performance-size tradeoff for retrieval and
31+
classification tasks with respect to larger decoder-only models. Despite being the workhorse of numerous
32+
production pipelines, there have been limited Pareto improvements to BERT since its release. In this paper, we
33+
introduce ModernBERT, bringing modern model optimizations to encoder-only models and representing a major Pareto
34+
improvement over older encoders. Trained on 2 trillion tokens with a native 8192 sequence length, ModernBERT
35+
models exhibit state-of-the-art results on a large pool of evaluations encompassing diverse classification tasks
36+
and both single and multi-vector retrieval on different domains (including code). In addition to strong downstream
37+
performance, ModernBERT is also the most speed and memory efficient encoder and is designed for inference on
38+
common GPUs.*
39+
{%- endcapture -%}
40+
41+
{%- capture input_anno -%}
42+
DOCUMENT, TOKEN
43+
{%- endcapture -%}
44+
45+
{%- capture output_anno -%}
46+
WORD_EMBEDDINGS
47+
{%- endcapture -%}
48+
49+
{%- capture api_link -%}
50+
[ModernBertEmbeddings](/api/com/johnsnowlabs/nlp/embeddings/ModernBertEmbeddings)
51+
{%- endcapture -%}
52+
53+
{%- capture python_api_link -%}
54+
[ModernBertEmbeddings](/api/python/reference/autosummary/sparknlp/annotator/embeddings/modernbert_embeddings/index.html#sparknlp.annotator.embeddings.modernbert_embeddings.ModernBertEmbeddings)
55+
{%- endcapture -%}
56+
57+
{%- capture source_link -%}
58+
[ModernBertEmbeddings](https://github.com/JohnSnowLabs/spark-nlp/tree/master/src/main/scala/com/johnsnowlabs/nlp/embeddings/ModernBertEmbeddings.scala)
59+
{%- endcapture -%}
60+
61+
{%- capture prediction_python_example -%}
62+
import sparknlp
63+
from sparknlp.base import *
64+
from sparknlp.annotator import *
65+
from pyspark.ml import Pipeline
66+
67+
# First extract the prerequisites for the NerDLModel
68+
documentAssembler = DocumentAssembler() \
69+
.setInputCol("text") \
70+
.setOutputCol("document")
71+
72+
sentence = SentenceDetector() \
73+
.setInputCols(["document"]) \
74+
.setOutputCol("sentence")
75+
76+
tokenizer = Tokenizer() \
77+
.setInputCols(["sentence"]) \
78+
.setOutputCol("token")
79+
80+
# Use the transformer embeddings
81+
embeddings = ModernBertEmbeddings.pretrained() \
82+
.setInputCols(["document", "token"]) \
83+
.setOutputCol("embeddings")
84+
85+
# This pretrained model requires those specific transformer embeddings
86+
ner_model = NerDLModel.pretrained("ner_dl_bert", "en") \
87+
.setInputCols(["document", "token", "embeddings"]) \
88+
.setOutputCol("ner")
89+
90+
pipeline = Pipeline().setStages([
91+
documentAssembler,
92+
sentence,
93+
tokenizer,
94+
embeddings,
95+
ner_model
96+
])
97+
98+
data = spark.createDataFrame([["U.N. official Ekeus heads for Baghdad."]]).toDF("text")
99+
result = pipeline.fit(data).transform(data)
100+
101+
result.select("ner.result").show(truncate=False)
102+
+------------------------------------+
103+
|result |
104+
+------------------------------------+
105+
|[I-LOC, O, O, I-PER, O, O, I-LOC, O]|
106+
+------------------------------------+
107+
{%- endcapture -%}
108+
109+
{%- capture prediction_scala_example -%}
110+
import spark.implicits._
111+
import com.johnsnowlabs.nlp.base.DocumentAssembler
112+
import com.johnsnowlabs.nlp.annotators.Tokenizer
113+
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
114+
import com.johnsnowlabs.nlp.embeddings.ModernBertEmbeddings
115+
import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLModel
116+
import org.apache.spark.ml.Pipeline
117+
118+
// First extract the prerequisites for the NerDLModel
119+
val documentAssembler = new DocumentAssembler()
120+
.setInputCol("text")
121+
.setOutputCol("document")
122+
123+
val sentence = new SentenceDetector()
124+
.setInputCols("document")
125+
.setOutputCol("sentence")
126+
127+
val tokenizer = new Tokenizer()
128+
.setInputCols("sentence")
129+
.setOutputCol("token")
130+
131+
// Use the transformer embeddings
132+
val embeddings = ModernBertEmbeddings.pretrained()
133+
.setInputCols(Array("document", "token"))
134+
.setOutputCol("embeddings")
135+
136+
// This pretrained model requires those specific transformer embeddings
137+
val nerModel = NerDLModel.pretrained("ner_dl_bert", "en")
138+
.setInputCols(Array("document", "token", "embeddings"))
139+
.setOutputCol("ner")
140+
141+
val pipeline = new Pipeline().setStages(Array(
142+
documentAssembler,
143+
sentence,
144+
tokenizer,
145+
embeddings,
146+
nerModel
147+
))
148+
149+
val data = Seq("U.N. official Ekeus heads for Baghdad.").toDF("text")
150+
val result = pipeline.fit(data).transform(data)
151+
152+
result.select("ner.result").show(false)
153+
+------------------------------------+
154+
|result |
155+
+------------------------------------+
156+
|[I-LOC, O, O, I-PER, O, O, I-LOC, O]|
157+
+------------------------------------+
158+
{%- endcapture -%}
159+
160+
{%- capture training_python_example -%}
161+
import sparknlp
162+
from sparknlp.base import *
163+
from sparknlp.annotator import *
164+
from sparknlp.training import *
165+
from pyspark.ml import Pipeline
166+
167+
# First extract the prerequisites for the NerDLApproach
168+
documentAssembler = DocumentAssembler() \
169+
.setInputCol("text") \
170+
.setOutputCol("document")
171+
172+
sentence = SentenceDetector() \
173+
.setInputCols(["document"]) \
174+
.setOutputCol("sentence")
175+
176+
tokenizer = Tokenizer() \
177+
.setInputCols(["sentence"]) \
178+
.setOutputCol("token")
179+
180+
embeddings = ModernBertEmbeddings.pretrained() \
181+
.setInputCols(["sentence", "token"]) \
182+
.setOutputCol("embeddings")
183+
184+
# Then the training can start with the transformer embeddings
185+
nerTagger = NerDLApproach() \
186+
.setInputCols(["sentence", "token", "embeddings"]) \
187+
.setLabelColumn("label") \
188+
.setOutputCol("ner") \
189+
.setMaxEpochs(1) \
190+
.setVerbose(0)
191+
192+
pipeline = Pipeline().setStages([
193+
documentAssembler,
194+
sentence,
195+
tokenizer,
196+
embeddings,
197+
nerTagger
198+
])
199+
200+
# We use the text and labels from the CoNLL dataset
201+
conll = CoNLL()
202+
trainingData = conll.readDataset(spark, "eng.train")
203+
204+
pipelineModel = pipeline.fit(trainingData)
205+
{%- endcapture -%}
206+
207+
{%- capture training_scala_example -%}
208+
import com.johnsnowlabs.nlp.base.DocumentAssembler
209+
import com.johnsnowlabs.nlp.annotators.Tokenizer
210+
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
211+
import com.johnsnowlabs.nlp.embeddings.ModernBertEmbeddings
212+
import com.johnsnowlabs.nlp.annotators.ner.dl.NerDLApproach
213+
import com.johnsnowlabs.nlp.training.CoNLL
214+
import org.apache.spark.ml.Pipeline
215+
216+
// First extract the prerequisites for the NerDLApproach
217+
val documentAssembler = new DocumentAssembler()
218+
.setInputCol("text")
219+
.setOutputCol("document")
220+
221+
val sentence = new SentenceDetector()
222+
.setInputCols("document")
223+
.setOutputCol("sentence")
224+
225+
val tokenizer = new Tokenizer()
226+
.setInputCols("sentence")
227+
.setOutputCol("token")
228+
229+
val embeddings = ModernBertEmbeddings.pretrained()
230+
.setInputCols("sentence", "token")
231+
.setOutputCol("embeddings")
232+
233+
// Then the training can start with the transformer embeddings
234+
val nerTagger = new NerDLApproach()
235+
.setInputCols("sentence", "token", "embeddings")
236+
.setLabelColumn("label")
237+
.setOutputCol("ner")
238+
.setMaxEpochs(1)
239+
.setVerbose(0)
240+
241+
val pipeline = new Pipeline().setStages(Array(
242+
documentAssembler,
243+
sentence,
244+
tokenizer,
245+
embeddings,
246+
nerTagger
247+
))
248+
249+
// We use the text and labels from the CoNLL dataset
250+
val conll = CoNLL()
251+
val trainingData = conll.readDataset(spark, "src/test/resources/conll2003/eng.train")
252+
253+
val pipelineModel = pipeline.fit(trainingData)
254+
{%- endcapture -%}
255+
256+
{%- capture embeddings_python_example -%}
257+
import sparknlp
258+
from sparknlp.base import *
259+
from sparknlp.common import *
260+
from sparknlp.annotator import *
261+
from sparknlp.training import *
262+
from pyspark.ml import Pipeline
263+
264+
documentAssembler = DocumentAssembler() \
265+
.setInputCol("text") \
266+
.setOutputCol("document")
267+
268+
tokenizer = Tokenizer() \
269+
.setInputCols(["document"]) \
270+
.setOutputCol("token")
271+
272+
embeddings = ModernBertEmbeddings.pretrained() \
273+
.setInputCols(["token", "document"]) \
274+
.setOutputCol("modernbert_embeddings")
275+
276+
embeddingsFinisher = EmbeddingsFinisher() \
277+
.setInputCols(["modernbert_embeddings"]) \
278+
.setOutputCols("finished_embeddings") \
279+
.setOutputAsVector(True)
280+
281+
pipeline = Pipeline().setStages([
282+
documentAssembler,
283+
tokenizer,
284+
embeddings,
285+
embeddingsFinisher
286+
])
287+
288+
data = spark.createDataFrame([["This is a sentence."]]).toDF("text")
289+
result = pipeline.fit(data).transform(data)
290+
291+
result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
292+
+--------------------------------------------------------------------------------+
293+
| result|
294+
+--------------------------------------------------------------------------------+
295+
|[-2.3497989177703857,0.480538547039032,-0.3238905668258667,-1.612930893898010...|
296+
|[-2.1357314586639404,0.32984697818756104,-0.6032363176345825,-1.6791689395904...|
297+
|[-1.8244884014129639,-0.27088963985443115,-1.059438943862915,-0.9817547798156...|
298+
|[-1.1648050546646118,-0.4725411534309387,-0.5938255786895752,-1.5780693292617...|
299+
|[-0.9125322699546814,0.4563939869403839,-0.3975459933280945,-1.81611204147338...|
300+
+--------------------------------------------------------------------------------+
301+
302+
{%- endcapture -%}
303+
304+
{%- capture embeddings_scala_example -%}
305+
import spark.implicits._
306+
import com.johnsnowlabs.nlp.base.DocumentAssembler
307+
import com.johnsnowlabs.nlp.annotators.Tokenizer
308+
import com.johnsnowlabs.nlp.embeddings.ModernBertEmbeddings
309+
import com.johnsnowlabs.nlp.EmbeddingsFinisher
310+
import org.apache.spark.ml.Pipeline
311+
312+
val documentAssembler = new DocumentAssembler()
313+
.setInputCol("text")
314+
.setOutputCol("document")
315+
316+
val tokenizer = new Tokenizer()
317+
.setInputCols("document")
318+
.setOutputCol("token")
319+
320+
val embeddings = ModernBertEmbeddings.pretrained("modernbert-base", "en")
321+
.setInputCols("token", "document")
322+
.setOutputCol("modernbert_embeddings")
323+
324+
val embeddingsFinisher = new EmbeddingsFinisher()
325+
.setInputCols("modernbert_embeddings")
326+
.setOutputCols("finished_embeddings")
327+
.setOutputAsVector(true)
328+
329+
val pipeline = new Pipeline().setStages(Array(
330+
documentAssembler,
331+
tokenizer,
332+
embeddings,
333+
embeddingsFinisher
334+
))
335+
336+
val data = Seq("This is a sentence.").toDF("text")
337+
val result = pipeline.fit(data).transform(data)
338+
339+
result.selectExpr("explode(finished_embeddings) as result").show(5, 80)
340+
+--------------------------------------------------------------------------------+
341+
| result|
342+
+--------------------------------------------------------------------------------+
343+
|[-2.3497989177703857,0.480538547039032,-0.3238905668258667,-1.612930893898010...|
344+
|[-2.1357314586639404,0.32984697818756104,-0.6032363176345825,-1.6791689395904...|
345+
|[-1.8244884014129639,-0.27088963985443115,-1.059438943862915,-0.9817547798156...|
346+
|[-1.1648050546646118,-0.4725411534309387,-0.5938255786895752,-1.5780693292617...|
347+
|[-0.9125322699546814,0.4563939869403839,-0.3975459933280945,-1.81611204147338...|
348+
+--------------------------------------------------------------------------------+
349+
350+
{%- endcapture -%}
351+
352+
{% include templates/transformer_usecases_template.md
353+
title=title
354+
description=description
355+
input_anno=input_anno
356+
output_anno=output_anno
357+
python_api_link=python_api_link
358+
api_link=api_link
359+
source_link=source_link
360+
prediction_python_example=prediction_python_example
361+
prediction_scala_example=prediction_scala_example
362+
training_python_example=training_python_example
363+
training_scala_example=training_scala_example
364+
embeddings_python_example=embeddings_python_example
365+
embeddings_scala_example=embeddings_scala_example
366+
%}

0 commit comments

Comments
 (0)