@@ -603,34 +603,6 @@ def impl_download_untrained_model(model_path, bucket_name, region):
603
603
print ("Downloaded from: {}\n " .format (key ))
604
604
605
605
606
- def impl_upload_trained_model (model_path , bucket_name , region ):
607
- """
608
- Uploads an trained model to S3.
609
-
610
- :param model_path: The file path to the model to upload, ending with the name of the model.
611
- :param bucket_name: The S3 bucket name.
612
- :param region: The region, or `None` to pull the region from the environment.
613
- """
614
- client = make_client ("s3" , region )
615
- key = "axon-trained-models/" + os .path .basename (model_path )
616
- client .upload_file (model_path , bucket_name , key )
617
- print ("Uploaded to: {}\n " .format (key ))
618
-
619
-
620
- def impl_download_trained_model (model_path , bucket_name , region ):
621
- """
622
- Downloads an trained model from S3.
623
-
624
- :param model_path: The file path to download to, ending with the name of the model.
625
- :param bucket_name: The S3 bucket name.
626
- :param region: The region, or `None` to pull the region from the environment.
627
- """
628
- client = make_client ("s3" , region )
629
- key = "axon-trained-models/" + os .path .basename (model_path )
630
- client .download_file (bucket_name , key , model_path )
631
- print ("Downloaded from: {}\n " .format (key ))
632
-
633
-
634
606
def impl_download_training_script (script_path , bucket_name , region ):
635
607
"""
636
608
Downloads a training script from S3.
@@ -723,6 +695,16 @@ def impl_remove_heartbeat(job_id, bucket_name, region):
723
695
print ("Removed heartbeat file in: {}\n " .format (remote_path ))
724
696
725
697
698
+ def impl_upload_training_results (job_id , output_dir , bucket_name , region ):
699
+ client = make_client ("s3" , region )
700
+ files_to_upload = [os .path .join (output_dir , it ) for it in os .listdir (output_dir )]
701
+ files_to_upload = [it for it in files_to_upload if os .path .isfile (it )]
702
+ for elem in files_to_upload :
703
+ key = "axon-training-results/{}/{}" .format (job_id , os .path .basename (elem ))
704
+ client .upload_file (elem , bucket_name , key )
705
+ print ("Uploaded to: {}\n " .format (key ))
706
+
707
+
726
708
def create_progress_prefix (job_id ):
727
709
return "axon-training-progress/{}" .format (job_id )
728
710
@@ -848,32 +830,6 @@ def download_untrained_model(model_path, region):
848
830
impl_download_untrained_model (model_path , ensure_s3_bucket (region ), region )
849
831
850
832
851
- @cli .command (name = "upload-trained-model" )
852
- @click .argument ("model-path" )
853
- @click .option ("--region" , help = "The region to connect to." ,
854
- type = click .Choice (region_choices ))
855
- def upload_trained_model (model_path , region ):
856
- """
857
- Uploads a trained model from a local file.
858
-
859
- MODEL_PATH The path to the model to upload, ending with the name of the model.
860
- """
861
- impl_upload_trained_model (model_path , ensure_s3_bucket (region ), region )
862
-
863
-
864
- @cli .command (name = "download-trained-model" )
865
- @click .argument ("model-path" )
866
- @click .option ("--region" , help = "The region to connect to." ,
867
- type = click .Choice (region_choices ))
868
- def download_trained_model (model_path , region ):
869
- """
870
- Downloads a trained model to a local file.
871
-
872
- MODEL_PATH The path to download the model to, ending with the name of the model.
873
- """
874
- impl_download_trained_model (model_path , ensure_s3_bucket (region ), region )
875
-
876
-
877
833
@cli .command (name = "download-training-script" )
878
834
@click .argument ("script-path" )
879
835
@click .option ("--region" , help = "The region to connect to." ,
@@ -955,3 +911,19 @@ def remove_heartbeat(job_id, region):
955
911
JOB_ID The unique Job ID.
956
912
"""
957
913
impl_remove_heartbeat (job_id , ensure_s3_bucket (region ), region )
914
+
915
+
916
+ @cli .command (name = "upload-training-results" )
917
+ @click .argument ("job-id" )
918
+ @click .argument ("output-dir" )
919
+ @click .option ("--region" , help = "The region to connect to." ,
920
+ type = click .Choice (region_choices ))
921
+ def upload_training_results (job_id , output_dir , region ):
922
+ """
923
+ Uploads the results from running a training script.
924
+
925
+ JOB_ID The unique Job ID.
926
+
927
+ OUTPUT_DIR The directory containing the results.
928
+ """
929
+ impl_upload_training_results (job_id , output_dir , ensure_s3_bucket (region ), region )
0 commit comments