|
504 | 504 | " )\n",
|
505 | 505 | "\n",
|
506 | 506 | " batch_size = assert_is_instance(parameters[\"batch_size\"], int)\n",
|
| 507 | + " num_epochs = len(train_x) // batch_size\n", |
507 | 508 | "\n",
|
508 | 509 | " start_time = time.time()\n",
|
509 |
| - " is_early_stopped = False # Flag to indicate if the trial was early stopped\n", |
510 |
| - " for i in range(0, len(train_x) // batch_size):\n", |
| 510 | + " for i in range(0, num_epochs):\n", |
511 | 511 | " start_idx = i * batch_size\n",
|
512 | 512 | " end_idx = (i + 1) * batch_size\n",
|
513 | 513 | "\n",
|
|
520 | 520 | " \"score\": clf.score(valid_x, valid_y),\n",
|
521 | 521 | " \"training_time\": time.time() - start_time,\n",
|
522 | 522 | " }\n",
|
| 523 | + "\n", |
| 524 | + " # On the final epoch call complete_trial and break, else call attach_data\n", |
| 525 | + " if i == num_epochs - 1:\n", |
| 526 | + " client.complete_trial(\n", |
| 527 | + " trial_index=trial_index,\n", |
| 528 | + " raw_data=raw_data,\n", |
| 529 | + " progression=end_idx, # Use the index of the last example in the batch as the progression value\n", |
| 530 | + " )\n", |
| 531 | + " break\n", |
| 532 | + "\n", |
523 | 533 | " client.attach_data(\n",
|
524 | 534 | " trial_index=trial_index,\n",
|
525 | 535 | " raw_data=raw_data,\n",
|
526 |
| - " progression=end_idx, # Use the index of the last example in the batch as the progression value\n", |
| 536 | + " progression=end_idx,\n", |
527 | 537 | " )\n",
|
528 | 538 | "\n",
|
529 | 539 | " # If the trial is underperforming, stop it\n",
|
530 | 540 | " if client.should_stop_trial_early(trial_index=trial_index):\n",
|
531 |
| - " client.mark_trial_early_stopped(\n", |
532 |
| - " trial_index=trial_index,\n", |
533 |
| - " raw_data=raw_data,\n", |
534 |
| - " progression=end_idx,\n", |
535 |
| - " )\n", |
536 |
| - " is_early_stopped = True\n", |
537 |
| - " break\n", |
538 |
| - "\n", |
539 |
| - " if not is_early_stopped:\n", |
540 |
| - " # If the trial was not early stopped, mark it as completed\n", |
541 |
| - " client.complete_trial(trial_index=trial_index)\n", |
542 |
| - " print(f\"Completed trial {trial_index} with {parameters=}, {raw_data=}\")" |
| 541 | + " client.mark_trial_early_stopped(trial_index=trial_index)\n", |
| 542 | + " break\n" |
543 | 543 | ]
|
544 | 544 | },
|
545 | 545 | {
|
|
0 commit comments