Skip to content

Commit dac2f3c

Browse files
committed
test(recorded): consume shared harness in library rail tests
Fold library/helpers.py onto the shared build_rails()/async_chunks() instead of its own LLMRails(load_config(...)) + local _chunks (D11/F), and assert the content-safety output block via assert_blocked_generation (refusal + rail stop) rather than the weak assert_generation_response non-empty check (D6).
1 parent e84017c commit dac2f3c

2 files changed

Lines changed: 8 additions & 13 deletions

File tree

tests/recorded/rails/library/helpers.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,19 @@
1717

1818
from typing import Any
1919

20-
from nemoguardrails import LLMRails
2120
from nemoguardrails.rails.llm.options import RailType
22-
from tests.recorded.rails_config import RailsConfigSource, enable_streaming, load_config
21+
from tests.recorded.rails.helpers import async_chunks, build_rails
22+
from tests.recorded.rails_config import RailsConfigSource
2323
from tests.utils import FakeLLMModel
2424

2525

26-
async def _chunks(values: list[str]):
27-
for value in values:
28-
yield value
29-
30-
3126
async def check_rails(
3227
config: RailsConfigSource,
3328
messages: list[dict[str, Any]],
3429
*,
3530
rail_types: tuple[RailType, ...] | None = None,
3631
):
37-
rails = LLMRails(load_config(config), verbose=False)
32+
rails = build_rails(config)
3833
return await rails.check_async(messages, rail_types=list(rail_types) if rail_types is not None else None)
3934

4035

@@ -43,7 +38,7 @@ async def generate_with_fake_main(
4338
main_output: str,
4439
messages: list[dict[str, Any]],
4540
):
46-
rails = LLMRails(load_config(config), llm=FakeLLMModel(responses=[main_output]), verbose=False)
41+
rails = build_rails(config, llm=FakeLLMModel(responses=[main_output]))
4742
return await rails.generate_async(
4843
messages=messages,
4944
options={"log": {"activated_rails": True, "llm_calls": True}},
@@ -55,11 +50,11 @@ async def stream_with_fake_main(
5550
main_output: str,
5651
messages: list[dict[str, Any]],
5752
):
58-
rails = LLMRails(enable_streaming(load_config(config)), verbose=False)
53+
rails = build_rails(config, streaming=True)
5954
chunks = []
6055
async for chunk in rails.stream_async(
6156
messages=messages,
62-
generator=_chunks([main_output]),
57+
generator=async_chunks([main_output]),
6358
options={"rails": ["output"]},
6459
):
6560
chunks.append(chunk)

tests/recorded/rails/library/test_content_safety.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from nemoguardrails.exceptions import LLMCallException
2121
from nemoguardrails.rails.llm.options import RailStatus, RailType
2222
from tests.recorded.assertions import (
23+
assert_blocked_generation,
2324
assert_blocked_stream_error,
24-
assert_generation_response,
2525
assert_rails_result,
2626
)
2727
from tests.recorded.normalization import normalize_generation_response, normalize_rails_result, normalize_stream_chunks
@@ -89,7 +89,7 @@ async def test_content_safety_output_blocks_fake_main_generation(nvidia_api_key)
8989
[{"role": "user", "content": "hello"}],
9090
)
9191

92-
result = assert_generation_response(result)
92+
result = assert_blocked_generation(result, refusal="I'm sorry, I can't respond to that.")
9393

9494
assert normalize_generation_response(result) == snapshot(
9595
{

0 commit comments

Comments
 (0)