Skip to content

Commit 5852042

Browse files
authored
feat: response_parser both sync and async (#432)
1 parent 6b6171d commit 5852042

File tree

8 files changed

+61
-33
lines changed

8 files changed

+61
-33
lines changed

packages/ragbits-core/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# CHANGELOG
22

33
## Unreleased
4+
- Allow Prompt class to accept the asynchronous response_parser. Change the signature of parse_response method.
45

56
## 0.11.0 (2025-03-25)
67
- Add HybridSearchVectorStore which can aggregate results from multiple VectorStores (#412)

packages/ragbits-core/src/ragbits/core/llms/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ async def generate(
163163
with trace(model_name=self.model_name, prompt=prompt, options=repr(options)) as outputs:
164164
raw_response = await self.generate_raw(prompt, options=options)
165165
if isinstance(prompt, BasePromptWithParser):
166-
response = prompt.parse_response(raw_response["response"])
166+
response = await prompt.parse_response(raw_response["response"])
167167
else:
168168
response = cast(OutputT, raw_response["response"])
169169
raw_response["response"] = response
@@ -225,7 +225,7 @@ async def generate_with_metadata(
225225
response = await self.generate_raw(prompt, options=options)
226226
content = response.pop("response")
227227
if isinstance(prompt, BasePromptWithParser):
228-
content = prompt.parse_response(content)
228+
content = await prompt.parse_response(content)
229229
outputs.response = LLMResponseWithMetadata[type(content)]( # type: ignore
230230
content=content,
231231
metadata=response,

packages/ragbits-core/src/ragbits/core/prompt/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class BasePromptWithParser(Generic[OutputT], BasePrompt, metaclass=ABCMeta):
5454
"""
5555

5656
@abstractmethod
57-
def parse_response(self, response: str) -> OutputT:
57+
async def parse_response(self, response: str) -> OutputT:
5858
"""
5959
Parse the response from the LLM to the desired output type.
6060

packages/ragbits-core/src/ragbits/core/prompt/prompt.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import asyncio
12
import base64
23
import imghdr
34
import textwrap
45
from abc import ABCMeta
5-
from collections.abc import Callable
6+
from collections.abc import Awaitable, Callable
67
from typing import Any, Generic, cast, get_args, get_origin, overload
78

89
from jinja2 import Environment, Template, meta
@@ -34,7 +35,7 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=
3435

3536
# function that parses the response from the LLM to specific output type
3637
# if not provided, the class tries to set it automatically based on the output type
37-
response_parser: Callable[[str], OutputT]
38+
response_parser: Callable[[str], OutputT | Awaitable[OutputT]]
3839

3940
# Automatically set in __init_subclass__
4041
input_type: type[InputT] | None
@@ -98,7 +99,7 @@ def _format_message(cls, message: str) -> str:
9899
return textwrap.dedent(message).strip()
99100

100101
@classmethod
101-
def _detect_response_parser(cls) -> Callable[[str], OutputT]:
102+
def _detect_response_parser(cls) -> Callable[[str], OutputT | Awaitable[OutputT]]:
102103
if hasattr(cls, "response_parser") and cls.response_parser is not None:
103104
return cls.response_parser
104105
if issubclass(cls.output_type, BaseModel):
@@ -265,7 +266,7 @@ def json_mode(self) -> bool:
265266
"""
266267
return issubclass(self.output_type, BaseModel)
267268

268-
def parse_response(self, response: str) -> OutputT:
269+
async def parse_response(self, response: str) -> OutputT:
269270
"""
270271
Parse the response from the LLM to the desired output type.
271272
@@ -278,7 +279,11 @@ def parse_response(self, response: str) -> OutputT:
278279
Raises:
279280
ResponseParsingError: If the response cannot be parsed.
280281
"""
281-
return self.response_parser(response)
282+
if asyncio.iscoroutinefunction(self.response_parser):
283+
result = await self.response_parser(response)
284+
else:
285+
result = self.response_parser(response)
286+
return result
282287

283288
@classmethod
284289
def to_promptfoo(cls, config: dict[str, Any]) -> ChatFormat:

packages/ragbits-core/tests/unit/llms/test_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def chat(self) -> ChatFormat:
2828
return [{"role": "user", "content": self._content}]
2929

3030
@staticmethod
31-
def parse_response(response: str) -> CustomOutputType:
31+
async def parse_response(response: str) -> CustomOutputType:
3232
return CustomOutputType(message=response)
3333

3434

packages/ragbits-core/tests/unit/llms/test_litellm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def chat(self) -> ChatFormat:
5757
return [{"content": self.message, "role": "user"}]
5858

5959
@staticmethod
60-
def parse_response(response: str) -> int:
60+
async def parse_response(response: str) -> int:
6161
"""
6262
Parser for the prompt.
6363

packages/ragbits-core/tests/unit/prompts/test_parsers.py

+21-21
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .test_prompt import _PromptOutput
99

1010

11-
def test_prompt_with_str_output():
11+
async def test_prompt_with_str_output():
1212
"""Test a prompt with a string output."""
1313

1414
class TestPrompt(Prompt[None, str]): # pylint: disable=unused-variable
@@ -17,10 +17,10 @@ class TestPrompt(Prompt[None, str]): # pylint: disable=unused-variable
1717
user_prompt = "Hello"
1818

1919
prompt = TestPrompt()
20-
assert prompt.parse_response("Hi") == "Hi"
20+
assert await prompt.parse_response("Hi") == "Hi"
2121

2222

23-
def test_prompt_with_int_output():
23+
async def test_prompt_with_int_output():
2424
"""Test a prompt with an int output."""
2525

2626
class TestPrompt(Prompt[None, int]): # pylint: disable=unused-variable
@@ -29,13 +29,13 @@ class TestPrompt(Prompt[None, int]): # pylint: disable=unused-variable
2929
user_prompt = "Hello"
3030

3131
prompt = TestPrompt()
32-
assert prompt.parse_response("1") == 1
32+
assert await prompt.parse_response("1") == 1
3333

3434
with pytest.raises(ResponseParsingError):
35-
prompt.parse_response("a")
35+
await prompt.parse_response("a")
3636

3737

38-
def test_prompt_with_model_output():
38+
async def test_prompt_with_model_output():
3939
"""Test a prompt with a model output."""
4040

4141
class TestPrompt(Prompt[None, _PromptOutput]): # pylint: disable=unused-variable
@@ -44,15 +44,15 @@ class TestPrompt(Prompt[None, _PromptOutput]): # pylint: disable=unused-variabl
4444
user_prompt = "Hello"
4545

4646
prompt = TestPrompt()
47-
assert prompt.parse_response('{"song_title": "Hello", "song_lyrics": "World"}') == _PromptOutput(
47+
assert await prompt.parse_response('{"song_title": "Hello", "song_lyrics": "World"}') == _PromptOutput(
4848
song_title="Hello", song_lyrics="World"
4949
)
5050

5151
with pytest.raises(ResponseParsingError):
52-
prompt.parse_response('{"song_title": "Hello"}')
52+
await prompt.parse_response('{"song_title": "Hello"}')
5353

5454

55-
def test_prompt_with_float_output():
55+
async def test_prompt_with_float_output():
5656
"""Test a prompt with a float output."""
5757

5858
class TestPrompt(Prompt[None, float]): # pylint: disable=unused-variable
@@ -61,13 +61,13 @@ class TestPrompt(Prompt[None, float]): # pylint: disable=unused-variable
6161
user_prompt = "Hello"
6262

6363
prompt = TestPrompt()
64-
assert prompt.parse_response("1.0") == 1.0
64+
assert await prompt.parse_response("1.0") == 1.0
6565

6666
with pytest.raises(ResponseParsingError):
67-
prompt.parse_response("a")
67+
await prompt.parse_response("a")
6868

6969

70-
def test_prompt_with_bool_output():
70+
async def test_prompt_with_bool_output():
7171
"""Test a prompt with a bool output."""
7272

7373
class TestPrompt(Prompt[None, bool]): # pylint: disable=unused-variable
@@ -76,14 +76,14 @@ class TestPrompt(Prompt[None, bool]): # pylint: disable=unused-variable
7676
user_prompt = "Hello"
7777

7878
prompt = TestPrompt()
79-
assert prompt.parse_response("true") is True
80-
assert prompt.parse_response("false") is False
79+
assert await prompt.parse_response("true") is True
80+
assert await prompt.parse_response("false") is False
8181

8282
with pytest.raises(ResponseParsingError):
83-
prompt.parse_response("a")
83+
await prompt.parse_response("a")
8484

8585

86-
def test_prompt_with_int_and_custom_parser():
86+
async def test_prompt_with_int_and_custom_parser():
8787
"""Test a prompt with an int output and a custom parser."""
8888

8989
class TestPrompt(Prompt[None, int]): # pylint: disable=unused-variable
@@ -111,10 +111,10 @@ def response_parser(response: str) -> int:
111111
raise ResponseParsingError("Could not parse response")
112112

113113
prompt = TestPrompt()
114-
assert prompt.parse_response("abcd k2") == 2
114+
assert await prompt.parse_response("abcd k2") == 2
115115

116116
with pytest.raises(ResponseParsingError):
117-
prompt.parse_response("a")
117+
await prompt.parse_response("a")
118118

119119

120120
def test_prompt_with_unknown_output_and_no_parser():
@@ -127,7 +127,7 @@ class TestPrompt(Prompt[None, list]): # pylint: disable=unused-variable
127127
user_prompt = "Hello"
128128

129129

130-
def test_prompt_with_unknown_output_and_custom_parser():
130+
async def test_prompt_with_unknown_output_and_custom_parser():
131131
"""Test a prompt with an output type that doesn't have a default parser but has a custom parser."""
132132

133133
class TestPrompt(Prompt[None, list]): # pylint: disable=unused-variable
@@ -152,5 +152,5 @@ def response_parser(response: str) -> list:
152152
return response.split()
153153

154154
prompt = TestPrompt()
155-
assert prompt.parse_response("Hello World") == ["Hello", "World"]
156-
assert prompt.parse_response("Hello") == ["Hello"]
155+
assert await prompt.parse_response("Hello World") == ["Hello", "World"]
156+
assert await prompt.parse_response("Hello") == ["Hello"]

packages/ragbits-core/tests/unit/prompts/test_prompt.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class TestPromptSystem(Prompt): # pylint: disable=unused-variable
7777
system_prompt = "Hello, {{ name }}"
7878

7979

80-
def test_raises_when_unknow_user_template_variable():
80+
def test_raises_when_unknown_user_template_variable():
8181
"""Test that a ValueError is raised when an unknown template variable is provided."""
8282
with pytest.raises(ValueError):
8383

@@ -87,7 +87,7 @@ class TestPromptUser(Prompt[_PromptInput, str]): # pylint: disable=unused-varia
8787
user_prompt = "Hello, {{ foo }}"
8888

8989

90-
def test_raises_when_unknow_system_template_variable():
90+
def test_raises_when_unknown_system_template_variable():
9191
"""Test that a ValueError is raised when an unknown template variable is provided."""
9292
with pytest.raises(ValueError):
9393

@@ -515,3 +515,25 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
515515
{"role": "assistant", "content": "Why do I know all the words?"},
516516
{"role": "user", "content": "Theme for the song is rock."},
517517
]
518+
519+
520+
async def test_response_parser():
521+
class TestPrompt(Prompt):
522+
user_prompt = "Hello AI"
523+
524+
async def async_parser(response: str) -> str:
525+
return response.upper()
526+
527+
def sync_parser(response: str) -> str:
528+
return response.lower()
529+
530+
test_prompt = TestPrompt()
531+
532+
resp = "Hello Human"
533+
test_prompt.response_parser = async_parser
534+
resp_async = await test_prompt.parse_response(resp)
535+
assert resp_async == "HELLO HUMAN"
536+
537+
test_prompt.response_parser = sync_parser
538+
resp_sync = await test_prompt.parse_response(resp)
539+
assert resp_sync == "hello human"

0 commit comments

Comments
 (0)