
Commit d92502f

Cli fixes and improvements (#25)

* Revamp cli args (#45)
* Rachel/follow (#46)
* Add fine_tunes.follow. Add better error handling for disconnected streams
* return early
* fix an oops
* lint
* Nicer strings
* ensure end token is not applied to classification (#44)
* ensure end token is not applied to classification
* black

Co-authored-by: Boris Power <[email protected]>
1 parent dc15660 commit d92502f

3 files changed (+67 -47):

* openai/cli.py
* openai/validators.py
* openai/version.py

openai/cli.py (+65 -45)
@@ -286,34 +286,26 @@ def create(cls, args):
             create_args["validation_file"] = cls._get_or_upload(
                 args.validation_file, args.check_if_files_exist
             )
-        if args.model:
-            create_args["model"] = args.model
-        if args.n_epochs:
-            create_args["n_epochs"] = args.n_epochs
-        if args.batch_size:
-            create_args["batch_size"] = args.batch_size
-        if args.learning_rate_multiplier:
-            create_args["learning_rate_multiplier"] = args.learning_rate_multiplier
-        create_args["use_packing"] = args.use_packing
-        if args.prompt_loss_weight:
-            create_args["prompt_loss_weight"] = args.prompt_loss_weight
-        if args.compute_classification_metrics:
-            create_args[
-                "compute_classification_metrics"
-            ] = args.compute_classification_metrics
-        if args.classification_n_classes:
-            create_args["classification_n_classes"] = args.classification_n_classes
-        if args.classification_positive_class:
-            create_args[
-                "classification_positive_class"
-            ] = args.classification_positive_class
-        if args.classification_betas:
-            betas = [float(x) for x in args.classification_betas.split(",")]
-            create_args["classification_betas"] = betas
+
+        for hparam in (
+            "model",
+            "n_epochs",
+            "batch_size",
+            "learning_rate_multiplier",
+            "prompt_loss_weight",
+            "use_packing",
+            "compute_classification_metrics",
+            "classification_n_classes",
+            "classification_positive_class",
+            "classification_betas",
+        ):
+            attr = getattr(args, hparam)
+            if attr is not None:
+                create_args[hparam] = attr

         resp = openai.FineTune.create(**create_args)

-        if args.no_wait:
+        if args.no_follow:
             print(resp)
             return
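
The refactor above replaces a chain of truthiness checks with one loop that copies only explicitly-set values into create_args. A minimal standalone sketch of that pattern, using a plain argparse.Namespace in place of the real parsed CLI arguments (the helper name is illustrative, not part of the library):

import argparse

def collect_set_args(args, names):
    # Copy only the attributes the user actually set (not None) into a dict.
    # Testing "is not None" instead of truthiness matters: falsy but meaningful
    # values such as 0, 0.0 or False still get through, which an
    # "if args.x:" check would silently drop.
    out = {}
    for name in names:
        value = getattr(args, name)
        if value is not None:
            out[name] = value
    return out

# Hypothetical example: only n_epochs and batch_size were provided.
ns = argparse.Namespace(model=None, n_epochs=4, batch_size=0, use_packing=None)
print(collect_set_args(ns, ("model", "n_epochs", "batch_size", "use_packing")))
# -> {'n_epochs': 4, 'batch_size': 0}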

@@ -345,20 +337,32 @@ def results(cls, args):

     @classmethod
     def events(cls, args):
-        if not args.stream:
-            resp = openai.FineTune.list_events(id=args.id)  # type: ignore
-            print(resp)
-            return
+        if args.stream:
+            raise openai.error.OpenAIError(
+                message=(
+                    "The --stream parameter is deprecated, use fine_tunes.follow "
+                    "instead:\n\n"
+                    " openai api fine_tunes.follow -i {id}\n".format(id=args.id)
+                ),
+            )
+
+        resp = openai.FineTune.list_events(id=args.id)  # type: ignore
+        print(resp)
+
+    @classmethod
+    def follow(cls, args):
         cls._stream_events(args.id)

     @classmethod
     def _stream_events(cls, job_id):
         def signal_handler(sig, frame):
             status = openai.FineTune.retrieve(job_id).status
             sys.stdout.write(
-                "\nStream interrupted. Job is still {status}. "
+                "\nStream interrupted. Job is still {status}.\n"
+                "To resume the stream, run:\n\n"
+                " openai api fine_tunes.follow -i {job_id}\n\n"
                 "To cancel your job, run:\n\n"
-                "openai api fine_tunes.cancel -i {job_id}\n".format(
+                " openai api fine_tunes.cancel -i {job_id}\n\n".format(
                     status=status, job_id=job_id
                 )
             )
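
For readers unfamiliar with the signal_handler pattern above: it swaps the default Ctrl-C behaviour for a message telling the user how to resume or cancel. A reduced sketch of that wiring in plain Python (job id and message text are placeholders, not the actual CLI output):

import signal
import sys

def install_interrupt_hint(job_id):
    # Replace the default KeyboardInterrupt traceback with a friendlier message.
    def handler(signum, frame):
        sys.stdout.write(
            "\nStream interrupted. Re-attach to job {job_id} to resume.\n".format(
                job_id=job_id
            )
        )
        sys.exit(0)

    signal.signal(signal.SIGINT, handler)

install_interrupt_hint("ft-XXXXXXXX")
# ... the long-running event-streaming loop would follow here ...
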
@@ -368,16 +372,24 @@ def signal_handler(sig, frame):

         events = openai.FineTune.stream_events(job_id)
         # TODO(rachel): Add a nifty spinner here.
-        for event in events:
-            sys.stdout.write(
-                "[%s] %s"
-                % (
-                    datetime.datetime.fromtimestamp(event["created_at"]),
-                    event["message"],
+        try:
+            for event in events:
+                sys.stdout.write(
+                    "[%s] %s"
+                    % (
+                        datetime.datetime.fromtimestamp(event["created_at"]),
+                        event["message"],
+                    )
                 )
+                sys.stdout.write("\n")
+                sys.stdout.flush()
+        except Exception:
+            sys.stdout.write(
+                "\nStream interrupted (client disconnected).\n"
+                "To resume the stream, run:\n\n"
+                " openai api fine_tunes.follow -i {job_id}\n\n".format(job_id=job_id)
             )
-            sys.stdout.write("\n")
-            sys.stdout.flush()
+            return

         resp = openai.FineTune.retrieve(id=job_id)
         status = resp["status"]
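
The try/except added above keeps a dropped HTTP stream from surfacing as a raw traceback: the loop ends with a hint on how to re-attach instead. A reduced sketch of the same shape, with a toy generator standing in for openai.FineTune.stream_events:

import sys

def fake_event_stream():
    # Stand-in for a server-sent event stream that dies part-way through.
    yield {"message": "Created fine-tune"}
    yield {"message": "Fine-tune started"}
    raise ConnectionError("stream closed by remote host")

try:
    for event in fake_event_stream():
        sys.stdout.write(event["message"] + "\n")
        sys.stdout.flush()
except Exception:
    # Catching broadly mirrors the diff: any disconnect ends the stream
    # gracefully rather than crashing the CLI.
    sys.stdout.write("Stream interrupted (client disconnected). Re-attach to resume.\n")
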
@@ -688,9 +700,9 @@ def help(args):
         help="The model to start fine-tuning from",
     )
     sub.add_argument(
-        "--no_wait",
+        "--no_follow",
         action="store_true",
-        help="If set, returns immediately after creating the job. Otherwise, waits for the job to complete.",
+        help="If set, returns immediately after creating the job. Otherwise, streams events and waits for the job to complete.",
     )
     sub.add_argument(
         "--n_epochs",
@@ -727,7 +739,7 @@ def help(args):
         dest="use_packing",
         help="Disables the packing flag (see --use_packing for description)",
     )
-    sub.set_defaults(use_packing=True)
+    sub.set_defaults(use_packing=None)
     sub.add_argument(
         "--prompt_loss_weight",
         type=float,
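
Changing the default from True to None above makes use_packing a three-state setting: None means "not specified, let the API default apply", while True/False are only sent when the user passes one of the paired flags. A self-contained argparse sketch of that idiom (the --use_packing flag and its dest come from the diff; spelling the disabling flag --no_packing is an assumption):

import argparse

parser = argparse.ArgumentParser()
# Two flags write into the same destination; giving neither leaves it None.
parser.add_argument("--use_packing", action="store_true", dest="use_packing")
parser.add_argument("--no_packing", action="store_false", dest="use_packing")
parser.set_defaults(use_packing=None)

print(parser.parse_args([]).use_packing)                 # None -> omitted from create_args
print(parser.parse_args(["--use_packing"]).use_packing)  # True
print(parser.parse_args(["--no_packing"]).use_packing)   # False

The same reasoning explains the new set_defaults(compute_classification_metrics=None) in the next hunk: a bare store_true flag would otherwise default to False and always be sent to the API.
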
@@ -741,6 +753,7 @@ def help(args):
         help="If set, we calculate classification-specific metrics such as accuracy "
         "and F-1 score using the validation set at the end of every epoch.",
     )
+    sub.set_defaults(compute_classification_metrics=None)
     sub.add_argument(
         "--classification_n_classes",
         type=int,
@@ -755,10 +768,11 @@ def help(args):
     )
     sub.add_argument(
         "--classification_betas",
+        type=float,
+        nargs="+",
         help="If this is provided, we calculate F-beta scores at the specified beta "
         "values. The F-beta score is a generalization of F-1 score. This is only "
-        "used for binary classification. The expected format is a comma-separated "
-        "list - e.g. 1,1.5,2",
+        "used for binary classification.",
     )
     sub.set_defaults(func=FineTune.create)
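
The old help text asked for a comma-separated string that the create command then split and converted by hand; with type=float and nargs="+", argparse produces the list of floats directly. A small illustrative sketch (only the flag name is taken from the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--classification_betas", type=float, nargs="+")

# Before: --classification_betas 1,1.5,2 was split on "," and cast manually.
# After: space-separated values arrive as a ready-made list of floats.
args = parser.parse_args(["--classification_betas", "1", "1.5", "2"])
print(args.classification_betas)  # [1.0, 1.5, 2.0]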

@@ -772,15 +786,21 @@ def help(args):

     sub = subparsers.add_parser("fine_tunes.events")
     sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
+
+    # TODO(rachel): Remove this in 1.0
     sub.add_argument(
         "-s",
         "--stream",
         action="store_true",
-        help="If set, events will be streamed until the job is done. Otherwise, "
+        help="[DEPRECATED] If set, events will be streamed until the job is done. Otherwise, "
         "displays the event history to date.",
     )
     sub.set_defaults(func=FineTune.events)

+    sub = subparsers.add_parser("fine_tunes.follow")
+    sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
+    sub.set_defaults(func=FineTune.follow)
+
     sub = subparsers.add_parser("fine_tunes.cancel")
     sub.add_argument("-i", "--id", required=True, help="The id of the fine-tune job")
     sub.set_defaults(func=FineTune.cancel)
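
The new fine_tunes.follow subparser reuses the same -i/--id argument as the other fine-tune subcommands, while fine_tunes.events keeps --stream only long enough to point users at the replacement. A compact, illustrative sketch of that subcommand layout (handler bodies reduced to prints; not the real CLI wiring):

import argparse

def events(args):
    if args.stream:
        raise SystemExit(
            "--stream is deprecated, use: openai api fine_tunes.follow -i %s" % args.id
        )
    print("listing events for", args.id)

def follow(args):
    print("streaming events for", args.id)

parser = argparse.ArgumentParser(prog="openai api")
subparsers = parser.add_subparsers()

sub = subparsers.add_parser("fine_tunes.events")
sub.add_argument("-i", "--id", required=True)
sub.add_argument("-s", "--stream", action="store_true")
sub.set_defaults(func=events)

sub = subparsers.add_parser("fine_tunes.follow")
sub.add_argument("-i", "--id", required=True)
sub.set_defaults(func=follow)

args = parser.parse_args(["fine_tunes.follow", "-i", "ft-XXXXXXXX"])
args.func(args)  # -> streaming events for ft-XXXXXXXX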

openai/validators.py (+1 -1)

@@ -326,7 +326,7 @@ def common_completion_suffix_validator(df):
     display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

     ft_type = infer_task_type(df)
-    if ft_type == "open-ended generation":
+    if ft_type == "open-ended generation" or ft_type == "classification":
         return Remediation(name="common_suffix")

     def add_suffix(x, suffix):
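
This one-line change means the common-suffix (end token) suggestion is now skipped for classification datasets as well as open-ended generation, matching the "ensure end token is not applied to classification" item in the commit message. A toy sketch of the guard in isolation (the function name and the third task-type string are illustrative, not the library's):

def should_suggest_common_suffix(task_type):
    # Mirrors the updated guard in common_completion_suffix_validator:
    # no end-token/suffix remediation for these task types.
    if task_type == "open-ended generation" or task_type == "classification":
        return False
    return True

print(should_suggest_common_suffix("classification"))          # False (changed by this commit)
print(should_suggest_common_suffix("conditional generation"))  # True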

openai/version.py (+1 -1)

@@ -1 +1 @@
-VERSION = "0.9.2"
+VERSION = "0.9.3"
