diff --git a/knowledge_storm/lm.py b/knowledge_storm/lm.py
index d5836f72..fa46d910 100644
--- a/knowledge_storm/lm.py
+++ b/knowledge_storm/lm.py
@@ -492,7 +492,6 @@ def __init__(
         self.kwargs = {
             "model": model,
             "temperature": 0.0,
-            "max_tokens": 150,
             "top_p": 1,
             "frequency_penalty": 0,
             "presence_penalty": 0,
@@ -500,6 +499,13 @@ def __init__(
             **kwargs,
         }
 
+        # The o-series reasoning models do not accept "max_tokens"; they
+        # use "max_completion_tokens" instead, so pick the right key here.
+        if model in ["o3-mini", "o1", "o1-mini"]:
+            self.kwargs["max_completion_tokens"] = kwargs.get("max_completion_tokens", 150)
+        else:
+            self.kwargs["max_tokens"] = kwargs.get("max_tokens", 150)
+
     @backoff.on_exception(
         backoff.expo,
         ERRORS,
@@ -511,5 +517,11 @@ def basic_request(self, prompt: str, **kwargs) -> Any:
         kwargs = {**self.kwargs, **kwargs}
 
         try:
+            # Per-call "max_tokens" overrides must also be renamed for the
+            # o-series models before the request is sent.
+            if self.model in ["o3-mini", "o1", "o1-mini"]:
+                if "max_tokens" in kwargs:
+                    kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
+
             if self.model_type == "chat":
                 messages = [{"role": "user", "content": prompt}]