Skip to content

Commit 688f6a2

Browse files
authored
Merge pull request #3 from yeti-platform/errorhandling
Add some error handling
2 parents 9209a1f + b62cbef commit 688f6a2

File tree

3 files changed

+45
-11
lines changed

3 files changed

+45
-11
lines changed

tests/api.py

+20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import unittest
22
from unittest.mock import patch, MagicMock
33
from yeti.api import YetiApi
4+
from yeti import errors
5+
6+
import requests
47

58

69
class TestYetiApi(unittest.TestCase):
@@ -259,6 +262,23 @@ def test_search_graph(self, mock_post):
259262
},
260263
)
261264

265+
@patch("yeti.api.requests.Session.post")
266+
def test_error_message(self, mock_post):
267+
# create mock requests response that raises an requests.exceptions.HTTPError for status
268+
mock_response = MagicMock()
269+
mock_exception_with_status_code = requests.exceptions.HTTPError()
270+
mock_exception_with_status_code.response = MagicMock()
271+
mock_exception_with_status_code.response.status_code = 400
272+
mock_exception_with_status_code.response.text = "error_message"
273+
mock_response.raise_for_status.side_effect = mock_exception_with_status_code
274+
mock_post.return_value = mock_response
275+
276+
with self.assertRaises(errors.YetiApiError) as raised:
277+
self.api.new_indicator({"name": "test_indicator"})
278+
279+
self.assertEqual(str(raised.exception), "error_message")
280+
self.assertEqual(raised.exception.status_code, 400)
281+
262282

263283
if __name__ == "__main__":
264284
unittest.main()

yeti/api.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Python client for the Yeti API."""
22

3-
import requests
4-
import requests_toolbelt.multipart.encoder as encoder
5-
63
import json
74
from typing import Any, Sequence
85

6+
import yeti.errors as errors
7+
import requests
8+
import requests_toolbelt.multipart.encoder as encoder
99

1010
TYPE_TO_ENDPOINT = {
1111
"indicator": "/api/v2/indicators",
@@ -73,14 +73,19 @@ def do_request(
7373
if body:
7474
request_kwargs["body"] = body
7575

76-
if method == "POST":
77-
response = self.client.post(url, **request_kwargs)
78-
elif method == "PATCH":
79-
response = self.client.patch(url, **request_kwargs)
80-
elif method == "GET":
81-
response = self.client.get(url, **request_kwargs)
82-
else:
83-
raise ValueError(f"Unsupported method: {method}")
76+
try:
77+
if method == "POST":
78+
response = self.client.post(url, **request_kwargs)
79+
elif method == "PATCH":
80+
response = self.client.patch(url, **request_kwargs)
81+
elif method == "GET":
82+
response = self.client.get(url, **request_kwargs)
83+
else:
84+
raise ValueError(f"Unsupported method: {method}")
85+
response.raise_for_status()
86+
except requests.exceptions.HTTPError as e:
87+
raise errors.YetiApiError(e.response.status_code, e.response.text)
88+
8489
return response.bytes
8590

8691
def auth_api_key(self, apikey: str) -> None:

yeti/errors.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class YetiApiError(RuntimeError):
2+
"""Base class for errors in the Yeti API."""
3+
4+
status_code: int
5+
message: str
6+
7+
def __init__(self, status_code: int, message: str):
8+
super().__init__(message)
9+
self.status_code = status_code

0 commit comments

Comments
 (0)