Skip to content

Commit f3a24de

Browse files
authored
fix: Skip refresh if already in progress or if lock is already held (feast-dev#5068)
1 parent 69d462c commit f3a24de

File tree

3 files changed

+219
-3
lines changed

3 files changed

+219
-3
lines changed

sdk/python/feast/infra/registry/caching_registry.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -425,12 +425,24 @@ def list_projects(
425425
return self._list_projects(tags)
426426

427427
def refresh(self, project: Optional[str] = None):
    """Reload the registry proto into the in-memory cache.

    Stores the result of ``self.proto()`` in ``cached_registry_proto`` and
    records ``_utc_now()`` in ``cached_registry_proto_created``.

    ``_refresh_lock`` is deliberately NOT inspected here.
    ``_refresh_cached_registry_if_necessary`` calls this method while it is
    already holding that lock, so a ``self._refresh_lock.locked()`` guard
    would always be true on that path and an expired cache would never be
    reloaded. Skipping concurrent refreshes is the job of the non-blocking
    ``acquire`` in the caller.

    Args:
        project: Accepted for interface compatibility; not used by this
            implementation.

    Errors raised while refreshing are logged with a traceback and are not
    re-raised, so a previously cached proto (if any) stays in place.
    """
    try:
        self.cached_registry_proto = self.proto()
        self.cached_registry_proto_created = _utc_now()
    except Exception as e:
        # Best effort: keep serving the stale cache rather than failing the caller.
        logger.error(f"Error while refreshing registry: {e}", exc_info=True)
430436

431437
def _refresh_cached_registry_if_necessary(self):
432438
if self.cache_mode == "sync":
433-
with self._refresh_lock:
439+
# Try acquiring the lock without blocking
440+
if not self._refresh_lock.acquire(blocking=False):
441+
logger.info(
442+
"Skipping refresh if lock is already held by another thread"
443+
)
444+
return
445+
try:
434446
if self.cached_registry_proto == RegistryProto():
435447
# Avoids the need to refresh the registry when cache is not populated yet
436448
# Specially during the __init__ phase
@@ -454,6 +466,13 @@ def _refresh_cached_registry_if_necessary(self):
454466
if expired:
455467
logger.info("Registry cache expired, so refreshing")
456468
self.refresh()
469+
except Exception as e:
470+
logger.error(
471+
f"Error in _refresh_cached_registry_if_necessary: {e}",
472+
exc_info=True,
473+
)
474+
finally:
475+
self._refresh_lock.release() # Always release the lock safely
457476

458477
def _start_thread_async_refresh(self, cache_ttl_seconds):
459478
self.refresh()

sdk/python/tests/unit/infra/registry/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
from datetime import datetime, timedelta, timezone
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
from feast.infra.registry.caching_registry import CachingRegistry
7+
8+
9+
class TestCachingRegistry(CachingRegistry):
    """Test subclass that implements abstract methods as no-ops"""

    # The class name starts with "Test", so pytest would try to collect it as a
    # test class and warn because it has a constructor. It is only a helper.
    __test__ = False

    # --- single-object getters -------------------------------------------
    def _get_any_feature_view(self, *args, **kwargs):
        pass

    def _get_data_source(self, *args, **kwargs):
        pass

    def _get_entity(self, *args, **kwargs):
        pass

    def _get_feature_service(self, *args, **kwargs):
        pass

    def _get_feature_view(self, *args, **kwargs):
        pass

    def _get_infra(self, *args, **kwargs):
        pass

    def _get_on_demand_feature_view(self, *args, **kwargs):
        pass

    def _get_permission(self, *args, **kwargs):
        pass

    def _get_project(self, *args, **kwargs):
        pass

    def _get_saved_dataset(self, *args, **kwargs):
        pass

    def _get_stream_feature_view(self, *args, **kwargs):
        pass

    def _get_validation_reference(self, *args, **kwargs):
        pass

    # --- listers -----------------------------------------------------------
    def _list_all_feature_views(self, *args, **kwargs):
        pass

    def _list_data_sources(self, *args, **kwargs):
        pass

    def _list_entities(self, *args, **kwargs):
        pass

    def _list_feature_services(self, *args, **kwargs):
        pass

    def _list_feature_views(self, *args, **kwargs):
        pass

    def _list_on_demand_feature_views(self, *args, **kwargs):
        pass

    def _list_permissions(self, *args, **kwargs):
        pass

    def _list_project_metadata(self, *args, **kwargs):
        pass

    def _list_projects(self, *args, **kwargs):
        pass

    def _list_saved_datasets(self, *args, **kwargs):
        pass

    def _list_stream_feature_views(self, *args, **kwargs):
        pass

    def _list_validation_references(self, *args, **kwargs):
        pass

    # --- apply / commit ----------------------------------------------------
    def apply_data_source(self, *args, **kwargs):
        pass

    def apply_entity(self, *args, **kwargs):
        pass

    def apply_feature_service(self, *args, **kwargs):
        pass

    def apply_feature_view(self, *args, **kwargs):
        pass

    def apply_materialization(self, *args, **kwargs):
        pass

    def apply_permission(self, *args, **kwargs):
        pass

    def apply_project(self, *args, **kwargs):
        pass

    def apply_saved_dataset(self, *args, **kwargs):
        pass

    def apply_user_metadata(self, *args, **kwargs):
        pass

    def apply_validation_reference(self, *args, **kwargs):
        pass

    def commit(self, *args, **kwargs):
        pass

    # --- deletes -----------------------------------------------------------
    def delete_data_source(self, *args, **kwargs):
        pass

    def delete_entity(self, *args, **kwargs):
        pass

    def delete_feature_service(self, *args, **kwargs):
        pass

    def delete_feature_view(self, *args, **kwargs):
        pass

    def delete_permission(self, *args, **kwargs):
        pass

    def delete_project(self, *args, **kwargs):
        pass

    def delete_validation_reference(self, *args, **kwargs):
        pass

    # --- misc --------------------------------------------------------------
    def get_user_metadata(self, *args, **kwargs):
        pass

    def proto(self, *args, **kwargs):
        pass

    def update_infra(self, *args, **kwargs):
        pass
146+
147+
148+
@pytest.fixture
def registry():
    """Provide a TestCachingRegistry in sync cache mode with a 2-second TTL."""
    init_kwargs = {
        "project": "test_example",
        "cache_ttl_seconds": 2,
        "cache_mode": "sync",
    }
    return TestCachingRegistry(**init_kwargs)
154+
155+
156+
def test_cache_expiry_triggers_refresh(registry):
    """Test that an expired cache triggers a refresh"""
    # Make the cache look populated but older than the fixture's 2-second TTL.
    registry.cached_registry_proto = "some_cached_data"
    registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
        seconds=5
    )

    # Patching the method under test and asserting that it was called proves
    # nothing, so only the observable effect is checked: refresh() is invoked.
    with patch.object(
        CachingRegistry, "refresh", wraps=registry.refresh
    ) as mock_refresh:
        registry._refresh_cached_registry_if_necessary()
        mock_refresh.assert_called_once()

    # The lock taken for the expiry check must be released afterwards,
    # otherwise every later check would be skipped.
    assert not registry._refresh_lock.locked()
179+
180+
181+
def test_skip_refresh_if_lock_held(registry):
    """Test that refresh is skipped if the lock is already held by another thread"""
    registry.cached_registry_proto = "some_cached_data"
    registry.cached_registry_proto_created = datetime.now(timezone.utc) - timedelta(
        seconds=5
    )

    # Acquire the lock manually to simulate another thread holding it
    registry._refresh_lock.acquire()
    try:
        with patch.object(
            CachingRegistry, "refresh", wraps=registry.refresh
        ) as mock_refresh:
            registry._refresh_cached_registry_if_necessary()

            # Since the lock was already held, refresh should NOT be called
            mock_refresh.assert_not_called()

        # The skipped check must leave the cached value untouched.
        assert registry.cached_registry_proto == "some_cached_data"
    finally:
        # Release even when an assertion fails so the lock is never left held.
        registry._refresh_lock.release()

0 commit comments

Comments
 (0)