@@ -16,12 +16,13 @@ def __repr__(self):
1616
1717
1818class FewShotAgent (BaseAgent ):
19- def __init__ (self , client_factory , prompt_builder ):
19+ def __init__ (self , client_factory , prompt_builder , max_icl_history ):
2020 """Initialize the FewShotAgent with a client and prompt builder."""
2121 super ().__init__ (client_factory , prompt_builder )
2222 self .client = client_factory ()
2323 self .icl_episodes = []
2424 self .icl_events = []
25+ self .max_icl_history = max_icl_history
2526 self .cached_icl = False
2627
2728 def update_icl_observation (self , obs : dict ):
@@ -76,8 +77,19 @@ def get_icl_prompt(self) -> List[Message]:
7677
7778 # unroll the wrapped icl episodes messages
7879 icl_messages = [icl_instruction ]
80+ i = 0
7981 for icl_episode in self .icl_episodes :
80- icl_messages .extend (icl_episode )
82+ episode_steps = len (icl_episode ) - 2 # not count start and end messages
83+ if i + episode_steps <= self .max_icl_history :
84+ icl_messages .extend (icl_episode )
85+ i += episode_steps
86+ else :
87+ icl_episode = icl_episode [: self .max_icl_history - i + 1 ] + [
88+ icl_episode [- 1 ]
89+ ] # +1 for start message -1 for end message
90+ icl_messages .extend (icl_episode )
91+ i += len (icl_episode ) - 2 # not count start and end messages
92+ break
8193
8294 end_demo_message = Message (
8395 role = "user" ,
0 commit comments