Commit 602c62d

Analytics cli (#894)

* run-stuff is straight copy of run-csv-items
* only 1 sut per run
* fix small bug
* correct output directory structure
* The runners are responsible for setting up the output directory
* write empty metadata
* rename sub_dir_name to run_id
* basic metadata
* annotation and response counts in metadata
* No more --cache-dir
* run-stuff is now run-csv
* black
* mypy
* more details about how validate_uid works
* log instead of print
* use logger instead of warnings in main
* rename to run-job

1 parent 18e16c1 commit 602c62d

File tree: 6 files changed (+430, -18 lines)

src/modelgauge/annotation_pipeline.py (+6)

```diff
@@ -141,19 +141,25 @@ def handle_uncached_item(self, item):
 
 class AnnotatorSink(Sink):
     unfinished: defaultdict[SutInteraction, dict[str, str]]
+    sut_response_counts: defaultdict[str, int]
+    annotation_counts: defaultdict[str, int]
 
     def __init__(self, annotators: dict[str, Annotator], writer: JsonlAnnotatorOutput):
         super().__init__()
         self.annotators = annotators
         self.writer = writer
         self.unfinished = defaultdict(lambda: dict())
+        self.sut_response_counts = defaultdict(lambda: 0)
+        self.annotation_counts = defaultdict(lambda: 0)
 
     def run(self):
         with self.writer:
             super().run()
 
     def handle_item(self, item):
         sut_interaction, annotator_uid, annotation = item
+        self.sut_response_counts[sut_interaction.sut_uid] += 1
+        self.annotation_counts[annotator_uid] += 1
         if isinstance(annotation, BaseModel):
             annotation = annotation.model_dump()
         self.unfinished[sut_interaction][annotator_uid] = annotation
```
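The counters follow the standard `defaultdict` tally pattern: one increment per item passing through `handle_item`, keyed by SUT or annotator UID. A minimal standalone sketch (UIDs are hypothetical):

```python
# Sketch of the tally pattern AnnotatorSink now uses. Each item reaching the
# sink is one (response, annotator) pair, so a SUT response is counted once
# per annotation of it. The UIDs here are hypothetical.
from collections import defaultdict

sut_response_counts = defaultdict(lambda: 0)
annotation_counts = defaultdict(lambda: 0)

items = [("demo_sut", "annotator_a"), ("demo_sut", "annotator_b")]
for sut_uid, annotator_uid in items:
    sut_response_counts[sut_uid] += 1
    annotation_counts[annotator_uid] += 1

print(dict(sut_response_counts))  # {'demo_sut': 2}
print(dict(annotation_counts))    # {'annotator_a': 1, 'annotator_b': 1}
```

These counts are what the pipeline runners read back when assembling run metadata (see pipeline_runner.py below).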

src/modelgauge/command_line.py (+3)

```diff
@@ -101,7 +101,10 @@ def validate_uid(ctx, param, value):
     """Callback function for click.option UID validation.
     Raises a BadParameter exception if the user-supplied arg(s) are not valid UIDs.
     Applicable for parameters '--sut', '--test', and '--annotator'.
+    If no UID is provided (e.g. an empty list or `None`), the value is returned as-is.
     """
+    if not value:
+        return value
     # Identify what object we are validating UIDs for.
     if "--sut" in param.opts:
         registry = SUTS
```
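The early return matters because both UID options are `required=False`: click hands the callback `None` for an omitted single-value option and an empty tuple for an omitted `multiple=True` option. A minimal standalone sketch of the guard (the real callback goes on to look the UIDs up in a registry):

```python
# Minimal sketch of the early-return guard, outside modelgauge. With
# required=False, click passes None (single option) or () (multiple=True)
# when the user supplies nothing, and validation should be skipped.
import click

def validate_uid(ctx, param, value):
    if not value:  # None or empty tuple: nothing to validate
        return value
    # ...the real callback validates the UID(s) against a registry here...
    return value

@click.command()
@click.option("--sut", "sut_uid", required=False, callback=validate_uid)
def cmd(sut_uid):
    click.echo(f"sut={sut_uid}")
```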

src/modelgauge/main.py (+108, -8)

```diff
@@ -1,4 +1,4 @@
-import datetime
+import logging
 import os
 import pathlib
 import warnings
@@ -35,6 +35,8 @@
 from modelgauge.sut_registry import SUTS
 from modelgauge.test_registry import TESTS
 
+logger = logging.getLogger(__name__)
+
 
 @modelgauge_cli.command(name="list")
 @LOCAL_PLUGIN_DIR_OPTION
@@ -233,6 +235,109 @@ def run_test(
     print("Full TestRecord json written to", output_file)
 
 
+@modelgauge_cli.command()
+@sut_options_options
+@click.option(
+    "sut_uid",
+    "-s",
+    "--sut",
+    help="Which registered SUT to run.",
+    multiple=False,
+    required=False,
+    callback=validate_uid,
+)
+@click.option(
+    "annotator_uids",
+    "-a",
+    "--annotator",
+    help="Which registered annotator(s) to run",
+    multiple=True,
+    required=False,
+    callback=validate_uid,
+)
+@click.option(
+    "--workers",
+    type=int,
+    default=None,
+    help="Number of worker threads, default is 10 * number of SUTs.",
+)
+@click.option(
+    "--output-dir",
+    "-o",
+    default="airr_data/runs",
+    type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
+)
+@click.option("--tag", type=str, help="Tag to include in the output directory name.")
+@click.option("--debug", is_flag=True, help="Show internal pipeline debugging information.")
+@click.argument(
+    "input_path",
+    type=click.Path(exists=True, path_type=pathlib.Path),
+)
+def run_job(sut_uid, annotator_uids, workers, output_dir, tag, debug, input_path, max_tokens, temp, top_p, top_k):
+    """Run rows in a CSV through some SUTs and/or annotators.
+
+    If running a SUT, the file must have 'UID' and 'Text' columns. The output will be saved to a CSV file.
+    If running ONLY annotators, the file must have 'UID', 'Prompt', 'SUT', and 'Response' columns. The output will be saved to a json lines file.
+    """
+    # Check all objects for missing secrets.
+    secrets = load_secrets_from_config()
+    if sut_uid:
+        check_secrets(secrets, sut_uids=[sut_uid], annotator_uids=annotator_uids)
+    else:
+        check_secrets(secrets, annotator_uids=annotator_uids)
+
+    if sut_uid:
+        sut = SUTS.make_instance(sut_uid, secrets=secrets)
+        if AcceptsTextPrompt not in sut.capabilities:
+            raise click.BadParameter(f"{sut_uid} does not accept text prompts")
+        suts = {sut_uid: sut}
+
+    annotators = {
+        annotator_uid: ANNOTATORS.make_instance(annotator_uid, secrets=secrets) for annotator_uid in annotator_uids
+    }
+
+    # Get all SUT options
+    sut_options = create_sut_options(max_tokens, temp, top_p, top_k)
+
+    # Create correct pipeline runner based on input.
+    if sut_uid and annotators:
+        pipeline_runner = PromptPlusAnnotatorRunner(
+            workers,
+            input_path,
+            output_dir,
+            None,
+            sut_options,
+            tag,
+            suts=suts,
+            annotators=annotators,
+        )
+    elif sut_uid:
+        pipeline_runner = PromptRunner(workers, input_path, output_dir, None, sut_options, tag, suts=suts)
+    elif annotators:
+        if max_tokens is not None or temp is not None or top_p is not None or top_k is not None:
+            logger.warning(f"Received SUT options but only running annotators. Options will not be used.")
+        pipeline_runner = AnnotatorRunner(workers, input_path, output_dir, None, None, tag, annotators=annotators)
+    else:
+        raise ValueError("Must specify at least one SUT or annotator.")
+
+    with click.progressbar(
+        length=pipeline_runner.num_total_items,
+        label=f"Processing {pipeline_runner.num_input_items} input items"
+        + (f" * 1 SUT" if sut_uid else "")
+        + (f" * {len(annotators)} annotators" if annotators else "")
+        + ":",
+    ) as bar:
+        last_complete_count = 0
+
+        def show_progress(data):
+            nonlocal last_complete_count
+            complete_count = data["completed"]
+            bar.update(complete_count - last_complete_count)
+            last_complete_count = complete_count
+
+        pipeline_runner.run(show_progress, debug)
+
+
 @modelgauge_cli.command()
 @sut_options_options
 @click.option(
@@ -300,9 +405,8 @@ def run_csv_items(sut_uids, annotator_uids, workers, cache_dir, debug, input_pat
     print(sut_options)
 
     # Create correct pipeline runner based on input.
-    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+    output_path = input_path.parent
     if suts and annotators:
-        output_path = input_path.parent / pathlib.Path(input_path.stem + "-annotated-responses" + timestamp + ".jsonl")
         pipeline_runner = PromptPlusAnnotatorRunner(
             workers,
             input_path,
@@ -313,13 +417,11 @@ def run_csv_items(sut_uids, annotator_uids, workers, cache_dir, debug, input_pat
             annotators=annotators,
         )
     elif suts:
-        output_path = input_path.parent / pathlib.Path(input_path.stem + "-responses-" + timestamp + ".csv")
         pipeline_runner = PromptRunner(workers, input_path, output_path, cache_dir, sut_options, suts=suts)
     elif annotators:
         if max_tokens is not None or temp is not None or top_p is not None or top_k is not None:
             warnings.warn(f"Received SUT options but only running annotators. Options will not be used.")
-        output_path = input_path.parent / pathlib.Path(input_path.stem + "-annotations-" + timestamp + ".jsonl")
-        pipeline_runner = AnnotatorRunner(workers, input_path, output_path, cache_dir, annotators=annotators)
+        pipeline_runner = AnnotatorRunner(workers, input_path, output_path, cache_dir, None, annotators=annotators)
     else:
         raise ValueError("Must specify at least one SUT or annotator.")
 
@@ -340,8 +442,6 @@ def show_progress(data):
 
         pipeline_runner.run(show_progress, debug)
 
-    print(f"output saved to {output_path}")
-
 
 def main():
     modelgauge_cli()
```
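The new command can be smoke-tested with click's own test runner; a sketch with hypothetical SUT/annotator UIDs and input file, assuming the corresponding secrets are configured:

```python
# Sketch of invoking the new run-job command via click.testing.CliRunner.
# "demo_sut", "demo_annotator", and prompts.csv are hypothetical placeholders.
from click.testing import CliRunner

from modelgauge.main import modelgauge_cli

runner = CliRunner()
result = runner.invoke(
    modelgauge_cli,
    [
        "run-job",
        "--sut", "demo_sut",              # hypothetical SUT UID
        "--annotator", "demo_annotator",  # hypothetical annotator UID
        "--tag", "smoke",
        "-o", "airr_data/runs",
        "prompts.csv",                    # needs 'UID' and 'Text' columns
    ],
)
print(result.output)
```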

src/modelgauge/pipeline_runner.py (+120, -4)

```diff
@@ -1,4 +1,7 @@
 from abc import ABC, abstractmethod
+import datetime
+import json
+import logging
 
 from modelgauge.annotation_pipeline import (
     AnnotatorAssigner,
@@ -18,15 +21,20 @@
     CsvPromptOutput,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class PipelineRunner(ABC):
-    def __init__(self, num_workers, input_path, output_path, cache_dir, sut_options=None):
+    def __init__(self, num_workers, input_path, output_dir, cache_dir, sut_options, tag=None):
         self.num_workers = num_workers
         self.input_path = input_path
-        self.output_path = output_path
+        self.root_dir = output_dir
         self.cache_dir = cache_dir
         self.sut_options = sut_options
+        self.tag = tag
         self.pipeline_segments = []
+        self.start_time = datetime.datetime.now()
+        self.finish_time = None
 
         self._initialize_segments()
 
@@ -44,13 +52,46 @@ def num_total_items(self):
         """Total number of items to process."""
         pass
 
+    def metadata(self):
+        duration = self.finish_time - self.start_time
+        hours, minutes, seconds = str(duration).split(":")
+        duration_string = f"{hours}h{minutes}m{seconds}s"
+
+        metadata = {
+            "run_id": self.run_id,
+            "run_info": {
+                "started": str(self.start_time),
+                "finished": str(self.finish_time),
+                "duration": duration_string,
+            },
+            "input": {
+                "source": self.input_path.name,
+                "num_items": self.num_input_items,
+            },
+        }
+        return metadata
+
+    def output_dir(self):
+        output_path = self.root_dir / self.run_id
+        if not output_path.exists():
+            logger.info(f"Creating output dir {output_path}")
+            output_path.mkdir(parents=True)
+        return output_path
+
     def run(self, progress_callback, debug):
         pipeline = Pipeline(
             *self.pipeline_segments,
             progress_callback=progress_callback,
             debug=debug,
         )
         pipeline.run()
+        self.finish_time = datetime.datetime.now()
+        logger.info(f"\noutput saved to {self.output_dir() / self.output_file_name}")
+        self._write_metadata()
+
+    @staticmethod
+    def format_date(date):
+        return date.strftime("%Y%m%d-%H%M%S")
 
     @abstractmethod
     def _initialize_segments(self):
```
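The duration string leans on `timedelta`'s default rendering, which is `H:MM:SS[.ffffff]` for durations under a day; a worked example of the formatting above:

```python
# str(timedelta) renders as "H:MM:SS[.ffffff]" (for durations under 24 hours),
# so splitting on ":" yields the hours, minutes, and seconds fields.
import datetime

duration = datetime.timedelta(minutes=1, seconds=23, microseconds=456789)
hours, minutes, seconds = str(duration).split(":")
print(f"{hours}h{minutes}m{seconds}s")  # 0h01m23.456789s
```

The next hunks route all sink output through `output_dir()` and add the per-run metadata helpers: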
```diff
@@ -62,7 +103,7 @@ def _add_prompt_segments(self, suts, include_sink=True):
         self.pipeline_segments.append(PromptSutAssigner(suts))
         self.pipeline_segments.append(PromptSutWorkers(suts, self.num_workers, cache_path=self.cache_dir))
         if include_sink:
-            output = CsvPromptOutput(self.output_path, suts)
+            output = CsvPromptOutput(self.output_dir() / self.output_file_name, suts)
             self.pipeline_segments.append(PromptSink(suts, output))
 
     def _add_annotator_segments(self, annotators, include_source=True):
@@ -71,9 +112,45 @@
             self.pipeline_segments.append(AnnotatorSource(input))
         self.pipeline_segments.append(AnnotatorAssigner(annotators))
         self.pipeline_segments.append(AnnotatorWorkers(annotators, self.num_workers))
-        output = JsonlAnnotatorOutput(self.output_path)
+        output = JsonlAnnotatorOutput(self.output_dir() / self.output_file_name)
         self.pipeline_segments.append(AnnotatorSink(annotators, output))
 
+    def _annotator_metadata(self):
+        counts = self.pipeline_segments[-1].annotation_counts
+        return {
+            "annotators": [
+                {
+                    "uid": uid,
+                }
+                for uid, annotator in self.annotators.items()
+            ],
+            "annotations": {
+                "count": sum(counts.values()),
+                "by_annotator": {uid: {"count": count} for uid, count in counts.items()},
+            },
+        }
+
+    def _sut_metadata(self):
+        counts = self.pipeline_segments[-1].sut_response_counts
+        return {
+            "suts": [
+                {
+                    "uid": uid,
+                    "initialization_record": sut.initialization_record.model_dump(),
+                    "sut_options": self.sut_options.model_dump(exclude_none=True),
+                }
+                for uid, sut in self.suts.items()
+            ],
+            "responses": {
+                "count": sum(counts.values()),
+                "by_sut": {uid: {"count": count} for uid, count in counts.items()},
+            },
+        }
+
+    def _write_metadata(self):
+        with open(self.output_dir() / "metadata.json", "w") as f:
+            json.dump(self.metadata(), f, indent=4)
+
 
 class PromptRunner(PipelineRunner):
     def __init__(self, *args, suts):
```
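With these helpers, every run gets its own directory under the output root, holding the pipeline output file plus `metadata.json`. A sketch of the resulting layout, with a hypothetical `run_id` (its composition is defined per runner in the hunks below):

```python
# Sketch of where a run's artifacts land. output_dir() creates
# <root>/<run_id>/ on first use; _write_metadata() adds metadata.json.
import pathlib

root_dir = pathlib.Path("airr_data/runs")          # CLI default
run_id = "20240101-120000-smoke-demo_sut"          # hypothetical
print(root_dir / run_id / "prompt-responses.csv")  # PromptRunner output
print(root_dir / run_id / "metadata.json")         # run metadata
```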
```diff
@@ -84,6 +161,19 @@ def __init__(self, *args, suts):
     def num_total_items(self):
         return self.num_input_items * len(self.suts)
 
+    @property
+    def output_file_name(self):
+        return "prompt-responses.csv"
+
+    @property
+    def run_id(self):
+        timestamp = self.format_date(self.start_time)
+        base_subdir_name = timestamp + "-" + self.tag if self.tag else timestamp
+        return f"{base_subdir_name}-{'-'.join(self.suts.keys())}"
+
+    def metadata(self):
+        return {**super().metadata(), **self._sut_metadata()}
+
     def _initialize_segments(self):
         self._add_prompt_segments(self.suts, include_sink=True)
 
@@ -98,6 +188,19 @@ def __init__(self, *args, suts, annotators):
     def num_total_items(self):
         return self.num_input_items * len(self.suts) * len(self.annotators)
 
+    @property
+    def output_file_name(self):
+        return "prompt-responses-annotated.jsonl"
+
+    @property
+    def run_id(self):
+        timestamp = self.format_date(self.start_time)
+        base_subdir_name = timestamp + "-" + self.tag if self.tag else timestamp
+        return f"{base_subdir_name}-{'-'.join(self.suts.keys())}-{'-'.join(self.annotators.keys())}"
+
+    def metadata(self):
+        return {**super().metadata(), **self._sut_metadata(), **self._annotator_metadata()}
+
     def _initialize_segments(self):
         # Hybrid pipeline: prompt source + annotator sink
         self._add_prompt_segments(self.suts, include_sink=False)
@@ -113,5 +216,18 @@ def __init__(self, *args, annotators):
     def num_total_items(self):
         return self.num_input_items * len(self.annotators)
 
+    @property
+    def output_file_name(self):
+        return "annotations.jsonl"
+
+    @property
+    def run_id(self):
+        timestamp = self.format_date(self.start_time)
+        base_subdir_name = timestamp + "-" + self.tag if self.tag else timestamp
+        return f"{base_subdir_name}-{'-'.join(self.annotators.keys())}"
+
+    def metadata(self):
+        return {**super().metadata(), **self._annotator_metadata()}
+
     def _initialize_segments(self):
         self._add_annotator_segments(self.annotators, include_source=True)
```
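Putting it together, a `PromptPlusAnnotatorRunner` run should write a `metadata.json` shaped roughly like the sketch below; UIDs, timestamps, and counts are hypothetical, and the `initialization_record`/`sut_options` values come from the respective `model_dump()` calls:

```python
# Hypothetical metadata.json contents for one SUT and one annotator over a
# 100-row input file; all concrete values here are illustrative only.
example_metadata = {
    "run_id": "20240101-120000-smoke-demo_sut-demo_annotator",
    "run_info": {
        "started": "2024-01-01 12:00:00.000000",
        "finished": "2024-01-01 12:01:23.456789",
        "duration": "0h01m23.456789s",
    },
    "input": {"source": "prompts.csv", "num_items": 100},
    "suts": [
        {
            "uid": "demo_sut",
            "initialization_record": {},         # sut.initialization_record.model_dump()
            "sut_options": {"max_tokens": 100},  # sut_options.model_dump(exclude_none=True)
        }
    ],
    "responses": {"count": 100, "by_sut": {"demo_sut": {"count": 100}}},
    "annotators": [{"uid": "demo_annotator"}],
    "annotations": {"count": 100, "by_annotator": {"demo_annotator": {"count": 100}}},
}
```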
