Skip to content

Commit ae3e69c

Browse files
authored
Merge pull request #5 from yeti-platform/features
Features
2 parents 177fa07 + 511c4c9 commit ae3e69c

File tree

4 files changed

+87
-8
lines changed

4 files changed

+87
-8
lines changed

tests/api.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ def test_new_dfiq_from_yaml(self, mock_post):
128128
json={
129129
"dfiq_type": "type",
130130
"dfiq_yaml": "yaml_content",
131-
"update_indicators": True,
132131
},
133132
)
134133

@@ -145,7 +144,24 @@ def test_patch_dfiq_from_yaml(self, mock_patch):
145144
json={
146145
"dfiq_type": "type",
147146
"dfiq_yaml": "yaml_content",
148-
"update_indicators": True,
147+
},
148+
)
149+
150+
@patch("yeti.api.requests.Session.patch")
151+
def test_patch_dfiq(self, mock_patch):
152+
mock_response = MagicMock()
153+
mock_response.content = b'{"id": "patched_dfiq"}'
154+
mock_patch.return_value = mock_response
155+
156+
result = self.api.patch_dfiq(
157+
{"name": "patched_dfiq", "id": 1, "type": "question"}
158+
)
159+
self.assertEqual(result, {"id": "patched_dfiq"})
160+
mock_patch.assert_called_with(
161+
"http://fake-url/api/v2/dfiq/1",
162+
json={
163+
"dfiq_object": {"name": "patched_dfiq", "type": "question", "id": 1},
164+
"dfiq_type": "question",
149165
},
150166
)
151167

yeti/api.py

+53-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Python client for the Yeti API."""
22

33
import json
4+
import logging
45
from typing import Any, Sequence
56

67
import requests
@@ -24,6 +25,13 @@
2425
YetiLinkObject = dict[str, Any]
2526

2627

28+
logger = logging.getLogger(__name__)
29+
handler = logging.StreamHandler()
30+
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
31+
handler.setFormatter(formatter)
32+
logger.addHandler(handler)
33+
34+
2735
class YetiApi:
2836
"""API object to interact with the Yeti API.
2937
@@ -40,13 +48,21 @@ def __init__(self, url_root: str):
4048
}
4149
self._url_root = url_root
4250

51+
self._auth_function = ""
52+
self._auth_function_map = {
53+
"auth_api_key": self.auth_api_key,
54+
}
55+
56+
self._apikey = None
57+
4358
def do_request(
4459
self,
4560
method: str,
4661
url: str,
4762
json_data: dict[str, Any] | None = None,
4863
body: bytes | None = None,
4964
headers: dict[str, Any] | None = None,
65+
retries: int = 3,
5066
) -> bytes:
5167
"""Issues a request to the given URL.
5268
@@ -56,6 +72,7 @@ def do_request(
5672
json: The JSON payload to include in the request.
5773
body: The body to include in the request.
5874
headers: Extra headers to include in the request.
75+
retries: The number of times to retry the request.
5976
6077
Returns:
6178
The response from the API; a bytes object.
@@ -85,17 +102,30 @@ def do_request(
85102
raise ValueError(f"Unsupported method: {method}")
86103
response.raise_for_status()
87104
except requests.exceptions.HTTPError as e:
105+
if e.response.status_code == 401:
106+
if retries == 0:
107+
raise errors.YetiAuthError(str(e)) from e
108+
self.refresh_auth()
109+
return self.do_request(
110+
method, url, json_data, body, headers, retries - 1
111+
)
112+
88113
raise errors.YetiApiError(e.response.status_code, e.response.text)
89114

90115
return response.content
91116

92-
def auth_api_key(self, apikey: str) -> None:
117+
def auth_api_key(self, apikey: str | None = None) -> None:
93118
"""Authenticates a session using an API key."""
94119
# Use long-term refresh API token to get an access token
120+
if apikey is not None:
121+
self._apikey = apikey
122+
if not self._apikey:
123+
raise ValueError("No API key provided.")
124+
95125
response = self.do_request(
96126
"POST",
97127
f"{self._url_root}{API_TOKEN_ENDPOINT}",
98-
headers={"x-yeti-apikey": apikey},
128+
headers={"x-yeti-apikey": self._apikey},
99129
)
100130

101131
access_token = json.loads(response).get("access_token")
@@ -107,6 +137,14 @@ def auth_api_key(self, apikey: str) -> None:
107137
authd_session.headers.update({"authorization": f"Bearer {access_token}"})
108138
self.client = authd_session
109139

140+
self._auth_function = "auth_api_key"
141+
142+
def refresh_auth(self):
143+
if self._auth_function:
144+
self._auth_function_map[self._auth_function]()
145+
else:
146+
logger.warning("No auth function set, cannot refresh auth.")
147+
110148
def search_indicators(
111149
self,
112150
name: str | None = None,
@@ -261,7 +299,6 @@ def new_dfiq_from_yaml(self, dfiq_type: str, dfiq_yaml: str) -> YetiObject:
261299
params = {
262300
"dfiq_type": dfiq_type,
263301
"dfiq_yaml": dfiq_yaml,
264-
"update_indicators": True,
265302
}
266303
response = self.do_request(
267304
"POST", f"{self._url_root}/api/v2/dfiq/from_yaml", json_data=params
@@ -278,13 +315,25 @@ def patch_dfiq_from_yaml(
278315
params = {
279316
"dfiq_type": dfiq_type,
280317
"dfiq_yaml": dfiq_yaml,
281-
"update_indicators": True,
282318
}
283319
response = self.do_request(
284320
"PATCH", f"{self._url_root}/api/v2/dfiq/{yeti_id}", json_data=params
285321
)
286322
return json.loads(response)
287323

324+
def patch_dfiq(self, dfiq_object: dict[str, Any]) -> YetiObject:
325+
"""Patches a DFIQ object in Yeti."""
326+
params = {
327+
"dfiq_type": dfiq_object["type"],
328+
"dfiq_object": dfiq_object,
329+
}
330+
response = self.do_request(
331+
"PATCH",
332+
f"{self._url_root}/api/v2/dfiq/{dfiq_object['id']}",
333+
json_data=params,
334+
)
335+
return json.loads(response)
336+
288337
def download_dfiq_archive(self, dfiq_type: str | None = None) -> bytes:
289338
"""Downloads an archive containing all DFIQ data from Yeti.
290339

yeti/client.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ def __init__(self):
1313
@click.group()
1414
@click.option("--api-key", envvar="YETI_API_KEY", required=True, help="Your API key.")
1515
@click.option(
16-
"--endpoint", envvar="YETI_WEB_ROOT", required=True, help="The Yeti endpoint."
16+
"--endpoint",
17+
envvar="YETI_WEB_ROOT",
18+
required=True,
19+
help="The Yeti endpoint, e.g. http://localhost:3000/",
1720
)
1821
@pass_context # Add this to pass the context to subcommands
1922
def cli(ctx, api_key, endpoint):

yeti/errors.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
class YetiApiError(RuntimeError):
1+
class YetiError(RuntimeError):
2+
"""Base class for errors in the Yeti package."""
3+
4+
5+
class YetiApiError(YetiError):
26
"""Base class for errors in the Yeti API."""
37

48
status_code: int
@@ -7,3 +11,10 @@ class YetiApiError(RuntimeError):
711
def __init__(self, status_code: int, message: str):
812
super().__init__(message)
913
self.status_code = status_code
14+
15+
16+
class YetiAuthError(YetiError):
17+
"""Error authenticating with the Yeti API."""
18+
19+
def __init__(self, message: str):
20+
super().__init__(message)

0 commit comments

Comments
 (0)