@@ -689,14 +689,48 @@ def impl_update_training_progress(model_name, dataset_name, progress_text, bucke
689
689
with open (local_file , "w" ) as f :
690
690
f .write (progress_text )
691
691
client = make_client ("s3" , region )
692
- remote_path = "axon-training-progress/" + os .path .basename (model_name ) + "/" + \
693
- os .path .basename (dataset_name ) + "/progress.txt"
692
+ remote_path = create_progress_prefix (model_name , dataset_name ) + "/progress.txt"
694
693
client .upload_file (path , bucket_name , remote_path )
695
694
print ("Updated progress in: {}\n " .format (remote_path ))
696
695
finally :
697
696
os .remove (path )
698
697
699
698
699
def impl_create_heartbeat(model_name, dataset_name, bucket_name, region):
    """
    Write the "running" heartbeat marker for a training run to S3.

    Stores an object whose body is ``"1"`` under ``<progress prefix>/heartbeat.txt``, where the
    prefix is produced by `create_progress_prefix`. Axon uses this heartbeat to check if the
    training script is running properly.

    :param model_name: The filename of the model.
    :param dataset_name: The filename of the dataset.
    :param bucket_name: The S3 bucket name.
    :param region: The region, or `None` to pull the region from the environment.
    """
    s3 = make_client("s3", region)
    prefix = create_progress_prefix(model_name, dataset_name)
    heartbeat_key = prefix + "/heartbeat.txt"
    # Body "1" marks the heartbeat as present; impl_remove_heartbeat writes "0" to the same key.
    s3.put_object(Bucket=bucket_name, Key=heartbeat_key, Body="1")
    print("Created heartbeat file in: {}\n ".format(heartbeat_key))
712
+
713
+
714
def impl_remove_heartbeat(model_name, dataset_name, bucket_name, region):
    """
    Mark the heartbeat for a training run as removed.

    The heartbeat object is not deleted from S3: the object at ``<progress prefix>/heartbeat.txt``
    is overwritten with the body ``"0"``. Axon uses this heartbeat to check if the training
    script is running properly.

    :param model_name: The filename of the model.
    :param dataset_name: The filename of the dataset.
    :param bucket_name: The S3 bucket name.
    :param region: The region, or `None` to pull the region from the environment.
    """
    s3 = make_client("s3", region)
    prefix = create_progress_prefix(model_name, dataset_name)
    heartbeat_key = prefix + "/heartbeat.txt"
    # Overwrite rather than delete: "0" is the counterpart of the "1" written on creation.
    s3.put_object(Bucket=bucket_name, Key=heartbeat_key, Body="0")
    print("Removed heartbeat file in: {}\n ".format(heartbeat_key))
727
+
728
+
729
def create_progress_prefix(model_name, dataset_name):
    """
    Build the S3 key prefix under which progress files for a training run are stored.

    Only the final path component of each argument is used, so directory parts are dropped.

    :param model_name: The filename (or path) of the model.
    :param dataset_name: The filename (or path) of the dataset.
    :return: ``axon-training-progress/<model basename>/<dataset basename>`` with no trailing slash.
    """
    model_part = os.path.basename(model_name)
    dataset_part = os.path.basename(dataset_name)
    return "axon-training-progress/" + model_part + "/" + dataset_part
732
+
733
+
700
734
# Root click command group; subcommands attach themselves with @cli.command(...).
# Deliberately documented with comments only: click would render a docstring here as help text.
@click.group()
def cli():
    pass
@@ -902,3 +936,35 @@ def update_training_progress(model_name, dataset_name, progress_text, region):
902
936
"""
903
937
impl_update_training_progress (model_name , dataset_name , progress_text , ensure_s3_bucket (region ),
904
938
region )
939
+
940
+
941
@cli.command(name="create-heartbeat")
@click.argument("model-name")
@click.argument("dataset-name")
@click.option("--region", type=click.Choice(region_choices), help="The region to connect to.")
def create_heartbeat(model_name, dataset_name, region):
    """
    Creates a heartbeat that Axon uses to check if the training script is running properly.

    MODEL_NAME The filename of the model currently being trained.

    DATASET_NAME The name of the dataset currently being trained on.
    """
    # The docstring above is what click shows as this command's help text, so it is kept as-is.
    # The bucket name comes from ensure_s3_bucket(region) and is handed to the implementation.
    bucket_name = ensure_s3_bucket(region)
    impl_create_heartbeat(
        model_name=model_name,
        dataset_name=dataset_name,
        bucket_name=bucket_name,
        region=region,
    )
955
+
956
+
957
@cli.command(name="remove-heartbeat")
@click.argument("model-name")
@click.argument("dataset-name")
@click.option("--region", type=click.Choice(region_choices), help="The region to connect to.")
def remove_heartbeat(model_name, dataset_name, region):
    """
    Removes a heartbeat that Axon uses to check if the training script is running properly.

    MODEL_NAME The filename of the model currently being trained.

    DATASET_NAME The name of the dataset currently being trained on.
    """
    # The docstring above is what click shows as this command's help text, so it is kept as-is.
    # The bucket name comes from ensure_s3_bucket(region) and is handed to the implementation.
    bucket_name = ensure_s3_bucket(region)
    impl_remove_heartbeat(
        model_name=model_name,
        dataset_name=dataset_name,
        bucket_name=bucket_name,
        region=region,
    )
0 commit comments