Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 7 additions & 32 deletions adserver/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from .constants import VIEWS
from .utils import COUNTRY_DICT
from .utils import anonymize_ip_address
from .utils import cached_method
from .utils import calculate_ctr
from .utils import generate_absolute_url
from .utils import get_ad_day
Expand Down Expand Up @@ -1324,49 +1325,23 @@ def days_remaining(self):
remaining_seconds = (end_datetime - timezone.now()).total_seconds()
return max(0, int(remaining_seconds / self.pacing_interval))

def views_today(self, bypass_cache=False):
# Check for a cached value that would come from an annotated queryset
if hasattr(self, "flight_views_today"):
return self.flight_views_today or 0

# Fetch this value from the local cache if present
# Otherwise, populate the local cache
cache_key = f"flight_views_today_{self.pk}"
cached_value = caches[settings.CACHE_LOCAL_ALIAS].get(cache_key)
if cached_value is not None and not bypass_cache:
return cached_value

@cached_method("flight_views_today")
def views_today(self):
aggregation = AdImpression.objects.filter(
advertisement__in=self.advertisements.all(), date=timezone.now().date()
).aggregate(total_views=models.Sum("views"))["total_views"]

# The aggregation can be `None` if there are no impressions
result = aggregation or 0
caches[settings.CACHE_LOCAL_ALIAS].set(cache_key, result, timeout=60 * 15)

return result

def clicks_today(self, bypass_cache=False):
# Check for a cached value that would come from an annotated queryset
if hasattr(self, "flight_clicks_today"):
return self.flight_clicks_today or 0

# Fetch this value from the local cache if present
# Otherwise, populate the local cache
cache_key = f"flight_clicks_today_{self.pk}"
cached_value = caches[settings.CACHE_LOCAL_ALIAS].get(cache_key)
if cached_value is not None and not bypass_cache:
return cached_value
return aggregation or 0

@cached_method("flight_clicks_today")
def clicks_today(self):
aggregation = AdImpression.objects.filter(
advertisement__in=self.advertisements.all(), date=timezone.now().date()
).aggregate(total_clicks=models.Sum("clicks"))["total_clicks"]

# The aggregation can be `None` if there are no impressions
result = aggregation or 0
caches[settings.CACHE_LOCAL_ALIAS].set(cache_key, result, timeout=60 * 15)

return result
return aggregation or 0

def spend_today(self):
"""Get the total spend for this flight today."""
Expand Down
64 changes: 64 additions & 0 deletions adserver/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import datetime
from unittest import mock

from django.conf import settings
from django.core.cache import caches
from django.db import IntegrityError
from django.test import override_settings
from django.utils import timezone
Expand Down Expand Up @@ -882,3 +884,65 @@ def test_refund(self):
self.assertAlmostEqual(report.total["views"], 1)
self.assertAlmostEqual(report.total["clicks"], 1)
self.assertAlmostEqual(report.total["cost"], 2.0)


class FlightCachedMethodTest(BaseAdModelsTestCase):
"""Tests for Flight.views_today() and Flight.clicks_today() caching behavior."""

def setUp(self):
super().setUp()
# Clear the local cache before each test to start fresh
caches[settings.CACHE_LOCAL_ALIAS].clear()

def test_views_today_uses_cache_on_second_call(self):
"""Repeated calls to views_today() should not hit the DB after the first call."""
# Prime the cache with a first call
self.flight.views_today()

# Second call should use the cache - zero DB queries
with self.assertNumQueries(0):
result = self.flight.views_today()

self.assertEqual(result, 0)

def test_clicks_today_uses_cache_on_second_call(self):
"""Repeated calls to clicks_today() should not hit the DB after the first call."""
self.flight.clicks_today()

with self.assertNumQueries(0):
result = self.flight.clicks_today()

self.assertEqual(result, 0)

def test_views_today_bypass_cache_hits_db(self):
"""bypass_cache=True should force a fresh DB query."""
self.flight.views_today() # prime cache

# bypass_cache forces at least one DB query
with self.assertNumQueries(1):
self.flight.views_today(bypass_cache=True)

def test_clicks_today_bypass_cache_hits_db(self):
"""bypass_cache=True should force a fresh DB query."""
self.flight.clicks_today() # prime cache

with self.assertNumQueries(1):
self.flight.clicks_today(bypass_cache=True)

def test_views_today_annotated_attr_skips_db(self):
"""An annotated flight_views_today attribute avoids DB and cache lookups."""
self.flight.flight_views_today = 55

