
Commit 13b3146

- moved hardcoded constants into options;
Parent: 5f56183

File tree

  src/coverup/coverup.py
  src/coverup/llm.py

2 files changed: +22 -9 lines

src/coverup/coverup.py

Lines changed: 14 additions & 1 deletion
@@ -64,7 +64,7 @@ def Path_dir(value):
 ap.add_argument('--checkpoint', type=Path,
                 help=f'path to save progress to (and to resume it from)')
 ap.add_argument('--no-checkpoint', action='store_const', const=None, dest='checkpoint', default=argparse.SUPPRESS,
-                help=f'disables checkpoint')
+                help='disable checkpoint')
 
 def default_model():
     if 'OPENAI_API_KEY' in os.environ:
@@ -82,6 +82,12 @@ def default_model():
                 default='gpt-v2',
                 help='Prompt style to use')
 
+ap.add_argument('--ollama-api-base', type=str, default="http://localhost:11434",
+                help='"api_base" setting for Ollama models')
+
+ap.add_argument('--bedrock-anthropic-version', type=str, default="bedrock-2023-05-31",
+                help='"anthropic_version" setting for bedrock Anthropic models')
+
 ap.add_argument('--model-temperature', type=float, default=0,
                 help='Model "temperature" to use')
 

@@ -657,6 +663,13 @@ def main():
     if args.rate_limit:
         chatter.set_token_rate_limit((args.rate_limit, 60))
 
+    extra_request_pars = {}
+    if "ollama" in args.model:
+        extra_request_pars['api_base'] = args.ollama_api_base
+    if args.model.startswith("bedrock/anthropic"):
+        extra_request_pars['anthropic_version'] = args.bedrock_anthropic_version
+    chatter.set_extra_request_pars(extra_request_pars)
+
     prompter = prompter_registry[args.prompt](cmd_args=args)
     for f in prompter.get_functions():
         chatter.add_function(f)
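
With the constants promoted to command-line options, the Ollama endpoint and the Bedrock Anthropic version can now be set per run instead of edited in source. A hypothetical invocation (the model string, host, and any other flags shown are illustrative, not taken from this commit):

    coverup --model ollama/codellama --ollama-api-base http://gpu-host:11434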

src/coverup/llm.py

Lines changed: 8 additions & 8 deletions
@@ -112,6 +112,7 @@ def __init__(self, model: str) -> None:
         self._signal_retry = lambda: None
         self._functions: dict[str, dict[str, T.Any]] = dict()
         self._max_func_calls_per_chat = 50
+        self._extra_request_pars = None
 
     @staticmethod
     def _validate_model(model) -> None:
@@ -174,6 +175,10 @@ def set_signal_retry(self, signal_retry: T.Callable) -> None:
         """Sets up a callback to indicate a retry."""
         self._signal_retry = signal_retry
 
+    def set_extra_request_pars(self, request_pars: dict[str, T.Any] | None) -> None:
+        """Sets additional parameters to pass into the LLM request."""
+        self._extra_request_pars = request_pars
+
     def add_function(self, function: T.Callable) -> None:
         """Makes a function availabe to the LLM."""
         if not litellm.supports_function_calling(self._model):
@@ -190,19 +195,14 @@ def add_function(self, function: T.Callable) -> None:
         self._functions[schema['name']] = {"function": function, "schema": schema}
 
     def _request(self, messages: T.List[dict]) -> dict:
-        request = {
+        return {
             'model': self._model,
             **({'temperature': self._model_temperature} if self._model_temperature is not None else {}),
             'messages': messages,
-            **({'api_base': "http://localhost:11434"} if "ollama" in self._model else {}),
-            **({'tools': [{'type': 'function', 'function': f['schema']} for f in self._functions.values()]} if self._functions else {})
+            **({'tools': [{'type': 'function', 'function': f['schema']} for f in self._functions.values()]} if self._functions else {}),
+            **(self._extra_request_pars if self._extra_request_pars else {})
         }
 
-        if self._model.startswith("bedrock/anthropic"):
-            request['anthropic_version'] = "bedrock-2023-05-31"
-
-        return request
-
     async def _send_request(self, request: dict, ctx: object) -> litellm.ModelResponse | None:
         """Sends the LLM chat request, handling common failures and returning the response."""
         sleep = 1
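
The refactor collapses _request into a single dict literal, using the `**({...} if condition else {})` idiom to splice optional keys in only when they apply; the per-provider special cases now arrive through self._extra_request_pars. A minimal, self-contained sketch of that idiom (the function and parameter names here are illustrative stand-ins, not CoverUp's API):

    # Sketch of the conditional dict-merge idiom used in _request().
    # build_request() and its parameters are hypothetical.
    def build_request(model, messages, temperature=None, extra_pars=None):
        return {
            'model': model,
            'messages': messages,
            # each ** splices in a sub-dict only when its condition holds
            **({'temperature': temperature} if temperature is not None else {}),
            **(extra_pars if extra_pars else {}),
        }

    req = build_request("ollama/llama3", [{"role": "user", "content": "hi"}],
                        extra_pars={'api_base': "http://localhost:11434"})
    assert req['api_base'] == "http://localhost:11434"
    assert 'temperature' not in req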
