Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions nbs/src/nixtla_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1290,13 +1290,16 @@
" warnings.warn(\"Azure endpoint detected, setting `model` to 'azureai'.\")\n",
" model = 'azureai'\n",
" return model\n",
" \n",
" def _make_client(self, **kwargs: Any) -> httpx.Client:\n",
" return httpx.Client(**kwargs)\n",
"\n",
" def _get_model_params(self, model: _Model, freq: str) -> tuple[int, int]:\n",
" key = (model, freq)\n",
" if key not in self._model_params:\n",
" logger.info('Querying model metadata...')\n",
" payload = {'model': model, 'freq': freq}\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" if self._is_azure:\n",
" resp_body = self._make_request_with_retries(\n",
" client, 'model_params', payload\n",
Expand Down Expand Up @@ -1443,7 +1446,7 @@
" 'validate_api_key is not implemented for Azure deployments, '\n",
" 'you can try using the forecasting methods directly.'\n",
" )\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" resp = client.get(\"/validate_api_key\")\n",
" body = resp.json()\n",
" if log:\n",
Expand All @@ -1459,7 +1462,7 @@
" Consumed requests and limits by minute and month.\"\"\"\n",
" if self._is_azure:\n",
" raise NotImplementedError('usage is not implemented for Azure deployments')\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" return self._get_request(client, '/usage')\n",
"\n",
" def finetune(\n",
Expand Down Expand Up @@ -1567,7 +1570,7 @@
" 'output_model_id': output_model_id,\n",
" 'finetuned_model_id': finetuned_model_id,\n",
" }\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" resp = self._make_request_with_retries(client, 'v2/finetune', payload)\n",
" return resp['finetuned_model_id']\n",
"\n",
Expand All @@ -1594,7 +1597,7 @@
" -------\n",
" list of FinetunedModel\n",
" List of available fine-tuned models.\"\"\"\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" resp_body = self._get_request(client, '/v2/finetuned_models')\n",
" models = [FinetunedModel(**m) for m in resp_body['finetuned_models']]\n",
" if as_df:\n",
Expand All @@ -1613,7 +1616,7 @@
" -------\n",
" FinetunedModel\n",
" Fine-tuned model metadata.\"\"\"\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" resp_body = self._get_request(\n",
" client, f'/v2/finetuned_models/{finetuned_model_id}'\n",
" )\n",
Expand All @@ -1631,7 +1634,7 @@
" -------\n",
" bool\n",
" Whether delete was successful.\"\"\"\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" resp = client.delete(\n",
" f\"/v2/finetuned_models/{finetuned_model_id}\",\n",
" headers={'accept-encoding': 'identity'},\n",
Expand Down Expand Up @@ -1945,7 +1948,7 @@
" 'finetuned_model_id': finetuned_model_id,\n",
" 'feature_contributions': feature_contributions and X is not None,\n",
" }\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" insample_feat_contributions = None\n",
" if num_partitions is None:\n",
" resp = self._make_request_with_retries(client, 'v2/forecast', payload)\n",
Expand Down Expand Up @@ -2204,7 +2207,7 @@
" 'clean_ex_first': clean_ex_first,\n",
" 'level': level,\n",
" }\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" if num_partitions is None:\n",
" resp = self._make_request_with_retries(\n",
" client, 'v2/anomaly_detection', payload\n",
Expand Down Expand Up @@ -2474,7 +2477,7 @@
" 'refit': refit,\n",
" 'hist_exog': hist_exog,\n",
" }\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" if num_partitions is None:\n",
" resp = self._make_request_with_retries(\n",
" client, 'v2/online_anomaly_detection', payload\n",
Expand Down Expand Up @@ -2775,7 +2778,7 @@
" 'finetuned_model_id': finetuned_model_id,\n",
" 'refit': refit,\n",
" }\n",
" with httpx.Client(**self._client_kwargs) as client:\n",
" with self._make_client(**self._client_kwargs) as client:\n",
" if num_partitions is None:\n",
" resp = self._make_request_with_retries(\n",
" client, 'v2/cross_validation', payload\n",
Expand Down
2 changes: 2 additions & 0 deletions nixtla/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
'nixtla/nixtla_client.py'),
'nixtla.nixtla_client.NixtlaClient._get_request': ( 'src/nixtla_client.html#nixtlaclient._get_request',
'nixtla/nixtla_client.py'),
'nixtla.nixtla_client.NixtlaClient._make_client': ( 'src/nixtla_client.html#nixtlaclient._make_client',
'nixtla/nixtla_client.py'),
'nixtla.nixtla_client.NixtlaClient._make_partitioned_requests': ( 'src/nixtla_client.html#nixtlaclient._make_partitioned_requests',
'nixtla/nixtla_client.py'),
'nixtla.nixtla_client.NixtlaClient._make_request': ( 'src/nixtla_client.html#nixtlaclient._make_request',
Expand Down
25 changes: 14 additions & 11 deletions nixtla/nixtla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,12 +978,15 @@ def _maybe_override_model(self, model: _Model) -> _Model:
model = "azureai"
return model

def _make_client(self, **kwargs: Any) -> httpx.Client:
return httpx.Client(**kwargs)

def _get_model_params(self, model: _Model, freq: str) -> tuple[int, int]:
key = (model, freq)
if key not in self._model_params:
logger.info("Querying model metadata...")
payload = {"model": model, "freq": freq}
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
if self._is_azure:
resp_body = self._make_request_with_retries(
client, "model_params", payload
Expand Down Expand Up @@ -1130,7 +1133,7 @@ def validate_api_key(self, log: bool = True) -> bool:
"validate_api_key is not implemented for Azure deployments, "
"you can try using the forecasting methods directly."
)
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
resp = client.get("/validate_api_key")
body = resp.json()
if log:
Expand All @@ -1146,7 +1149,7 @@ def usage(self) -> dict[str, dict[str, int]]:
Consumed requests and limits by minute and month."""
if self._is_azure:
raise NotImplementedError("usage is not implemented for Azure deployments")
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
return self._get_request(client, "/usage")

def finetune(
Expand Down Expand Up @@ -1254,7 +1257,7 @@ def finetune(
"output_model_id": output_model_id,
"finetuned_model_id": finetuned_model_id,
}
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
resp = self._make_request_with_retries(client, "v2/finetune", payload)
return resp["finetuned_model_id"]

Expand All @@ -1279,7 +1282,7 @@ def finetuned_models(
-------
list of FinetunedModel
List of available fine-tuned models."""
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
resp_body = self._get_request(client, "/v2/finetuned_models")
models = [FinetunedModel(**m) for m in resp_body["finetuned_models"]]
if as_df:
Expand All @@ -1298,7 +1301,7 @@ def finetuned_model(self, finetuned_model_id: str) -> FinetunedModel:
-------
FinetunedModel
Fine-tuned model metadata."""
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
resp_body = self._get_request(
client, f"/v2/finetuned_models/{finetuned_model_id}"
)
Expand All @@ -1316,7 +1319,7 @@ def delete_finetuned_model(self, finetuned_model_id: str) -> bool:
-------
bool
Whether delete was successful."""
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
resp = client.delete(
f"/v2/finetuned_models/{finetuned_model_id}",
headers={"accept-encoding": "identity"},
Expand Down Expand Up @@ -1631,7 +1634,7 @@ def forecast(
"finetuned_model_id": finetuned_model_id,
"feature_contributions": feature_contributions and X is not None,
}
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
insample_feat_contributions = None
if num_partitions is None:
resp = self._make_request_with_retries(client, "v2/forecast", payload)
Expand Down Expand Up @@ -1896,7 +1899,7 @@ def detect_anomalies(
"clean_ex_first": clean_ex_first,
"level": level,
}
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
if num_partitions is None:
resp = self._make_request_with_retries(
client, "v2/anomaly_detection", payload
Expand Down Expand Up @@ -2174,7 +2177,7 @@ def detect_anomalies_online(
"refit": refit,
"hist_exog": hist_exog,
}
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
if num_partitions is None:
resp = self._make_request_with_retries(
client, "v2/online_anomaly_detection", payload
Expand Down Expand Up @@ -2479,7 +2482,7 @@ def cross_validation(
"finetuned_model_id": finetuned_model_id,
"refit": refit,
}
with httpx.Client(**self._client_kwargs) as client:
with self._make_client(**self._client_kwargs) as client:
if num_partitions is None:
resp = self._make_request_with_retries(
client, "v2/cross_validation", payload
Expand Down