Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit ccf9232

Browse files
committed
Add create and remove heartbeat
1 parent fc12020 commit ccf9232

File tree

1 file changed

+68
-2
lines changed

1 file changed

+68
-2
lines changed

axon/client.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -689,14 +689,48 @@ def impl_update_training_progress(model_name, dataset_name, progress_text, bucke
689689
with open(local_file, "w") as f:
690690
f.write(progress_text)
691691
client = make_client("s3", region)
692-
remote_path = "axon-training-progress/" + os.path.basename(model_name) + "/" + \
693-
os.path.basename(dataset_name) + "/progress.txt"
692+
remote_path = create_progress_prefix(model_name, dataset_name) + "/progress.txt"
694693
client.upload_file(path, bucket_name, remote_path)
695694
print("Updated progress in: {}\n".format(remote_path))
696695
finally:
697696
os.remove(path)
698697

699698

699+
def impl_create_heartbeat(model_name, dataset_name, bucket_name, region):
700+
"""
701+
Creates a heartbeat that Axon uses to check if the training script is running properly.
702+
703+
:param model_name: The filename of the model.
704+
:param dataset_name: The filename of the dataset.
705+
:param bucket_name: The S3 bucket name.
706+
:param region: The region, or `None` to pull the region from the environment.
707+
"""
708+
client = make_client("s3", region)
709+
remote_path = create_progress_prefix(model_name, dataset_name) + "/heartbeat.txt"
710+
client.put_object(Body="1", Bucket=bucket_name, Key=remote_path)
711+
print("Created heartbeat file in: {}\n".format(remote_path))
712+
713+
714+
def impl_remove_heartbeat(model_name, dataset_name, bucket_name, region):
715+
"""
716+
Removes a heartbeat that Axon uses to check if the training script is running properly.
717+
718+
:param model_name: The filename of the model.
719+
:param dataset_name: The filename of the dataset.
720+
:param bucket_name: The S3 bucket name.
721+
:param region: The region, or `None` to pull the region from the environment.
722+
"""
723+
client = make_client("s3", region)
724+
remote_path = create_progress_prefix(model_name, dataset_name) + "/heartbeat.txt"
725+
client.put_object(Body="0", Bucket=bucket_name, Key=remote_path)
726+
print("Removed heartbeat file in: {}\n".format(remote_path))
727+
728+
729+
def create_progress_prefix(model_name, dataset_name):
730+
return "axon-training-progress/" + os.path.basename(model_name) + "/" + \
731+
os.path.basename(dataset_name)
732+
733+
700734
@click.group()
701735
def cli():
702736
return
@@ -902,3 +936,35 @@ def update_training_progress(model_name, dataset_name, progress_text, region):
902936
"""
903937
impl_update_training_progress(model_name, dataset_name, progress_text, ensure_s3_bucket(region),
904938
region)
939+
940+
941+
@cli.command(name="create-heartbeat")
942+
@click.argument("model-name")
943+
@click.argument("dataset-name")
944+
@click.option("--region", help="The region to connect to.",
945+
type=click.Choice(region_choices))
946+
def create_heartbeat(model_name, dataset_name, region):
947+
"""
948+
Creates a heartbeat that Axon uses to check if the training script is running properly.
949+
950+
MODEL_NAME The filename of the model currently being trained.
951+
952+
DATASET_NAME The name of the dataset currently being trained on.
953+
"""
954+
impl_create_heartbeat(model_name, dataset_name, ensure_s3_bucket(region), region)
955+
956+
957+
@cli.command(name="remove-heartbeat")
958+
@click.argument("model-name")
959+
@click.argument("dataset-name")
960+
@click.option("--region", help="The region to connect to.",
961+
type=click.Choice(region_choices))
962+
def remove_heartbeat(model_name, dataset_name, region):
963+
"""
964+
Removes a heartbeat that Axon uses to check if the training script is running properly.
965+
966+
MODEL_NAME The filename of the model currently being trained.
967+
968+
DATASET_NAME The name of the dataset currently being trained on.
969+
"""
970+
impl_remove_heartbeat(model_name, dataset_name, ensure_s3_bucket(region), region)

0 commit comments

Comments
 (0)