Skip to content

Commit dd74a80

Browse files
[core] feat: add support for caching fetching results
1 parent e075165 commit dd74a80

File tree

8 files changed

+251
-26
lines changed

8 files changed

+251
-26
lines changed

libs/core/garf_core/cache.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Stores and loads reports from a cache instead of calling API."""
16+
17+
from __future__ import annotations
18+
19+
import hashlib
20+
import json
21+
import logging
22+
import os
23+
import pathlib
24+
from typing import Final
25+
26+
from garf_core import exceptions, query_editor, report
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class GarfCacheFileNotFoundError(exceptions.GarfError):
32+
"""Exception for not found cached report."""
33+
34+
35+
DEFAULT_CACHE_LOCATION: Final[str] = os.getenv(
36+
'GARF_CACHE_LOCATION', str(pathlib.Path.home() / '.garf/cache/')
37+
)
38+
39+
40+
class GarfCache:
41+
"""Stores and loads reports from a cache instead of calling API.
42+
43+
Attribute:
44+
location: Folder where cached results are stored.
45+
"""
46+
47+
def __init__(
48+
self,
49+
location: str | None = None,
50+
) -> None:
51+
"""Stores and loads reports from a cache instead of calling API.
52+
53+
Args:
54+
location: Folder where cached results are stored.
55+
"""
56+
self.location = pathlib.Path(location or DEFAULT_CACHE_LOCATION)
57+
58+
def load(
59+
self, query: query_editor.BaseQueryElements, args=None, kwargs=None
60+
) -> report.GarfReport:
61+
"""Loads report from cache based on query definition.
62+
63+
Args:
64+
query: Query elements.
65+
args: Query parameters.
66+
kwargs: Optional keyword arguments.
67+
68+
Returns:
69+
Cached report.
70+
71+
Raises:
72+
GarfCacheFileNotFoundError: If cached report not found
73+
"""
74+
args_hash = args.hash if args else ''
75+
kwargs_hash = (
76+
hashlib.md5(json.dumps(kwargs).encode('utf-8')).hexdigest()
77+
if kwargs
78+
else ''
79+
)
80+
hash_identifier = f'{query.hash}:{args_hash}:{kwargs_hash}'
81+
cached_path = self.location / f'{hash_identifier}.json'
82+
if cached_path.exists():
83+
with open(cached_path, 'r', encoding='utf-8') as f:
84+
data = json.load(f)
85+
logger.debug('Report is loaded from cache: %s', str(cached_path))
86+
return report.GarfReport.from_json(data)
87+
raise GarfCacheFileNotFoundError
88+
89+
def save(
90+
self,
91+
fetched_report: report.GarfReport,
92+
query: query_editor.BaseQueryElements,
93+
args=None,
94+
kwargs=None,
95+
) -> None:
96+
"""Saves report to cache based on query definition.
97+
98+
Args:
99+
fetched_report: Report to save.
100+
query: Query elements.
101+
args: Query parameters.
102+
kwargs: Optional keyword arguments.
103+
"""
104+
self.location.mkdir(parents=True, exist_ok=True)
105+
args_hash = args.hash if args else ''
106+
kwargs_hash = (
107+
hashlib.md5(json.dumps(kwargs).encode('utf-8')).hexdigest()
108+
if kwargs
109+
else ''
110+
)
111+
hash_identifier = f'{query.hash}:{args_hash}:{kwargs_hash}'
112+
cached_path = self.location / f'{hash_identifier}.json'
113+
logger.debug('Report is saved to cache: %s', str(cached_path))
114+
with open(cached_path, 'w', encoding='utf-8') as f:
115+
json.dump(fetched_report.to_json(), f)

libs/core/garf_core/query_editor.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515

1616
from __future__ import annotations
1717

18-
import dataclasses
1918
import datetime
19+
import hashlib
20+
import json
2021
import logging
2122
import re
2223
from typing import Generator, Union
@@ -39,6 +40,11 @@ class GarfQueryParameters(pydantic.BaseModel):
3940
macro: QueryParameters = pydantic.Field(default_factory=dict)
4041
template: QueryParameters = pydantic.Field(default_factory=dict)
4142

43+
@property
44+
def hash(self) -> str:
45+
hash_fields = self.model_dump(exclude_none=True)
46+
return hashlib.md5(json.dumps(hash_fields).encode('utf-8')).hexdigest()
47+
4248

4349
class GarfMacroError(query_parser.GarfQueryError):
4450
"""Specifies incorrect macro in Garf query."""
@@ -52,33 +58,32 @@ class GarfBuiltInQueryError(query_parser.GarfQueryError):
5258
"""Specifies non-existing builtin query."""
5359

5460

55-
@dataclasses.dataclass
56-
class BaseQueryElements:
61+
class BaseQueryElements(pydantic.BaseModel):
5762
"""Contains raw query and parsed elements.
5863
5964
Attributes:
60-
title: Title of the query that needs to be parsed.
61-
text: Text of the query that needs to be parsed.
62-
resource_name: Name of Google Ads API reporting resource.
63-
fields: Ads API fields that need to be fetched.
64-
column_names: Friendly names for fields which are used when saving data
65-
column_names: Friendly names for fields which are used when saving data
66-
customizers: Attributes of fields that need to be be extracted.
67-
virtual_columns: Attributes of fields that need to be be calculated.
68-
is_builtin_query: Whether query is built-in.
65+
title: Title of the query that needs to be parsed.
66+
text: Text of the query that needs to be parsed.
67+
resource_name: Name of Google Ads API reporting resource.
68+
fields: Ads API fields that need to be fetched.
69+
column_names: Friendly names for fields which are used when saving data
70+
column_names: Friendly names for fields which are used when saving data
71+
customizers: Attributes of fields that need to be be extracted.
72+
virtual_columns: Attributes of fields that need to be be calculated.
73+
is_builtin_query: Whether query is built-in.
6974
"""
7075

71-
title: str
76+
title: str | None
7277
text: str
7378
resource_name: str | None = None
74-
fields: list[str] = dataclasses.field(default_factory=list)
75-
filters: list[str] = dataclasses.field(default_factory=list)
76-
sorts: list[str] = dataclasses.field(default_factory=list)
77-
column_names: list[str] = dataclasses.field(default_factory=list)
78-
customizers: dict[str, dict[str, str]] = dataclasses.field(
79+
fields: list[str] = pydantic.Field(default_factory=list)
80+
filters: list[str] = pydantic.Field(default_factory=list)
81+
sorts: list[str] = pydantic.Field(default_factory=list)
82+
column_names: list[str] = pydantic.Field(default_factory=list)
83+
customizers: dict[str, query_parser.Customizer] = pydantic.Field(
7984
default_factory=dict
8085
)
81-
virtual_columns: dict[str, query_parser.VirtualColumn] = dataclasses.field(
86+
virtual_columns: dict[str, query_parser.VirtualColumn] = pydantic.Field(
8287
default_factory=dict
8388
)
8489
is_builtin_query: bool = False
@@ -87,12 +92,16 @@ def __eq__(self, other: BaseQueryElements) -> bool: # noqa: D105
8792
return (
8893
self.column_names,
8994
self.fields,
95+
self.filters,
96+
self.sorts,
9097
self.resource_name,
9198
self.customizers,
9299
self.virtual_columns,
93100
) == (
94101
other.column_names,
95102
other.fields,
103+
other.filters,
104+
other.sorts,
96105
other.resource_name,
97106
other.customizers,
98107
other.virtual_columns,
@@ -103,6 +112,11 @@ def request(self) -> str:
103112
"""API request."""
104113
return ','.join(self.fields)
105114

115+
@property
116+
def hash(self) -> str:
117+
hash_fields = self.model_dump(exclude_none=True, exclude={'title', 'text'})
118+
return hashlib.md5(json.dumps(hash_fields).encode('utf-8')).hexdigest()
119+
106120

107121
class CommonParametersMixin:
108122
"""Helper mixin to inject set of common parameters to all queries."""

libs/core/garf_core/report_fetcher.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@
2323

2424
import asyncio
2525
import logging
26+
import pathlib
2627
from typing import Callable
2728

2829
from opentelemetry import trace
2930

3031
from garf_core import (
3132
api_clients,
33+
cache,
3234
exceptions,
3335
parsers,
3436
query_editor,
@@ -63,6 +65,7 @@ class ApiReportFetcher:
6365
query_specification_builder: Class to perform query parsing.
6466
builtin_queries:
6567
Mapping between query name and function for generating GarfReport.
68+
enable_cache: Whether to load / save report from / to cache.
6669
"""
6770

