Skip to content

Commit 6d8d50b

Browse files
authored
fixing benchmark tests (#1244)
1 parent 5e62157 commit 6d8d50b

4 files changed

+8
-8
lines changed

benchmark/benchmark_basic_english_normalize.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -18,17 +18,17 @@ def _run_benchmark_lookup(train, tokenizer):
1818
experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize)
1919

2020
# existing eager lookup
21-
train, _ = AG_NEWS()
21+
train = AG_NEWS(split='train')
2222
print("BasicEnglishNormalize - Eager Mode")
2323
_run_benchmark_lookup(train, existing_basic_english_tokenizer)
2424

2525
# experimental eager lookup
26-
train, _ = AG_NEWS()
26+
train = AG_NEWS(split='train')
2727
print("BasicEnglishNormalize Experimental - Eager Mode")
2828
_run_benchmark_lookup(train, experimental_basic_english_normalize)
2929

3030
# experimental jit lookup
31-
train, _ = AG_NEWS()
31+
train = AG_NEWS(split='train')
3232
print("BasicEnglishNormalize Experimental - Jit Mode")
3333
_run_benchmark_lookup(train, experimental_jit_basic_english_normalize)
3434

benchmark/benchmark_experimental_vectors.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ def _run_benchmark_lookup(tokens, vector):
1313
vector[token]
1414
print("Lookup time:", time.monotonic() - t0)
1515

16-
train, = AG_NEWS(data_select='train')
16+
train = AG_NEWS(split='train')
1717
vocab = train.get_vocab()
1818
tokens = []
1919
for (label, text) in train:

benchmark/benchmark_experimental_vocab.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def _run_benchmark_lookup(tokens, vocab):
9494
tokens = []
9595
tokens_lists = []
9696

97-
train, = AG_NEWS(data_select='train')
97+
train = AG_NEWS(split='train')
9898
vocab = train.get_vocab()
9999
for (_, text) in train:
100100
cur_tokens = []

benchmark/benchmark_sentencepiece.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from torchtext.experimental.transforms import load_sp_model as load_pybind_sp_model
44
from torchtext.data.functional import load_sp_model as load_torchbind_sp_model
55
from torchtext.utils import download_from_url
6-
from torchtext.datasets import text_classification as raw
6+
from torchtext.datasets import DATASETS
77

88

99
def benchmark_sentencepiece(args):
@@ -17,13 +17,13 @@ def _run_benchmark(train, spm_processor):
1717
sp_model_path = download_from_url('https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_15000.model')
1818

1919
# existing sentencepiece model with torchbind
20-
train, _ = raw.DATASETS[args.dataset]()
20+
train = DATASETS[args.dataset](split='train')
2121
sp_model = load_torchbind_sp_model(sp_model_path)
2222
print("SentencePiece EncodeAsIds - torchbind")
2323
_run_benchmark(train, sp_model.EncodeAsIds)
2424

2525
# experimental sentencepiece model with pybind
26-
train, _ = raw.DATASETS[args.dataset]()
26+
train = DATASETS[args.dataset](split='train')
2727
sp_model = load_pybind_sp_model(sp_model_path)
2828
print("SentencePiece EncodeAsIds - pybind")
2929
_run_benchmark(train, sp_model.EncodeAsIds)

0 commit comments

Comments
 (0)