with self.assertNumQueries(0):
result = self.flight.views_today()

self.assertEqual(result, 55)

def test_clicks_today_annotated_attr_skips_db(self):
"""An annotated flight_clicks_today attribute avoids DB and cache lookups."""
self.flight.flight_clicks_today = 33

with self.assertNumQueries(0):
result = self.flight.clicks_today()

self.assertEqual(result, 33)
62 changes: 62 additions & 0 deletions adserver/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..utils import anonymize_ip_address
from ..utils import anonymize_user_agent
from ..utils import build_blocked_ip_set
from ..utils import cached_method
from ..utils import calculate_ctr
from ..utils import calculate_ecpm
from ..utils import calculate_percent_diff
Expand Down Expand Up @@ -318,3 +319,64 @@ def test_is_proxy_ip_true(self):

def test_offers_dump_exists_placeholder(self):
self.assertFalse(offers_dump_exists(datetime.date.today()))


class CachedMethodTest(TestCase):
"""Tests for the cached_method decorator."""

_next_pk = 1000

def _make_obj(self, **attrs):
"""Return a (fake_model_instance, call_count_list) pair for testing."""
CachedMethodTest._next_pk += 1
call_count_ref = [0]

class FakeModel:
@cached_method("my_attr")
def my_method(self):
call_count_ref[0] += 1
return 42

obj = FakeModel()
obj.pk = CachedMethodTest._next_pk
for key, value in attrs.items():
setattr(obj, key, value)
return obj, call_count_ref

def test_cached_method_caches_result(self):
"""Second call should return cached value without calling the function."""
obj, call_count_ref = self._make_obj()

result1 = obj.my_method()
result2 = obj.my_method()

self.assertEqual(result1, 42)
self.assertEqual(result2, 42)
self.assertEqual(call_count_ref[0], 1)

def test_cached_method_bypass_cache(self):
"""bypass_cache=True should force re-evaluation."""
obj, call_count_ref = self._make_obj()

obj.my_method()
obj.my_method(bypass_cache=True)

self.assertEqual(call_count_ref[0], 2)

def test_cached_method_uses_annotated_attr(self):
"""An annotated queryset attribute is returned without calling the function."""
obj, call_count_ref = self._make_obj(my_attr=99)

result = obj.my_method()

self.assertEqual(result, 99)
self.assertEqual(call_count_ref[0], 0)

def test_cached_method_annotated_attr_none_returns_zero(self):
"""An annotated None value (e.g. no DB rows) is treated as 0."""
obj, call_count_ref = self._make_obj(my_attr=None)

result = obj.my_method()

self.assertEqual(result, 0)
self.assertEqual(call_count_ref[0], 0)
37 changes: 37 additions & 0 deletions adserver/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Ad server utilities."""

import functools
import hashlib
import ipaddress
import logging
Expand All @@ -17,6 +18,7 @@
from django.contrib.gis.geoip2 import GeoIP2
from django.contrib.gis.geoip2 import GeoIP2Exception
from django.contrib.sites.shortcuts import get_current_site
from django.core.cache import caches
from django.urls import reverse
from django.utils import timezone
from django.utils.crypto import get_random_string
Expand All @@ -40,6 +42,41 @@
COUNTRY_DICT = dict(countries)


def cached_method(attr_name, cache_alias=None, timeout=60 * 15):
"""
Decorator for caching model method results using Django's local cache.

First checks for an ``attr_name`` attribute on the instance (which may come
from an annotated queryset), then checks the local cache, and finally calls
the wrapped method and caches the result.

The wrapped method gains a ``bypass_cache`` keyword argument (default ``False``)
that, when ``True``, forces a fresh DB query and overwrites the cached value.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(self, bypass_cache=False):
# Check for a cached value that would come from an annotated queryset
if hasattr(self, attr_name):
value = getattr(self, attr_name)
return value if value is not None else 0

alias = cache_alias or settings.CACHE_LOCAL_ALIAS
cache_key = f"{self.__class__.__name__}_{attr_name}_{self.pk}"
cached_value = caches[alias].get(cache_key)
if cached_value is not None and not bypass_cache:
return cached_value

result = func(self)
caches[alias].set(cache_key, result, timeout=timeout)
return result

return wrapper

return decorator


@dataclass
class GeolocationData:
"""Dataclass for (temporarily) storing geolocation information for ad viewers."""
Expand Down