
Commit 2856c96

fix linting and tests
1 parent 7733072 commit 2856c96

6 files changed: +99 -57 lines changed

mypy.ini (+3)

@@ -7,3 +7,6 @@ ignore_missing_imports = True
 
 [mypy-boto3.*]
 ignore_missing_imports = True
+
+[mypy-pytz.*]
+ignore_missing_imports = True

tap_cloudwatch/cloudwatch_api.py (+60 -28)

@@ -2,13 +2,15 @@
 
 import os
 import time
+from collections import deque
 from datetime import datetime, timezone
+from math import ceil
+
+import boto3
 import pytz
 
 from tap_cloudwatch.exception import InvalidQueryException
-from collections import deque
-import boto3
-from math import ceil
+
 
 class CloudwatchAPI:
     """Cloudwatch class for interacting with the API."""
@@ -66,7 +68,7 @@ def _create_client(self, config):
     def _request_more_records():
         return True
 
-    def split_batch_into_windows(self, start_time, end_time, batch_increment_s):
+    def _split_batch_into_windows(self, start_time, end_time, batch_increment_s):
         diff_s = end_time - start_time
         total_batches = ceil(diff_s / batch_increment_s)
         batch_windows = []
@@ -81,62 +83,86 @@ def split_batch_into_windows(self, start_time, end_time, batch_increment_s):
             batch_windows.append((query_start, query_end))
         return batch_windows
 
-    def validate_query(self, query):
+    def _validate_query(self, query):
         if "|sort" in query.replace(" ", ""):
             raise InvalidQueryException("sort not allowed")
         if "|limit" in query.replace(" ", ""):
             raise InvalidQueryException("limit not allowed")
         if "stats" in query:
             raise InvalidQueryException("stats not allowed")
         if "@timestamp" not in query.split("|")[0]:
-            raise InvalidQueryException("@timestamp field is used as the replication key so it must be selected")
+            raise InvalidQueryException(
+                "@timestamp field is used as the replication key so it must be selected"
+            )
 
     def get_records_iterator(self, bookmark, log_group, query, batch_increment_s):
         """Retrieve records from Cloudwatch."""
         end_time = datetime.now(timezone.utc).timestamp()
         start_time = bookmark.timestamp()
-        self.validate_query(query)
-        batch_windows = self.split_batch_into_windows(start_time, end_time, batch_increment_s)
+        self._validate_query(query)
+        batch_windows = self._split_batch_into_windows(
+            start_time, end_time, batch_increment_s
+        )
 
         queue = deque()
         for window in batch_windows:
             if len(queue) < (self.max_concurrent_queries - 1):
-                queue.append((self.start_query(window[0], window[1], log_group, query), window[0], window[1]))
+                queue.append(
+                    (
+                        self._start_query(window[0], window[1], log_group, query),
+                        window[0],
+                        window[1],
+                    )
+                )
             else:
                 query_id, start, end = queue.popleft()
-                queue.append((self.start_query(window[0], window[1], log_group, query), window[0], window[1]))
-                results = self.get_results(log_group, start, end, query, query_id)
+                queue.append(
+                    (
+                        self._start_query(window[0], window[1], log_group, query),
+                        window[0],
+                        window[1],
+                    )
+                )
+                results = self._get_results(log_group, start, end, query, query_id)
                 yield results
 
         while len(queue) > 0:
             query_id, start, end = queue.popleft()
-            results = self.get_results(log_group, start, end, query, query_id)
+            results = self._get_results(log_group, start, end, query, query_id)
             yield results
 
-    def handle_limit_exceeded(self, response, log_group, query_start, query_end, query):
+    def _handle_limit_exceeded(
+        self, response, log_group, query_start, query_end, query
+    ):
         results = response.get("results")
         last_record = results[-1]
 
-        latest_ts_str = [i["value"] for i in last_record if i["field"] == "@timestamp"][0]
+        latest_ts_str = [i["value"] for i in last_record if i["field"] == "@timestamp"][
+            0
+        ]
         # Include latest ts in query, this could cause duplicates but
         # without it we might miss ties
-        new_query_start = int(datetime.fromisoformat(latest_ts_str).replace(tzinfo=pytz.UTC).timestamp())
-        new_query_id = self.start_query(new_query_start, query_end, log_group, query)
-        return self.get_results(log_group, new_query_start, query_end, query, new_query_id)
+        new_query_start = int(
+            datetime.fromisoformat(latest_ts_str).replace(tzinfo=pytz.UTC).timestamp()
+        )
+        new_query_id = self._start_query(new_query_start, query_end, log_group, query)
+        return self._get_results(
+            log_group, new_query_start, query_end, query, new_query_id
+        )
 
-    def alter_query(self, query):
+    def _alter_query(self, query):
         query += " | sort @timestamp asc"
         return query
 
-    def start_query(self, query_start, query_end, log_group, query, prev_start=None):
+    def _start_query(self, query_start, query_end, log_group, query, prev_start=None):
         self.logger.info(
             (
                 "Submitting query for batch from:"
                 f" `{datetime.utcfromtimestamp(query_start).isoformat()} UTC` -"
                 f" `{datetime.utcfromtimestamp(query_end).isoformat()} UTC`"
             )
         )
-        query = self.alter_query(query)
+        query = self._alter_query(query)
         start_query_response = self.client.start_query(
             logGroupName=log_group,
             startTime=query_start,
@@ -146,7 +172,9 @@ def start_query(self, query_start, query_end, log_group, query, prev_start=None)
         )
         return start_query_response["queryId"]
 
-    def get_results(self, log_group, query_start, query_end, query, query_id, prev_start=None):
+    def _get_results(
+        self, log_group, query_start, query_end, query, query_id, prev_start=None
+    ):
         self.logger.info(
             (
                 "Retrieving results for batch from:"
@@ -161,15 +189,19 @@ def get_results(self, log_group, query_start, query_end, query, query_id, prev_s
         if response.get("ResponseMetadata", {}).get("HTTPStatusCode") != 200:
             raise Exception(f"Failed: {response}")
         result_size = response.get("statistics", {}).get("recordsMatched")
-        results = response['results']
-        self.logger.info(
-            f"Result set size '{int(result_size)}' received."
-        )
+        results = response["results"]
+        self.logger.info(f"Result set size '{int(result_size)}' received.")
         if result_size > self.limit:
             if prev_start == query_start:
-                raise Exception("Stuck in a loop, smaller batch still exceeds limit. Reduce batch window.")
+                raise Exception(
+                    "Stuck in a loop, smaller batch still exceeds limit."
+                    "Reduce batch window."
+                )
             self.logger.info(
-                f"Result set size '{int(result_size)}' exceeded limit '{self.limit}'. Re-running sub-batch..."
+                f"Result set size '{int(result_size)}' exceeded limit "
+                f"'{self.limit}'. Re-running sub-batch..."
+            )
+            results += self._handle_limit_exceeded(
+                response, log_group, query_start, query_end, query
             )
-            results += self.handle_limit_exceeded(response, log_group, query_start, query_end, query)
         return results
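
For context on the diff above: get_records_iterator keeps at most max_concurrent_queries - 1 Logs Insights queries in flight by pushing (query_id, start, end) tuples onto a deque and draining the oldest entry each time a new window is submitted, so one window's results can be fetched while the next query is already running. The following is a minimal, self-contained sketch of that rolling-window pattern. The start_query/get_results stand-ins are illustrative, the windowing arithmetic is only inferred from the expectations in test_cloudwatch_api.py, and the over-limit re-query handled by _get_results/_handle_limit_exceeded is omitted.

# A minimal sketch of the rolling-window pattern used by get_records_iterator.
# Stand-in functions replace the boto3 calls; not the tap's actual code.
from collections import deque
from math import ceil

MAX_CONCURRENT_QUERIES = 3  # stand-in for self.max_concurrent_queries


def split_batch_into_windows(start_time, end_time, batch_increment_s):
    # Plausible reconstruction inferred from the parametrized tests: each window
    # spans batch_increment_s seconds and the next window starts one second
    # after the previous one ends.
    total_batches = ceil((end_time - start_time) / batch_increment_s)
    windows = []
    for i in range(total_batches):
        query_start = int(start_time + i * batch_increment_s) + (1 if i else 0)
        query_end = int(start_time + (i + 1) * batch_increment_s)
        windows.append((query_start, query_end))
    return windows


def start_query(start, end):
    # Stand-in for CloudWatchLogs.Client.start_query; returns a fake query id.
    return f"query-{start}-{end}"


def get_results(query_id):
    # Stand-in for polling get_query_results until the query completes.
    return [{"query_id": query_id}]


def records_iterator(start_time, end_time, batch_increment_s):
    # Submit ahead until the deque holds MAX_CONCURRENT_QUERIES - 1 queries,
    # then pop the oldest window and collect its results before submitting the
    # next one. Remaining windows are drained at the end.
    queue = deque()
    for start, end in split_batch_into_windows(start_time, end_time, batch_increment_s):
        if len(queue) < (MAX_CONCURRENT_QUERIES - 1):
            queue.append((start_query(start, end), start, end))
        else:
            query_id, _, _ = queue.popleft()
            queue.append((start_query(start, end), start, end))
            yield get_results(query_id)
    while queue:
        query_id, _, _ = queue.popleft()
        yield get_results(query_id)


if __name__ == "__main__":
    for batch in records_iterator(1672272000, 1672279200, 3600):
        print(batch)  # two windows -> two result batches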

tap_cloudwatch/exception.py (+6 -2)

@@ -1,3 +1,7 @@
+"""Custom exceptions."""
+
+
 class InvalidQueryException(Exception):
-    "Raised when the input value is less than 18"
-    pass
+    """Raised when the input query is invalid."""
+
+    pass

tap_cloudwatch/streams.py (+2 -6)

@@ -23,16 +23,12 @@ def schema(self):
         # | parse @message "[*] *" as loggingType, loggingMessage
         properties.append(
             th.Property(
-                "ptr",
-                th.StringType(),
-                description="The identifier for the log record."
+                "ptr", th.StringType(), description="The identifier for the log record."
             )
         )
         properties.append(
             th.Property(
-                "timestamp",
-                th.DateTimeType(),
-                description="The timestamp of the log."
+                "timestamp", th.DateTimeType(), description="The timestamp of the log."
            )
        )
         for prop in self.config.get("query").split("|")[0].split(","):
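
The th.Property calls reformatted above come from the Meltano Singer SDK typing helpers (singer_sdk.typing). A rough sketch of how such properties become a JSON schema, assuming singer-sdk is installed; the stream's actual schema() also derives extra fields from the configured query, which is not reproduced here.

# A rough sketch, assuming singer-sdk is installed. The "message" property is a
# hypothetical extra field standing in for whatever the configured query selects.
from singer_sdk import typing as th

properties = [
    th.Property("ptr", th.StringType(), description="The identifier for the log record."),
    th.Property("timestamp", th.DateTimeType(), description="The timestamp of the log."),
    th.Property("message", th.StringType(), description="Hypothetical query-selected field."),
]

# PropertiesList(...).to_dict() yields the JSON-schema dict a stream's `schema`
# property is expected to return.
schema = th.PropertiesList(*properties).to_dict()
print(sorted(schema["properties"].keys()))  # ['message', 'ptr', 'timestamp']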

tap_cloudwatch/tests/test_cloudwatch_api.py (+26 -19)

@@ -1,37 +1,43 @@
 """Tests cloudwatch api module."""
 
-from tap_cloudwatch.cloudwatch_api import CloudwatchAPI
-from tap_cloudwatch.exception import InvalidQueryException
-import pytest
-from unittest.mock import patch
-from datetime import datetime, timezone
+import logging
+from contextlib import nullcontext as does_not_raise
 
 import boto3
+import pytest
 from botocore.stub import Stubber
 from freezegun import freeze_time
-import logging
-from contextlib import nullcontext as does_not_raise
+
+from tap_cloudwatch.cloudwatch_api import CloudwatchAPI
+from tap_cloudwatch.exception import InvalidQueryException
 
 
 @pytest.mark.parametrize(
-    'start,end,batch,expected',
+    "start,end,batch,expected",
     [
         [1672272000, 1672275600, 3600, [(1672272000, 1672275600)]],
-        [1672272000, 1672275601, 3600, [(1672272000, 1672275600), (1672275601, 1672279200)]],
+        [
+            1672272000,
+            1672275601,
+            3600,
+            [(1672272000, 1672275600), (1672275601, 1672279200)],
+        ],
     ],
 )
 def test_split_batch_into_windows(start, end, batch, expected):
     """Run standard tap tests from the SDK."""
     api = CloudwatchAPI(None)
-    batches = api.split_batch_into_windows(start, end, batch)
+    batches = api._split_batch_into_windows(start, end, batch)
     assert batches == expected
 
 
-
 @pytest.mark.parametrize(
-    'query,expectation',
+    "query,expectation",
     [
-        ["fields @timestamp, @message | sort @timestamp desc", pytest.raises(InvalidQueryException)],
+        [
+            "fields @timestamp, @message | sort @timestamp desc",
+            pytest.raises(InvalidQueryException),
+        ],
         ["fields @timestamp, @message | limit 5", pytest.raises(InvalidQueryException)],
         ["stats count(*) by duration as time", pytest.raises(InvalidQueryException)],
         ["fields @message", pytest.raises(InvalidQueryException)],
@@ -42,7 +48,8 @@ def test_validate_query(query, expectation):
     """Run standard tap tests from the SDK."""
     api = CloudwatchAPI(None)
     with expectation:
-        api.validate_query(query)
+        api._validate_query(query)
+
 
 @freeze_time("2022-12-30")
 def test_handle_batch_window():
@@ -66,7 +73,7 @@ def test_handle_batch_window():
             ]
         ],
         "ResponseMetadata": {"HTTPStatusCode": 200},
-        "statistics": {"recordsMatched": 10000}
+        "statistics": {"recordsMatched": 10000},
     }
     stubber.add_response(
         "start_query",
@@ -86,7 +93,7 @@ def test_handle_batch_window():
     )
     stubber.activate()
 
-    query_id = api.start_query(query_start, query_end, log_group, in_query)
-    output = api.get_results(log_group, query_start, query_end, in_query, query_id)
-
-    assert response["results"] == output
+    query_id = api._start_query(query_start, query_end, log_group, in_query)
+    output = api._get_results(log_group, query_start, query_end, in_query, query_id)
+
+    assert response["results"] == output
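
test_handle_batch_window above pins the clock with freezegun and stubs the CloudWatch Logs client with botocore's Stubber so no real AWS calls are made. A minimal sketch of that stubbing pattern, assuming the same libraries, driven against a bare boto3 "logs" client rather than the tap's CloudwatchAPI; the log group, timestamps, and query string are illustrative only.

# A minimal sketch of the Stubber pattern used in test_handle_batch_window,
# assuming boto3, botocore and freezegun are installed.
import boto3
from botocore.stub import Stubber
from freezegun import freeze_time


@freeze_time("2022-12-30")
def test_start_query_is_stubbed():
    # Dummy credentials so the stubbed request can still be signed locally.
    client = boto3.client(
        "logs",
        region_name="us-east-1",
        aws_access_key_id="testing",
        aws_secret_access_key="testing",
    )
    stubber = Stubber(client)
    stubber.add_response(
        "start_query",
        {"queryId": "123"},  # canned service response
        {
            # expected_params: the stub raises if the call deviates from these
            "logGroupName": "my-log-group",
            "startTime": 1672272000,
            "endTime": 1672275600,
            "queryString": "fields @timestamp, @message | sort @timestamp asc",
        },
    )
    stubber.activate()

    response = client.start_query(
        logGroupName="my-log-group",
        startTime=1672272000,
        endTime=1672275600,
        queryString="fields @timestamp, @message | sort @timestamp asc",
    )

    assert response["queryId"] == "123"
    stubber.assert_no_pending_responses()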

tap_cloudwatch/tests/test_core.py (+2 -2)

@@ -15,7 +15,7 @@
     "query": "fields @timestamp, @message",
     "aws_region_name": "us-east-1",
     "start_date": "2022-12-29",
-    "batch_increment_s": 86400
+    "batch_increment_s": 86400,
 }
 
 client = boto3.client("logs", region_name="us-east-1")
@@ -49,7 +49,7 @@ def test_standard_tap_tests(patch_client):
                ]
            ],
            "ResponseMetadata": {"HTTPStatusCode": 200},
-            "statistics": {"recordsMatched": 0}
+            "statistics": {"recordsMatched": 0},
        },
        {"queryId": "123"},
    )
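
test_core.py runs the SDK's built-in tap checks against a sample config while the boto3 client is stubbed via the patch_client fixture. A minimal sketch of that loop, assuming the singer-sdk helper of that era (get_standard_tap_tests) and the conventional tap_cloudwatch.tap.TapCloudwatch entry point; the client stubbing and any config keys beyond those visible in the diff are omitted, so this sketch would attempt live AWS calls if run as-is.

# A minimal sketch, assuming singer-sdk's get_standard_tap_tests helper and the
# conventional entry point tap_cloudwatch.tap.TapCloudwatch. The real test also
# stubs boto3 through the patch_client fixture, which is omitted here.
from singer_sdk.testing import get_standard_tap_tests

from tap_cloudwatch.tap import TapCloudwatch

SAMPLE_CONFIG = {
    "query": "fields @timestamp, @message",
    "aws_region_name": "us-east-1",
    "start_date": "2022-12-29",
    "batch_increment_s": 86400,
}


def test_standard_tap_tests():
    """Run the built-in singer-sdk tap checks."""
    for test in get_standard_tap_tests(TapCloudwatch, config=SAMPLE_CONFIG):
        test()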
