49 changes: 48 additions & 1 deletion acquire/outputs/tar.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import copy
import io
import shutil
import tarfile
from typing import TYPE_CHECKING, BinaryIO

@@ -100,7 +102,52 @@ def write(
if stat:
info.mtime = stat.st_mtime

self.tar.addfile(info, fh)
# Inline version of Python stdlib's tarfile.addfile & tarfile.copyfileobj,
# to allow for padding and more control over the tar file writing.
self.tar._check("awx")

if fh is None and info.isreg() and info.size != 0:
raise ValueError("fileobj not provided for non zero-size regular file")

info = copy.copy(info)

buf = info.tobuf(self.tar.format, self.tar.encoding, self.tar.errors)
Review comment (Member): You could make this even safer by truncating to the previous offset/tar member end if any exception occurs while writing.
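A rough sketch of that idea (illustrative only, not part of this change; it assumes self.tar.fileobj supports seek() and truncate(), which plain file objects do but arbitrary streams may not):

    # Hypothetical roll-back on failure: remember where the previous member ended
    # and truncate back to it if writing this member fails partway through.
    member_start = self.tar.offset
    try:
        self.tar.fileobj.write(buf)
        self.tar.offset += len(buf)
        # ... copy the file data and block padding here ...
    except Exception:
        # Discard the partially written member and restore the archive state.
        self.tar.fileobj.seek(member_start)
        self.tar.fileobj.truncate(member_start)
        self.tar.offset = member_start
        raise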

self.tar.fileobj.write(buf)
self.tar.offset += len(buf)
bufsize = self.tar.copybufsize
if fh is not None:
bufsize = bufsize or 16 * 1024

if info.size == 0:
return
if info.size is None:
shutil.copyfileobj(fh, self.tar.fileobj, bufsize)
return
Review comment (Member) on lines +123 to +125: I think we can remove this block since it would be an illegal action in this context.


blocks, remainder = divmod(info.size, bufsize)
for _ in range(blocks):
# Prevents "long reads" because it reads at max bufsize bytes at a time
Review comment (Member): Long or short?

buf = fh.read(bufsize)
if len(buf) < bufsize:
Review comment (Member): I think you can generalize this case instead of doing it twice. Keep track of how many bytes you actually wrote (i.e. using .tell()) and only pad once.
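A loose sketch of that generalization (illustrative only, not the PR's code; it assumes self.tar.fileobj is seekable so .tell() is meaningful, and reuses the names info, fh, bufsize and tarfile.NUL from the surrounding code):

    # Hypothetical "pad once" variant: copy whatever the source yields, track how
    # much was actually written, then top up the shortfall in a single step.
    data_start = self.tar.fileobj.tell()
    remaining = info.size
    while remaining > 0:
        # Never write more than the size recorded in the header, even if the source grew.
        chunk = fh.read(min(bufsize, remaining))[:remaining]
        if not chunk:
            break
        self.tar.fileobj.write(chunk)
        remaining -= len(chunk)
    written = self.tar.fileobj.tell() - data_start
    if written < info.size:
        # The source shrank while reading; pad up to the size recorded in the header.
        self.tar.fileobj.write(tarfile.NUL * (info.size - written))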

# PATCH; instead of raising an exception, pad the data to the desired length
buf += tarfile.NUL * (bufsize - len(buf))
self.tar.fileobj.write(buf)

if remainder != 0:
# Prevents "long reads" because it reads at max bufsize bytes at a time
Review comment (Member): Long or short?

buf = fh.read(remainder)
if len(buf) < remainder:
# PATCH; instead of raising an exception, pad the data to the desired length
buf += tarfile.NUL * (remainder - len(buf))
self.tar.fileobj.write(buf)

blocks, remainder = divmod(info.size, tarfile.BLOCKSIZE)
if remainder > 0:
self.tar.fileobj.write(tarfile.NUL * (tarfile.BLOCKSIZE - remainder))
blocks += 1
self.tar.offset += blocks * tarfile.BLOCKSIZE

self.tar.members.append(info)

def close(self) -> None:
"""Closes the tar file."""
75 changes: 75 additions & 0 deletions tests/test_outputs_tar.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import io
import tarfile
from pathlib import Path
from typing import TYPE_CHECKING
@@ -63,3 +64,77 @@ def test_tar_output_encrypt(mock_fs: VirtualFilesystem, public_key: bytes, tmp_p

with tarfile.open(name=decrypted_path, mode="r") as tar_file:
assert entry.open().read() == tar_file.extractfile(entry_name).read()


def test_tar_output_race_condition_with_shrinking_file(tmp_path: Path, public_key: bytes) -> None:
class ShrinkingFile(io.BytesIO):
"""
A file-like object that returns 5 bytes less than required.
Simulates a file on disk that has shrunk in between the time of
determining the size and actually reading the data.
"""

def __init__(self, data: bytes):
super().__init__(data)

def read(self, size: int) -> bytes:
return super().read(size - 5)

content = b"some text"

content_padded = content[:-5] + tarfile.NUL * 5
file = ShrinkingFile(content)

tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key)
tar_output.write("file.log", file)
tar_output.close()
file.close()

encrypted_stream = EncryptedFile(tar_output.path.open("rb"), Path("tests/_data/private_key.pem"))
decrypted_path = tmp_path / "decrypted.tar"

# Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly
Path(decrypted_path).write_bytes(encrypted_stream.read())

with tarfile.open(name=decrypted_path, mode="r") as tar_file:
member = tar_file.getmember("file.log")
extracted = tar_file.extractfile(member).read()
# The content should be padded with zeros to match the original size, despite the fact that the file shrunk
assert extracted == content_padded


def test_tar_output_race_condition_with_growing_file(tmp_path: Path, public_key: bytes) -> None:
class GrowingFile(io.BytesIO):
"""
A file-like object that returns 3 extra bytes.
Simulates a file on disk that has grown in between the time of
determining the size and actually reading the data.
"""

def __init__(self, data: bytes):
super().__init__(data)

def read(self, size: int) -> bytes:
return super().read(size) + b"FOX"

content = b"some text"

file = GrowingFile(content)

tar_output = TarOutput(tmp_path / "race.tar", encrypt=True, public_key=public_key)
tar_output.write("file.log", file)
tar_output.close()
file.close()

encrypted_stream = EncryptedFile(tar_output.path.open("rb"), Path("tests/_data/private_key.pem"))
decrypted_path = tmp_path / "decrypted.tar"

# Direct streaming is not an option because tarfile needs seek when reading from encrypted files directly
Path(decrypted_path).write_bytes(encrypted_stream.read())

with tarfile.open(name=decrypted_path, mode="r") as tar_file:
member = tar_file.getmember("file.log")
extracted = tar_file.extractfile(member).read()
# The content should match the original content, without the extra bytes
# because the file was read with the original size
assert extracted == content