Skip to content

Features #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions tests/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def test_new_dfiq_from_yaml(self, mock_post):
json={
"dfiq_type": "type",
"dfiq_yaml": "yaml_content",
"update_indicators": True,
},
)

Expand All @@ -145,7 +144,24 @@ def test_patch_dfiq_from_yaml(self, mock_patch):
json={
"dfiq_type": "type",
"dfiq_yaml": "yaml_content",
"update_indicators": True,
},
)

@patch("yeti.api.requests.Session.patch")
def test_patch_dfiq(self, mock_patch):
mock_response = MagicMock()
mock_response.content = b'{"id": "patched_dfiq"}'
mock_patch.return_value = mock_response

result = self.api.patch_dfiq(
{"name": "patched_dfiq", "id": 1, "type": "question"}
)
self.assertEqual(result, {"id": "patched_dfiq"})
mock_patch.assert_called_with(
"http://fake-url/api/v2/dfiq/1",
json={
"dfiq_object": {"name": "patched_dfiq", "type": "question", "id": 1},
"dfiq_type": "question",
},
)

Expand Down
57 changes: 53 additions & 4 deletions yeti/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Python client for the Yeti API."""

import json
import logging
from typing import Any, Sequence

import requests
Expand All @@ -24,6 +25,13 @@
YetiLinkObject = dict[str, Any]


logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)


class YetiApi:
"""API object to interact with the Yeti API.

Expand All @@ -40,13 +48,21 @@ def __init__(self, url_root: str):
}
self._url_root = url_root

self._auth_function = ""
self._auth_function_map = {
"auth_api_key": self.auth_api_key,
}

self._apikey = None

def do_request(
self,
method: str,
url: str,
json_data: dict[str, Any] | None = None,
body: bytes | None = None,
headers: dict[str, Any] | None = None,
retries: int = 3,
) -> bytes:
"""Issues a request to the given URL.

Expand All @@ -56,6 +72,7 @@ def do_request(
json: The JSON payload to include in the request.
body: The body to include in the request.
headers: Extra headers to include in the request.
retries: The number of times to retry the request.

Returns:
The response from the API; a bytes object.
Expand Down Expand Up @@ -85,17 +102,30 @@ def do_request(
raise ValueError(f"Unsupported method: {method}")
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code == 401:
if retries == 0:
raise errors.YetiAuthError(str(e)) from e
self.refresh_auth()
return self.do_request(
method, url, json_data, body, headers, retries - 1
)

raise errors.YetiApiError(e.response.status_code, e.response.text)

return response.content

def auth_api_key(self, apikey: str) -> None:
def auth_api_key(self, apikey: str | None = None) -> None:
"""Authenticates a session using an API key."""
# Use long-term refresh API token to get an access token
if apikey is not None:
self._apikey = apikey
if not self._apikey:
raise ValueError("No API key provided.")

response = self.do_request(
"POST",
f"{self._url_root}{API_TOKEN_ENDPOINT}",
headers={"x-yeti-apikey": apikey},
headers={"x-yeti-apikey": self._apikey},
)

access_token = json.loads(response).get("access_token")
Expand All @@ -107,6 +137,14 @@ def auth_api_key(self, apikey: str) -> None:
authd_session.headers.update({"authorization": f"Bearer {access_token}"})
self.client = authd_session

self._auth_function = "auth_api_key"

def refresh_auth(self):
if self._auth_function:
self._auth_function_map[self._auth_function]()
else:
logger.warning("No auth function set, cannot refresh auth.")

def search_indicators(
self,
name: str | None = None,
Expand Down Expand Up @@ -261,7 +299,6 @@ def new_dfiq_from_yaml(self, dfiq_type: str, dfiq_yaml: str) -> YetiObject:
params = {
"dfiq_type": dfiq_type,
"dfiq_yaml": dfiq_yaml,
"update_indicators": True,
}
response = self.do_request(
"POST", f"{self._url_root}/api/v2/dfiq/from_yaml", json_data=params
Expand All @@ -278,13 +315,25 @@ def patch_dfiq_from_yaml(
params = {
"dfiq_type": dfiq_type,
"dfiq_yaml": dfiq_yaml,
"update_indicators": True,
}
response = self.do_request(
"PATCH", f"{self._url_root}/api/v2/dfiq/{yeti_id}", json_data=params
)
return json.loads(response)

def patch_dfiq(self, dfiq_object: dict[str, Any]) -> YetiObject:
"""Patches a DFIQ object in Yeti."""
params = {
"dfiq_type": dfiq_object["type"],
"dfiq_object": dfiq_object,
}
response = self.do_request(
"PATCH",
f"{self._url_root}/api/v2/dfiq/{dfiq_object['id']}",
json_data=params,
)
return json.loads(response)

def download_dfiq_archive(self, dfiq_type: str | None = None) -> bytes:
"""Downloads an archive containing all DFIQ data from Yeti.

Expand Down
5 changes: 4 additions & 1 deletion yeti/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ def __init__(self):
@click.group()
@click.option("--api-key", envvar="YETI_API_KEY", required=True, help="Your API key.")
@click.option(
"--endpoint", envvar="YETI_WEB_ROOT", required=True, help="The Yeti endpoint."
"--endpoint",
envvar="YETI_WEB_ROOT",
required=True,
help="The Yeti endpoint, e.g. http://localhost:3000/",
)
@pass_context # Add this to pass the context to subcommands
def cli(ctx, api_key, endpoint):
Expand Down
13 changes: 12 additions & 1 deletion yeti/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
class YetiApiError(RuntimeError):
class YetiError(RuntimeError):
"""Base class for errors in the Yeti package."""


class YetiApiError(YetiError):
"""Base class for errors in the Yeti API."""

status_code: int
Expand All @@ -7,3 +11,10 @@ class YetiApiError(RuntimeError):
def __init__(self, status_code: int, message: str):
super().__init__(message)
self.status_code = status_code


class YetiAuthError(YetiError):
"""Error authenticating with the Yeti API."""

def __init__(self, message: str):
super().__init__(message)