Skip to content

Commit 4a21589

Browse files
[youtube-data-api] feat: add support for filters in WHERE statement
1 parent 7cc19bd commit 4a21589

File tree

2 files changed: +113 -0 lines changed

libs/community/google/youtube/youtube-data-api/garf_youtube_data_api/api_clients.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,15 @@
1313
# limitations under the License.
1414
"""Creates API client for YouTube Data API."""
1515

16+
import datetime
17+
import functools
1618
import logging
19+
import operator
1720
import os
1821
import warnings
1922

23+
import dateutil
24+
import pydantic
2025
from garf_core import api_clients, query_editor
2126
from googleapiclient.discovery import build
2227
from googleapiclient.errors import HttpError
@@ -92,6 +97,39 @@ def get_response(
9297
if data := result.get('items'):
9398
results.extend(data)
9499

100+
if filters := request.filters:
101+
span.set_attribute('youtube_data_api.filters', filters)
102+
filtered_results = []
103+
comparators = []
104+
for filter in filters:
105+
field, op, value = filter.split(' ')
106+
comparators.append(Comparator(field=field, operator=op, value=value))
107+
with telemetry.tracer.start_as_current_span(
108+
'youtube_data_api.apply_filters'
109+
):
110+
for row in results:
111+
include_row = True
112+
for comparator in comparators:
113+
key = comparator.field.split('.')
114+
res = functools.reduce(operator.getitem, key, row)
115+
if isinstance(comparator.value, datetime.date):
116+
expr = f'res {comparator.operator} comp'
117+
include_row = eval(
118+
expr,
119+
{
120+
'res': dateutil.parser.parse(res).date(),
121+
'comp': comparator.value,
122+
},
123+
)
124+
else:
125+
include_row = eval(
126+
f'{res} {comparator.operator} {comparator.value}', globals()
127+
)
128+
if not include_row:
129+
break
130+
if include_row:
131+
filtered_results.append(row)
132+
return api_clients.GarfApiResponse(results=filtered_results)
95133
return api_clients.GarfApiResponse(results=results)
96134

97135
def _list(
@@ -105,3 +143,15 @@ def _list(
105143
return service.list(part=part, **kwargs).execute()
106144
except HttpError:
107145
return {'items': None}
146+
147+
148+
class Comparator(pydantic.BaseModel):
  """A single `field operator value` filter condition.

  NOTE(review): `operator` and `value` are interpolated into an `eval()` call
  by the filtering code in `get_response`; consider validating `operator`
  against an allowlist of comparison operators before it gets there.

  Attributes:
    field: Dot-separated path to a value inside a response item
      (e.g. `statistics.viewCount`).
    operator: Comparison operator; a lone `=` is normalized to `==`.
    value: Right-hand side of the comparison. Stays a string unless `field`
      is a date field, in which case it is converted to `datetime.date`.
  """

  field: str
  operator: str
  value: str | datetime.date

  def model_post_init(self, __context) -> None:
    """Normalizes the operator and converts values of date fields."""
    # Imported explicitly: a bare `import dateutil` is not guaranteed to
    # expose the `parser` submodule on every python-dateutil release.
    from dateutil import parser as date_parser

    if self.operator == '=':
      self.operator = '=='
    # Must be a real tuple: `x in ('snippet.publishedAt')` is a substring
    # test against a plain string, so fields such as `snippet` would wrongly
    # be treated as dates. Values that are already dates are left untouched.
    if self.field in ('snippet.publishedAt',) and isinstance(self.value, str):
      self.value = date_parser.parse(self.value).date()
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import garf_core
16+
import pytest
17+
from garf_youtube_data_api.report_fetcher import YouTubeDataApiReportFetcher
18+
19+
20+
class TestYouTubeDataApiReportFetcher:
  """Tests report fetching when the query contains WHERE filters."""

  @pytest.fixture
  def fetcher(self):
    """Provides a fetcher created with default arguments."""
    return YouTubeDataApiReportFetcher()

  def test_fetch(self, mocker, fetcher):
    """Only the item satisfying every WHERE condition ends up in the report."""
    # Fails the `publishedAt`, `viewCount` and `likeCount` conditions.
    older_video = {
      'id': 1,
      'statistics': {'viewCount': 10, 'likeCount': 1},
      'snippet': {'publishedAt': '2024-07-10T22:15:44Z'},
    }
    # Satisfies all three conditions of the query below.
    matching_video = {
      'id': 2,
      'statistics': {'viewCount': 11, 'likeCount': 2},
      'snippet': {'publishedAt': '2025-07-10T22:15:44Z'},
    }
    mocker.patch(
      'garf_youtube_data_api.api_clients.YouTubeDataApiClient._list',
      return_value={'items': [older_video, matching_video]},
    )
    query = """
    SELECT
      id,
      statistics.viewCount AS views,
      statistics.likeCount AS likes,
      snippet.publishedAt AS published_at,
    FROM videos
    WHERE
      snippet.publishedAt > 2025-01-01
      AND statistics.viewCount = 11
      AND statistics.likeCount > 1
    """

    actual = fetcher.fetch(query, id=['1', '2'])

    assert actual == garf_core.GarfReport(
      results=[[2, 11, 2, '2025-07-10T22:15:44Z']],
      column_names=['id', 'views', 'likes', 'published_at'],
    )

0 commit comments

Comments
 (0)