Skip to content

Commit 1786263

Browse files
Added zero shot udf with span (#284)
1 parent 96c746e commit 1786263

44 files changed

Lines changed: 637 additions & 328 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

doc/changes/changes_2.2.0.md

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,18 @@
1-
# Transformers Extension 2.2.0, T.B.D
1+
# Transformers Extension 2.2.0, 2025-01-21
22

3-
Code name: T.B.D
3+
Code name: Bugfix for token classification
44

55
## Summary
66

7-
T.B.D
8-
9-
### Features
10-
11-
n/a
7+
This release includes a bugfix for handling unexpected results in the token classification udf,
8+
as well as internal refactorings for the unit tests.
129

1310
### Bugs
1411

1512
- #272: Fixed unit test assertions not working correctly
1613
- #275: Fixed a bug where models returning unexpected results were not handled correctly
1714

18-
### Documentation
19-
20-
n/a
21-
2215
### Refactorings
2316

2417
- #273: Refactored unit tests for token_classification_udf to use StandAloneUDFMock, made params files more maintainable
2518
- #271: Moved test cases which only pertain to the base udf to base udf unit tests
26-
- #274: Refactored unit tests for zero_shot_text_classification_udf to use StandAloneUDFMock, made params files more maintainable
27-
28-
### Security
29-
30-
n/a

doc/developer_guide/developer_guide.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ inference output to the inputs.
8282
Before implementing the UDF logic (examined in item 4 in this section), the
8383
`run` function responsible for calling the newly created UDF script should be
8484
defined in `exasol_transformers_extension/udfs/callers/`.
85+
Also add the new UDF to the lists in `tests/utils/db_queries.py`.
8586

8687
### 3. UDF Template-Caller Matching
8788
The added UDF template and defined UDF caller should be added to the dictionary

exasol_transformers_extension/deployment/constants.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
"text_generation_udf_call.py":
1616
"text_generation_udf.jinja.sql",
1717
"translation_udf_call.py":
18-
"translation_udf.jinja.sql",
19-
"zero_shot_text_classification_udf.py":
20-
"zero_shot_text_classification_udf.jinja.sql"
18+
"translation_udf.jinja.sql"
2119
}
2220

