Skip to content
Open
97 changes: 6 additions & 91 deletions posthog/api/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import time
import uuid
import hashlib
from collections import defaultdict
from collections.abc import Iterator
from typing import Annotated, Any, Literal, Optional, Union, cast

Expand Down Expand Up @@ -78,6 +77,7 @@
CohortOrEmpty,
CohortType,
)
from products.cohorts.backend.models.dependencies import get_flag_excluded_behavioral_cohort_ids
from products.cohorts.backend.models.util import (
CohortErrorCode,
cohort_filters_have_values,
Expand Down Expand Up @@ -1379,16 +1379,11 @@ def safely_get_queryset(self, queryset) -> QuerySet:
from products.feature_flags.backend.api.feature_flag import _is_realtime_cohort_flag_targeting_enabled

allow_realtime_backfilled = _is_realtime_cohort_flag_targeting_enabled(self.request)
# Lists every column read by _find_behavioral_cohorts (is_static, filters)
# and Cohort.is_flag_compatible (cohort_type, last_backfill_person_properties_at);
# dropping one triggers a per-cohort deferred query (an N+1 that only bites a
# team with thousands of cohorts). Deferring the rest keeps the graph scan cheap.
graph_source = queryset.only(
"id", "is_static", "filters", "cohort_type", "last_backfill_person_properties_at"
)
all_cohorts = {cohort.id: cohort for cohort in graph_source.all()}
behavioral_cohort_ids = self._find_behavioral_cohorts(
all_cohorts, allow_realtime_backfilled=allow_realtime_backfilled
# The flag's cohort typeahead hits this endpoint on every keystroke, so the
# behavioral set is computed once per team and cached (invalidated on cohort
# writes); see get_flag_excluded_behavioral_cohort_ids.
behavioral_cohort_ids = get_flag_excluded_behavioral_cohort_ids(
self.team_id, allow_realtime_backfilled=allow_realtime_backfilled
)
queryset = queryset.exclude(id__in=behavioral_cohort_ids)

Expand Down Expand Up @@ -1420,86 +1415,6 @@ def safely_get_queryset(self, queryset) -> QuerySet:
.order_by("-created_at")
)

def _find_behavioral_cohorts(
self, all_cohorts: dict[int, Cohort], *, allow_realtime_backfilled: bool = False
) -> set[int]:
"""Find cohorts that are behavioral, or reference (transitively) a behavioral cohort.

A cohort is affected if it's a behavioral seed, or references one through the
dependency graph. We walk the *reverse* graph once from the seeds (O(V+E)) —
every node that can reach a seed via forward edges is affected.

When allow_realtime_backfilled is True, realtime cohorts that have been backfilled
are not seeds: they can be evaluated via the cohort_membership table during flag
evaluation. (They can still be pulled in if they reference another seed.)
"""
graph, behavioral_cohorts = self._build_cohort_dependency_graph(all_cohorts)

flag_compatible: set[int] = set()
if allow_realtime_backfilled:
flag_compatible = {
cid for cid in behavioral_cohorts if (cohort := all_cohorts.get(cid)) and cohort.is_flag_compatible
}
seeds = behavioral_cohorts - flag_compatible

# Reverse adjacency: target -> sources that reference it.
reverse: dict[int, set[int]] = defaultdict(set)
for source_id, targets in graph.items():
for target_id in targets:
reverse[target_id].add(source_id)

affected = set(seeds)
stack = list(seeds)
while stack:
node = stack.pop()
for source_id in reverse.get(node, ()):
if source_id not in affected:
affected.add(source_id)
stack.append(source_id)

return affected

def _build_cohort_dependency_graph(self, all_cohorts: dict[int, Cohort]) -> tuple[dict[int, set[int]], set[int]]:
"""
Builds a directed graph of cohort dependencies and identifies behavioral cohorts.
Returns (adjacency_list, behavioral_cohort_ids).
"""
graph = defaultdict(set)
behavioral_cohorts = set()

def check_property_values(values: Any, source_id: int) -> None:
"""Process property values to build graph edges and identify behavioral cohorts."""
if not isinstance(values, list):
return

for value in values:
if not isinstance(value, dict):
continue

if value.get("type") == "behavioral":
behavioral_cohorts.add(source_id)
elif value.get("type") == "cohort":
try:
target_id = int(value.get("value", "0"))
if target_id in all_cohorts:
graph[source_id].add(target_id)
except ValueError:
continue
elif value.get("type") in ("AND", "OR") and value.get("values"):
check_property_values(value["values"], source_id)

for cohort_id, cohort in all_cohorts.items():
# Static cohorts have pre-computed membership and don't re-evaluate
# their filters, so they're always safe to use regardless of filter type.
if cohort.is_static:
continue
if cohort.filters:
properties = cohort.filters.get("properties", {})
if isinstance(properties, dict):
check_property_values(properties.get("values", []), cohort_id)

return graph, behavioral_cohorts

@extend_schema(
parameters=[
OpenApiParameter(
Expand Down
9 changes: 4 additions & 5 deletions posthog/api/test/test_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from posthog.schema import PersonsOnEventsMode, PropertyOperator

from posthog.api.cohort import COHORT_USED_IN_PAGE_SIZE, CohortViewSet
from posthog.api.cohort import COHORT_USED_IN_PAGE_SIZE
from posthog.clickhouse.client.execute import sync_execute
from posthog.models import Person, User
from posthog.models.activity_logging.activity_log import ActivityLog
Expand All @@ -43,6 +43,7 @@

from products.actions.backend.models.action import Action
from products.cohorts.backend.models.cohort import Cohort, CohortType
from products.cohorts.backend.models.dependencies import find_behavioral_cohorts
from products.exports.backend.api.test.test_exports import TestExportMixin
from products.feature_flags.backend.models.feature_flag import FeatureFlag
from products.product_analytics.backend.models.insight import Insight
Expand Down Expand Up @@ -1806,13 +1807,11 @@ def make(cid: int, *, behavioral: bool = False, refs: tuple[int, ...] = (), real
make(7, refs=(1, 5)),
]
}
viewset = CohortViewSet()

# Without the realtime exemption, every behavioral cohort and its referrers are excluded.
self.assertEqual(viewset._find_behavioral_cohorts(cohorts), {1, 2, 3, 5, 6, 7})
self.assertEqual(find_behavioral_cohorts(cohorts), {1, 2, 3, 5, 6, 7})
# With it, 5 is flag-compatible (not a seed) and 6 only referenced 5, so both stay.
# 7 still reaches real seed 1, so it remains excluded.
self.assertEqual(viewset._find_behavioral_cohorts(cohorts, allow_realtime_backfilled=True), {1, 2, 3, 7})
self.assertEqual(find_behavioral_cohorts(cohorts, allow_realtime_backfilled=True), {1, 2, 3, 7})

@patch("posthog.api.cohort.report_user_action")
def test_basic_list_omits_heavy_fields(self, patch_capture):
Expand Down
143 changes: 143 additions & 0 deletions products/cohorts/backend/models/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from collections import defaultdict
from typing import Any

from django.core.cache import cache
from django.db import transaction
from django.db.models import Q, TextField
from django.db.models.functions import Cast
from django.db.models.signals import post_delete, post_save, pre_save
from django.dispatch import receiver

Expand Down Expand Up @@ -35,6 +40,141 @@ def _cohort_dependents_key(cohort_id: int) -> str:
return f"cohort:dependents:{cohort_id}"


# Set of behavioral (flag-incompatible) cohort ids per team, hidden from the feature-flag
# property picker. Cached because the flag's cohort typeahead hits the cohorts list endpoint
# on every keystroke, and recomputing the dependency graph there means loading every cohort
# for the team into memory. The TTL is a backstop; the cache is invalidated whenever a cohort
# in the team changes (see cohort_changed / cohort_deleted). It is keyed on
# allow_realtime_backfilled because that toggles which realtime cohorts count as seeds.
BEHAVIORAL_COHORT_IDS_CACHE_TIMEOUT = 60 * 60 # 1 hour


def _behavioral_cohort_ids_key(team_id: int, allow_realtime_backfilled: bool) -> str:
return f"cohort:flag_excluded_behavioral_ids:{team_id}:{int(allow_realtime_backfilled)}"


def _build_cohort_dependency_graph(all_cohorts: dict[int, Cohort]) -> tuple[dict[int, set[int]], set[int]]:
"""Build a directed graph of cohort dependencies and identify behavioral cohorts.

Returns (adjacency_list, behavioral_cohort_ids). Static cohorts are skipped: they have
pre-computed membership and don't re-evaluate their filters, so they're always safe to
use regardless of filter type.
"""
graph: dict[int, set[int]] = defaultdict(set)
behavioral_cohorts: set[int] = set()

def check_property_values(values: Any, source_id: int) -> None:
if not isinstance(values, list):
return

for value in values:
if not isinstance(value, dict):
continue

if value.get("type") == "behavioral":
behavioral_cohorts.add(source_id)
elif value.get("type") == "cohort":
try:
target_id = int(value.get("value", "0"))
except ValueError:
continue
if target_id in all_cohorts:
graph[source_id].add(target_id)
elif value.get("type") in ("AND", "OR") and value.get("values"):
check_property_values(value["values"], source_id)

for cohort_id, cohort in all_cohorts.items():
if cohort.is_static:
continue
if cohort.filters:
properties = cohort.filters.get("properties", {})
if isinstance(properties, dict):
check_property_values(properties.get("values", []), cohort_id)

return graph, behavioral_cohorts


def find_behavioral_cohorts(all_cohorts: dict[int, Cohort], *, allow_realtime_backfilled: bool = False) -> set[int]:
"""Find cohorts that are behavioral, or reference (transitively) a behavioral cohort.

A cohort is affected if it's a behavioral seed, or references one through the dependency
graph. We walk the *reverse* graph once from the seeds (O(V+E)) — every node that can
reach a seed via forward edges is affected.

When allow_realtime_backfilled is True, realtime cohorts that have been backfilled are
not seeds: they can be evaluated via the cohort_membership table during flag evaluation.
(They can still be pulled in if they reference another seed.)
"""
graph, behavioral_cohorts = _build_cohort_dependency_graph(all_cohorts)

flag_compatible: set[int] = set()
if allow_realtime_backfilled:
flag_compatible = {
cid for cid in behavioral_cohorts if (cohort := all_cohorts.get(cid)) and cohort.is_flag_compatible
}
seeds = behavioral_cohorts - flag_compatible

# Reverse adjacency: target -> sources that reference it.
reverse: dict[int, set[int]] = defaultdict(set)
for source_id, targets in graph.items():
for target_id in targets:
reverse[target_id].add(source_id)

affected = set(seeds)
stack = list(seeds)
while stack:
node = stack.pop()
for source_id in reverse.get(node, ()):
if source_id not in affected:
affected.add(source_id)
stack.append(source_id)

return affected


def _compute_flag_excluded_behavioral_cohort_ids(team_id: int, *, allow_realtime_backfilled: bool) -> set[int]:
# Only non-static cohorts whose filters reference a behavioral node or another cohort can
# be a seed or reach one; the rest are leaves that never get excluded. Filtering them out
# in SQL keeps the in-memory graph — and the JSON we parse — small. The bare-word match
# can't produce false negatives: a behavioral or cohort node always serializes the literal
# "behavioral"/"cohort" substring. A false positive (e.g. a person-property value of
# "cohort") only loads an extra leaf, which the graph walk then ignores.
graph_source = (
Cohort.objects.filter(team_id=team_id, deleted=False, is_static=False)
.annotate(_filters_text=Cast("filters", output_field=TextField()))
.filter(Q(_filters_text__icontains="behavioral") | Q(_filters_text__icontains="cohort"))
.only("id", "is_static", "filters", "cohort_type", "last_backfill_person_properties_at")
)
all_cohorts = {cohort.id: cohort for cohort in graph_source}
return find_behavioral_cohorts(all_cohorts, allow_realtime_backfilled=allow_realtime_backfilled)


def get_flag_excluded_behavioral_cohort_ids(team_id: int, *, allow_realtime_backfilled: bool | None) -> set[int]:
"""Behavioral (flag-incompatible) cohort ids for a team, cached across requests."""
# feature_enabled can return None when the flag can't be evaluated; normalize so the
# cache key is stable and the compute path sees a real bool.
allow_realtime_backfilled = bool(allow_realtime_backfilled)
cache_key = _behavioral_cohort_ids_key(team_id, allow_realtime_backfilled)
cached = cache.get(cache_key)
if cached is not None: # empty list is a valid cached result, not a miss
return set(cached)

behavioral_cohort_ids = _compute_flag_excluded_behavioral_cohort_ids(
team_id, allow_realtime_backfilled=allow_realtime_backfilled
)
cache.set(cache_key, list(behavioral_cohort_ids), timeout=BEHAVIORAL_COHORT_IDS_CACHE_TIMEOUT)
return behavioral_cohort_ids


def _invalidate_team_behavioral_cohort_cache(team_id: int) -> None:
cache.delete_many(
[
_behavioral_cohort_ids_key(team_id, allow_realtime_backfilled=True),
_behavioral_cohort_ids_key(team_id, allow_realtime_backfilled=False),
]
)


def extract_cohort_dependencies(cohort: Cohort) -> set[int]:
"""
Extract cohort dependencies from the given cohort.
Expand Down Expand Up @@ -339,6 +479,7 @@ def cohort_changed(sender, instance, **kwargs):
return

transaction.on_commit(lambda: _on_cohort_changed(instance))
transaction.on_commit(lambda: _invalidate_team_behavioral_cohort_cache(instance.team_id))


@receiver(post_save, sender=Cohort)
Expand Down Expand Up @@ -400,6 +541,7 @@ def cohort_deleted(sender, instance, **kwargs):
Clear and rebuild dependency caches when cohort is deleted.
"""
transaction.on_commit(lambda: _on_cohort_changed(instance, always_invalidate=True))
transaction.on_commit(lambda: _invalidate_team_behavioral_cohort_cache(instance.team_id))


@receiver(post_delete, sender=Team)
Expand All @@ -413,5 +555,6 @@ def clear_cache():
for cohort_id in team_cohorts:
cache.delete(_cohort_dependencies_key(cohort_id))
cache.delete(_cohort_dependents_key(cohort_id))
_invalidate_team_behavioral_cohort_cache(instance.pk)

transaction.on_commit(clear_cache)
Loading
Loading