libs/community/google/ads/garf_google_ads/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -22,4 +22,4 @@
'GoogleAdsApiReportFetcher',
]

__version__ = '0.0.3'
__version__ = '0.0.4'
libs/community/google/ads/garf_google_ads/report_fetcher.py (27 changes: 22 additions & 5 deletions)
@@ -17,6 +17,7 @@
import asyncio
import functools
import operator
import warnings

import garf_core

@@ -51,7 +52,14 @@ def __init__(
if not api_client:
api_client = GoogleAdsApiClient(**kwargs)
self.parallel_threshold = parallel_threshold
super().__init__(api_client, parser, query_spec, builtin_queries, **kwargs)
super().__init__(
api_client=api_client,
parser=parser,
query_specification_builder=query_spec,
builtin_queries=builtin_queries,
preprocessors={'account': self.expand_mcc},
**kwargs,
)

def fetch(
self,
@@ -90,7 +98,7 @@ def fetch(
args = {}
if expand_mcc or customer_ids_query:
account = self.expand_mcc(
customer_ids=account, customer_ids_query=customer_ids_query
account=account, customer_ids_query=customer_ids_query
)
if not account:
raise GoogleAdsApiReportFetcherError(
@@ -132,20 +140,29 @@ async def run_with_semaphore(fn):

def expand_mcc(
self,
customer_ids: str | list[str],
account: str | list[str],
customer_ids_query: str | None = None,
customer_ids: str | list | None = None,
) -> list[str]:
"""Performs Manager account(s) expansion to child accounts.

Args:
customer_ids: Manager account(s) to be expanded.
account: Manager account(s) to be expanded.
customer_ids_query: GAQL query used to reduce the number of customer_ids.
customer_ids: Deprecated alias for `account`.

Returns:
All child accounts under the provided account(s).
"""
if customer_ids and not account:
warnings.warn(
'`customer_ids` is deprecated, use `account` instead',
category=DeprecationWarning,
stacklevel=2,
)
account = customer_ids
return self._get_customer_ids(
seed_customer_ids=customer_ids, customer_ids_query=customer_ids_query
seed_customer_ids=account, customer_ids_query=customer_ids_query
)

def _get_customer_ids(
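A minimal usage sketch of the renamed parameter, assuming an already-instantiated GoogleAdsApiReportFetcher named `fetcher`; the account id and the GAQL filter are placeholders:

# Manager account(s) are now passed via `account` (a string or a list of strings).
child_accounts = fetcher.expand_mcc(account='1234567890')

# Optionally narrow the expansion with a customer_ids_query (illustrative GAQL).
filtered_accounts = fetcher.expand_mcc(
  account=['1234567890'],
  customer_ids_query=(
    'SELECT customer_client.id FROM customer_client '
    "WHERE customer_client.status = 'ENABLED'"
  ),
)

The `customer_ids` keyword remains only as a deprecated alias and triggers a DeprecationWarning directing callers to `account`.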
libs/core/garf_core/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -26,4 +26,4 @@
'ApiReportFetcher',
]

__version__ = '0.6.3'
__version__ = '0.7.0'
libs/core/garf_core/report_fetcher.py (11 changes: 10 additions & 1 deletion)
@@ -24,9 +24,10 @@
import asyncio
import logging
import pathlib
from typing import Callable
from typing import Any, Callable

from opentelemetry import trace
from typing_extensions import TypeAlias

from garf_core import (
api_clients,
@@ -40,6 +41,8 @@

logger = logging.getLogger(__name__)

Processor: TypeAlias = Callable[..., Any]


class ApiReportFetcherError(exceptions.GarfError):
"""Base exception for all ApiReportFetchers."""
@@ -81,6 +84,8 @@ def __init__(
enable_cache: bool = False,
cache_path: str | pathlib.Path | None = None,
cache_ttl_seconds: int = 3600,
preprocessors: dict[str, Processor] | None = None,
postprocessors: dict[str, Processor] | None = None,
**kwargs: str,
) -> None:
"""Instantiates ApiReportFetcher based on provided api client.
@@ -94,6 +99,8 @@
enable_cache: Whether to load / save report from / to cache.
cache_path: Optional path to cache folder.
cache_ttl_seconds: Maximum lifespan of cached reports.
preprocessors: Functions to execute before fetching the query.
postprocessors: Functions to execute after fetching the query.
"""
self.api_client = api_client
self.parser = parser
@@ -102,6 +109,8 @@
self.enable_cache = enable_cache
self.cache = cache.GarfCache(cache_path, cache_ttl_seconds)
self.builtin_queries = builtin_queries or {}
self.preprocessors = preprocessors or {}
self.postprocessors = postprocessors or {}

def add_builtin_queries(
self,
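A minimal sketch of the new hooks on the core fetcher, assuming `api_client` is any already-constructed garf_core API client and the remaining constructor arguments keep their defaults; `normalize_account` is an illustrative helper, and the 'account' key mirrors the Google Ads fetcher above:

from garf_core import ApiReportFetcher

def normalize_account(account):
  # Illustrative preprocessor: always hand downstream code a list of ids.
  return [account] if isinstance(account, str) else list(account)

fetcher = ApiReportFetcher(
  api_client=api_client,
  preprocessors={'account': normalize_account},
  postprocessors={},
)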
libs/executors/garf_executors/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -43,7 +43,7 @@ def setup_executor(
else:
concrete_api_fetcher = fetchers.get_report_fetcher(source)
query_executor = ApiQueryExecutor(
concrete_api_fetcher(
fetcher=concrete_api_fetcher(
**fetcher_parameters,
enable_cache=enable_cache,
cache_ttl_seconds=cache_ttl_seconds,
@@ -57,4 +57,4 @@
'ApiExecutionContext',
]

__version__ = '0.1.7'
__version__ = '0.2.0'
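For orientation, a schematic call to the updated factory; the parameter names are read off the function body above, while the source alias and values are placeholders rather than values confirmed by this diff:

from garf_executors import setup_executor

query_executor = setup_executor(
  source='google-ads',    # assumed alias resolved by fetchers.get_report_fetcher
  fetcher_parameters={},  # forwarded to the fetcher as keyword arguments
  enable_cache=False,
  cache_ttl_seconds=3600,
)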
libs/executors/garf_executors/api_executor.py (6 changes: 5 additions & 1 deletion)
@@ -51,6 +51,10 @@ def __init__(self, fetcher: report_fetcher.ApiReportFetcher) -> None:
fetcher: Instantiated report fetcher.
"""
self.fetcher = fetcher
super().__init__(
preprocessors=self.fetcher.preprocessors,
postprocessors=self.fetcher.postprocessors,
)

@classmethod
def from_fetcher_alias(
@@ -59,7 +63,7 @@ def from_fetcher_alias(
if not fetcher_parameters:
fetcher_parameters = {}
concrete_api_fetcher = fetchers.get_report_fetcher(source)
return ApiQueryExecutor(concrete_api_fetcher(**fetcher_parameters))
return ApiQueryExecutor(fetcher=concrete_api_fetcher(**fetcher_parameters))

@tracer.start_as_current_span('api.execute')
def execute(
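The new super().__init__ call makes the executor mirror its fetcher's hooks; a short sketch, reusing a fetcher configured with preprocessors as in the garf_core example above:

from garf_executors.api_executor import ApiQueryExecutor

executor = ApiQueryExecutor(fetcher=fetcher)
# The executor now carries the fetcher's processor hooks.
assert executor.preprocessors == fetcher.preprocessors
assert executor.postprocessors == fetcher.postprocessors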
libs/executors/garf_executors/executor.py (39 changes: 38 additions & 1 deletion)
@@ -15,7 +15,10 @@
"""Defines common functionality between executors."""

import asyncio
import inspect
from typing import Optional

from garf_core import report_fetcher
from opentelemetry import trace

from garf_executors import execution_context
@@ -25,6 +28,14 @@
class Executor:
"""Defines common functionality between executors."""

def __init__(
self,
preprocessors: Optional[dict[str, report_fetcher.Processor]] = None,
postprocessors: Optional[dict[str, report_fetcher.Processor]] = None,
) -> None:
self.preprocessors = preprocessors or {}
self.postprocessors = postprocessors or {}

@tracer.start_as_current_span('api.execute_batch')
def execute_batch(
self,
@@ -34,6 +45,9 @@ def execute_batch(
) -> list[str]:
"""Executes batch of queries for a common context.

If the executor has any pre/post processors, preprocessors run before the
batch (and may modify the context) and postprocessors run after it.

Args:
batch: Mapping between query_title and its text.
context: Execution context.
@@ -44,11 +58,19 @@
"""
span = trace.get_current_span()
span.set_attribute('api.parallel_threshold', parallel_threshold)
return asyncio.run(
_handle_processors(processors=self.preprocessors, context=context)
results = asyncio.run(
self._run(
batch=batch, context=context, parallel_threshold=parallel_threshold
)
)
_handle_processors(processors=self.postprocessors, context=context)
return results

def add_preprocessor(
self, preprocessors: dict[str, report_fetcher.Processor]
) -> None:
self.preprocessors.update(preprocessors)

async def aexecute(
self,
@@ -85,3 +107,18 @@ async def run_with_semaphore(fn):
for title, query in batch.items()
]
return await asyncio.gather(*(run_with_semaphore(task) for task in tasks))


def _handle_processors(
processors: dict[str, report_fetcher.Processor],
context: execution_context.ExecutionContext,
) -> None:
for k, processor in processors.items():
processor_signature = list(inspect.signature(processor).parameters.keys())
if k in context.fetcher_parameters:
processor_parameters = {
k: v
for k, v in context.fetcher_parameters.items()
if k in processor_signature
}
context.fetcher_parameters[k] = processor(**processor_parameters)
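To make the parameter filtering concrete, a self-contained illustration with a stand-in processor; SimpleNamespace substitutes for ExecutionContext, which is assumed to expose fetcher_parameters as a mutable dict:

from types import SimpleNamespace

from garf_executors.executor import _handle_processors  # module-private helper above

def expand(account, customer_ids_query=None):
  # Stand-in for GoogleAdsApiReportFetcher.expand_mcc.
  return [f'{account}-child-1', f'{account}-child-2']

context = SimpleNamespace(
  fetcher_parameters={
    'account': '1234567890',     # matches the processor key, so the processor runs
    'customer_ids_query': None,  # matches a parameter of `expand`, so it is forwarded
    'report_format': 'csv',      # not in expand's signature, so it is dropped from the call
  }
)
_handle_processors(processors={'account': expand}, context=context)
# context.fetcher_parameters['account'] is now
# ['1234567890-child-1', '1234567890-child-2']; the other keys are untouched.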