2321
constants = InstallScriptsConstants(

exasol_transformers_extension/deployment/work_with_spans_constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
UDF_CALL_TEMPLATES = {
66
"span_token_classification_udf_call.py":
77
"span_token_classification_udf.jinja.sql",
8+
"span_zero_shot_text_classification_udf_call.py":
9+
"span_zero_shot_text_classification_udf.jinja.sql"
810
}
911

1012
work_with_spans_constants = InstallScriptsConstants(

exasol_transformers_extension/deployment/work_without_spans_constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
UDF_CALL_TEMPLATES = {
66
"token_classification_udf_call.py":
77
"token_classification_udf.jinja.sql",
8+
"zero_shot_text_classification_udf.py":
9+
"zero_shot_text_classification_udf.jinja.sql"
810
}
911

1012
work_without_spans_constants = InstallScriptsConstants(

exasol_transformers_extension/resources/templates/with_spans/span_token_classification_udf.jinja.sql

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ CREATE OR REPLACE {{ language_alias }} SET SCRIPT "TE_TOKEN_CLASSIFICATION_UDF_W
44
sub_dir VARCHAR(2000000),
55
model_name VARCHAR(2000000),
66
text_data VARCHAR(2000000),
7-
text_data_docid INTEGER,
7+
text_data_doc_id INTEGER,
88
text_data_char_begin INTEGER,
99
text_data_char_end INTEGER,
1010
aggregation_strategy VARCHAR(2000000)
@@ -13,14 +13,14 @@ CREATE OR REPLACE {{ language_alias }} SET SCRIPT "TE_TOKEN_CLASSIFICATION_UDF_W
1313
bucketfs_conn VARCHAR(2000000),
1414
sub_dir VARCHAR(2000000),
1515
model_name VARCHAR(2000000),
16-
text_data_docid INTEGER,
16+
text_data_doc_id INTEGER,
1717
text_data_char_begin INTEGER,
1818
text_data_char_end INTEGER,
1919
aggregation_strategy VARCHAR(2000000),
2020
entity_covered_text VARCHAR(2000000),
2121
entity_type VARCHAR(2000000),
2222
score DOUBLE,
23-
entity_docid INTEGER,
23+
entity_doc_id INTEGER,
2424
entity_char_begin INTEGER,
2525
entity_char_end INTEGER,
2626
error_message VARCHAR(2000000) ) AS
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
CREATE OR REPLACE {{ language_alias }} SET SCRIPT "TE_ZERO_SHOT_TEXT_CLASSIFICATION_UDF_WITH_SPAN"(
2+
device_id INTEGER,
3+
bucketfs_conn VARCHAR(2000000),
4+
sub_dir VARCHAR(2000000),
5+
model_name VARCHAR(2000000),
6+
text_data VARCHAR(2000000),
7+
text_data_doc_id INTEGER,
8+
text_data_char_begin INTEGER,
9+
text_data_char_end INTEGER,
10+
candidate_labels VARCHAR(2000000)
11+
ORDER BY {{ ordered_columns | join(" ASC,") }} ASC
12+
)EMITS (
13+
bucketfs_conn VARCHAR(2000000),
14+
sub_dir VARCHAR(2000000),
15+
model_name VARCHAR(2000000),
16+
text_data_doc_id INTEGER,
17+
text_data_char_begin INTEGER,
18+
text_data_char_end INTEGER,
19+
label VARCHAR(2000000),
20+
score DOUBLE,
21+
rank INTEGER,
22+
error_message VARCHAR(2000000) ) AS
23+
24+
{{ script_content }}
25+
26+
/

exasol_transformers_extension/resources/templates/zero_shot_text_classification_udf.jinja.sql renamed to exasol_transformers_extension/resources/templates/without_spans/zero_shot_text_classification_udf.jinja.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ CREATE OR REPLACE {{ language_alias }} SET SCRIPT "TE_ZERO_SHOT_TEXT_CLASSIFICAT
1010
bucketfs_conn VARCHAR(2000000),
1111
sub_dir VARCHAR(2000000),
1212
model_name VARCHAR(2000000),
13-
test_data VARCHAR(2000000),
13+
text_data VARCHAR(2000000),
1414
candidate_labels VARCHAR(2000000),
1515
label VARCHAR(2000000),
1616
score DOUBLE,
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from exasol_transformers_extension.udfs.models.zero_shot_text_classification_udf import \
2+
ZeroShotTextClassificationUDF
3+
4+
udf = ZeroShotTextClassificationUDF(exa, work_with_spans=True)
5+
6+
7+
def run(ctx):
8+
return udf.run(ctx)

exasol_transformers_extension/udfs/models/token_classification_udf.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def execute_prediction(self, model_df: pd.DataFrame) -> List[List[Dict[str, Any]
7979

8080
def create_new_span_columns(self, model_df: pd.DataFrame) -> pd.DataFrame:
8181
# create new columns for use with spans
82-
model_df[["entity_docid", "entity_char_begin", "entity_char_end"]] = None, None, None
82+
model_df[["entity_doc_id", "entity_char_begin", "entity_char_end"]] = None, None, None
8383
# we use different names in udf with span and without, so need to rename
8484
# this decision was made as to improve the naming of the columns without
8585
# breaking the interface of the existing udf
@@ -95,10 +95,10 @@ def drop_old_data_for_span_execution(self, model_df: pd.DataFrame) -> pd.DataFra
9595
return model_df
9696

9797
def make_entity_span(self, df_row):
98-
token_docid = df_row["text_data_docid"]
98+
token_doc_id = df_row["text_data_doc_id"]
9999
token_char_begin = df_row["start_pos"] + df_row['text_data_char_begin']
100100
token_char_end = df_row["end_pos"] + df_row['text_data_char_begin']
101-
return pd.Series([token_docid, token_char_begin, token_char_end])
101+
return pd.Series([token_doc_id, token_char_begin, token_char_end])
102102

103103
def append_predictions_to_input_dataframe(
104104
self, model_df: pd.DataFrame, pred_df_list: List[pd.DataFrame]) \
@@ -124,7 +124,7 @@ def append_predictions_to_input_dataframe(
124124

125125
if self.work_with_spans:
126126
model_df = self.create_new_span_columns(model_df)
127-
model_df[["entity_docid", "entity_char_begin", "entity_char_end"]] =\
127+
model_df[["entity_doc_id", "entity_char_begin", "entity_char_end"]] =\
128128
model_df.apply(self.make_entity_span, axis=1)
129129
model_df = self.drop_old_data_for_span_execution(model_df)
130130
return model_df

0 commit comments

Comments
 (0)