Skip to content

Commit af3ea18

Browse files
committed
Add extract calls
API to match the zipfile module.
1 parent 36b113d commit af3ea18

2 files changed

Lines changed: 84 additions & 0 deletions

File tree

arpy.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"""
6161

6262
import io
63+
import os
6364
import struct
6465
from typing import Optional, List, Dict, BinaryIO, cast, Union
6566

@@ -84,6 +85,9 @@ class ArchiveFormatError(Exception):
8485
class ArchiveAccessError(IOError):
    """ Signals a failure while accessing one of the archived files """
88+
class ExtractBreakoutAttempt(IOError):
    """ Signals that a file would end up above the requested extraction path """
8791

8892
class ArchiveFileHeader(object):
8993
""" File header of an archived file, or a special data segment """
@@ -417,6 +421,47 @@ def open(self, name: Union[bytes,ArchiveFileHeader]) -> ArchiveFileData:
417421

418422
raise ValueError("Can't look up file using type %s, expected bytes or ArchiveFileHeader" % (type(name),))
419423

424+
def extract(self, member: bytes, path=None):
    """
    Extract a single archived file into a directory.

    Only the base name of ``member`` is used for the output file, so the
    file is always created directly inside ``path``.

    :param member: name of the archived file (bytes), looked up via
        ``self.open``.
    :param path: destination directory as str, bytes or a path-like
        object; the current working directory is used when it is None.
    """
    if path is None:
        # The advertised default used to crash on ``None.encode``.
        path = os.getcwdb()

    filename = os.path.basename(member)
    # Member names are bytes, so the directory has to be bytes as well.
    # os.fsencode leaves bytes untouched and encodes str with the
    # filesystem encoding instead of a hard-coded UTF-8.
    filepath = os.path.join(os.fsencode(path), filename)
    buf_size = 8 * 1024

    ar_member = self.open(member)
    with open(filepath, 'wb') as f:
        # Copy in fixed-size chunks so large members are never held in
        # memory in one piece.
        while True:
            buffer = ar_member.read(buf_size)
            if not buffer:
                break
            f.write(buffer)
439+
440+
def extractall(self, path: Union[str, bytes], members: Optional[List[bytes]]=None):
    """
    Extract several archived files below a destination directory.

    Directory components of a member name are recreated under ``path``.

    :param path: destination directory (str or bytes).
    :param members: names of the files to extract; every file listed in
        ``self.archived_files`` is extracted when it is None.
    :raises ExtractBreakoutAttempt: when a member name would place the
        file outside of ``path``.
    """
    self.read_all_headers()

    # Member names are bytes, so all path arithmetic is done on bytes.
    normpath = os.path.normpath(os.fsencode(path))
    # The containment check works on absolute paths: os.path.commonpath
    # drops "." components and refuses to mix absolute and relative paths,
    # which breaks the comparison for relative destinations.
    abs_root = os.path.abspath(normpath)

    if members is None:
        sources = self.archived_files.keys()
    else:
        sources = members

    for member in sources:
        member_dir = os.path.dirname(member)
        member_name = os.path.basename(member)
        norm_dirpath = os.path.normpath(os.path.join(normpath, member_dir))
        norm_filepath = os.path.normpath(os.path.join(norm_dirpath, member_name))

        try:
            common = os.path.commonpath([abs_root, os.path.abspath(norm_filepath)])
        except ValueError:
            # e.g. paths on different drives: certainly not inside the root
            common = None

        if common != abs_root:
            raise ExtractBreakoutAttempt("file %s would be extracted above specified path" % (member,))

        os.makedirs(norm_dirpath, exist_ok=True)
        # extract() appends the member's base name itself, so it must be
        # given the target directory rather than the full file path.
        self.extract(member, norm_dirpath)
464+
420465
def __enter__(self):
    """ Context manager entry: hand back this very object """
    return self
422467

test/test_contents.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import unittest
33
import os
44
import io
5+
from unittest.mock import patch, mock_open, call
56

67
class ArContents(unittest.TestCase):
78
def test_archive_contents(self):
@@ -13,6 +14,44 @@ def test_archive_contents(self):
1314
self.assertEqual(b'test_in_file_2\n', f2_contents)
1415
ar.close()
1516

17+
def test_extract(self):
    """ extract() with a str path writes the member contents and closes the file """
    opener = mock_open()
    archive_path = os.path.join(os.path.dirname(__file__), 'contents.ar')
    # The archive is opened before the patch is active, so only the
    # output file goes through the mocked open().
    with arpy.Archive(archive_path) as ar, patch('arpy.open', opener):
        ar.extract(b'file1', '/foobar')

    handle = opener()
    handle.write.assert_called_once_with(b'test_in_file_1\n')
    handle.__exit__.assert_called_once_with(None, None, None)
25+
26+
def test_extract_byte_path(self):
    """ extract() with a bytes path writes the member contents and closes the file """
    opener = mock_open()
    archive_path = os.path.join(os.path.dirname(__file__), 'contents.ar')
    # The archive is opened before the patch is active, so only the
    # output file goes through the mocked open().
    with arpy.Archive(archive_path) as ar, patch('arpy.open', opener):
        ar.extract(b'file1', b'/foobar')

    handle = opener()
    handle.write.assert_called_once_with(b'test_in_file_1\n')
    handle.__exit__.assert_called_once_with(None, None, None)
34+
35+
def test_extractall(self):
    """ extractall() without a member list extracts every archived file """
    archive_path = os.path.join(os.path.dirname(__file__), 'contents.ar')
    with arpy.Archive(archive_path) as ar:
        with patch.object(ar, 'extract') as m_extract:
            with patch('os.makedirs') as m_makedirs:
                ar.extractall('/foobar')

    # extract() joins the path it receives with the member's base name, so
    # it has to be handed the destination directory; passing the full file
    # path would make it write to /foobar/file1/file1.
    m_extract.assert_has_calls([
        call(b'file1', b'/foobar'),
        call(b'file2', b'/foobar'),
    ])
    m_makedirs.assert_called_with(b'/foobar', exist_ok=True)
46+
47+
def test_extractall2(self):
    """ extractall() with an explicit member list extracts only those files """
    archive_path = os.path.join(os.path.dirname(__file__), 'contents.ar')
    with arpy.Archive(archive_path) as ar:
        with patch.object(ar, 'extract') as m_extract:
            with patch('os.makedirs') as m_makedirs:
                ar.extractall('/foobar', [b'file2'])

    # extract() joins the path it receives with the member's base name, so
    # it has to be handed the destination directory; passing the full file
    # path would make it write to /foobar/file2/file2.
    m_extract.assert_called_once_with(b'file2', b'/foobar')
    m_makedirs.assert_called_once_with(b'/foobar', exist_ok=True)
1655

1756
class ArZipLike(unittest.TestCase):
1857
def setUp(self):

0 commit comments

Comments
 (0)