6871
def __init__(
@@ -74,6 +77,8 @@ def __init__(
7477
),
7578
builtin_queries: dict[str, Callable[[ApiReportFetcher], report.GarfReport]]
7679
| None = None,
80+
enable_cache: bool = False,
81+
cache_path: str | pathlib.Path | None = None,
7782
**kwargs: str,
7883
) -> None:
7984
"""Instantiates ApiReportFetcher based on provided api client.
@@ -84,11 +89,15 @@ def __init__(
8489
query_specification_builder: Class to perform query parsing.
8590
builtin_queries:
8691
Mapping between query name and function for generating GarfReport.
92+
enable_cache: Whether to load / save report from / to cache.
93+
cache_path: Optional path to cache folder.
8794
"""
8895
self.api_client = api_client
8996
self.parser = parser
9097
self.query_specification_builder = query_specification_builder
9198
self.query_args = kwargs
99+
self.enable_cache = enable_cache
100+
self.cache = cache.GarfCache(cache_path)
92101
self.builtin_queries = builtin_queries or {}
93102

94103
def add_builtin_queries(
@@ -156,13 +165,24 @@ def fetch(
156165
)
157166
return builtin_report(self, **kwargs)
158167

168+
if self.enable_cache:
169+
try:
170+
cached_report = self.cache.load(query, args, kwargs)
171+
logger.warning('Cached version of query is loaded')
172+
span.set_attribute('is_cached_query', True)
173+
return cached_report
174+
except cache.GarfCacheFileNotFoundError:
175+
logger.debug('Cached version not found, generating')
159176
response = self.api_client.call_api(query, **kwargs)
160177
if not response:
161178
return report.GarfReport(query_specification=query)
162179

163180
parsed_response = self.parser(query).parse_response(response)
164-
return report.GarfReport(
181+
fetched_report = report.GarfReport(
165182
results=parsed_response,
166183
column_names=query.column_names,
167184
query_specification=query,
168185
)
186+
if self.enable_cache:
187+
self.cache.save(fetched_report, query, args, kwargs)
188+
return fetched_report

libs/core/tests/unit/test_cache.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from garf_core import cache, query_editor, report
17+
18+
19+
class TestGarfCache:
20+
@pytest.fixture()
21+
def cache(self, tmp_path):
22+
return cache.GarfCache(str(tmp_path))
23+
24+
def test_save(self, cache):
25+
test_report = report.GarfReport(results=[[1]], column_names=['test'])
26+
query = query_editor.QuerySpecification(
27+
text='SELECT test FROM test'
28+
).generate()
29+
30+
cache.save(test_report, query)
31+
loaded_report = cache.load(query)
32+
33+
assert loaded_report == test_report

libs/core/tests/unit/test_report_fetcher.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import datetime
17+
import logging
1718

1819
import pytest
1920
from garf_core import (
@@ -51,6 +52,38 @@ def test_fetch_returns_correct_report_for_dict_parser(
5152

5253
assert test_report == expected_report
5354

55+
def test_fetch_returns_saves_and_loads_cached_version(self, caplog, tmp_path):
56+
test_api_client = api_clients.FakeApiClient(
57+
results=[
58+
{'column': {'name': 1}, 'other_column': 2},
59+
{'column': {'name': 2}, 'other_column': 2},
60+
{'column': {'name': 3}, 'other_column': 2},
61+
]
62+
)
63+
test_fetcher = report_fetcher.ApiReportFetcher(
64+
api_client=test_api_client,
65+
parser=parsers.DictParser,
66+
enable_cache=True,
67+
cache_path=tmp_path,
68+
)
69+
query = 'SELECT column.name, other_column FROM test'
70+
expected_report = report.GarfReport(
71+
results=[[1, 2], [2, 2], [3, 2]],
72+
column_names=['column_name', 'other_column'],
73+
)
74+
75+
with caplog.at_level(logging.DEBUG):
76+
test_report = test_fetcher.fetch(query)
77+
assert 'Report is saved to cache' in caplog.text
78+
assert 'Cached version not found, generating' in caplog.text
79+
assert test_report == expected_report
80+
81+
with caplog.at_level(logging.DEBUG):
82+
test_report = test_fetcher.fetch(query)
83+
assert 'Report is loaded from cache' in caplog.text
84+
assert 'Cached version of query is loaded' in caplog.text
85+
assert test_report == expected_report
86+
5487
def test_fetch_returns_empty_report_for_empty_api_response(self):
5588
test_api_client = api_clients.FakeApiClient(results=[])
5689
fetcher = report_fetcher.ApiReportFetcher(

libs/executors/garf_executors/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424

2525
@tracer.start_as_current_span('setup_executor')
2626
def setup_executor(
27-
source: str, fetcher_parameters: dict[str, str]
27+
source: str,
28+
fetcher_parameters: dict[str, str | int | bool],
29+
enable_cache: bool = False,
2830
) -> type[executor.Executor]:
2931
"""Initializes executors based on a source and parameters."""
3032
if source == 'bq':
@@ -40,7 +42,7 @@ def setup_executor(
4042
else:
4143
concrete_api_fetcher = fetchers.get_report_fetcher(source)
4244
query_executor = ApiQueryExecutor(
43-
concrete_api_fetcher(**fetcher_parameters)
45+
concrete_api_fetcher(**fetcher_parameters, enable_cache=enable_cache)
4446
)
4547
return query_executor
4648

0 commit comments

Comments
 (0)