6060"""
6161
6262import io
63+ import os
6364import struct
6465from typing import Optional , List , Dict , BinaryIO , cast , Union
6566
@@ -84,6 +85,9 @@ class ArchiveFormatError(Exception):
8485class ArchiveAccessError (IOError ):
8586 """ Raised on problems with accessing the archived files """
8687 pass
88+ class ExtractBreakoutAttempt (IOError ):
89+ """ Raised on files which would be extracted above specified path """
90+ pass
8791
8892class ArchiveFileHeader (object ):
8993 """ File header of an archived file, or a special data segment """
@@ -417,6 +421,47 @@ def open(self, name: Union[bytes,ArchiveFileHeader]) -> ArchiveFileData:
417421
418422 raise ValueError ("Can't look up file using type %s, expected bytes or ArchiveFileHeader" % (type (name ),))
419423
424+ def extract (self , member : bytes , path = None ):
425+ filename = os .path .basename (member )
426+ if isinstance (path , bytes ):
427+ filepath = os .path .join (path , filename )
428+ else :
429+ filepath = os .path .join (path .encode ('utf-8' ), filename )
430+ buf_size = 8 * 1024
431+
432+ ar_member = self .open (member )
433+ with open (filepath , 'wb' ) as f :
434+ while True :
435+ buffer = ar_member .read (buf_size )
436+ if not len (buffer ):
437+ break
438+ f .write (buffer )
439+
440+ def extractall (self , path : Union [str , bytes ], members : Optional [List [bytes ]]= None ):
441+ self .read_all_headers ()
442+
443+ normpath = os .path .normpath (path )
444+ if isinstance (normpath , str ):
445+ normpath = normpath .encode ('utf-8' )
446+
447+ if members is None :
448+ sources = self .archived_files .keys ()
449+ else :
450+ sources = members
451+
452+ for member in sources :
453+ member_dir = os .path .dirname (member )
454+ member_name = os .path .basename (member )
455+ filepath = os .path .join (normpath , member_dir , member_name )
456+ norm_filepath = os .path .normpath (filepath )
457+
458+ if os .path .commonpath ([normpath , norm_filepath ]) != normpath :
459+ raise ExtractBreakoutAttempt ("file %s would be extracted below specified path" % (member ,))
460+
461+ norm_dirpath = os .path .normpath (os .path .join (normpath , member_dir ))
462+ os .makedirs (norm_dirpath , exist_ok = True )
463+ self .extract (member , norm_filepath )
464+
420465 def __enter__ (self ):
421466 return self
422467
0 commit comments