1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import os
15+ import io
16+ import tarfile
17+ from unittest import mock
18+
19+ import pytest
1520
1621import recirq
1722from recirq .readout_scan .tasks import ReadoutScanTask
@@ -35,6 +40,23 @@ def test_display_markdown_docstring():
3540"""
3641
3742
38- def test_fetch_guide_data_collection_data (tmpdir ):
39- recirq .fetch_guide_data_collection_data (base_dir = tmpdir )
40- assert os .path .exists (f'{ tmpdir } /2020-02-tutorial' )
43+ @mock .patch ('urllib.request.urlopen' )
44+ def test_fetch_guide_data_collection_data_traversal (mock_urlopen , tmpdir ):
45+ # Create a malicious tarball in memory.
46+ malicious_tar_stream = io .BytesIO ()
47+ with tarfile .open (fileobj = malicious_tar_stream , mode = 'w:xz' ) as tf :
48+ # Add a file that tries to write outside the target directory
49+ malicious_info = tarfile .TarInfo (name = "../../tmp/pwned" )
50+ tf .addfile (malicious_info , io .BytesIO (b"pwned" ))
51+ malicious_tar_stream .seek (0 )
52+
53+ # Read the stream into a BytesIO object so that the mock should return a
54+ # response object whose read() method returns the tarball content.
55+ mock_response = mock .Mock ()
56+ mock_response .read .return_value = malicious_tar_stream .getvalue ()
57+ mock_urlopen .return_value = mock_response
58+
59+ with pytest .raises (ValueError , match = "Encountered untrusted path" ):
60+ recirq .fetch_guide_data_collection_data (base_dir = tmpdir )
61+
62+ assert not os .path .exists ('/tmp/pwned' )
0 commit comments