Skip to content

Commit 54b9f6d

Browse files
mpolson64facebook-github-bot
authored andcommitted
Fix tutorials test (#3458)
Summary: Pull Request resolved: #3458 Reviewed By: Cesar-Cardoso Differential Revision: D70631834 fbshipit-source-id: 4e59b4df955103f0eca73907cb80d48c3965c0b3
1 parent 3fc31f7 commit 54b9f6d

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

tutorials/automl/automl.ipynb

+15-15
Original file line numberDiff line numberDiff line change
@@ -504,10 +504,10 @@
504504
" )\n",
505505
"\n",
506506
" batch_size = assert_is_instance(parameters[\"batch_size\"], int)\n",
507+
" num_epochs = len(train_x) // batch_size\n",
507508
"\n",
508509
" start_time = time.time()\n",
509-
" is_early_stopped = False # Flag to indicate if the trial was early stopped\n",
510-
" for i in range(0, len(train_x) // batch_size):\n",
510+
" for i in range(0, num_epochs):\n",
511511
" start_idx = i * batch_size\n",
512512
" end_idx = (i + 1) * batch_size\n",
513513
"\n",
@@ -520,26 +520,26 @@
520520
" \"score\": clf.score(valid_x, valid_y),\n",
521521
" \"training_time\": time.time() - start_time,\n",
522522
" }\n",
523+
"\n",
524+
" # On the final epoch call complete_trial and break, else call attach_data\n",
525+
" if i == num_epochs - 1:\n",
526+
" client.complete_trial(\n",
527+
" trial_index=trial_index,\n",
528+
" raw_data=raw_data,\n",
529+
" progression=end_idx, # Use the index of the last example in the batch as the progression value\n",
530+
" )\n",
531+
" break\n",
532+
"\n",
523533
" client.attach_data(\n",
524534
" trial_index=trial_index,\n",
525535
" raw_data=raw_data,\n",
526-
" progression=end_idx, # Use the index of the last example in the batch as the progression value\n",
536+
" progression=end_idx,\n",
527537
" )\n",
528538
"\n",
529539
" # If the trial is underperforming, stop it\n",
530540
" if client.should_stop_trial_early(trial_index=trial_index):\n",
531-
" client.mark_trial_early_stopped(\n",
532-
" trial_index=trial_index,\n",
533-
" raw_data=raw_data,\n",
534-
" progression=end_idx,\n",
535-
" )\n",
536-
" is_early_stopped = True\n",
537-
" break\n",
538-
"\n",
539-
" if not is_early_stopped:\n",
540-
" # If the trial was not early stopped, mark it as completed\n",
541-
" client.complete_trial(trial_index=trial_index)\n",
542-
" print(f\"Completed trial {trial_index} with {parameters=}, {raw_data=}\")"
541+
" client.mark_trial_early_stopped(trial_index=trial_index)\n",
542+
" break\n"
543543
]
544544
},
545545
{

tutorials/early_stopping/early_stopping.ipynb

+1-3
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,7 @@
352352
"\n",
353353
" # If the trial is underperforming, stop it\n",
354354
" if client.should_stop_trial_early(trial_index=trial_index):\n",
355-
" client.mark_trial_early_stopped(\n",
356-
" trial_index=trial_index, raw_data=raw_data, progression=t\n",
357-
" )\n",
355+
" client.mark_trial_early_stopped(trial_index=trial_index)\n",
358356
" break"
359357
]
360358
},

0 commit comments

Comments
 (0)