Skip to content

Commit ee82177

Browse files
committed
add parameter to limit the size of icl context
1 parent c530bd2 commit ee82177

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

balrog/agents/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def create_agent(self):
4949
elif self.config.agent.type == "custom":
5050
return CustomAgent(client_factory, prompt_builder)
5151
elif self.config.agent.type == "few_shot":
52-
return FewShotAgent(client_factory, prompt_builder)
52+
return FewShotAgent(client_factory, prompt_builder, self.config.agent.max_icl_history)
5353

5454
else:
5555
raise ValueError(f"Unknown agent type: {self.config.agent}")

balrog/agents/few_shot.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@ def __repr__(self):
1616

1717

1818
class FewShotAgent(BaseAgent):
19-
def __init__(self, client_factory, prompt_builder):
19+
def __init__(self, client_factory, prompt_builder, max_icl_history):
2020
"""Initialize the FewShotAgent with a client and prompt builder."""
2121
super().__init__(client_factory, prompt_builder)
2222
self.client = client_factory()
2323
self.icl_episodes = []
2424
self.icl_events = []
25+
self.max_icl_history = max_icl_history
2526
self.cached_icl = False
2627

2728
def update_icl_observation(self, obs: dict):
@@ -76,8 +77,19 @@ def get_icl_prompt(self) -> List[Message]:
7677

7778
# unroll the wrapped icl episodes messages
7879
icl_messages = [icl_instruction]
80+
i = 0
7981
for icl_episode in self.icl_episodes:
80-
icl_messages.extend(icl_episode)
82+
episode_steps = len(icl_episode) - 2 # not count start and end messages
83+
if i + episode_steps <= self.max_icl_history:
84+
icl_messages.extend(icl_episode)
85+
i += episode_steps
86+
else:
87+
icl_episode = icl_episode[: self.max_icl_history - i + 1] + [
88+
icl_episode[-1]
89+
] # +1 for start message -1 for end message
90+
icl_messages.extend(icl_episode)
91+
i += len(icl_episode) - 2 # not count start and end messages
92+
break
8193

8294
end_demo_message = Message(
8395
role="user",

balrog/config/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ agent:
44
max_history: 16 # Maximum number of previous turns to keep in the dialogue history
55
max_image_history: 0 # Maximum number of images to keep in the history
66
max_cot_history: 1 # Maximum number of chain-of-thought steps to keep in history (if using 'cot' type of agent)
7+
max_icl_history: 200 # Maximum number of ICL steps to keep in history (if using 'few_shot' type of agent)
78
cache_icl: False
89

910
eval:

0 commit comments

Comments
 (0)