Skip to content

Commit c76fda1

Browse files
authored
Merge pull request #56 from tartiflette/ISSUE-55
ISSUE-55 - Provides a context factory parameter
2 parents ddaf215 + ed6e739 commit c76fda1

10 files changed

Lines changed: 116 additions & 37 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
1212
## [Released]
1313

1414
- [1.x.x]
15+
- [1.1.x]
16+
- [1.1.0](./changelogs/1.1.0.md) - 2019-10-02
1517
- [1.0.x]
1618
- [1.0.0](./changelogs/1.0.0.md) - 2019-09-12
1719
- [0.x.x]

changelogs/1.1.0.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# [1.1.0] -- 2019-10-02
2+
3+
## Added
4+
5+
- [ISSUE-55](https://github.com/tartiflette/tartiflette-aiohttp/issues/55) - Add
6+
a new optional `context_factory` parameter to the `register_graphql_handlers`
7+
function. This parameter can take a coroutine function which will be called on
8+
each request with the following signature:
9+
```python
10+
async def context_factory(
11+
context: Dict[str, Any], req: "aiohttp.web.Request"
12+
) -> Dict[str, Any]:
13+
"""
14+
Generates a new context.
15+
:param context: the value filled in through the `executor_context`
16+
parameter
17+
:param req: the incoming aiohttp request instance
18+
:type context: Dict[str, Any]
19+
:type req: aiohttp.web.Request
20+
:return: the context for the incoming request
21+
:rtype: Dict[str, Any]
22+
"""
23+
```
24+
25+
The aim of this function will be to returns the context which will be forwarded
26+
to the Tartiflette engine on the `execute` or `subscribe` method.

changelogs/next.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# [next]
1+
# [Next]
22

33
## Added
44

55
## Changed
66

7-
## Fixed
7+
## Fixed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"isort==4.3.21",
1414
]
1515

16-
_VERSION = "1.0.0"
16+
_VERSION = "1.1.0"
1717

1818
_PACKAGES = find_packages(exclude=["tests*"])
1919

tartiflette_aiohttp/__init__.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import json
22

33
from functools import partial
4-
from inspect import iscoroutine
5-
from typing import Any, Dict, List, Optional, Union
4+
from inspect import iscoroutine, iscoroutinefunction
5+
from typing import Any, Callable, Dict, List, Optional, Union
66

