Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def _get_logits_batch(

with torch.no_grad():
# Get model outputs for the entire batch
outputs = (
outputs = ( # pyre-fixme[29]: `typing.Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function.
self.model.module # pyre-ignore
if hasattr(self.model, "module")
else self.model
Expand Down
13 changes: 13 additions & 0 deletions privacy_guard/attacks/extraction/tests/test_generation_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from unittest.mock import MagicMock

import transformers
from packaging.version import Version
from privacy_guard.attacks.extraction.generation_attack import GenerationAttack


Expand Down Expand Up @@ -145,6 +147,17 @@ def test_generation_attack_missing_column(self) -> None:
self.assertIn("Missing required columns", str(context.exception))
bad_input_file.close()

def test_transformers_version_in_generation_attack(self) -> None:
"""Verify that transformers version is greater than or equal to 4.55.0"""
current_version = transformers.__version__
required_version = "4.55.0"

self.assertGreaterEqual(
Version(current_version),
Version(required_version),
f"Transformers version {current_version} must be greater than or equal to 4.55.0",
)

def tearDown(self) -> None:
"""Clean up temporary files."""
self.input_file.close()
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ dependencies = [
"torch",
'tqdm',
'textdistance',
'transformers',
'transformers>=4.55.0',
'accelerate',
'later',
]

Expand Down
Loading