Skip to content
This repository was archived by the owner on Sep 26, 2020. It is now read-only.

Commit 35eb053

Browse files
committed
Add upload training results
1 parent 5b9ae9d commit 35eb053

File tree

1 file changed

+26
-54
lines changed

1 file changed

+26
-54
lines changed

axon/client.py

+26-54
Original file line numberDiff line numberDiff line change
@@ -603,34 +603,6 @@ def impl_download_untrained_model(model_path, bucket_name, region):
603603
print("Downloaded from: {}\n".format(key))
604604

605605

606-
def impl_upload_trained_model(model_path, bucket_name, region):
607-
"""
608-
Uploads an trained model to S3.
609-
610-
:param model_path: The file path to the model to upload, ending with the name of the model.
611-
:param bucket_name: The S3 bucket name.
612-
:param region: The region, or `None` to pull the region from the environment.
613-
"""
614-
client = make_client("s3", region)
615-
key = "axon-trained-models/" + os.path.basename(model_path)
616-
client.upload_file(model_path, bucket_name, key)
617-
print("Uploaded to: {}\n".format(key))
618-
619-
620-
def impl_download_trained_model(model_path, bucket_name, region):
621-
"""
622-
Downloads an trained model from S3.
623-
624-
:param model_path: The file path to download to, ending with the name of the model.
625-
:param bucket_name: The S3 bucket name.
626-
:param region: The region, or `None` to pull the region from the environment.
627-
"""
628-
client = make_client("s3", region)
629-
key = "axon-trained-models/" + os.path.basename(model_path)
630-
client.download_file(bucket_name, key, model_path)
631-
print("Downloaded from: {}\n".format(key))
632-
633-
634606
def impl_download_training_script(script_path, bucket_name, region):
635607
"""
636608
Downloads a training script from S3.
@@ -723,6 +695,16 @@ def impl_remove_heartbeat(job_id, bucket_name, region):
723695
print("Removed heartbeat file in: {}\n".format(remote_path))
724696

725697

698+
def impl_upload_training_results(job_id, output_dir, bucket_name, region):
699+
client = make_client("s3", region)
700+
files_to_upload = [os.path.join(output_dir, it) for it in os.listdir(output_dir)]
701+
files_to_upload = [it for it in files_to_upload if os.path.isfile(it)]
702+
for elem in files_to_upload:
703+
key = "axon-training-results/{}/{}".format(job_id, os.path.basename(elem))
704+
client.upload_file(elem, bucket_name, key)
705+
print("Uploaded to: {}\n".format(key))
706+
707+
726708
def create_progress_prefix(job_id):
727709
return "axon-training-progress/{}".format(job_id)
728710

@@ -848,32 +830,6 @@ def download_untrained_model(model_path, region):
848830
impl_download_untrained_model(model_path, ensure_s3_bucket(region), region)
849831

850832

851-
@cli.command(name="upload-trained-model")
852-
@click.argument("model-path")
853-
@click.option("--region", help="The region to connect to.",
854-
type=click.Choice(region_choices))
855-
def upload_trained_model(model_path, region):
856-
"""
857-
Uploads a trained model from a local file.
858-
859-
MODEL_PATH The path to the model to upload, ending with the name of the model.
860-
"""
861-
impl_upload_trained_model(model_path, ensure_s3_bucket(region), region)
862-
863-
864-
@cli.command(name="download-trained-model")
865-
@click.argument("model-path")
866-
@click.option("--region", help="The region to connect to.",
867-
type=click.Choice(region_choices))
868-
def download_trained_model(model_path, region):
869-
"""
870-
Downloads a trained model to a local file.
871-
872-
MODEL_PATH The path to download the model to, ending with the name of the model.
873-
"""
874-
impl_download_trained_model(model_path, ensure_s3_bucket(region), region)
875-
876-
877833
@cli.command(name="download-training-script")
878834
@click.argument("script-path")
879835
@click.option("--region", help="The region to connect to.",
@@ -955,3 +911,19 @@ def remove_heartbeat(job_id, region):
955911
JOB_ID The unique Job ID.
956912
"""
957913
impl_remove_heartbeat(job_id, ensure_s3_bucket(region), region)
914+
915+
916+
@cli.command(name="upload-training-results")
917+
@click.argument("job-id")
918+
@click.argument("output-dir")
919+
@click.option("--region", help="The region to connect to.",
920+
type=click.Choice(region_choices))
921+
def upload_training_results(job_id, output_dir, region):
922+
"""
923+
Uploads the results from running a training script.
924+
925+
JOB_ID The unique Job ID.
926+
927+
OUTPUT_DIR The directory containing the results.
928+
"""
929+
impl_upload_training_results(job_id, output_dir, ensure_s3_bucket(region), region)

0 commit comments

Comments
 (0)