Guard tarfile extraction against path traversal in three call sites.

Each hunk replaces a bare tar.extractall() with a local safe_extract() helper
that checks every member path before extracting. Notes for reviewers:
  * Containment is tested with os.path.commonpath (whole path components).
    os.path.commonprefix compares characters and would accept a sibling
    directory such as "<dest>_evil".
  * numeric_owner is forwarded by keyword in all three hunks; it is
    keyword-only in TarFile.extractall and raises TypeError when positional.
  * The check is lexical (os.path.abspath); link members are not inspected.
  * Whitespace in this patch was reconstructed. Verify context-line
    indentation against the target files, or apply with
    `git apply --ignore-space-change`. The index hashes are those of the
    original patch.

diff --git a/superai/meta_ai/ai.py b/superai/meta_ai/ai.py
index 33f87cbd..6ddd7a63 100644
--- a/superai/meta_ai/ai.py
+++ b/superai/meta_ai/ai.py
@@ -352,7 +352,26 @@ def load_from_s3(cls, path: str, weights_path: Optional[str] = None) -> "AI":
             log.info(f"Downloading and unpacking AI object from bucket `{bucket_name}` and path `{path_to_object}`")
             s3.download_file(bucket_name, path_to_object, os.path.join(download_folder, "AISavedModel.tar.gz"))
             with tarfile.open(os.path.join(download_folder, "AISavedModel.tar.gz")) as tar:
-                tar.extractall(path=os.path.join(download_folder, "AISavedModel"))
+                def is_within_directory(directory, target):
+                    """Return True if `target` is `directory` or lies beneath it."""
+                    abs_directory = os.path.abspath(directory)
+                    abs_target = os.path.abspath(target)
+                    # commonpath compares whole path components; commonprefix compares
+                    # characters and would accept a sibling such as "<directory>_evil".
+                    prefix = os.path.commonpath([abs_directory, abs_target])
+                    return prefix == abs_directory
+
+                def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+                    """Extract `tar` into `path`, rejecting members whose path falls outside it."""
+                    for member in tar.getmembers():
+                        member_path = os.path.join(path, member.name)
+                        if not is_within_directory(path, member_path):
+                            raise Exception("Attempted Path Traversal in Tar File")
+                    # numeric_owner is keyword-only in TarFile.extractall; passing it
+                    # positionally raises TypeError.
+                    tar.extractall(path, members, numeric_owner=numeric_owner)
+
+                safe_extract(tar, path=os.path.join(download_folder, "AISavedModel"))
             return cls.load_local(
                 load_path=os.path.join(download_folder, "AISavedModel", "ai"),
                 weights_path=weights_path,
diff --git a/superai/meta_ai/base/base_ai.py b/superai/meta_ai/base/base_ai.py
index a4279633..7f1262db 100644
--- a/superai/meta_ai/base/base_ai.py
+++ b/superai/meta_ai/base/base_ai.py
@@ -131,7 +131,26 @@ def _pull_weights(weights_uri: str, output_path: str) -> str:
             log.info(f"Downloading and unpacking AI object from bucket `{bucket_name}` and path `{path_to_object}`")
             s3.download_file(bucket_name, path_to_object, os.path.join(output_path, object_name))
             with tarfile.open(os.path.join(output_path, object_name)) as tar:
-                tar.extractall(path=full_path)
+                def is_within_directory(directory, target):
+                    """Return True if `target` is `directory` or lies beneath it."""
+                    abs_directory = os.path.abspath(directory)
+                    abs_target = os.path.abspath(target)
+                    # commonpath compares whole path components; commonprefix compares
+                    # characters and would accept a sibling such as "<directory>_evil".
+                    prefix = os.path.commonpath([abs_directory, abs_target])
+                    return prefix == abs_directory
+
+                def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+                    """Extract `tar` into `path`, rejecting members whose path falls outside it."""
+                    for member in tar.getmembers():
+                        member_path = os.path.join(path, member.name)
+                        if not is_within_directory(path, member_path):
+                            raise Exception("Attempted Path Traversal in Tar File")
+                    # numeric_owner is keyword-only in TarFile.extractall; passing it
+                    # positionally raises TypeError.
+                    tar.extractall(path, members, numeric_owner=numeric_owner)
+
+                safe_extract(tar, path=full_path)
             log.info(f"Successfully downloaded and unpacked weights to path `{full_path}`")
         else:
             BaseModel._pull_s3_folder(weights_uri, full_path)
diff --git a/tests/meta_ai/test_unit_ai.py b/tests/meta_ai/test_unit_ai.py
index 1f45b34d..132eb501 100644
--- a/tests/meta_ai/test_unit_ai.py
+++ b/tests/meta_ai/test_unit_ai.py
@@ -42,7 +42,26 @@ def test_compression():
     another_folder_path = os.path.join(".AISave", "another_folder")
     os.makedirs(another_folder_path)
     with tarfile.open(path_to_tarfile) as tar:
-        tar.extractall(path=another_folder_path)
+        def is_within_directory(directory, target):
+            """Return True if `target` is `directory` or lies beneath it."""
+            abs_directory = os.path.abspath(directory)
+            abs_target = os.path.abspath(target)
+            # commonpath compares whole path components; commonprefix compares
+            # characters and would accept a sibling such as "<directory>_evil".
+            prefix = os.path.commonpath([abs_directory, abs_target])
+            return prefix == abs_directory
+
+        def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+            """Extract `tar` into `path`, rejecting members whose path falls outside it."""
+            for member in tar.getmembers():
+                member_path = os.path.join(path, member.name)
+                if not is_within_directory(path, member_path):
+                    raise Exception("Attempted Path Traversal in Tar File")
+            # numeric_owner is keyword-only in TarFile.extractall; passing it
+            # positionally raises TypeError.
+            tar.extractall(path, members, numeric_owner=numeric_owner)
+
+        safe_extract(tar, path=another_folder_path)
     for i in range(1, 5):
         assert os.path.exists(os.path.join(another_folder_path, f"{i}_file.txt"))
     shutil.rmtree(folder_path)