Skip to content

Commit 8a78e5c

Browse files
authored
Add --output-dir to run-csv-items (#896)
1 parent 602c62d commit 8a78e5c

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

src/modelgauge/main.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -370,12 +370,20 @@ def show_progress(data):
370370
help="Directory to cache model answers (only applies to SUTs).",
371371
type=click.Path(file_okay=False, dir_okay=True, writable=True, path_type=pathlib.Path),
372372
)
373+
@click.option(
374+
"--output-dir",
375+
"-o",
376+
default=".",
377+
type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
378+
)
373379
@click.option("--debug", is_flag=True, help="Show internal pipeline debugging information.")
374380
@click.argument(
375381
"input_path",
376382
type=click.Path(exists=True, path_type=pathlib.Path),
377383
)
378-
def run_csv_items(sut_uids, annotator_uids, workers, cache_dir, debug, input_path, max_tokens, temp, top_p, top_k):
384+
def run_csv_items(
385+
sut_uids, annotator_uids, workers, cache_dir, output_dir, debug, input_path, max_tokens, temp, top_p, top_k
386+
):
379387
"""Run rows in a CSV through some SUTs and/or annotators.
380388
381389
If running SUTs, the file must have 'UID' and 'Text' columns. The output will be saved to a CSV file.
@@ -405,23 +413,22 @@ def run_csv_items(sut_uids, annotator_uids, workers, cache_dir, debug, input_pat
405413
print(sut_options)
406414

407415
# Create correct pipeline runner based on input.
408-
output_path = input_path.parent
409416
if suts and annotators:
410417
pipeline_runner = PromptPlusAnnotatorRunner(
411418
workers,
412419
input_path,
413-
output_path,
420+
output_dir,
414421
cache_dir,
415422
sut_options,
416423
suts=suts,
417424
annotators=annotators,
418425
)
419426
elif suts:
420-
pipeline_runner = PromptRunner(workers, input_path, output_path, cache_dir, sut_options, suts=suts)
427+
pipeline_runner = PromptRunner(workers, input_path, output_dir, cache_dir, sut_options, suts=suts)
421428
elif annotators:
422429
if max_tokens is not None or temp is not None or top_p is not None or top_k is not None:
423430
warnings.warn(f"Received SUT options but only running annotators. Options will not be used.")
424-
pipeline_runner = AnnotatorRunner(workers, input_path, output_path, cache_dir, None, annotators=annotators)
431+
pipeline_runner = AnnotatorRunner(workers, input_path, output_dir, cache_dir, None, annotators=annotators)
425432
else:
426433
raise ValueError("Must specify at least one SUT or annotator.")
427434

tests/modelgauge_tests/test_cli.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,14 @@ def test_run_prompts_normal(caplog, tmp_path):
149149
runner = CliRunner()
150150
result = runner.invoke(
151151
main.modelgauge_cli,
152-
["run-csv-items", "--sut", "demo_yes_no", str(in_path)],
152+
["run-csv-items", "--sut", "demo_yes_no", "-o", tmp_path, str(in_path)],
153153
catch_exceptions=False,
154154
)
155155

156156
assert result.exit_code == 0
157157

158158
out_path = re.findall(r"\S+\.csv", caplog.text)[0]
159-
with open(tmp_path / out_path, "r") as f:
159+
with open(out_path, "r") as f:
160160
reader = csv.DictReader(f)
161161

162162
rows = (next(reader), next(reader))
@@ -175,7 +175,7 @@ def test_run_prompts_invalid_sut(arg_name, tmp_path):
175175
runner = CliRunner()
176176
result = runner.invoke(
177177
main.modelgauge_cli,
178-
["run-csv-items", arg_name, "unknown-uid", str(in_path)],
178+
["run-csv-items", arg_name, "unknown-uid", "-o", tmp_path, str(in_path)],
179179
catch_exceptions=False,
180180
)
181181

@@ -188,7 +188,7 @@ def test_run_prompts_multiple_invalid_suts(tmp_path):
188188
runner = CliRunner()
189189
result = runner.invoke(
190190
main.modelgauge_cli,
191-
["run-csv-items", "--sut", "unknown-uid1", "--sut", "unknown-uid2", str(in_path)],
191+
["run-csv-items", "--sut", "unknown-uid1", "--sut", "unknown-uid2", "-o", tmp_path, str(in_path)],
192192
catch_exceptions=False,
193193
)
194194

@@ -203,7 +203,7 @@ def test_run_prompts_invalid_annotator(sut_uid, tmp_path):
203203
runner = CliRunner()
204204
result = runner.invoke(
205205
main.modelgauge_cli,
206-
["run-csv-items", "--sut", sut_uid, "--annotator", "unknown-uid", str(in_path)],
206+
["run-csv-items", "--sut", sut_uid, "--annotator", "unknown-uid", "-o", tmp_path, str(in_path)],
207207
catch_exceptions=False,
208208
)
209209

@@ -225,6 +225,8 @@ def test_run_prompts_with_annotators(caplog, tmp_path):
225225
"demo_annotator",
226226
"--workers",
227227
"5",
228+
"-o",
229+
tmp_path,
228230
str(in_path),
229231
],
230232
catch_exceptions=False,
@@ -233,7 +235,7 @@ def test_run_prompts_with_annotators(caplog, tmp_path):
233235

234236
out_path = re.findall(r"\S+\.jsonl", caplog.text)[0]
235237
output = []
236-
with jsonlines.open(tmp_path / out_path) as reader:
238+
with jsonlines.open(out_path) as reader:
237239
output.append(reader.read())
238240
output.append(reader.read())
239241
assert {
@@ -271,6 +273,8 @@ def test_run_prompts_with_options(mock_translate_text_prompt, tmp_path, extra_op
271273
"0.5",
272274
"--top-k",
273275
"0",
276+
"-o",
277+
tmp_path,
274278
str(in_path),
275279
*extra_options,
276280
],
@@ -293,7 +297,7 @@ def test_run_prompts_bad_sut(tmp_path):
293297
runner = CliRunner()
294298
result = runner.invoke(
295299
main.modelgauge_cli,
296-
["run-csv-items", "--sut", "noreqs", str(in_path)],
300+
["run-csv-items", "--sut", "noreqs", "-o", tmp_path, str(in_path)],
297301
catch_exceptions=False,
298302
)
299303
assert result.exit_code == 2
@@ -317,14 +321,16 @@ def test_run_annotators(caplog, tmp_path):
317321
"run-csv-items",
318322
"--annotator",
319323
"demo_annotator",
324+
"-o",
325+
tmp_path,
320326
str(in_path),
321327
],
322328
catch_exceptions=False,
323329
)
324330
assert result.exit_code == 0
325331

326332
out_path = re.findall(r"\S+\.jsonl", caplog.text)[0]
327-
with jsonlines.open(tmp_path / out_path) as reader:
333+
with jsonlines.open(out_path) as reader:
328334
assert reader.read() == {
329335
"UID": "p1",
330336
"Prompt": "Say yes",
@@ -356,6 +362,8 @@ def test_run_annotators_with_sut_options(tmp_path, option_name, option_val):
356362
"demo_annotator",
357363
f"--{option_name}",
358364
option_val,
365+
"-o",
366+
tmp_path,
359367
str(in_path),
360368
],
361369
catch_exceptions=False,

0 commit comments

Comments
 (0)