77
from tartiflette import Engine
8+
from tartiflette_aiohttp._context_factory import default_context_factory
89
from tartiflette_aiohttp._graphiql import graphiql_handler
910
from tartiflette_aiohttp._handler import Handlers
1011
from tartiflette_aiohttp._subscription_ws_handler import (
@@ -35,15 +36,15 @@ def validate_and_compute_graphiql_option(
3536
def _set_subscription_ws_handler(
3637
app: "Application",
3738
subscription_ws_endpoint: Optional[str],
38-
context: Dict[str, Any],
39+
context_factory: Callable,
3940
) -> None:
4041
if not subscription_ws_endpoint:
4142
return
4243

4344
app.router.add_route(
4445
"GET",
4546
subscription_ws_endpoint,
46-
AIOHTTPSubscriptionHandler(app, context),
47+
AIOHTTPSubscriptionHandler(app, context_factory),
4748
)
4849

4950

@@ -116,6 +117,7 @@ def register_graphql_handlers(
116117
engine_modules: Optional[
117118
List[Union[str, Dict[str, Union[str, Dict[str, str]]]]]
118119
] = None,
120+
context_factory: Optional[Callable] = None,
119121
) -> "Application":
120122
"""Register a Tartiflette Engine to an app
121123
@@ -133,10 +135,12 @@ def register_graphql_handlers(
133135
graphiql_enabled {bool} -- Determines whether or not we should handle a GraphiQL endpoint (default: {False})
134136
graphiql_options {dict} -- Customization options for the GraphiQL instance (default: {None})
135137
engine_modules: {Optional[List[Union[str, Dict[str, Union[str, Dict[str, str]]]]]]} -- Module to import (default:{None})
138+
context_factory: {Optional[Callable]} -- coroutine function in charge of generating the context for each request (default: {None})
136139
137140
Raises:
138141
Exception -- On bad sdl/engine parameter combinaison.
139142
Exception -- On unsupported HTTP Method.
143+
Exception -- if `context_factory` is filled in without a coroutine function.
140144
141145
Return:
142146
The app object.
@@ -150,6 +154,16 @@ def register_graphql_handlers(
150154
if not executor_http_methods:
151155
executor_http_methods = ["GET", "POST"]
152156

157+
if context_factory is None:
158+
context_factory = default_context_factory
159+
160+
if not iscoroutinefunction(context_factory):
161+
raise Exception(
162+
"`context_factory` parameter should be a coroutine function."
163+
)
164+
165+
context_factory = partial(context_factory, executor_context)
166+
153167
if not engine:
154168
engine = Engine()
155169

@@ -174,14 +188,14 @@ def register_graphql_handlers(
174188
executor_http_endpoint,
175189
partial(
176190
getattr(Handlers, "handle_%s" % method.lower()),
177-
executor_context,
191+
context_factory=context_factory,
178192
),
179193
)
180194
except AttributeError:
181195
raise Exception("Unsupported < %s > http method" % method)
182196

183197
_set_subscription_ws_handler(
184-
app, subscription_ws_endpoint, executor_context
198+
app, subscription_ws_endpoint, context_factory
185199
)
186200

187201
_set_graphiql_handler(
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Any, Dict
2+
3+
__all__ = ("default_context_factory",)
4+
5+
6+
async def default_context_factory(
7+
context: Dict[str, Any], req: "aiohttp.web.Request"
8+
) -> Dict[str, Any]:
9+
"""
10+
Generates a new context.
11+
:param context: the value filled in through the `executor_context`
12+
parameter
13+
:param req: the incoming aiohttp request instance
14+
:type context: Dict[str, Any]
15+
:type req: aiohttp.web.Request
16+
:return: the context for the incoming request
17+
:rtype: Dict[str, Any]
18+
"""
19+
return {**context, "req": req}

tartiflette_aiohttp/_handler.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import json
22
import logging
33

4-
from copy import copy
5-
64
from aiohttp import web
75

86
logger = logging.getLogger(__name__)
@@ -34,8 +32,11 @@ def prepare_response(data):
3432
return web.json_response(data, headers=headers, dumps=json.dumps)
3533

3634

37-
async def _handle_query(req, query, query_vars, operation_name, context):
38-
context = copy(context)
35+
async def _handle_query(
36+
req, query, query_vars, operation_name, context_factory
37+
):
38+
context = await context_factory(req)
39+
3940
try:
4041
if not operation_name:
4142
operation_name = None
@@ -94,23 +95,23 @@ async def _post_params(req):
9495

9596
class Handlers:
9697
@staticmethod
97-
async def _handle(param_func, user_c, req):
98-
user_c["req"] = req
99-
98+
async def _handle(param_func, req, context_factory):
10099
try:
101100
qry, qry_vars, oprn_name = await param_func(req)
102101
return prepare_response(
103-
await _handle_query(req, qry, qry_vars, oprn_name, user_c)
102+
await _handle_query(
103+
req, qry, qry_vars, oprn_name, context_factory
104+
)
104105
)
105106
except BadRequestError as e:
106107
return prepare_response(
107108
{"data": None, "errors": _format_errors([e])}
108109
)
109110

110111
@staticmethod
111-
async def handle_get(user_context, req):
112-
return await Handlers._handle(_get_params, user_context, req)
112+
async def handle_get(req, context_factory):
113+
return await Handlers._handle(_get_params, req, context_factory)
113114

114115
@staticmethod
115-
async def handle_post(user_context, req):
116-
return await Handlers._handle(_post_params, user_context, req)
116+
async def handle_post(req, context_factory):
117+
return await Handlers._handle(_post_params, req, context_factory)

tartiflette_aiohttp/_subscription_ws_handler.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22

33
from asyncio import ensure_future, shield, wait
4-
from typing import Any, AsyncIterator, Dict, Optional, Set
4+
from typing import Any, AsyncIterator, Callable, Dict, Optional, Set
55

66
from aiohttp import WSMsgType, web
77

@@ -79,9 +79,11 @@ async def close(self, code: int) -> None:
7979

8080

8181
class AIOHTTPSubscriptionHandler:
82-
def __init__(self, app: "Application", context: Dict[str, Any]) -> None:
82+
def __init__(self, app: "Application", context_factory: Callable) -> None:
8383
self._app: "Application" = app
84-
self._context = context
84+
self._context_factory = context_factory
85+
self._socket: Optional["web.WebSocketResponse"] = None
86+
self._context: Optional[Dict[str, Any]] = None
8587

8688
async def _send_message(
8789
self,
@@ -255,9 +257,8 @@ async def _handle_request(self) -> None:
255257
await self._on_close(connection_context, tasks)
256258

257259
async def __call__(self, request: "Request") -> "WebSocketResponse":
258-
self._socket = web.WebSocketResponse( # pylint: disable=attribute-defined-outside-init
259-
protocols=(WS_PROTOCOL,)
260-
)
260+
self._socket = web.WebSocketResponse(protocols=(WS_PROTOCOL,))
261+
self._context = await self._context_factory(request)
261262
await self._socket.prepare(request)
262263
await shield(self._handle_request())
263264
return self._socket

tests/integration/test_handlers.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from functools import partial
12
from unittest.mock import Mock
23

34
import pytest
45

6+
from tartiflette_aiohttp import default_context_factory
7+
58

69
@pytest.mark.asyncio
710
async def test_handler__handle_query__context_unicity():
@@ -31,18 +34,18 @@ async def resolver_hello(parent, args, ctx, info):
3134
a_req = Mock()
3235
a_req.app = {"ttftt_engine": tftt_engine}
3336

34-
ctx = {}
37+
context_factory = partial(default_context_factory, {})
3538

3639
await _handle_query(
37-
a_req, 'query { hello(name: "Chuck") }', None, None, ctx
40+
a_req, 'query { hello(name: "Chuck") }', None, None, context_factory
3841
)
3942

4043
await _handle_query(
41-
a_req, 'query { hello(name: "Chuck") }', None, None, ctx
44+
a_req, 'query { hello(name: "Chuck") }', None, None, context_factory
4245
)
4346

4447
b_response = await _handle_query(
45-
a_req, 'query { hello(name: "Chuck") }', None, None, ctx
48+
a_req, 'query { hello(name: "Chuck") }', None, None, context_factory
4649
)
4750

4851
assert b_response == {"data": {"hello": "hello 1"}}
@@ -71,7 +74,7 @@ async def resolver_hello(parent, args, ctx, info):
7174
a_req = Mock()
7275
a_req.app = {"ttftt_engine": tftt_engine}
7376

74-
ctx = {}
77+
context_factory = partial(default_context_factory, {})
7578

7679
result = await _handle_query(
7780
a_req,
@@ -82,7 +85,7 @@ async def resolver_hello(parent, args, ctx, info):
8285
""",
8386
None,
8487
"B",
85-
ctx,
88+
context_factory,
8689
)
8790

8891
assert result == {"data": {"hello": "hello Bar"}}

tests/unit/test_handlers.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from functools import partial
12
from unittest.mock import Mock
23

34
import pytest
45

56
from asynctest import CoroutineMock
67

8+
from tartiflette_aiohttp import default_context_factory
9+
710

811
@pytest.mark.parametrize(
912
"value,expected",
@@ -32,7 +35,11 @@ async def test_handler__handle_query():
3235
a_req.app = {"ttftt_engine": an_engine}
3336

3437
a_response = await _handle_query(
35-
a_req, "query a {}", {"B": "C"}, "a", {"D": "E"}
38+
a_req,
39+
"query a {}",
40+
{"B": "C"},
41+
"a",
42+
partial(default_context_factory, {"D": "E"}),
3643
)
3744

3845
assert a_response == "T"
@@ -42,7 +49,7 @@ async def test_handler__handle_query():
4249
{
4350
"query": "query a {}",
4451
"variables": {"B": "C"},
45-
"context": {"D": "E"},
52+
"context": {"D": "E", "req": a_req},
4653
"operation_name": "a",
4754
},
4855
)
@@ -60,7 +67,11 @@ async def test_handler__handle_query_nok():
6067
a_req.app = {}
6168

6269
a_response = await _handle_query(
63-
a_req, "query a {}", {"B": "C"}, "a", {"D": "E"}
70+
a_req,
71+
"query a {}",
72+
{"B": "C"},
73+
"a",
74+
partial(default_context_factory, {"D": "E"}),
6475
)
6576

6677
assert a_response == {
@@ -179,7 +190,9 @@ async def test_handler__handle():
179190

180191
a_method = CoroutineMock(return_value=("a", "b", "c"))
181192

182-
await Handlers._handle(a_method, {}, a_req)
193+
await Handlers._handle(
194+
a_method, a_req, partial(default_context_factory, {})
195+
)
183196

184197
assert a_method.call_args_list == [((a_req,),)]
185198

0 commit comments

Comments
 (0)