3 changes: 3 additions & 0 deletions airbyte_cdk/sources/declarative/request_local/__init__.py
@@ -0,0 +1,3 @@
from .request_local import RequestLocal

__all__ = ["RequestLocal"]
39 changes: 39 additions & 0 deletions airbyte_cdk/sources/declarative/request_local/request_local.py
@@ -0,0 +1,39 @@
from threading import local, Lock


class RequestLocal(local):
    _instance = None
    _lock = Lock()  # Thread-safe singleton creation

    def __new__(cls, *args, **kwargs):
        # Use double-checked locking for thread safety
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(RequestLocal, cls).__new__(cls)
        return cls._instance
Comment on lines +7 to +13
Contributor

⚠️ Potential issue

Add type annotations to fix mypy errors

The __new__ method is missing type annotations. Would you consider adding them to satisfy the mypy checks? wdyt?

-    def __new__(cls, *args, **kwargs):
+    def __new__(cls, *args, **kwargs) -> "RequestLocal":
🧰 Tools
🪛 GitHub Actions: Linters

[error] 7-7: mypy: Function is missing a type annotation. (no-untyped-def)

🤖 Prompt for AI Agents
In airbyte_cdk/sources/declarative/request_local/request_local.py around lines 7
to 13, the __new__ method lacks type annotations causing mypy errors. Add
appropriate type annotations to the __new__ method signature, specifying the
class type for cls and the return type as an instance of the class, to satisfy
static type checking.


    def __init__(self):
        # __init__ will be called every time the class is instantiated,
        # but the object itself is only created once by __new__.
        # Use a flag to prevent re-initialization
        if not hasattr(self, '_initialized'):
            self._stream_slice = None  # Initialize _stream_slice
            self._initialized = True
Comment on lines +15 to +21
Contributor

⚠️ Potential issue

Add return type annotation and consider thread-safe initialization

The __init__ method needs a return type annotation. Also, the initialization check using hasattr might have edge cases in concurrent scenarios. Would you consider using a more robust approach? wdyt?

-    def __init__(self):
+    def __init__(self) -> None:
         # __init__ will be called every time the class is instantiated,
         # but the object itself is only created once by __new__.
         # Use a flag to prevent re-initialization
-        if not hasattr(self, '_initialized'):
+        with self._lock:
+            if not hasattr(self, '_initialized'):
+                self._stream_slice = None  # Initialize _stream_slice
+                self._initialized = True
-            self._stream_slice = None  # Initialize _stream_slice
-            self._initialized = True

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 GitHub Actions: Linters

[error] 15-15: mypy: Function is missing a return type annotation. Use '-> None' if function does not return a value. (no-untyped-def)

🤖 Prompt for AI Agents
In airbyte_cdk/sources/declarative/request_local/request_local.py around lines
15 to 21, add a return type annotation of None to the __init__ method. To
improve thread safety during initialization, replace the hasattr check with a
thread-safe mechanism such as using a threading.Lock or a class-level flag
protected by synchronization to ensure the initialization code runs only once
even in concurrent scenarios.


    @property
    def stream_slice(self):
        return self._stream_slice
Comment on lines +23 to +25
Contributor

⚠️ Potential issue

Add return type annotation to property getter

The property getter needs a return type annotation.

     @property
-    def stream_slice(self):
+    def stream_slice(self) -> Optional[Dict[str, Any]]:
         return self._stream_slice

Don't forget to import the necessary types at the top of the file:

from typing import Optional, Dict, Any
🧰 Tools
🪛 GitHub Actions: Linters

[error] 24-24: mypy: Function is missing a return type annotation. (no-untyped-def)

🤖 Prompt for AI Agents
In airbyte_cdk/sources/declarative/request_local/request_local.py around lines
23 to 25, the stream_slice property getter lacks a return type annotation. Add a
return type annotation to the method signature, such as Optional[Dict[str,
Any]], and ensure the necessary types Optional, Dict, and Any are imported from
typing at the top of the file.


    @stream_slice.setter
    def stream_slice(self, stream_slice):
        self._stream_slice = stream_slice
Comment on lines +27 to +29
Contributor

⚠️ Potential issue

Add type annotation to property setter

The setter needs a type annotation for the parameter.

     @stream_slice.setter
-    def stream_slice(self, stream_slice):
+    def stream_slice(self, stream_slice: Optional[Dict[str, Any]]) -> None:
         self._stream_slice = stream_slice
🧰 Tools
🪛 GitHub Actions: Linters

[error] 28-28: mypy: Function is missing a type annotation. (no-untyped-def)

🤖 Prompt for AI Agents
In airbyte_cdk/sources/declarative/request_local/request_local.py around lines
27 to 29, the stream_slice setter method lacks a type annotation for its
parameter. Add the appropriate type annotation to the stream_slice parameter in
the setter definition to improve code clarity and type checking.


    @classmethod
    def get_instance(cls):
        """
        Get the singleton instance of RequestLocal.
        This is the recommended way to get the instance.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
Comment on lines +32 to +39
Contributor Author

I think we don't need this

Comment on lines +31 to +39
Contributor

🛠️ Refactor suggestion

⚠️ Potential issue

Add return type annotation and simplify implementation

The class method needs a return type annotation. Also, since __new__ already handles singleton creation with proper locking, would it make sense to simplify this method? wdyt?

     @classmethod
-    def get_instance(cls):
+    def get_instance(cls) -> "RequestLocal":
         """
         Get the singleton instance of RequestLocal.
         This is the recommended way to get the instance.
         """
-        if cls._instance is None:
-            cls._instance = cls()
-        return cls._instance
+        return cls()
🧰 Tools
🪛 GitHub Actions: Linters

[error] 32-32: mypy: Function is missing a return type annotation. (no-untyped-def)

🤖 Prompt for AI Agents
In airbyte_cdk/sources/declarative/request_local/request_local.py around lines
31 to 39, add a return type annotation to the get_instance class method to
specify it returns an instance of the class. Since the singleton pattern is
already handled in the __new__ method with locking, simplify get_instance by
removing the explicit instance check and creation, and just return cls()
directly.
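
A quick aside on the semantics this class is going for, since it is easy to misread: __new__ makes every call return one shared instance, while inheriting from threading.local makes attribute storage per-thread. The following minimal sketch (standalone, not part of the diff; worker and seen are illustrative names only) shows both properties together; the unit tests at the bottom of this PR assert the same behavior:

import threading

from airbyte_cdk.sources.declarative.request_local import RequestLocal


def worker(name, seen):
    rl = RequestLocal()  # __new__ always hands back the one shared instance
    assert rl.stream_slice is None  # threading.local: attributes start fresh in each thread
    rl.stream_slice = {"slice": name}
    seen[name] = (id(rl), rl.stream_slice)


seen = {}
threads = [threading.Thread(target=worker, args=(f"t{i}", seen)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert len({instance_id for instance_id, _ in seen.values()}) == 1  # one singleton
for name, (_, stream_slice) in seen.items():
    assert stream_slice == {"slice": name}  # each thread kept its own value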

5 changes: 5 additions & 0 deletions airbyte_cdk/sources/declarative/requesters/http_requester.py
@@ -31,6 +31,7 @@
    combine_mappings,
    get_interpolation_context,
)
from airbyte_cdk.sources.declarative.request_local import RequestLocal


@dataclass
@@ -449,6 +450,9 @@ def send_request(
        request_body_json: Optional[Mapping[str, Any]] = None,
        log_formatter: Optional[Callable[[requests.Response], Any]] = None,
    ) -> Optional[requests.Response]:
        request_local = RequestLocal()
        request_local.stream_slice = stream_slice

        request, response = self._http_client.send_request(
            http_method=self.get_method().value,
            url=self._get_url(
@@ -473,6 +477,7 @@
            dedupe_query_params=True,
            log_formatter=log_formatter,
            exit_on_rate_limit=self._exit_on_rate_limit,
            stream_slice=stream_slice,
Contributor Author

remove this.

        )

        return response
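
Taken together with the rate_limiting.py changes below, the intended contract is: the requester writes the slice to the thread-local singleton just before sending, and the retry handlers read it back on the same thread, which is what makes the extra stream_slice= parameters (marked "remove this" throughout this review) unnecessary. A rough sketch of that contract, with hypothetical function names:

from airbyte_cdk.sources.declarative.request_local import RequestLocal


def before_send(stream_slice):
    # producer side, as in HttpRequester.send_request above
    RequestLocal().stream_slice = stream_slice


def decorate_log_message(message):
    # consumer side, as in the rate_limiting.py backoff handlers below;
    # this only works because retries run on the thread that set the slice
    request_local = RequestLocal()
    if request_local.stream_slice:
        message += f" for slice: {request_local.stream_slice}"
    return message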
1 change: 1 addition & 0 deletions airbyte_cdk/sources/streams/http/http.py
@@ -546,6 +546,7 @@ def _fetch_next_page(
            dedupe_query_params=True,
            log_formatter=self.get_log_formatter(),
            exit_on_rate_limit=self.exit_on_rate_limit,
            stream_slice=stream_slice,
Contributor Author

remove this

        )

        return request, response
9 changes: 6 additions & 3 deletions airbyte_cdk/sources/streams/http/http_client.py
@@ -243,6 +243,7 @@ def _send_with_retry(
        request_kwargs: Mapping[str, Any],
        log_formatter: Optional[Callable[[requests.Response], Any]] = None,
        exit_on_rate_limit: Optional[bool] = False,
        stream_slice: Optional[Mapping[str, Any]] = None,
Contributor Author

remove this

    ) -> requests.Response:
        """
        Sends a request with retry logic.
@@ -259,9 +260,9 @@
        max_tries = max(0, max_retries) + 1
        max_time = self._max_time

-        user_backoff_handler = user_defined_backoff_handler(max_tries=max_tries, max_time=max_time)(
-            self._send
-        )
+        user_backoff_handler = user_defined_backoff_handler(
+            max_tries=max_tries, max_time=max_time, stream_slice=stream_slice
+        )(self._send)
        rate_limit_backoff_handler = rate_limit_default_backoff_handler(max_tries=max_tries)
        backoff_handler = http_client_default_backoff_handler(
            max_tries=max_tries, max_time=max_time
@@ -506,6 +507,7 @@ def send_request(
        dedupe_query_params: bool = False,
        log_formatter: Optional[Callable[[requests.Response], Any]] = None,
        exit_on_rate_limit: Optional[bool] = False,
        stream_slice: Optional[Mapping[str, Any]] = None,
Contributor
@maxi297 maxi297 Jul 6, 2025

Is the connector you are trying to update low-code?

Wild idea: instead of modifying all the interfaces, which would require us to pass more information through them, could we create a thread-local in a threading.py file, so that anything that wants to log more information, like the slice, could have access to it? PartitionReader.process_partition would basically do thread_local.stream_slice = partition.to_slice, and at that point http_client can just get the information from thread_local.stream_slice. If anything other than the HttpClient needs access to this information, it just needs to do the same.

We can probably assume it would work not only for the concurrent CDK, as we could register the information in the thread local even in a non-concurrent world. I just haven't checked where this would happen.
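
In code, the idea might look something like this minimal sketch (the module path, attribute name, and the exact PartitionReader hook are hypothetical):

# hypothetical threading.py module inside the CDK
import threading

thread_local = threading.local()

# PartitionReader.process_partition would register the slice:
#     thread_local.stream_slice = partition.to_slice()

# and HttpClient (or anything else on the same thread) reads it back,
# with no signature changes anywhere in between:
current_slice = getattr(thread_local, "stream_slice", None)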

Contributor Author

> Is the connector you are trying to update low-code?

@maxi297 Yes, it's Slack. I received feedback from a user who was unable to tell, after a few hours, whether the attempt was "stuck". Honestly, I couldn't say either, as the rate-limit logs don't tell me much.

I was recently working on bumping this connector to CDK 6 and was trying to get better logs on the fly.

Contributor Author

@maxi297 I need to do some cleanup, but I added the singleton class that inherits from local and added some tests around it.

    ) -> Tuple[requests.PreparedRequest, requests.Response]:
        """
        Prepares and sends request and return request and response objects.
@@ -526,6 +528,7 @@
            request_kwargs=request_kwargs,
            log_formatter=log_formatter,
            exit_on_rate_limit=exit_on_rate_limit,
            stream_slice=stream_slice,
Contributor Author

remove this

        )

        return request, response
26 changes: 22 additions & 4 deletions airbyte_cdk/sources/streams/http/rate_limiting.py
@@ -10,11 +10,14 @@
import backoff
from requests import PreparedRequest, RequestException, Response, codes, exceptions

from airbyte_cdk.utils.datetime_helpers import ab_datetime_now

from .exceptions import (
    DefaultBackoffException,
    RateLimitBackoffException,
    UserDefinedBackoffException,
)
from airbyte_cdk.sources.declarative.request_local import RequestLocal

TRANSIENT_EXCEPTIONS = (
    DefaultBackoffException,
@@ -101,7 +104,10 @@ def should_give_up(exc: Exception) -> bool:


def user_defined_backoff_handler(
-    max_tries: Optional[int], max_time: Optional[int] = None, **kwargs: Any
+    max_tries: Optional[int],
+    max_time: Optional[int] = None,
+    stream_slice: Optional[Mapping[str, Any]] = None,
+    **kwargs: Any,
) -> Callable[[SendRequestCallableType], SendRequestCallableType]:
    def sleep_on_ratelimit(details: Mapping[str, Any]) -> None:
        _, exc, _ = sys.exc_info()
@@ -111,7 +117,14 @@ def sleep_on_ratelimit(details: Mapping[str, Any]) -> None:
f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}"
)
retry_after = exc.backoff
logger.info(f"Retrying. Sleeping for {retry_after} seconds")
# server logs are misleading as several sleeping messages are logged at the same timestamp
logging_message = (
f"Retrying. Sleeping for {retry_after} seconds at {ab_datetime_now()} UTC"
)
request_local = RequestLocal()
if request_local.stream_slice:
logging_message += f" for slice: {request_local.stream_slice}"
logger.info(logging_message)
            time.sleep(retry_after + 1)  # extra second to cover any fractions of second

    def log_give_up(details: Mapping[str, Any]) -> None:
@@ -145,9 +158,14 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None:
            logger.info(
                f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}"
            )
-        logger.info(
-            f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
-        )
+        logger_slice_info = ""
+        request_local = RequestLocal()
+        if request_local.stream_slice:
+            logger_slice_info = f" for slice: {request_local.stream_slice}"
+        logger_info_message = (
+            f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying{logger_slice_info}..."
+        )
+        logger.info(logger_info_message)

    return backoff.on_exception(  # type: ignore # Decorator function returns a function with a different signature than the input function, so mypy can't infer the type of the returned function
        backoff.expo,
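For context on where these handlers run: user_defined_backoff_handler wraps _send with backoff.on_exception, and the backoff library invokes on_backoff callbacks with a details mapping ('tries', 'wait', 'elapsed', ...) on the same thread as the failing call, which is why reading RequestLocal inside them picks up the right slice. A standalone sketch of that callback contract (toy function, not CDK code):

import backoff


def log_retry(details):
    # backoff populates details with 'tries', 'wait', 'elapsed', 'target', ...
    print(f"retry {details['tries']}, waiting {details['wait']:.1f}s")


@backoff.on_exception(backoff.expo, ValueError, max_tries=3, on_backoff=log_retry)
def flaky_call():
    raise ValueError("transient failure")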
167 changes: 167 additions & 0 deletions unit_tests/sources/declarative/request_local/test_request_local.py
@@ -0,0 +1,167 @@
import sys
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor


from airbyte_cdk.sources.declarative.request_local.request_local import RequestLocal

STREAM_SLICE_KEY = "stream_slice"
INSTANCE_ID_KEY = "instance_id"

def test_basic_singleton():
    """Test basic singleton behavior"""
    # Multiple instantiations return same instance
    instance1 = RequestLocal()
    instance2 = RequestLocal()
    instance3 = RequestLocal()

    assert instance1 is instance2
    assert instance1 is instance3, "All instances should be the same singleton instance"
    assert instance2 is instance3, "All instances should be the same singleton instance"


    # get_instance class method returns the same singleton
    instance4 = RequestLocal.get_instance()

    # stream_slice property is shared across references on the same thread
    instance1.stream_slice = {"test": "data"}
    assert instance1.stream_slice is instance4.stream_slice
    assert instance2.stream_slice is instance4.stream_slice


def create_instance_in_thread(thread_id, results):
    """Function to create instance in a separate thread"""
    instance = RequestLocal()

    results[thread_id] = {
        'instance_id': id(instance),
        'thread_id': threading.get_ident()
    }
    time.sleep(0.1)  # Small delay to ensure threads overlap


def test_thread_safety():
    """Ensure that RequestLocal is thread-safe and behaves as a singleton across threads"""
    print("\n=== Testing Thread Safety ===")

    results = {}
    threads = []
    total_threads = 5
    # Create multiple threads that instantiate RequestLocal
    for i in range(total_threads):
        thread = threading.Thread(target=create_instance_in_thread, args=(i, results))
        threads.append(thread)
        thread.start()

    # Wait for all threads to complete
    for thread in threads:
        thread.join()

    # Analyze results
    instance_ids = [result[INSTANCE_ID_KEY] for result in results.values()]
    unique_ids = set(instance_ids)

    assert len(results) == total_threads, "All threads should have created an instance"
    assert len(unique_ids) == 1, "All threads should see the same singleton instance"



def test_threading_local_behavior():
    """Test how threading.local affects the singleton"""
    def thread_func(thread_name, shared_results, time_sleep):
        instance = RequestLocal()
        assert instance.stream_slice is None, "Initial stream_slice should be empty"
        instance.stream_slice = {f"data_from_{thread_name}": True}

        shared_results[thread_name] = {
            'instance_id': id(instance),
            'stream_slice': instance.stream_slice.copy(),
            'thread_id': threading.get_ident()
        }

        # Check if we can see data from other threads
        # this should not happen as RequestLocal is a singleton
        time.sleep(time_sleep)
        shared_results[f"{thread_name}_after_sleep"] = {
            'instance_id': id(instance),
            'stream_slice': instance.stream_slice.copy(),
            'end_time': time.time(),
        }

    results = {}
    threads = {}
    threads_amount = 3
    time_sleep = 0.9
    thread_names = []
    for i in range(threads_amount):
        tread_name = f"thread_{i}"
        thread_names.append(tread_name)
        thread = threading.Thread(target=thread_func, args=(tread_name, results, time_sleep))
        time_sleep /=3  # Decrease sleep time for each thread to ensure they overlap
        threads[tread_name]= thread
        thread.start()

    for _, thread in threads.items():
        thread.join()

    end_times = [results[thread_name + "_after_sleep"]['end_time'] for thread_name in thread_names]
    last_end_time = end_times.pop()
    while end_times:
        current_end_time = end_times.pop()
        # Just checking the last thread created ended before the previous ones
        # so we could ensure the first thread created that sleep for a longer time
        # was not affected by the other threads
        assert last_end_time < current_end_time, "End times should be in increasing order"
        last_end_time = current_end_time

    assert len(thread_names) > 1
    assert len(set(thread_names)) == len(thread_names), "Thread names should be unique"
    for curren_thread_name in thread_names:
        current_thread_name_after_sleep = f"{curren_thread_name}_after_sleep"
        assert results[curren_thread_name][STREAM_SLICE_KEY] == results[current_thread_name_after_sleep][STREAM_SLICE_KEY], \
            f"Stream slice should remain consistent across thread {curren_thread_name} before and after sleep"
        assert results[curren_thread_name][INSTANCE_ID_KEY] == results[current_thread_name_after_sleep][INSTANCE_ID_KEY], \
            f"Instance ID should remain consistent across thread {curren_thread_name} before and after sleep"

        # Check if stream slices are different across threads
        # but same instance ID
        for other_tread_name in [thread_name for thread_name in thread_names if thread_name != curren_thread_name]:
            assert results[curren_thread_name][STREAM_SLICE_KEY] != results[other_tread_name][STREAM_SLICE_KEY], \
                f"Stream slices from different threads should not be the same: {curren_thread_name} vs {other_tread_name}"
            assert results[curren_thread_name][INSTANCE_ID_KEY] == results[other_tread_name][INSTANCE_ID_KEY]

Comment on lines +74 to +137
Contributor

⚠️ Potential issue

Fix multiple typos and formatting issues

There are several typos and formatting issues that need to be addressed.

         }
-    
+
     results = {}
     threads = {}
     threads_amount = 3
     time_sleep = 0.9
     thread_names = []
     for i in range(threads_amount):
-        tread_name = f"thread_{i}"
-        thread_names.append(tread_name)
-        thread = threading.Thread(target=thread_func, args=(tread_name, results, time_sleep))
-        time_sleep /=3  # Decrease sleep time for each thread to ensure they overlap
-        threads[tread_name]= thread
+        thread_name = f"thread_{i}"
+        thread_names.append(thread_name)
+        thread = threading.Thread(target=thread_func, args=(thread_name, results, time_sleep))
+        time_sleep /= 3  # Decrease sleep time for each thread to ensure they overlap
+        threads[thread_name] = thread
         thread.start()

     for _, thread in threads.items():
         thread.join()

     end_times = [results[thread_name + "_after_sleep"]['end_time'] for thread_name in thread_names]
     last_end_time = end_times.pop()
     while end_times:
         current_end_time = end_times.pop()
         # Just checking the last thread created ended before the previous ones
         # so we could ensure the first thread created that sleep for a longer time
         # was not affected by the other threads
         assert last_end_time < current_end_time, "End times should be in increasing order"
         last_end_time = current_end_time

     assert len(thread_names) > 1
     assert len(set(thread_names)) == len(thread_names), "Thread names should be unique"
-    for curren_thread_name in thread_names:
-        current_thread_name_after_sleep = f"{curren_thread_name}_after_sleep"
-        assert results[curren_thread_name][STREAM_SLICE_KEY] == results[current_thread_name_after_sleep][STREAM_SLICE_KEY], \
-            f"Stream slice should remain consistent across thread {curren_thread_name} before and after sleep"
-        assert results[curren_thread_name][INSTANCE_ID_KEY] == results[current_thread_name_after_sleep][INSTANCE_ID_KEY], \
-            f"Instance ID should remain consistent across thread {curren_thread_name} before and after sleep"
+    for current_thread_name in thread_names:
+        current_thread_name_after_sleep = f"{current_thread_name}_after_sleep"
+        assert results[current_thread_name][STREAM_SLICE_KEY] == results[current_thread_name_after_sleep][STREAM_SLICE_KEY], \
+            f"Stream slice should remain consistent across thread {current_thread_name} before and after sleep"
+        assert results[current_thread_name][INSTANCE_ID_KEY] == results[current_thread_name_after_sleep][INSTANCE_ID_KEY], \
+            f"Instance ID should remain consistent across thread {current_thread_name} before and after sleep"

         # Check if stream slices are different across threads
         # but same instance ID
-        for other_tread_name in [thread_name for thread_name in thread_names if thread_name != curren_thread_name]:
-            assert results[curren_thread_name][STREAM_SLICE_KEY] != results[other_tread_name][STREAM_SLICE_KEY], \
-                f"Stream slices from different threads should not be the same: {curren_thread_name} vs {other_tread_name}"
-            assert results[curren_thread_name][INSTANCE_ID_KEY] == results[other_tread_name][INSTANCE_ID_KEY]
+        for other_thread_name in [thread_name for thread_name in thread_names if thread_name != current_thread_name]:
+            assert results[current_thread_name][STREAM_SLICE_KEY] != results[other_thread_name][STREAM_SLICE_KEY], \
+                f"Stream slices from different threads should not be the same: {current_thread_name} vs {other_thread_name}"
+            assert results[current_thread_name][INSTANCE_ID_KEY] == results[other_thread_name][INSTANCE_ID_KEY]
🤖 Prompt for AI Agents
In unit_tests/sources/declarative/request_local/test_request_local.py between
lines 74 and 137, fix typos such as "tread_name" to "thread_name" and
"curren_thread_name" to "current_thread_name". Also, correct variable names like
"other_tread_name" to "other_thread_name". Ensure consistent formatting and
spacing throughout the function for readability and correctness.

# FIXME: Uncomment this test, add asserts, and remove prints to test concurrent access
# def test_concurrent_access():
#     """Test concurrent access using ThreadPoolExecutor"""
#     print("\n=== Testing Concurrent Access ===")
#
#     def worker(worker_id):
#         instance = RequestLocal()
#         return {
#             'worker_id': worker_id,
#             'instance_id': id(instance),
#             'thread_id': threading.get_ident()
#         }
#
#     with ThreadPoolExecutor(max_workers=10) as executor:
#         futures = [executor.submit(worker, i) for i in range(20)]
#         results = [future.result() for future in futures]
#
#     # Analyze results
#     instance_ids = [result[INSTANCE_ID_KEY] for result in results]
#     unique_ids = set(instance_ids)
#
#     print(f"Total workers: {len(results)}")
#     print(f"Unique instance IDs: {len(unique_ids)}")
#     print(f"Singleton behavior maintained: {len(unique_ids) == 1}")
#
#     # Show first few results
#     print("First 5 results:")
#     for result in results[:5]:
#         print(f"  Worker {result['worker_id']}: ID={result[INSTANCE_ID_KEY]}")
