File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -43,7 +43,7 @@ def get_query(hparams):
4343 hparams .CLS , hparams .SEP , hparams .PAD ,
4444 hparams .max_len ,
4545 hparams .min_len ,
46- cnn_filter_window_size = max (hparams .filter_window_sizes )
46+ cnn_filter_window_size = max (hparams .filter_window_sizes ) if hparams . ftr_ext == 'cnn' else 0
4747 )
4848 return query , query_placeholder
4949
@@ -75,7 +75,7 @@ def get_doc_fields(hparams):
7575 hparams .CLS , hparams .SEP , hparams .PAD ,
7676 hparams .max_len ,
7777 hparams .min_len ,
78- cnn_filter_window_size = max (hparams .filter_window_sizes )
78+ cnn_filter_window_size = max (hparams .filter_window_sizes ) if hparams . ftr_ext == 'cnn' else 0
7979 )
8080 one_doc_field = tf .expand_dims (one_doc_field , axis = 0 )
8181 doc_fields .append (one_doc_field )
@@ -118,7 +118,7 @@ def get_usr_fields(hparams):
118118 hparams .CLS , hparams .SEP , hparams .PAD ,
119119 hparams .max_len ,
120120 hparams .min_len ,
121- cnn_filter_window_size = max (hparams .filter_window_sizes )
121+ cnn_filter_window_size = max (hparams .filter_window_sizes ) if hparams . ftr_ext == 'cnn' else 0
122122 )
123123 usr_fields .append (one_usr_field )
124124 return usr_fields , usr_text_placeholders
Original file line number Diff line number Diff line change @@ -99,6 +99,10 @@ def extend_hparams(hparams):
9999 tok2regex_pattern = {'plain' : None , 'punct' : r'(\pP)' }
100100 hparams .regex_replace_pattern = tok2regex_pattern [hparams .tokenization ]
101101
102+ # if not using cnn models, then disable cnn parameters
103+ if hparams .ftr_ext != 'cnn' :
104+ hparams .filter_window_sizes = [0 ]
105+
102106 assert hparams .pmetric is not None , "Please set your primary evaluation metric using --pmetric option"
103107 assert hparams .pmetric != 'confusion_matrix' , 'confusion_matrix cannot be used as primary evaluation metric.'
104108
You can’t perform that action at this time.
0 commit comments