Skip to content

Commit 78fcde5

Browse files
authored
[REF] error handling compose (#1146)
* make sure headers are added * test cors headers * handle all errors in compose
1 parent 1dcbad6 commit 78fcde5

File tree

6 files changed

+78
-24
lines changed

6 files changed

+78
-24
lines changed

compose/backend/neurosynth_compose/__init__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33

44
import connexion
55
from authlib.integrations.flask_client import OAuth
6-
from connexion.middleware import MiddlewarePosition
76
from connexion.resolver import MethodResolver
8-
from flask_cors import CORS
97
from starlette.middleware.cors import CORSMiddleware
108

119
from .database import init_db
12-
from .resources.auth import handle_auth_error, AuthError
10+
from .resources.auth import AuthError, handle_auth_error
1311

1412

1513
def create_app():
@@ -32,9 +30,7 @@ def create_app():
3230
"APIKEYINFO_FUNC", app.config.get("APIKEYINFO_FUNC", default_apikey)
3331
)
3432

35-
connexion_app.add_middleware(
36-
CORSMiddleware,
37-
position=MiddlewarePosition.BEFORE_ROUTING,
33+
cors_kwargs = dict(
3834
allow_origins=["*"],
3935
allow_credentials=True,
4036
allow_methods=["*"],
@@ -79,10 +75,12 @@ def create_app():
7975
init_db(app)
8076

8177
app.secret_key = app.config["JWT_SECRET_KEY"]
82-
CORS(app)
8378

8479
app.register_error_handler(AuthError, handle_auth_error)
8580

81+
cors_asgi_app = CORSMiddleware(connexion_app, **cors_kwargs)
82+
8683
app.extensions["connexion_app"] = connexion_app
84+
app.extensions["connexion_asgi"] = cors_asgi_app
8785

8886
return app

compose/backend/neurosynth_compose/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .__init__ import create_app
44

55
app = create_app()
6+
asgi_app = app.extensions["connexion_asgi"]
67

78
ext_celery = FlaskCeleryExt(app)
89
celery_app = ext_celery.celery

compose/backend/neurosynth_compose/tests/request_utils.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
22
from functools import partialmethod
33

4+
from starlette.testclient import TestClient as StarletteTestClient
5+
46

57
class Client(object):
68
def __init__(self, token, test_client=None, prepend="", username=None):
@@ -9,32 +11,26 @@ def __init__(self, token, test_client=None, prepend="", username=None):
911
if test_client is None:
1012
from flask import current_app as app
1113

14+
asgi_app = app.extensions.get("connexion_asgi")
1215
connexion_app = app.extensions.get("connexion_app")
1316

14-
if connexion_app is not None:
15-
target_app = connexion_app
16-
if not getattr(target_app, "test_client", None):
17-
target_app = getattr(target_app, "app", target_app)
18-
if not getattr(target_app, "test_client", None):
19-
target_app = getattr(target_app, "_app", target_app)
20-
if not getattr(target_app, "test_client", None):
21-
target_app = getattr(target_app, "app", target_app)
22-
test_client = target_app.test_client()
17+
if asgi_app is not None:
18+
test_client = StarletteTestClient(asgi_app)
19+
elif connexion_app is not None and hasattr(connexion_app, "test_client"):
20+
test_client = connexion_app.test_client()
2321
else:
2422
test_client = app.test_client()
2523

26-
if hasattr(test_client, "open"):
27-
self.client_mode = "flask"
28-
else:
29-
self.client_mode = "requests"
24+
self.client_flask = hasattr(test_client, "open")
25+
self.client_mode = "flask" if self.client_flask else "requests"
3026

3127
self.client = test_client
3228
self.prepend = prepend
3329
self.token = token
3430
self.username = username
3531

3632
def close(self):
37-
if self.client_flask and hasattr(self.client, "close"):
33+
if hasattr(self.client, "close"):
3834
self.client.close()
3935

4036
def _get_headers(self):
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import pytest
2+
from starlette.testclient import TestClient
3+
from werkzeug.routing import Rule
4+
5+
from .request_utils import Client
6+
7+
8+
@pytest.fixture(scope="module")
9+
def cors_test_client(app):
10+
endpoint_name = "__cors_test_status"
11+
if endpoint_name not in app.view_functions:
12+
from flask import jsonify
13+
14+
def _cors_status(status_code):
15+
if status_code == 204:
16+
return "", 204
17+
return jsonify({"status": status_code}), status_code
18+
19+
app.url_map.add(
20+
Rule(
21+
"/__test/cors/<int:status_code>",
22+
endpoint=endpoint_name,
23+
methods=["GET"],
24+
)
25+
)
26+
app.view_functions[endpoint_name] = _cors_status
27+
28+
client = TestClient(app.extensions["connexion_asgi"])
29+
try:
30+
yield client
31+
finally:
32+
client.close()
33+
34+
35+
@pytest.fixture
36+
def anonymous_client():
37+
return Client(token=None)
38+
39+
40+
@pytest.mark.parametrize(
41+
"client_fixture, method, path, expected_status",
42+
[
43+
("auth_client", "get", "/api/specifications?page_size=1", 200),
44+
("cors_test_client", "get", "/__test/cors/204", 204),
45+
("cors_test_client", "get", "/__test/cors/400", 400),
46+
("anonymous_client", "post", "/api/annotations", 401),
47+
("cors_test_client", "get", "/__test/cors/404", 404),
48+
("cors_test_client", "get", "/__test/cors/420", 420),
49+
("cors_test_client", "get", "/__test/cors/500", 500),
50+
],
51+
)
52+
def test_cors_headers_present(
53+
client_fixture, method, path, expected_status, request, user_data
54+
):
55+
client = request.getfixturevalue(client_fixture)
56+
headers = {"Origin": "https://client.example"}
57+
response = getattr(client, method)(path, headers=headers)
58+
59+
assert response.status_code == expected_status
60+
assert response.headers.get("Access-Control-Allow-Origin") == "*"

compose/docker-compose.dev.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@ services:
1111
compose:
1212
expose:
1313
- "8000"
14-
command: /usr/local/bin/gunicorn -w 2 -b :8000 neurosynth_compose.core:app --log-level debug --timeout 300 --reload
14+
command: /usr/local/bin/gunicorn -k uvicorn.workers.UvicornWorker --reload -w 2 -b :8000 neurosynth_compose.core:asgi_app --log-level debug --timeout 300
1515
restart: "no"
1616

1717
compose_worker:
1818
volumes:
1919
- ./neurosynth_compose:/usr/local/lib/python3.10/site-packages/neurosynth_compose
20-

compose/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ services:
1717
- ./backend/migrations:/migrations
1818
- ./:/compose
1919
- ${FILE_DIR}:/file-data
20-
command: /usr/local/bin/gunicorn -w 2 -b :8000 neurosynth_compose.core:app --log-level debug --timeout 120
20+
command: /usr/local/bin/gunicorn -k uvicorn.workers.UvicornWorker -w 2 -b :8000 neurosynth_compose.core:asgi_app --log-level debug --timeout 120
2121
env_file:
2222
- .env
2323

0 commit comments

Comments
 (0)