11# -*- coding: utf-8 -*-
22"""Get the formatter and model based on the model provider."""
3+ import re
4+ import agentscope
35from agentscope .formatter import (
46 DashScopeChatFormatter ,
57 OpenAIChatFormatter ,
1820)
1921
2022
23+ def is_agentscope_version_ge (target_version : tuple ) -> bool :
24+ """
25+ Check if the current agentscope version is greater than or equal to the target version.
26+
27+ Args:
28+ target_version: A tuple of (major, minor, patch) version numbers.
29+
30+ Returns:
31+ True if current version >= target version, False otherwise.
32+
33+ Example:
34+ >>> is_agentscope_version_ge((1, 0, 9)) # Works with "1.0.9" or "1.0.9dev"
35+ True
36+ """
37+ version_str = agentscope .__version__
38+ version_match = re .match (r'^(\d+)\.(\d+)\.(\d+)' , version_str )
39+ if version_match :
40+ major , minor , patch = map (int , version_match .groups ())
41+ current_version = (major , minor , patch )
42+ return current_version >= target_version
43+ return False
44+
45+
2146def get_formatter (llmProvider : str ) -> FormatterBase :
2247 """Get the formatter based on the model provider."""
2348 match llmProvider .lower ():
@@ -36,7 +61,13 @@ def get_formatter(llmProvider: str) -> FormatterBase:
3661 f"Unsupported model provider: { llmProvider } . "
3762 )
3863
39- def get_model (llmProvider :str , modelName : str , apiKey : str , baseUrl : str = None ) -> ChatModelBase :
64+ def get_model (
65+ llmProvider : str ,
66+ modelName : str ,
67+ apiKey : str ,
68+ client_kwargs : dict = {},
69+ generate_kwargs : dict = {},
70+ ) -> ChatModelBase :
4071 """Get the model instance based on the input arguments."""
4172
4273 match llmProvider .lower ():
@@ -45,34 +76,47 @@ def get_model(llmProvider:str, modelName: str, apiKey: str, baseUrl: str = None)
4576 model_name = modelName ,
4677 api_key = apiKey ,
4778 stream = True ,
79+ generate_kwargs = generate_kwargs ,
4880 )
4981 case "openai" :
50- client_args = {}
51- if baseUrl :
52- client_args ["base_url" ] = baseUrl
5382 return OpenAIChatModel (
5483 model_name = modelName ,
5584 api_key = apiKey ,
5685 stream = True ,
57- client_args = client_args ,
86+ client_kwargs = client_kwargs ,
87+ generate_kwargs = generate_kwargs ,
5888 )
5989 case "ollama" :
60- return OllamaChatModel (
61- model_name = modelName ,
62- stream = True ,
63- host = baseUrl ,
64- )
90+ if is_agentscope_version_ge ((1 , 0 , 9 )):
91+ # For agentscope >= 1.0.9
92+ return OllamaChatModel (
93+ model_name = modelName ,
94+ stream = True ,
95+ client_kwargs = client_kwargs ,
96+ generate_kwargs = generate_kwargs ,
97+ )
98+ else :
99+ # For agentscope < 1.0.9
100+ return OllamaChatModel (
101+ model_name = modelName ,
102+ stream = True ,
103+ ** client_kwargs ,
104+ )
65105 case "gemini" :
66106 return GeminiChatModel (
67107 model_name = modelName ,
68108 api_key = apiKey ,
69109 stream = True ,
110+ client_kwargs = client_kwargs ,
111+ generate_kwargs = generate_kwargs ,
70112 )
71113 case "anthropic" :
72114 return AnthropicChatModel (
73115 model_name = modelName ,
74116 api_key = apiKey ,
75117 stream = True ,
118+ client_kwargs = client_kwargs ,
119+ generate_kwargs = generate_kwargs ,
76120 )
77121 case _:
78122 raise ValueError (
0 commit comments