From 7e1134f0a94450a02d020a21f100c1e40f0dd4e9 Mon Sep 17 00:00:00 2001 From: Matthew Grange Date: Mon, 17 Nov 2025 12:59:13 -0800 Subject: [PATCH] Update PACKAGE transformers library version to support gpt-oss (#81) Summary: To unblock the tutorial PrivacyGuard Tutorial: Memorization of Online Content in LLMs, updating the transformers version to enable gpt-oss to be used. Differential Revision: D86675884 --- .../extraction/predictors/huggingface_predictor.py | 2 +- .../extraction/tests/test_generation_attack.py | 13 +++++++++++++ pyproject.toml | 3 ++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/privacy_guard/attacks/extraction/predictors/huggingface_predictor.py b/privacy_guard/attacks/extraction/predictors/huggingface_predictor.py index 433cd6d..045da99 100644 --- a/privacy_guard/attacks/extraction/predictors/huggingface_predictor.py +++ b/privacy_guard/attacks/extraction/predictors/huggingface_predictor.py @@ -284,7 +284,7 @@ def _get_logits_batch( with torch.no_grad(): # Get model outputs for the entire batch - outputs = ( + outputs = ( # pyre-fixme[29]: `typing.Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function. self.model.module # pyre-ignore if hasattr(self.model, "module") else self.model diff --git a/privacy_guard/attacks/extraction/tests/test_generation_attack.py b/privacy_guard/attacks/extraction/tests/test_generation_attack.py index 9a9322d..c644d6f 100644 --- a/privacy_guard/attacks/extraction/tests/test_generation_attack.py +++ b/privacy_guard/attacks/extraction/tests/test_generation_attack.py @@ -19,6 +19,8 @@ from unittest.mock import MagicMock +import transformers +from packaging.version import Version from privacy_guard.attacks.extraction.generation_attack import GenerationAttack @@ -145,6 +147,17 @@ def test_generation_attack_missing_column(self) -> None: self.assertIn("Missing required columns", str(context.exception)) bad_input_file.close() + def test_transformers_version_in_generation_attack(self) -> None: + """Verify that transformers version is greater than or equal to 4.55.0""" + current_version = transformers.__version__ + required_version = "4.55.0" + + self.assertGreaterEqual( + Version(current_version), + Version(required_version), + f"Transformers version {current_version} must be greater than or equal to 4.55.0", + ) + def tearDown(self) -> None: """Clean up temporary files.""" self.input_file.close() diff --git a/pyproject.toml b/pyproject.toml index fdef049..527532d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,8 @@ dependencies = [ "torch", 'tqdm', 'textdistance', - 'transformers', + 'transformers>=4.55.0', + 'accelerate', 'later', ]