Skip to content

Commit 7c1fac6

Browse files
authored
fix text preprocessing for inference model (#40)
* fix text preprocessing for inference model
* add filter_window_sizes override for non-cnn models
1 parent 43c213c commit 7c1fac6

2 files changed

Lines changed: 7 additions & 3 deletions

File tree

src/detext/train/train_helper.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def get_query(hparams):
4343
hparams.CLS, hparams.SEP, hparams.PAD,
4444
hparams.max_len,
4545
hparams.min_len,
46-
cnn_filter_window_size=max(hparams.filter_window_sizes)
46+
cnn_filter_window_size=max(hparams.filter_window_sizes) if hparams.ftr_ext == 'cnn' else 0
4747
)
4848
return query, query_placeholder
4949

@@ -75,7 +75,7 @@ def get_doc_fields(hparams):
7575
hparams.CLS, hparams.SEP, hparams.PAD,
7676
hparams.max_len,
7777
hparams.min_len,
78-
cnn_filter_window_size=max(hparams.filter_window_sizes)
78+
cnn_filter_window_size=max(hparams.filter_window_sizes) if hparams.ftr_ext == 'cnn' else 0
7979
)
8080
one_doc_field = tf.expand_dims(one_doc_field, axis=0)
8181
doc_fields.append(one_doc_field)
@@ -118,7 +118,7 @@ def get_usr_fields(hparams):
118118
hparams.CLS, hparams.SEP, hparams.PAD,
119119
hparams.max_len,
120120
hparams.min_len,
121-
cnn_filter_window_size=max(hparams.filter_window_sizes)
121+
cnn_filter_window_size=max(hparams.filter_window_sizes) if hparams.ftr_ext == 'cnn' else 0
122122
)
123123
usr_fields.append(one_usr_field)
124124
return usr_fields, usr_text_placeholders

src/detext/utils/misc_utils.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,10 @@ def extend_hparams(hparams):
9999
tok2regex_pattern = {'plain': None, 'punct': r'(\pP)'}
100100
hparams.regex_replace_pattern = tok2regex_pattern[hparams.tokenization]
101101

102+
# if not using cnn models, then disable cnn parameters
103+
if hparams.ftr_ext != 'cnn':
104+
hparams.filter_window_sizes = [0]
105+
102106
assert hparams.pmetric is not None, "Please set your primary evaluation metric using --pmetric option"
103107
assert hparams.pmetric != 'confusion_matrix', 'confusion_matrix cannot be used as primary evaluation metric.'
104108

0 commit comments

Comments
 (0)