Skip to content

Commit 58a15b9

Browse files
author
Stephen Hoover
authored
BUG Ensure that polling thread shuts down when no longer needed (#45)
* BUG Ensure that polling thread shuts down when no longer needed
  When a Python session shuts down, all threads are joined. That means that we need to wait for the polling thread to complete before the interpreter can shut down. It hadn't been noticeable when polling was the primary way of getting results, but with the use of the notifications endpoint, it matters. I observed a very short job which used notifications wait 9.5 minutes to complete, because it had to wait for the `sleep` inside our polling thread to finish. There's no way to shut down a thread in a ``ThreadPoolExecutor`` early, so replace it with our own thread. The ``cancel`` method lets us stop it early, and overriding ``join`` ensures that we never need to wait for a polling cycle to complete when exiting the interpreter. This also solves a problem I'd observed in testing, that the test session wouldn't exit until the poller shut down.
* MAINT Remove overly-paranoid exception catching
  Stop wrapping `_check_result` in a try/except inside the thread. It's only necessary if we think there could be coding bugs inside `_check_result`. If there are, we should fix them rather than swallowing them.
* TST Conditionally select polling interval in tests
  When we record tests with ``vcr``, we need to have normal polling intervals, so as not to break the Platform API. But when running tests later, we don't want to wait through the entire sequence of polling at the normal intervals. We had been solving this by patching some internals of `PollableResult`. That solution breaks when we swap in the new threading poller. Replace with conditional poll intervals. We need to remember to set `polling_interval` for all new tests, but in return we don't need to patch private methods of `PollableResult` (and it works with the new threading).
1 parent 3e3135b commit 58a15b9

File tree

5 files changed

+99
-104
lines changed

5 files changed

+99
-104
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
66
### API Changes
77
- Deprecate ``api_key`` input to higher-level functions and classes in favor of an ``APIClient`` input. The ``api_key`` will be removed in v2.0.0. (#46)
88

9+
### Fixed
10+
- Improved threading implementation in ``PollableResult`` so that it no longer blocks interpreter shutdown.
11+
912
### Added
1013
- Decorator function for deprecating parameters (#46)
1114

civis/polling.py

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from concurrent import futures
21
import time
2+
import threading
33

44
from civis.base import CivisJobFailure, CivisAsyncResultBase, FAILED, DONE
55
from civis.response import Response
@@ -8,6 +8,37 @@
88
_DEFAULT_POLLING_INTERVAL = 15
99

1010

11+
class _ResultPollingThread(threading.Thread):
12+
"""Poll a function until it returns a Response with a DONE state
13+
"""
14+
# Inspired by `threading.Timer`
15+
16+
def __init__(self, poller, poller_args, polling_interval):
17+
super().__init__()
18+
self.polling_interval = polling_interval
19+
self.poller = poller
20+
self.poller_args = poller_args
21+
self.finished = threading.Event()
22+
23+
def cancel(self):
24+
"""Stop the poller if it hasn't finished yet.
25+
"""
26+
self.finished.set()
27+
28+
def join(self, timeout=None):
29+
"""Shut down the polling when the thread is terminated.
30+
"""
31+
self.cancel()
32+
super().join(timeout=timeout)
33+
34+
def run(self):
35+
"""Poll until done.
36+
"""
37+
while not self.finished.wait(self.polling_interval):
38+
if self.poller(*self.poller_args).state in DONE:
39+
self.finished.set()
40+
41+
1142
class PollableResult(CivisAsyncResultBase):
1243
"""A class for tracking pollable results.
1344
@@ -55,7 +86,7 @@ class PollableResult(CivisAsyncResultBase):
5586
# Implementation notes: The `PollableResult` depends on some private
5687
# features of the `concurrent.futures.Future` class, so it's possible
5788
# that future versions of Python could break something here.
58-
# (It works under at least 3.4 and 3.5.)
89+
# (It works under at least 3.4, 3.5, and 3.6)
5990
# We use the following `Future` implementation details
6091
# - The `Future` checks its state against predefined strings. We use
6192
# `STATE_TRANS` to translate from the Civis platform states to `Future`
@@ -76,6 +107,8 @@ def __init__(self, poller, poller_args,
76107
api_key=api_key,
77108
client=client,
78109
poll_on_creation=poll_on_creation)
110+
if self.polling_interval <= 0:
111+
raise ValueError("The polling interval must be positive.")
79112

80113
# Polling arguments. Never poll more often than the requested interval.
81114
if poll_on_creation:
@@ -84,37 +117,16 @@ def __init__(self, poller, poller_args,
84117
self._last_polled = time.time()
85118
self._last_result = None
86119

87-
self._self_polling_executor = None
88-
89-
def _wait_for_completion(self):
90-
"""Poll the job every `polling_interval` seconds. Blocks until the
91-
job completes.
92-
"""
93-
try:
94-
while self._civis_state not in DONE:
95-
time.sleep(self.polling_interval)
96-
except Exception as e:
97-
# Exceptions are caught in `_check_result`, so
98-
# we should never get here. If there were to be a
99-
# bug in `_check_result`, however, we would get stuck
100-
# in an infinite loop without setting the `_result`.
101-
with self._condition:
102-
self._set_api_exception(exc=e)
103-
104-
def _poll_wait_elapsed(self, now):
105-
# this exists because it's easier to monkeypatch in testing
106-
return (now - self._last_polled) >= self.polling_interval
120+
self._polling_thread = _ResultPollingThread(self._check_result, (),
121+
polling_interval)
107122

108123
def _check_result(self):
109124
"""Return the job result from Civis. Once the job completes, store the
110125
result and never poll again."""
111-
112-
# If we haven't started the polling thread, do it now.
113-
if self._self_polling_executor is None and self._result is None:
114-
# Start a single thread continuously polling. It will stop once the
115-
# job completes.
116-
self._self_polling_executor = futures.ThreadPoolExecutor(1)
117-
self._self_polling_executor.submit(self._wait_for_completion)
126+
# Start a single thread continuously polling.
127+
# It will stop once the job completes.
128+
if not self._polling_thread.is_alive() and self._result is None:
129+
self._polling_thread.start()
118130

119131
with self._condition:
120132
if self._result is not None:
@@ -125,7 +137,8 @@ def _check_result(self):
125137
# Check to see if the job has finished, but don't poll more
126138
# frequently than the requested polling frequency.
127139
now = time.time()
128-
if not self._last_polled or self._poll_wait_elapsed(now):
140+
if (not self._last_polled or
141+
(now - self._last_polled) >= self.polling_interval):
129142
# Poll for a new result
130143
self._last_polled = now
131144
try:
@@ -165,5 +178,7 @@ def _set_api_exception(self, exc, result=None):
165178
self.cleanup()
166179

167180
def cleanup(self):
168-
# This gets called after the result is set
169-
pass
181+
# This gets called after the result is set.
182+
# Ensure that the polling thread shuts down when it's no longer needed.
183+
if self._polling_thread.is_alive():
184+
self._polling_thread.cancel()

civis/tests/test_io.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,20 @@
1818
from civis.resources._resources import get_swagger_spec, generate_classes
1919
from civis.tests.testcase import (CivisVCRTestCase,
2020
cassette_dir,
21-
conditionally_patch)
21+
POLL_INTERVAL)
2222

2323
swagger_import_str = 'civis.resources._resources.get_swagger_spec'
2424
THIS_DIR = os.path.dirname(os.path.realpath(__file__))
2525
with open(os.path.join(THIS_DIR, "civis_api_spec.json")) as f:
2626
civis_api_spec = json.load(f, object_pairs_hook=OrderedDict)
2727

2828

29-
@conditionally_patch('civis.polling.time.sleep', return_value=None)
30-
@conditionally_patch('civis.polling.PollableResult._poll_wait_elapsed',
31-
return_value=True)
3229
@patch(swagger_import_str, return_value=civis_api_spec)
3330
class ImportTests(CivisVCRTestCase):
31+
# Note that all functions tested here should use a
32+
# `polling_interval=POLL_INTERVAL` input. This lets us use
33+
# sensible polling intervals when recording, but speed through
34+
# the calls in the VCR cassette when testing later.
3435

3536
@classmethod
3637
def setUpClass(cls):
@@ -43,9 +44,6 @@ def tearDownClass(cls):
4344
generate_classes.cache_clear()
4445

4546
@classmethod
46-
@conditionally_patch('civis.polling.time.sleep', return_value=None)
47-
@conditionally_patch('civis.polling.PollableResult._poll_wait_elapsed',
48-
return_value=True)
4947
@patch(swagger_import_str, return_value=civis_api_spec)
5048
def setup_class(cls, *mocks):
5149
setup_vcr = vcr.VCR(filter_headers=['Authorization'])
@@ -71,14 +69,16 @@ def setup_class(cls, *mocks):
7169
INSERT INTO scratch.api_client_test_fixture
7270
VALUES (1,2,3);
7371
"""
74-
res = civis.io.query_civis(sql, 'redshift-general')
72+
res = civis.io.query_civis(sql, 'redshift-general',
73+
polling_interval=POLL_INTERVAL)
7574
res.result() # block
7675

7776
# create an export to check get_url. also tests export_csv
7877
with tempfile.NamedTemporaryFile() as tmp:
7978
sql = "SELECT * FROM scratch.api_client_test_fixture"
8079
database = 'redshift-general'
81-
result = civis.io.civis_to_csv(tmp.name, sql, database)
80+
result = civis.io.civis_to_csv(tmp.name, sql, database,
81+
polling_interval=POLL_INTERVAL)
8282
result = result.result()
8383
assert result.state == 'succeeded'
8484

@@ -106,7 +106,8 @@ def test_csv_to_civis(self, *mocks):
106106
table = "scratch.api_client_test_fixture"
107107
database = 'redshift-general'
108108
result = civis.io.csv_to_civis(tmp.name, database, table,
109-
existing_table_rows='truncate')
109+
existing_table_rows='truncate',
110+
polling_interval=POLL_INTERVAL)
110111
result = result.result() # block until done
111112

112113
assert isinstance(result.id, int)
@@ -117,22 +118,25 @@ def test_csv_to_civis(self, *mocks):
117118
def test_read_civis_pandas(self, *mocks):
118119
expected = pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
119120
df = civis.io.read_civis('scratch.api_client_test_fixture',
120-
'redshift-general', use_pandas=True)
121+
'redshift-general', use_pandas=True,
122+
polling_interval=POLL_INTERVAL)
121123
assert df.equals(expected)
122124

123125
@patch(swagger_import_str, return_value=civis_api_spec)
124126
def test_read_civis_no_pandas(self, *mocks):
125127
expected = [['a', 'b', 'c'], ['1', '2', '3']]
126128
data = civis.io.read_civis('scratch.api_client_test_fixture',
127-
'redshift-general', use_pandas=False)
129+
'redshift-general', use_pandas=False,
130+
polling_interval=POLL_INTERVAL)
128131
assert data == expected
129132

130133
@patch(swagger_import_str, return_value=civis_api_spec)
131134
def test_read_civis_sql(self, *mocks):
132135
sql = "SELECT * FROM scratch.api_client_test_fixture"
133136
expected = [['a', 'b', 'c'], ['1', '2', '3']]
134137
data = civis.io.read_civis_sql(sql, 'redshift-general',
135-
use_pandas=False)
138+
use_pandas=False,
139+
polling_interval=POLL_INTERVAL)
136140
assert data == expected
137141

138142
@pytest.mark.skipif(not has_pandas, reason="pandas not installed")
@@ -141,15 +145,16 @@ def test_dataframe_to_civis(self, *mocks):
141145
df = pd.DataFrame([['1', '2', '3']], columns=['a', 'b', 'c'])
142146
result = civis.io.dataframe_to_civis(df, 'redshift-general',
143147
'scratch.api_client_test_fixture',
144-
existing_table_rows='truncate')
148+
existing_table_rows='truncate',
149+
polling_interval=POLL_INTERVAL)
145150
result = result.result()
146151
assert result.state == 'succeeded'
147152

148153
@patch(swagger_import_str, return_value=civis_api_spec)
149154
def test_civis_to_multifile_csv(self, *mocks):
150155
sql = "SELECT * FROM scratch.api_client_test_fixture"
151-
result = civis.io.civis_to_multifile_csv(sql,
152-
database='redshift-general')
156+
result = civis.io.civis_to_multifile_csv(
157+
sql, database='redshift-general', polling_interval=POLL_INTERVAL)
153158
assert set(result.keys()) == {'entries', 'query', 'header'}
154159
assert result['query'] == sql
155160
assert result['header'] == ['a', 'b', 'c']
@@ -164,13 +169,15 @@ def test_civis_to_multifile_csv(self, *mocks):
164169
def test_transfer_table(self, *mocks):
165170
result = civis.io.transfer_table('redshift-general', 'redshift-test',
166171
'scratch.api_client_test_fixture',
167-
'scratch.api_client_test_fixture')
172+
'scratch.api_client_test_fixture',
173+
polling_interval=POLL_INTERVAL)
168174
result = result.result()
169175
assert result.state == 'succeeded'
170176

171177
# check for side effect
172178
sql = 'select * from scratch.api_client_test_fixture'
173-
result = civis.io.query_civis(sql, 'redshift-test').result()
179+
result = civis.io.query_civis(sql, 'redshift-test',
180+
polling_interval=POLL_INTERVAL).result()
174181
assert result.state == 'succeeded'
175182

176183
def test_get_sql_select(self, *mocks):

civis/tests/test_polling.py

Lines changed: 16 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,8 @@ def __init__(self, state):
1515
self.state = state
1616

1717

18-
def func():
19-
pass
20-
21-
2218
def create_pollable_result(state, exception=None, result=None):
23-
f = PollableResult(State, (state, ), polling_interval=0)
19+
f = PollableResult(State, (state, ), polling_interval=0.001)
2420
f._exception = exception
2521
f._result = result
2622
return f
@@ -29,12 +25,9 @@ def create_pollable_result(state, exception=None, result=None):
2925
CANCELLED_RESULT = create_pollable_result(state='cancelled')
3026
FINISHED_RESULT = create_pollable_result(state='success')
3127
QUEUED_RESULT = create_pollable_result(state='queued')
32-
# avoid the polling thread hanging
33-
QUEUED_RESULT._wait_for_completion = func
3428

3529

3630
class TestPolling(unittest.TestCase):
37-
3831
def test_as_completed(self):
3932
my_futures = [QUEUED_RESULT, CANCELLED_RESULT, FINISHED_RESULT]
4033
fs = futures.as_completed(my_futures)
@@ -64,51 +57,33 @@ def test_error_setting(self):
6457
assert isinstance(pollable.exception(), ZeroDivisionError)
6558

6659
def test_timeout(self):
67-
# Note: Something about the test framework seems to prevent the
68-
# Pollable result from being destroyed while the polling
69-
# thread is running. The test will hang if the PollableResult
70-
# never completes. I haven't seen the same problem in
71-
# the interpreter.
7260
pollable = PollableResult(
73-
mock.Mock(side_effect=[Response({"state": "running"}),
74-
ValueError()]), (),
61+
mock.Mock(return_value=Response({"state": "running"})),
62+
poller_args=(),
7563
polling_interval=0.1)
7664
pytest.raises(futures.TimeoutError, pollable.result, timeout=0.05)
7765

78-
def test_no_hanging(self):
79-
# Make sure that an error in the `_check_result` doesn't
80-
# cause an infinite loop.
81-
class PollableResultTester(PollableResult):
82-
def __init__(self, *args, **kwargs):
83-
self._poll_ct = 0
84-
super().__init__(*args, **kwargs)
85-
86-
def _check_result(self):
87-
if self._poll_ct is not None:
88-
self._poll_ct += 1
89-
if self._poll_ct > 10:
90-
self._poll_ct = None # Disable the counter.
91-
# Make the _wait_for_completion loop fail.
92-
raise ZeroDivisionError()
93-
return super()._check_result()
94-
95-
# The following should raise a CivisJobFailure before a TimeoutError.
96-
pollable = PollableResultTester(
97-
lambda: Response({"state": "running"}), (),
98-
polling_interval=0.1)
99-
pytest.raises(ZeroDivisionError, pollable.result, timeout=5)
100-
10166
def test_poll_on_creation(self):
102-
poller = mock.Mock(side_effect=Response({"state": "running"}))
67+
poller = mock.Mock(return_value=Response({"state": "running"}))
10368
pollable = PollableResult(poller,
10469
(),
10570
polling_interval=0.01,
10671
poll_on_creation=False)
107-
repr(pollable)
72+
pollable.done() # Check status once to start the polling thread
10873
assert poller.call_count == 0
109-
time.sleep(0.02)
74+
time.sleep(0.015)
11075
assert poller.call_count == 1
11176

11277

78+
def test_repeated_polling():
    """Verify that a never-finishing job is polled the expected number
    of times: once on the first status check, then once per interval.
    """
    interval = 0.1
    still_running = Response({"state": "running"})
    poller = mock.Mock(return_value=still_running)
    pollable = PollableResult(poller, (), polling_interval=interval)

    # The first status check polls immediately and launches the
    # background polling thread.
    pollable.done()
    assert poller.call_count == 1, "Poll once on the first status check"

    # Two more polling cycles fit inside the sleep below.
    time.sleep(0.25)
    assert poller.call_count == 3, "After waiting 2.5x the polling interval"
86+
87+
11388
if __name__ == '__main__':
11489
unittest.main()

civis/tests/testcase.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
1-
from functools import wraps
21
import os
32

43
from vcr import persist
54
from vcr.serialize import CASSETTE_FORMAT_VERSION
65
from vcr.serializers import compat
76
from vcr_unittest import VCRTestCase
87

9-
10-
def conditionally_patch(target, *args, **kwargs):
11-
from unittest.mock import patch
12-
if os.getenv('GENERATE_TESTS'):
13-
def pass_func(func):
14-
wraps(func)
15-
return func
16-
return pass_func
17-
else:
18-
def decorated_func(func):
19-
return patch(target, *args, **kwargs)(func)
20-
return decorated_func
8+
# The "GENERATE_TESTS" environment variable indicates that
9+
# we're recording new cassettes.
10+
if os.environ.get('GENERATE_TESTS'):
11+
# Use default polling intervals if generating new tests.
12+
POLL_INTERVAL = None
13+
else:
14+
# Speed through calls in pre-recorded VCR cassettes.
15+
POLL_INTERVAL = 0.00001
2116

2217

2318
def cassette_dir():

0 commit comments

Comments
 (0)