diff --git a/delphi/clients/offline.py b/delphi/clients/offline.py index ecd07d37..7ad8adbf 100644 --- a/delphi/clients/offline.py +++ b/delphi/clients/offline.py @@ -86,16 +86,28 @@ async def process_func( Process a single request. """ - # This is actually stupid + # Extract params from kwargs - must pass to constructor, not mutate after, + # because SamplingParams.__post_init__ likely does some extra setup, + # and mutation after construction skips this. + logprobs = None + prompt_logprobs = None + max_tokens = self.sampling_params.max_tokens + temperature = 1.0 for kwarg in kwargs: if "logprobs" in kwarg: - self.sampling_params.logprobs = kwarg["top_logprobs"] + logprobs = kwarg["top_logprobs"] if "prompt_logprobs" in kwarg: - self.sampling_params.prompt_logprobs = kwarg["prompt_logprobs"] + prompt_logprobs = kwarg["prompt_logprobs"] if "max_tokens" in kwarg: - self.sampling_params.max_tokens = kwarg["max_tokens"] + max_tokens = kwarg["max_tokens"] if "temperature" in kwarg: - self.sampling_params.temperature = kwarg["temperature"] + temperature = kwarg["temperature"] + sampling_params = SamplingParams( + max_tokens=max_tokens, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs, + temperature=temperature, + ) loop = asyncio.get_running_loop() prompts = [] statistics = [] @@ -124,7 +136,7 @@ async def process_func( partial( self.client.generate, # type: ignore prompts, - sampling_params=self.sampling_params, + sampling_params=sampling_params, # Use fresh sampling_params use_tqdm=False, ), )