Skip to content

Commit c0b2f14

Browse files
committed
Added the smart drafting argument to the product prediction script
1 parent 6cb70b7 commit c0b2f14

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

scripts/product_prediction.sh

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,12 @@ function run_speculative_beam_search() {
118118
local NBEST=${3:-5} # Number of best sequences
119119
local DRAFT_LEN=${4:-10} # Draft sequence length
120120
local N_DRAFTS=${5:-23} # Maximum number of parallel drafts
121-
local GPU=${6:-1}
122-
local SAVE_PREDICTIONS=${7:-false} # Whether to save predictions to disk. Slows down the run.
121+
local SMART_DRAFTS_MODE=${6:-false}
122+
local GPU=${7:-0}
123+
local SAVE_PREDICTIONS=${8:-false} # Whether to save predictions to disk. Slows down the run.
123124

124125
local DEVICE="--trainer.accelerator cpu --trainer.devices 1"
125-
if [ -n "${GPU}" ]; then
126+
if [ "${GPU}" != "false" ]; then
126127
DEVICE="--trainer.accelerator gpu --trainer.devices [${GPU}]"
127128
fi
128129

@@ -143,6 +144,7 @@ function run_speculative_beam_search() {
143144
--model.report_prediction_file ${OUTPUT_DIR}/report.txt \
144145
--data.batch_size ${BS} \
145146
--model.generation beam_search_speculative \
147+
--model.smart_drafts_mode ${SMART_DRAFTS_MODE} \
146148
--model.draft_len ${DRAFT_LEN} \
147149
--model.beam_size ${NBEST} \
148150
--model.max_len ${MAX_LEN} \
@@ -193,6 +195,7 @@ done
193195

194196
SAVE_PREDICTIONS=false
195197
N_BEST=5
198+
SMART_DRAFTS=false
196199

197200
# Beam search decoding with five hypotheses
198201
# Five runs for time spread estimation
@@ -203,28 +206,28 @@ for i in {1..6}; do
203206
draft_len=10
204207
n_drafts=23
205208
run_beam_search results_product_final_beam_search ${batch_size} ${N_BEST} ${GPU} ${SAVE_PREDICTIONS}
206-
run_speculative_beam_search results_product_final_beam_search_speculative ${batch_size} ${N_BEST} ${draft_len} ${n_drafts} ${GPU} ${SAVE_PREDICTIONS}
209+
run_speculative_beam_search results_product_final_beam_search_speculative ${batch_size} ${N_BEST} ${draft_len} ${n_drafts} ${SMART_DRAFTS} ${GPU} ${SAVE_PREDICTIONS}
207210

208211
# Batch size 2, 14 draft tokens, 10 drafts
209212
batch_size=2
210213
draft_len=14
211214
n_drafts=10
212215
run_beam_search results_product_final_beam_search ${batch_size} ${N_BEST} ${GPU} ${SAVE_PREDICTIONS}
213-
run_speculative_beam_search results_product_final_beam_search_speculative ${batch_size} ${N_BEST} ${draft_len} ${n_drafts} ${GPU} ${SAVE_PREDICTIONS}
216+
run_speculative_beam_search results_product_final_beam_search_speculative ${batch_size} ${N_BEST} ${draft_len} ${n_drafts} ${SMART_DRAFTS} ${GPU} ${SAVE_PREDICTIONS}
214217

215218
# Batch size 3, 9 draft tokens, 10 drafts
216219
batch_size=3
217220
draft_len=9
218221
n_drafts=10
219222
run_beam_search results_product_final_beam_search ${batch_size} ${N_BEST} ${GPU} ${SAVE_PREDICTIONS}
220-
run_speculative_beam_search results_product_final_beam_search_speculative ${batch_size} ${N_BEST} ${draft_len} ${n_drafts} ${GPU} ${SAVE_PREDICTIONS}
223+
run_speculative_beam_search results_product_final_beam_search_speculative ${batch_size} ${N_BEST} ${draft_len} ${n_drafts} ${SMART_DRAFTS} ${GPU} ${SAVE_PREDICTIONS}
221224

222225
# Batch size 4, 10 draft tokens, 7 drafts
223226
batch_size=4
224227
draft_len=10
225228
n_drafts=7
226229
run_beam_search results_product_final_beam_search ${batch_size} ${N_BEST} ${GPU} ${SAVE_PREDICTIONS}
227-
run_speculative_beam_search results_product_final_beam_search_speculative ${batch_size} ${N_BEST} ${draft_len} ${n_drafts} ${GPU} ${SAVE_PREDICTIONS}
230+
run_speculative_beam_search results_product_final_beam_search_speculative ${batch_size} ${N_BEST} ${draft_len} ${n_drafts} ${SMART_DRAFTS} ${GPU} ${SAVE_PREDICTIONS}
228231

229232
if [ "$i" -eq 5 ]; then
230233
SAVE_PREDICTIONS=true

0 commit comments

Comments
 (0)