Skip to content

Commit e809e16

Browse files
authored
fix: Resolve bugs introduced by PR #68 (#139)
<!-- SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. --> <!-- SPDX-License-Identifier: Apache-2.0 --> <!-- Thank you for contributing to Safe Synthesizer! --> # Summary <!-- Brief description of changes --> #68 missed updating vllm_backend.py for the change to `_build_json_based_regex()` arguments, fix them now. ## Pre-Review Checklist <!-- These checks should be completed before a PR is reviewed, --> <!-- but you can submit a draft early to indicate that the issue is being worked on. --> Ensure that the following pass: - [x] `make format && make lint` or via prek validation. - [x] `make test` passes locally - [ ] `make test-e2e` passes locally - [ ] `make test-ci-container` passes locally (recommended) ## Pre-Merge Checklist <!-- These checks need to be completed before a PR is merged, --> <!-- but as PRs often change significantly during review, --> <!-- it's OK for them to be incomplete when review is first requested. --> - [ ] New or updated tests for any fix or new behavior - [ ] Updated documentation for new features and behaviors, including docstrings for API docs. ## Other Notes <!-- Please add the issue number that should be closed when this PR is merged. --> - Closes #<issue> --------- Signed-off-by: memadi <memadi@nvidia.com>
1 parent b869c24 commit e809e16

2 files changed

Lines changed: 6 additions & 7 deletions

File tree

src/nemo_safe_synthesizer/generation/vllm_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ def _build_structured_output_params(self) -> StructuredOutputsParams | None:
118118
logger.info("Structured generation is enabled, using a regex to enforce the schema")
119119
regex = build_json_based_regex(
120120
self.schema,
121+
self.config,
121122
self.model_metadata.prompt_config.bos_token,
122123
self.model_metadata.prompt_config.eos_token,
123-
group_by=self.config.data.group_training_examples_by is not None,
124124
)
125125
params["regex"] = regex
126126
elif self.config.generation.structured_generation_schema_method == "json_schema":

tests/generation/test_vllm_backend.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,11 @@ def test_returns_params_with_regex_when_regex_method(
158158
return_value="test_regex_pattern",
159159
) as mock_build_regex:
160160
result = backend._build_structured_output_params()
161-
162161
mock_build_regex.assert_called_once_with(
163162
mock_schema,
163+
params_with_structured_generation_regex,
164164
mock_model_metadata.prompt_config.bos_token,
165165
mock_model_metadata.prompt_config.eos_token,
166-
group_by=False,
167166
)
168167
assert result is not None
169168
assert result.regex == "test_regex_pattern"
@@ -179,10 +178,10 @@ def test_returns_params_with_json_when_json_schema_method(
179178
assert result is not None
180179
assert result.json == mock_schema
181180

182-
def test_group_by_passed_when_grouping_enabled(
181+
def test_config_with_grouping_passed_to_build_regex(
183182
self, params_with_structured_generation_regex, mock_model_metadata, mock_schema, mock_workdir
184183
):
185-
"""Test that group_by=True is passed when group_training_examples_by is set."""
184+
"""Test that config with group_training_examples_by set is passed to build_json_based_regex."""
186185
params_with_structured_generation_regex.data.group_training_examples_by = "category"
187186
backend = create_backend(
188187
params_with_structured_generation_regex, mock_model_metadata, mock_schema, mock_workdir
@@ -195,8 +194,8 @@ def test_group_by_passed_when_grouping_enabled(
195194
backend._build_structured_output_params()
196195

197196
mock_build_regex.assert_called_once()
198-
_, kwargs = mock_build_regex.call_args
199-
assert kwargs.get("group_by") is True
197+
call_args, _ = mock_build_regex.call_args
198+
assert call_args[1].data.group_training_examples_by == "category"
200199

201200

202201
class TestResolveTemperature:

0 commit comments

Comments
 (0)