Skip to content

Commit 0baa3fc

Browse files
committed
Added OpenSSL in linux arm build step in ci.yml. Bumped to 0.1.9
1 parent 86b3527 commit 0baa3fc

File tree

6 files changed

+56
-5
lines changed

6 files changed

+56
-5
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ jobs:
163163
with:
164164
python-version: ${{ matrix.py-version }}
165165

166+
- name: Install OpenSSL
167+
run: sudo apt-get install libssl-dev
168+
166169
- name: Build wheels
167170
uses: PyO3/maturin-action@v1
168171
with:

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "polars-candle"
3-
version = "0.1.7"
3+
version = "0.1.9"
44
edition = "2021"
55
authors = ["Wouter Doppenberg <[email protected]>"]
66
description = "A text embedding extension for the Polars Dataframe library."

polars_candle/candle_ext.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,42 @@
66

77
ModelName = Literal["bert-base-uncased", "bert-base-cased"]
88

9+
# Old models that don't belong to any organization
10+
# Source https://github.com/UKPLab/sentence-transformers/blob/c0fc0e8238f7f48a1e92dc90f6f96c86f69f1e02/sentence_transformers/SentenceTransformer.py#L219
11+
__MODEL_HUB_ORGANIZATION__ = "sentence-transformers"
12+
__BASIC_TRANSFORMER_MODELS__ = (
13+
"bert-base-cased-finetuned-mrpc",
14+
"bert-base-cased",
15+
"bert-base-chinese",
16+
"bert-base-german-cased",
17+
"bert-base-german-dbmdz-cased",
18+
"bert-base-german-dbmdz-uncased",
19+
"bert-base-multilingual-cased",
20+
"bert-base-multilingual-uncased",
21+
"bert-base-uncased",
22+
"bert-large-cased-whole-word-masking-finetuned-squad",
23+
"bert-large-cased-whole-word-masking",
24+
"bert-large-cased",
25+
"bert-large-uncased-whole-word-masking-finetuned-squad",
26+
"bert-large-uncased-whole-word-masking",
27+
"bert-large-uncased",
28+
"camembert-base",
29+
"distilbert-base-cased-distilled-squad",
30+
"distilbert-base-cased",
31+
"distilbert-base-german-cased",
32+
"distilbert-base-multilingual-cased",
33+
"distilbert-base-uncased-distilled-squad",
34+
"distilbert-base-uncased-finetuned-sst-2-english",
35+
"distilbert-base-uncased",
36+
"distilgpt2",
37+
"distilroberta-base",
38+
"roberta-base-openai-detector",
39+
"roberta-base",
40+
"roberta-large-mnli",
41+
"roberta-large-openai-detector",
42+
"roberta-large",
43+
)
44+
945

1046
@pl.api.register_expr_namespace("candle")
1147
class CandleExt:
@@ -23,7 +59,7 @@ class CandleExt:
2359
... ],
2460
... })
2561
>>> df.with_columns(
26-
... pl.col("text").candle.embed_text("bert-base-uncased")
62+
... pl.col("text").candle.embed_text("bert-base-cased")
2763
... )
2864
"""
2965

@@ -57,6 +93,8 @@ def embed_text(
5793
Expr
5894
An expression with the embedded text.
5995
"""
96+
if "/" not in model_repo:
97+
model_repo = __MODEL_HUB_ORGANIZATION__ + "/" + model_repo
6098

6199
return register_plugin_function(
62100
plugin_path=Path(__file__).parent,

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "polars_candle"
3-
version = "0.1.7"
3+
version = "0.1.9"
44
requires-python = ">=3.9"
55
license = "Apache-2.0"
66
classifiers = [
@@ -14,7 +14,7 @@ keywords = ["polars", "dataframe", "embedding", "nlp", "candle"]
1414

1515
[tool.poetry]
1616
name = "polars_candle"
17-
version = "0.1.7"
17+
version = "0.1.8"
1818
description = "A text embedding extension for the Polars Dataframe library."
1919
authors = ["Wouter Doppenberg <[email protected]>"]
2020

tests/test_candle.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ def test_basic_two_sentences_with_gpu():
4242
)
4343

4444

45+
def test_basic_model_no_repo():
46+
df = pl.DataFrame({"s": ["This is a sentence", "This is another sentence"]})
47+
48+
df = df.with_columns(
49+
pl.col("s")
50+
.candle.embed_text("nli-distilbert-base", device="gpu")
51+
.alias("s_embedding")
52+
)
53+
54+
4555
def test_basic_with_none():
4656
df = pl.DataFrame(
4757
{"s": ["This is a sentence", None, "This is another sentence", None]}

0 commit comments

Comments
 (0)