|
| 1 | +--- |
| 2 | +title: Query Complexity Estimator |
| 3 | +summary: Add a validator to estimate the complexity of GraphQL operations. |
| 4 | +tags: security |
| 5 | +--- |
| 6 | + |
| 7 | +# `QueryComplexityEstimator` |
| 8 | + |
| 9 | +Estimate the complexity of a query and attach its cost to the execution context. |
| 10 | + |
| 11 | +This extension works by traversing through the query document and evaluating |
| 12 | +each node's cost. If no field-specific override is provided, field costs are |
| 13 | +estimated using `default_estimator`. |
| 14 | + |
| 15 | +When the extension finishes estimating the complexity of the operations, |
| 16 | +`callback` is called with a map of complexities of all operations and the |
| 17 | +current execution context. This callback can be used for things such as a |
| 18 | +token-bucket rate-limiter based on query complexity, a complexity logger, or for |
| 19 | +storing the complexities in the current execution context so that it can used by |
| 20 | +downstream resolvers. |
| 21 | + |
| 22 | +Additionally, you can configure the extension also to add the complexity |
| 23 | +dictionary to the response that gets sent to the client by setting |
| 24 | +`response_key`. |
| 25 | + |
| 26 | +## Usage example: |
| 27 | + |
| 28 | +```python |
| 29 | +from typing import Iterator |
| 30 | + |
| 31 | +from graphql.error import GraphQLError |
| 32 | + |
| 33 | +import strawberry |
| 34 | +from strawberry.types import ExecutionContext |
| 35 | +from strawberry.extensions import FieldComplexityEstimator, QueryComplexityEstimator |
| 36 | + |
| 37 | + |
| 38 | +class MyEstimator(FieldComplexityEstimator): |
| 39 | + def estimate_complexity( |
| 40 | + self, child_complexities: Iterator[int], arguments: dict[str, Any] |
| 41 | + ) -> int: |
| 42 | + children_sum = sum(child_complexities) |
| 43 | + # scalar fields cost 1 |
| 44 | + if children_sum == 0: |
| 45 | + return 1 |
| 46 | + |
| 47 | + # non-list object fields cost the sum of their children |
| 48 | + if "page_size" not in field_kwargs: |
| 49 | + return children_sum |
| 50 | + |
| 51 | + # paginated fields cost gets multiplied by page size |
| 52 | + return children_sum * field_kwargs["page_size"] |
| 53 | + |
| 54 | + |
| 55 | +# initialize your rate-limiter somehow |
| 56 | +rate_limiter = ... |
| 57 | + |
| 58 | + |
| 59 | +def my_callback( |
| 60 | + complexities: dict[str, int], execution_context: ExecutionContext |
| 61 | +) -> None: |
| 62 | + # add complexities to execution context |
| 63 | + execution_context.context["complexities"] = complexities |
| 64 | + |
| 65 | + # apply a token-bucket rate-limiter |
| 66 | + total_cost = sum(complexities.values()) |
| 67 | + bucket = rate_limiter.get_bucket_for_key(execution_context.context["user_id"]) |
| 68 | + tokens_left = bucket.take_tokens(total_cost) |
| 69 | + if tokens_left <= 0: |
| 70 | + raise GraphQLError( |
| 71 | + "Rate-limit exhausted. Please wait for some time before trying again." |
| 72 | + ) |
| 73 | + |
| 74 | + |
| 75 | +schema = strawberry.Schema( |
| 76 | + Query, |
| 77 | + extensions=[ |
| 78 | + QueryComplexityEstimator( |
| 79 | + default_estimator=MyEstimator(), |
| 80 | + callback=my_callback, |
| 81 | + ), |
| 82 | + ], |
| 83 | +) |
| 84 | +``` |
| 85 | + |
| 86 | +## API reference: |
| 87 | + |
| 88 | +```python |
| 89 | +class QueryComplexityEstimator(default_estimator, callback, response_key): ... |
| 90 | +``` |
| 91 | + |
| 92 | +#### `default_estimator: Union[FieldComplexityEstimator, int]` |
| 93 | + |
| 94 | +The default complexity estimator for fields that don't specify overrides. If |
| 95 | +it's an integer, the default estimator will be a |
| 96 | +`ConstantFieldComplexityEstimator` with the integer value. |
| 97 | + |
| 98 | +#### `callback: Optional[Callable[[Dict[str, int], ExecutionContext], None]]` |
| 99 | + |
| 100 | +Called each time validation runs. Receives a dictionary which is a map of the |
| 101 | +complexity for each operation. |
| 102 | + |
| 103 | +#### `response_key: Optional[str]` |
| 104 | + |
| 105 | +If provided, this extension will add the calculated query complexities to the |
| 106 | +response that gets sent to the client via `get_results()`. The resulting |
| 107 | +complexities will be under the specified key. |
| 108 | + |
| 109 | +```python |
| 110 | +class FieldComplexityEstimator: ... |
| 111 | +``` |
| 112 | + |
| 113 | +Estimate the complexity of a single field. |
| 114 | + |
| 115 | +### `estimate_complexity(child_complexities, arguments) -> None:` |
| 116 | + |
| 117 | +The implementation of the estimator |
| 118 | + |
| 119 | +#### `child_complexities: Iterator[int]` |
| 120 | + |
| 121 | +An iterator over the complexities of child fields, if they exist. This iterator |
| 122 | +is lazy, meaning the complexity of each child will only be evaluated if `next()` |
| 123 | +gets called on the iterator. As such, to avoud unnnecessary computation we |
| 124 | +recommend only iterating over child complexities if you'll use them. |
| 125 | + |
| 126 | +#### `arguments: Dict[str, Any]` |
| 127 | + |
| 128 | +A dict that maps field arguments to their values. |
0 commit comments