From 423749aa09e21049021bf8429a5cd83db03745d1 Mon Sep 17 00:00:00 2001 From: TrellixVulnTeam Date: Mon, 21 Nov 2022 08:06:18 +0000 Subject: [PATCH] Adding tarfile member sanitization to extractall() --- .../training/ms_marco/eval_msmarco.py | 21 ++++++- .../multilingual/translate_queries.py | 21 ++++++- .../ms_marco/train_bi-encoder_margin-mse.py | 42 ++++++++++++- .../ms_marco/train_bi-encoder_mnrl.py | 63 ++++++++++++++++++- .../ms_marco/train_cross-encoder_kd.py | 42 ++++++++++++- .../ms_marco/train_cross-encoder_scratch.py | 42 ++++++++++++- 6 files changed, 220 insertions(+), 11 deletions(-) diff --git a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/eval_msmarco.py b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/eval_msmarco.py index 88a278e..b244982 100644 --- a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/eval_msmarco.py +++ b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/eval_msmarco.py @@ -46,7 +46,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) if not os.path.exists(qrels_filepath): diff --git a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/multilingual/translate_queries.py b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/multilingual/translate_queries.py index ef726b3..3554463 100644 --- a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/multilingual/translate_queries.py +++ b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/multilingual/translate_queries.py @@ -61,7 +61,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) with open(queries_filepath, 'r', encoding='utf8') as fIn: diff --git a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_bi-encoder_margin-mse.py b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_bi-encoder_margin-mse.py index 2c29494..88c13be 100644 --- a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_bi-encoder_margin-mse.py +++ b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_bi-encoder_margin-mse.py @@ -86,7 +86,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) logging.info("Read corpus: collection.tsv") with open(collection_filepath, 'r', encoding='utf8') as fIn: @@ -106,7 +125,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) with open(queries_filepath, 'r', encoding='utf8') as fIn: diff --git a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_bi-encoder_mnrl.py b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_bi-encoder_mnrl.py index 8ac807c..62e51f7 100644 --- a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_bi-encoder_mnrl.py +++ b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_bi-encoder_mnrl.py @@ -221,7 +221,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) if not os.path.exists(queries_filepath): tar_filepath = os.path.join(data_folder, 'queries.tar.gz') @@ -230,7 +249,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) # Load a dict (qid, pid) -> ce_score that maps query-ids (qid) and paragraph-ids (pid) # to the CrossEncoder score computed by the cross-encoder/ms-marco-MiniLM-L-6-v2 model @@ -423,7 +461,26 @@ def __len__(self): tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) if not os.path.exists(qrels_filepath): util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/qrels.dev.tsv', qrels_filepath) diff --git a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_cross-encoder_kd.py b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_cross-encoder_kd.py index e78c17d..a5778a9 100644 --- a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_cross-encoder_kd.py +++ b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_cross-encoder_kd.py @@ -64,7 +64,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) with open(collection_filepath, 'r', encoding='utf8') as fIn: for line in fIn: @@ -82,7 +101,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) with open(queries_filepath, 'r', encoding='utf8') as fIn: diff --git a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_cross-encoder_scratch.py b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_cross-encoder_scratch.py index 67e4374..46bff78 100644 --- a/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_cross-encoder_scratch.py +++ b/biencoder/nli_msmarco/sentence-transformers/examples/training/ms_marco/train_cross-encoder_scratch.py @@ -69,7 +69,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) with open(collection_filepath, 'r', encoding='utf8') as fIn: for line in fIn: @@ -87,7 +106,26 @@ util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath) with tarfile.open(tar_filepath, "r:gz") as tar: - tar.extractall(path=data_folder) + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, path=data_folder) with open(queries_filepath, 'r', encoding='utf8') as fIn: