Skip to content

Commit d1d59b5

Browse files
authored
fix: (CDK) (AsyncRetriever) - fix bug when TIMEOUT Jobs are retried, ignoring the polling_job_timeout setting (#429)
1 parent 923acc2 commit d1d59b5

File tree

3 files changed

+93
-20
lines changed

3 files changed

+93
-20
lines changed

airbyte_cdk/sources/declarative/async_job/job_orchestrator.py

+38-10
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def __init__(
179179
self._non_breaking_exceptions: List[Exception] = []
180180

181181
def _replace_failed_jobs(self, partition: AsyncPartition) -> None:
182-
failed_status_jobs = (AsyncJobStatus.FAILED, AsyncJobStatus.TIMED_OUT)
182+
failed_status_jobs = (AsyncJobStatus.FAILED,)
183183
jobs_to_replace = [job for job in partition.jobs if job.status() in failed_status_jobs]
184184
for job in jobs_to_replace:
185185
new_job = self._start_job(job.job_parameters(), job.api_job_id())
@@ -359,14 +359,11 @@ def _process_running_partitions_and_yield_completed_ones(
359359
self._process_partitions_with_errors(partition)
360360
case _:
361361
self._stop_timed_out_jobs(partition)
362+
# re-allocate FAILED jobs, but TIMEOUT jobs are not re-allocated
363+
self._reallocate_partition(current_running_partitions, partition)
362364

363-
# job will be restarted in `_start_job`
364-
current_running_partitions.insert(0, partition)
365-
366-
for job in partition.jobs:
367-
# We only remove completed jobs as we want failed/timed out jobs to be re-allocated in priority
368-
if job.status() == AsyncJobStatus.COMPLETED:
369-
self._job_tracker.remove_job(job.api_job_id())
365+
# We only remove completed / timed-out jobs as we want failed jobs to be re-allocated in priority
366+
self._remove_completed_or_timed_out_jobs(partition)
370367

371368
# update the referenced list with running partitions
372369
self._running_partitions = current_running_partitions
@@ -381,8 +378,11 @@ def _stop_partition(self, partition: AsyncPartition) -> None:
381378
def _stop_timed_out_jobs(self, partition: AsyncPartition) -> None:
    """
    Abort every TIMED_OUT job in the partition, then fail the sync.

    Timed-out jobs are not retried: each one is aborted with its budget
    allocation freed, and a config error is raised so the user knows to
    increase the polling job timeout.

    Args:
        partition (AsyncPartition): The partition whose jobs are inspected.

    Raises:
        AirbyteTracedException: If at least one job in the partition timed out.
    """
    timed_out_jobs = [
        job for job in partition.jobs if job.status() == AsyncJobStatus.TIMED_OUT
    ]
    # Abort ALL timed-out jobs before raising: raising inside the loop would
    # leave any subsequent timed-out job running with its budget still allocated.
    for job in timed_out_jobs:
        self._abort_job(job, free_job_allocation=True)
    if timed_out_jobs:
        # Message references the first timed-out job, matching the original
        # single-job error text that callers/tests match against.
        job = timed_out_jobs[0]
        raise AirbyteTracedException(
            internal_message=f"Job {job.api_job_id()} has timed out. Try increasing the `polling job timeout`.",
            failure_type=FailureType.config_error,
        )
386386

387387
def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
388388
try:
@@ -392,6 +392,34 @@ def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
392392
except Exception as exception:
393393
LOGGER.warning(f"Could not free budget for job {job.api_job_id()}: {exception}")
394394

395+
def _remove_completed_or_timed_out_jobs(self, partition: AsyncPartition) -> None:
    """
    Release the job-tracker budget for the partition's COMPLETED and TIMED_OUT jobs.

    Jobs are removed from the job tracker (not from the partition itself), so
    that FAILED jobs keep their allocation and can be re-allocated in priority.

    Args:
        partition (AsyncPartition): The partition to process.
    """
    for job in partition.jobs:
        # Terminal statuses whose budget should be freed; FAILED is deliberately
        # excluded because failed jobs are restarted.
        if job.status() in (AsyncJobStatus.COMPLETED, AsyncJobStatus.TIMED_OUT):
            self._job_tracker.remove_job(job.api_job_id())
405+
406+
def _reallocate_partition(
    self,
    current_running_partitions: List[AsyncPartition],
    partition: AsyncPartition,
) -> None:
    """
    Re-insert the partition at the head of the running list so its FAILED
    jobs are retried in priority.

    A partition whose jobs all TIMED_OUT is not re-inserted: timed-out jobs
    are not retried.

    Args:
        current_running_partitions (List[AsyncPartition]): The list of currently running partitions (mutated in place).
        partition (AsyncPartition): The partition to reallocate.
    """
    # Insert at most once: looping and inserting per non-timed-out job (as the
    # original did) duplicates the partition in the running list, causing it
    # to be processed multiple times.
    if any(job.status() != AsyncJobStatus.TIMED_OUT for job in partition.jobs):
        current_running_partitions.insert(0, partition)
422+
395423
def _process_partitions_with_errors(self, partition: AsyncPartition) -> None:
396424
"""
397425
Process a partition with status errors (FAILED and TIMEOUT).

airbyte_cdk/sources/declarative/requesters/http_job_repository.py

+52-4
Original file line numberDiff line numberDiff line change
@@ -273,24 +273,72 @@ def _clean_up_job(self, job_id: str) -> None:
273273
del self._create_job_response_by_id[job_id]
274274
del self._polling_job_response_by_id[job_id]
275275

276+
def _get_creation_response_interpolation_context(self, job: AsyncJob) -> Dict[str, Any]:
    """
    Build the interpolation context from the job-creation response.

    The decoded JSON body is exposed at the top level; the response's
    `headers` and `request` are added under those keys unless the body
    already provides them.

    Args:
        job (AsyncJob): The job for which to get the creation response interpolation context.

    Returns:
        Dict[str, Any]: The interpolation context as a dictionary.
    """
    # TODO: currently we support only JsonDecoder to decode the response to track the ids or
    # the status of the Jobs. We should consider adding support for other decoders
    # (e.g. XMLDecoder) in the future.
    response = self._create_job_response_by_id[job.api_job_id()]  # hoist repeated lookup
    creation_response_context = dict(response.json())
    if "headers" not in creation_response_context:
        creation_response_context["headers"] = response.headers
    if "request" not in creation_response_context:
        creation_response_context["request"] = response.request
    return creation_response_context
298+
299+
def _get_polling_response_interpolation_context(self, job: AsyncJob) -> Dict[str, Any]:
    """
    Build the interpolation context from the job-polling response.

    The decoded JSON body is exposed at the top level; the response's
    `headers` and `request` are added under those keys unless the body
    already provides them.

    Args:
        job (AsyncJob): The job for which to get the polling response interpolation context.

    Returns:
        Dict[str, Any]: The interpolation context as a dictionary.
    """
    # TODO: currently we support only JsonDecoder to decode the response to track the ids or
    # the status of the Jobs. We should consider adding support for other decoders
    # (e.g. XMLDecoder) in the future.
    response = self._polling_job_response_by_id[job.api_job_id()]  # hoist repeated lookup
    polling_response_context = dict(response.json())
    if "headers" not in polling_response_context:
        polling_response_context["headers"] = response.headers
    if "request" not in polling_response_context:
        polling_response_context["request"] = response.request
    return polling_response_context
321+
276322
def _get_create_job_stream_slice(self, job: AsyncJob) -> StreamSlice:
    """
    Build an empty-partition stream slice exposing the job-creation response
    under the `creation_response` extra field.
    """
    extra = {"creation_response": self._get_creation_response_interpolation_context(job)}
    return StreamSlice(partition={}, cursor_slice={}, extra_fields=extra)
284331

285332
def _get_download_targets(self, job: AsyncJob) -> Iterable[str]:
286333
if not self.download_target_requester:
287334
url_response = self._polling_job_response_by_id[job.api_job_id()]
288335
else:
289-
polling_response = self._polling_job_response_by_id[job.api_job_id()].json()
290336
stream_slice: StreamSlice = StreamSlice(
291337
partition={},
292338
cursor_slice={},
293-
extra_fields={"polling_response": polling_response},
339+
extra_fields={
340+
"polling_response": self._get_polling_response_interpolation_context(job),
341+
},
294342
)
295343
url_response = self.download_target_requester.send_request(stream_slice=stream_slice) # type: ignore # we expect download_target_requester to always be presented, otherwise raise an exception as we cannot proceed with the report
296344
if not url_response:

unit_tests/sources/declarative/async_job/test_job_orchestrator.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -144,13 +144,10 @@ def test_given_timeout_when_create_and_get_completed_partitions_then_free_budget
144144
)
145145
orchestrator = self._orchestrator([_A_STREAM_SLICE], job_tracker)
146146

147-
with pytest.raises(AirbyteTracedException):
147+
with pytest.raises(AirbyteTracedException) as error:
148148
list(orchestrator.create_and_get_completed_partitions())
149-
assert job_tracker.try_to_get_intent()
150-
assert (
151-
self._job_repository.start.call_args_list
152-
== [call(_A_STREAM_SLICE)] * _MAX_NUMBER_OF_ATTEMPTS
153-
)
149+
150+
assert "Job an api job id has timed out" in str(error.value)
154151

155152
@mock.patch(sleep_mock_target)
156153
def test_given_failure_when_create_and_get_completed_partitions_then_raise_exception(

0 commit comments

Comments
 (0)