
Commit 3834bb3

Data wrangle benchmark (#95)
* Benchmarked OPT-6.7B
* Added some new scripts
* Updated some results for OPT-30B
* Updated some results
* Added some OPT-175B results, one piece left
* Updated the missing results
* Updated README.md
1 parent 74bdca7 commit 3834bb3

14 files changed: +514 -39 lines

flexgen/apps/README.md (+13)

@@ -7,6 +7,19 @@ python completion.py --model facebook/opt-30b --percent 100 0 100 0 100 0 --compress-weight
 python completion.py --model facebook/opt-66b --percent 50 10 100 0 100 0 --compress-weight
 ```
 
+### Data Wrangling
+
+Run the data wrangling tasks from the [fm_data_tasks](https://github.com/HazyResearch/fm_data_tasks) repo by [HazyResearch](https://github.com/HazyResearch).
+See the [data wrangling README](./data_wrangle/README.md) for more details.
+```
+cd data_wrangle
+bash install.sh
+bash test_batch_query_all_opt6.7b.sh
+bash test_batch_query_all_opt30b.sh
+bash test_batch_query_all_opt175b.sh
+```
+
 ### HELM benchmark
 Run Massive Multitask Language Understanding (MMLU) scenario.
 ```

flexgen/apps/data_wrangle/README.md (+70 -3)

@@ -1,6 +1,6 @@
 # FlexGen for Data Wrangling Tasks
 
-Here we show how to use FlexGen for the data wrangling tasks. The implementation follows the [fm_data_tasks](https://github.com/HazyResearch/fm_data_tasks) repo from [HazyResearch](https://github.com/HazyResearch).
+Here we show how to use FlexGen for data wrangling tasks, including entity matching (EM), data imputation (DI), and error detection (ED). The implementation follows the [fm_data_tasks](https://github.com/HazyResearch/fm_data_tasks) repo from [HazyResearch](https://github.com/HazyResearch).
 
 ## Install
 
@@ -9,10 +9,77 @@ Here we show how to use FlexGen for the data wrangling tasks. The implementation
 
 ## Examples
 
-- To check the outcome and verify the result of a data imputation task (e.g., Restaurant), run:
+- To check the outcome and verify the result of a data imputation task (e.g., Restaurant on OPT-6.7B), run:
 
       bash test_single_query_case.sh
 
-- To test FlexGen Throughput of a data imputation task (e.g., Restaurant), run:
+- To test the throughput of FlexGen on a data imputation task (e.g., Restaurant on OPT-6.7B), run:
 
       bash test_batch_query_case.sh
+
+- To run the complete tests of all tasks on OPT-6.7B:
+
+      bash test_batch_query_all_opt6.7b.sh
+
+- To run the complete tests of all tasks on OPT-30B:
+
+      bash test_batch_query_all_opt30b.sh
+
+- To run the complete tests of all tasks on OPT-175B:
+
+      bash test_batch_query_all_opt175b.sh
+
+## Benchmark Results
+
+- Note that in these data wrangling tasks (EM, DI, and ED), the input sequences are **very long** (from 123 to 1274 tokens) while the outputs are **very short** (e.g., 3, 5, or 10 tokens). Most of the inference time is therefore spent in the prefill phase, so we report a throughput that counts both input and output tokens (see the sketch after the tables).
+
+- We run the experiments with the same setup as the HELM benchmark: a single T4 (16 GB) GPU, 200 GB of DRAM, and a 1.5 TB SSD connected by NVMe.
+
+### OPT-6.7B
+
+| Task                   | Tested Samples | Input Length (tokens) | Output Length (tokens) | Time (s) | Input + Output Throughput (token/s) |
+|------------------------|----------------|-----------------------|------------------------|----------|-------------------------------------|
+| EM: Fodors-Zagats      | 189            | 744                   | 3                      | 109.556  | 1281.871                            |
+| EM: Beer               | 91             | 592                   | 3                      | 42.087   | 1272.360                            |
+| EM: iTunes-Amazon      | 109            | 529                   | 3                      | 59.467   | 966.178                             |
+| EM: Walmart-Amazon     | 200            | 748                   | 3                      | 126.538  | 1186.992                            |
+| EM: Amazon-Google      | 200            | 876                   | 3                      | 144.593  | 1215.828                            |
+| EM: DBLP-ACM           | 200            | 1274                  | 3                      | 207.513  | 1230.767                            |
+| EM: DBLP-GoogleScholar | 200            | 1209                  | 3                      | 232.65   | 1097.78                             |
+| DI: Restaurant         | 86             | 123                   | 5                      | 10.397   | 984.865                             |
+| DI: Buy                | 65             | 488                   | 10                     | 43.077   | 739.876                             |
+| ED: Hospital           | 200            | 200                   | 3                      | 30.137   | 1347.203                            |
+
+### OPT-30B
+
+| Task                   | Tested Samples | Input Length (tokens) | Output Length (tokens) | Time (s) | Input + Output Throughput (token/s) |
+|------------------------|----------------|-----------------------|------------------------|----------|-------------------------------------|
+| EM: Fodors-Zagats      | 189            | 744                   | 3                      | 541.550  | 248.287                             |
+| EM: Beer               | 91             | 592                   | 3                      | 238.58   | 224.450                             |
+| EM: iTunes-Amazon      | 109            | 529                   | 3                      | 267.639  | 198.775                             |
+| EM: Walmart-Amazon     | 200            | 748                   | 3                      | 682.635  | 220.030                             |
+| EM: Amazon-Google      | 200            | 876                   | 3                      | 799.514  | 219.884                             |
+| EM: DBLP-ACM           | 200            | 1274                  | 3                      | 1119.272 | 228.184                             |
+| EM: DBLP-GoogleScholar | 200            | 1209                  | 3                      | 1271.534 | 190.636                             |
+| DI: Restaurant         | 86             | 123                   | 5                      | 60.310   | 169.790                             |
+| DI: Buy                | 65             | 488                   | 10                     | 185.882  | 160.747                             |
+| ED: Hospital           | 200            | 200                   | 3                      | 158.329  | 256.429                             |
+
+### OPT-175B
+
+| Task                   | Tested Samples | Input Length (tokens) | Output Length (tokens) | Time (s) | Input + Output Throughput (token/s) |
+|------------------------|----------------|-----------------------|------------------------|----------|-------------------------------------|
+| EM: Fodors-Zagats      | 189            | 744                   | 3                      | 3928.310 | 34.228                              |
+| EM: Beer               | 91             | 592                   | 3                      | 1356.786 | 35.083                              |
+| EM: iTunes-Amazon      | 109            | 529                   | 3                      | 1569.062 | 33.906                              |
+| EM: Walmart-Amazon     | 200            | 748                   | 3                      | 4171.319 | 36.008                              |
+| EM: Amazon-Google      | 200            | 876                   | 3                      | 4893.572 | 35.925                              |
+| EM: DBLP-ACM           | 200            | 1274                  | 3                      | 7624.726 | 33.496                              |
+| EM: DBLP-GoogleScholar | 200            | 1209                  | 3                      | 8275.828 | 29.290                              |
+| DI: Restaurant         | 86             | 123                   | 5                      | 648.762  | 16.968                              |
+| DI: Buy                | 65             | 488                   | 10                     | 2086.961 | 14.317                              |
+| ED: Hospital           | 200            | 200                   | 3                      | 1154.133 | 35.178                              |

flexgen/apps/data_wrangle/data_wrangle_run.py (+65 -29)
@@ -1,13 +1,17 @@
+# The source code in this file is partially adapted from
+# https://github.com/HazyResearch/fm_data_tasks/blob/main/fm_data_tasks/utils/prompt_utils.py
+# which is under Apache License Version 2.0.
+
 """Run inference."""
 import argparse
+from tqdm import tqdm
 import json
+import math
 import logging
 from pathlib import Path
 import time
 import numpy as np
 from transformers import AutoTokenizer, AutoConfig
-# from manifest import Manifest
-
 import flexgen.apps.data_wrangle.utils.data_utils as data_utils
 import flexgen.apps.data_wrangle.utils.prompt_utils as prompt_utils
 from flexgen.apps.data_wrangle.utils import constants

@@ -174,7 +178,10 @@ def parse_args() -> argparse.Namespace:
 
 
 def get_tokenizer(name):
-    tokenizer = AutoTokenizer.from_pretrained(name, padding_side="left")
+    if name == 'facebook/opt-175b':
+        tokenizer = AutoTokenizer.from_pretrained('facebook/opt-30b', padding_side="left")
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(name, padding_side="left")
     tokenizer.add_bos_token = False
     if 'galactica' in name:
         config = AutoConfig.from_pretrained(name)
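A note on the `get_tokenizer` change above: `facebook/opt-175b` has no tokenizer hosted on the Hugging Face Hub (the 175B weights are distributed separately), while the OPT checkpoints that are on the Hub ship the same BPE vocabulary, so the OPT-30B tokenizer works as a drop-in replacement. An illustrative check, assuming Hub access (the variable names are ours, not the repo's):

```
# Illustrative: OPT checkpoints on the Hub share one BPE vocabulary, so the
# OPT-30B tokenizer can stand in for OPT-175B, which has no Hub tokenizer.
from transformers import AutoTokenizer

tok_30b = AutoTokenizer.from_pretrained("facebook/opt-30b")
tok_125m = AutoTokenizer.from_pretrained("facebook/opt-125m")
assert tok_30b.get_vocab() == tok_125m.get_vocab()
print("vocab size:", len(tok_30b.get_vocab()))
```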
@@ -204,12 +211,11 @@ def single_query_test(args, task_instruction, test_data, task, pd_data_files, te
                                      num_bits=4, group_size=64,
                                      group_dim=2, symmetric=False))
 
-    print(f"Init weights begin.")
+    logger.info(f"Init weights begin.")
     tic = time.time()
     model = OptLM(args.model, env, args.path, policy)
-    print(f"Init weights end. Elapsed: {time.time() - tic:.2f} s", flush=True)
+    logger.info(f"Init weights end. Elapsed: {time.time() - tic:.2f} s")
 
-
     if args.add_task_instruction:
         prompt = lambda x: f"{task_instruction} {x}"
     else:

@@ -246,14 +252,13 @@ def single_query_test(args, task_instruction, test_data, task, pd_data_files, te
     gt = test_data["label_str"]
     preds = []
     idx = 0
-    # Run a few for printing -- they are cached
     for _ in range(args.num_print):
         logger.info(prompt(queries[idx]))
         tic = time.time()
         input_ids_tmp = tokenizer(prompt(queries[idx]), padding="max_length",
                                   return_tensors="np",
                                   max_length=args.pad_to_seq_len).input_ids
-        print(input_ids_tmp.shape)
+        logger.info(input_ids_tmp.shape)
         output_ids_tmp = model.generate(input_ids_tmp,
                                         do_sample=True,
                                         temperature=args.temperature,

@@ -292,6 +297,7 @@ def single_query_test(args, task_instruction, test_data, task, pd_data_files, te
         f"_{int(args.add_task_instruction)}inst"
         f"_{int(args.class_balanced)}cb"
         f"_{args.sample_method}"
+        f"_{args.model}"
         f"_{args.num_print}run"
         f"_{int(args.dry_run)}dry" / f"trial_{trial_num}.feather"
     )

@@ -336,16 +342,17 @@ def batch_query_test(args, task_instruction, test_data, task, pd_data_files, tes
                                      num_bits=4, group_size=64,
                                      group_dim=2, symmetric=False))
 
-    print(f"Init weights begin.")
+    logger.info(f"Init weights begin.")
     tic = time.time()
     model = OptLM(args.model, env, args.path, policy)
-    print(f"Init weights end. Elapsed: {time.time() - tic:.2f} s", flush=True)
+    logger.info(f"Init weights end. Elapsed: {time.time() - tic:.2f} s.")
 
     if args.add_task_instruction:
         prompt = lambda x: f"{task_instruction} {x}"
     else:
         prompt = lambda x: f"{x}"
-    trial_metrics = {"prec": [], "rec": [], "f1": [], "acc": [], "throughput": []}
+    trial_metrics = {"prec": [], "rec": [], "f1": [], "acc": [], "total_time": [],
+                     "output_throughput": [], "total_throughput": []}
 
     saved_prefix = None
 
@@ -377,31 +384,54 @@ def batch_query_test(args, task_instruction, test_data, task, pd_data_files, tes
     preds = []
     idx = 0
 
-    # Run a few for printing -- they are cached
+    max_prompt_seq_length = 0
     prompt_strs = []
     for _ in range(args.num_run):
-        if idx == 0:
-            logger.info(f"This is a sample prompt: {prompt(queries[idx])}")
+        # if idx == 0:
+        #     logger.info(f"This is a sample prompt: {prompt(queries[idx])}")
         prompt_strs.append(prompt(queries[idx]))
-        idx += 1
 
+        current_prompt_tmp = tokenizer(prompt(queries[idx]), padding="max_length",
+                                       return_tensors="np", max_length=args.pad_to_seq_len).input_ids
+        # logger.info(f"Current prompt <{idx}> length: {current_prompt_tmp.shape[1]}")
+        max_prompt_seq_length = max(max_prompt_seq_length, current_prompt_tmp.shape[1])
+        idx += 1
+
+    logger.info(f"max_prompt_seq_length: {max_prompt_seq_length}")
     tic = time.time()
 
-    input_ids_tmp = tokenizer(prompt_strs, padding="max_length",
+    input_ids = tokenizer(prompt_strs, padding="max_length",
                           return_tensors="np",
-                          max_length=args.pad_to_seq_len).input_ids
-    output_ids_tmp = model.generate(input_ids_tmp,
-                                    do_sample=True,
-                                    temperature=args.temperature,
-                                    max_new_tokens=args.max_tokens,
-                                    stop=args.stop_token)
+                          max_length=max_prompt_seq_length).input_ids
+    output_ids = []
+
+    flexgen_batch_size = args.gpu_batch_size * args.num_gpu_batches
+    num_batched_run = math.floor(args.num_run / flexgen_batch_size)
+    args.num_run = num_batched_run * flexgen_batch_size
+    input_ids = input_ids[0:args.num_run]
+
+    for i in tqdm(range(num_batched_run)):
+        input_ids_tmp = input_ids[i*flexgen_batch_size: (i+1)*flexgen_batch_size]
+        output_ids_tmp = model.generate(input_ids_tmp,
+                                        do_sample=True,
+                                        temperature=args.temperature,
+                                        max_new_tokens=args.max_tokens,
+                                        stop=args.stop_token)
+        output_ids.extend(output_ids_tmp)
+
     toc = time.time()
-    input_strs = tokenizer.batch_decode(input_ids_tmp, skip_special_tokens=True)
-    output_strs = tokenizer.batch_decode(output_ids_tmp, skip_special_tokens=True)
-    preds = [ output_strs[i][len(input_strs[i]):] for i in range(len(input_strs))]
+    input_strs = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+    output_strs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+    preds = [output_strs[i][len(input_strs[i]):] for i in range(len(input_strs))]
 
-    throughput = args.num_run * args.max_tokens/(time.time() - tic)
-    print(f"Batch inference run end. Elapsed: { toc - tic:.2f} s, Throughput: {throughput:.2f} token/s")
+    total_time = time.time() - tic
+    total_prompt_tokens = args.num_run * max_prompt_seq_length
+    total_generate_tokens = args.num_run * args.max_tokens
+    output_throughput = total_generate_tokens / total_time
+    total_throughput = (total_prompt_tokens + total_generate_tokens) / total_time
+    logger.info(f"Batch inference run end. Elapsed: {total_time:.2f} s;")
+    logger.info(f"Output throughput: {output_throughput:.2f} token/s;")
+    logger.info(f"Total throughput: {total_throughput:.2f} token/s;")
     # Save trial predictions
     save_data = test_data.iloc[:args.num_run].copy(deep=True).reset_index()
     gt = gt[:args.num_run]

@@ -412,13 +442,18 @@ def batch_query_test(args, task_instruction, test_data, task, pd_data_files, tes
 
     logger.info(
         f"Metrics Trial {trial_num}\n"
-        f"Prec: {prec:.3f} Recall: {rec:.3f} Acc: {acc:.3f} F1: {f1:.3f} FlexGen Throughput: {throughput:.3f}"
+        f"Prec: {prec:.3f} Recall: {rec:.3f} Acc: {acc:.3f} F1: {f1:.3f} \n"
+        f"<FlexGen> time: {total_time:.3f} \n"
+        f"<FlexGen> output throughput: {output_throughput:.3f} \n"
+        f"<FlexGen> total throughput: {total_throughput:.3f}"
     )
     trial_metrics["rec"].append(rec)
     trial_metrics["prec"].append(prec)
     trial_metrics["acc"].append(acc)
     trial_metrics["f1"].append(f1)
-    trial_metrics["throughput"].append(throughput)
+    trial_metrics["total_time"].append(total_time)
+    trial_metrics["output_throughput"].append(output_throughput)
+    trial_metrics["total_throughput"].append(total_throughput)
 
     output_file = (
         Path(args.output_dir)

@@ -429,6 +464,7 @@ def batch_query_test(args, task_instruction, test_data, task, pd_data_files, tes
         f"_{int(args.add_task_instruction)}inst"
         f"_{int(args.class_balanced)}cb"
         f"_{args.sample_method}"
+        f"_{args.model}"
         f"_{args.num_run}run"
         f"_{int(args.dry_run)}dry" / f"trial_{trial_num}.feather"
     )
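The heart of the batching change above: FlexGen consumes exactly `gpu_batch_size * num_gpu_batches` prompts per `generate()` call, so `num_run` is rounded down to a multiple of that effective batch size and the prompts are fed in fixed-size slices. A self-contained sketch of just that slicing logic; `generate` here is a stand-in for `OptLM.generate`:

```
import math

# Sketch of the slicing logic above: round num_run down to a multiple of the
# effective FlexGen batch size, then feed fixed-size slices to generate().
def run_in_flexgen_batches(input_ids, num_run, gpu_batch_size, num_gpu_batches, generate):
    flexgen_batch_size = gpu_batch_size * num_gpu_batches
    num_batched_run = math.floor(num_run / flexgen_batch_size)
    num_run = num_batched_run * flexgen_batch_size  # e.g., 189 -> 180 for a batch of 90
    output_ids = []
    for i in range(num_batched_run):
        batch = input_ids[i * flexgen_batch_size:(i + 1) * flexgen_batch_size]
        output_ids.extend(generate(batch))  # stand-in for model.generate(...)
    return output_ids, num_run

# Toy usage with a dummy generate() that echoes its input:
outs, used = run_in_flexgen_batches(list(range(189)), 189, 15, 6, lambda b: b)
assert used == 180 and len(outs) == 180
```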

flexgen/apps/data_wrangle/install.sh (+1 -1)

@@ -1,5 +1,5 @@
 pip install pandas==1.4.2
-pip install sentence_transformers==2.2.0
+pip install sentence-transformers==2.2.2
 pip install rich==12.2.0
 pip install pyarrow==7.0.0

flexgen/apps/data_wrangle/test_batch_query_all_opt175b.sh (new file, +88)

@@ -0,0 +1,88 @@
+python3 ./data_wrangle_run.py \
+    --num_run 189 \
+    --num_trials 1 \
+    --nan_tok "" \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/entity_matching/structured/Fodors-Zagats \
+    --batch_run --pad-to-seq-len 744 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 15 --num-gpu-batches 6
+
+python3 ./data_wrangle_run.py \
+    --num_run 91 \
+    --num_trials 1 \
+    --nan_tok "" \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/entity_matching/structured/Beer \
+    --batch_run --pad-to-seq-len 592 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 20 --num-gpu-batches 4
+
+python3 ./data_wrangle_run.py \
+    --num_run 109 \
+    --num_trials 1 \
+    --nan_tok "" \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/entity_matching/structured/iTunes-Amazon \
+    --batch_run --pad-to-seq-len 529 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 20 --num-gpu-batches 5
+
+python3 ./data_wrangle_run.py \
+    --num_run 200 \
+    --num_trials 1 \
+    --nan_tok "" \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/entity_matching/structured/Walmart-Amazon \
+    --batch_run --pad-to-seq-len 748 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 10 --num-gpu-batches 10
+
+python3 ./data_wrangle_run.py \
+    --num_run 200 \
+    --num_trials 1 \
+    --nan_tok "" \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/entity_matching/structured/Amazon-Google \
+    --batch_run --pad-to-seq-len 876 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 10 --num-gpu-batches 10
+
+python3 ./data_wrangle_run.py \
+    --num_run 200 \
+    --num_trials 1 \
+    --nan_tok "" \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/entity_matching/structured/DBLP-ACM \
+    --batch_run --pad-to-seq-len 1274 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 8 --num-gpu-batches 5
+
+python3 ./data_wrangle_run.py \
+    --num_run 200 \
+    --num_trials 1 \
+    --nan_tok "" \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/entity_matching/structured/DBLP-GoogleScholar \
+    --batch_run --pad-to-seq-len 1209 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 8 --num-gpu-batches 5
+
+python3 ./data_wrangle_run.py \
+    --num_run 86 \
+    --num_trials 1 \
+    --max_tokens 5 \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/data_imputation/Restaurant \
+    --batch_run --pad-to-seq-len 123 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 86 --num-gpu-batches 1
+
+python3 ./data_wrangle_run.py \
+    --num_run 65 \
+    --num_trials 1 \
+    --max_tokens 10 \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/data_imputation/Buy \
+    --batch_run --pad-to-seq-len 488 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 30 --num-gpu-batches 2
+
+python3 ./data_wrangle_run.py \
+    --num_run 200 \
+    --num_trials 1 \
+    --do_test \
+    --sample_method manual \
+    --data_dir data/datasets/error_detection/Hospital \
+    --batch_run --pad-to-seq-len 200 --model facebook/opt-175b --pin-weight 0 --cpu --percent 0 50 0 0 0 100 --gpu-batch-size 50 --num-gpu-batches 4
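For readers new to FlexGen's flags: `--percent` takes six integers which, per the FlexGen README, give the GPU and CPU percentages for weights, KV cache, and activations, in that order, with the remainder of each pair offloaded to disk. A small illustrative helper (not part of the repo) that decodes the `--percent 0 50 0 0 0 100` setting used throughout this script:

```
# Illustrative helper (not part of the repo): decode FlexGen's --percent flag.
# Per the FlexGen README, the six integers are the GPU and CPU percentages for
# weights, KV cache, and activations; the remainder of each pair goes to disk.
def decode_percent(w_gpu, w_cpu, c_gpu, c_cpu, a_gpu, a_cpu):
    return {
        "weights":     {"gpu": w_gpu, "cpu": w_cpu, "disk": 100 - w_gpu - w_cpu},
        "kv_cache":    {"gpu": c_gpu, "cpu": c_cpu, "disk": 100 - c_gpu - c_cpu},
        "activations": {"gpu": a_gpu, "cpu": a_cpu, "disk": 100 - a_gpu - a_cpu},
    }

# `--percent 0 50 0 0 0 100` above: half the OPT-175B weights on CPU and half
# on disk, the entire KV cache on disk, and all activations on CPU.
print(decode_percent(0, 50, 0, 0, 0, 100))
```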
