diff --git a/docs/quickstart.md b/docs/quickstart.md index 5ffad1c..dfa12d9 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -160,9 +160,22 @@ def logout(): return "You've been successfully logged out!" ``` -If the logout view is mounted under a custom endpoint (other than the default, which is -[the name of the view function](https://flask.palletsprojects.com/en/2.0.x/api/#flask.Flask.route)), or if using Blueprints, you -must specify the full URL in the Flask-pyoidc configuration using `post_logout_redirect_uris`: +If you are using Blueprints to create routes, you can provide `logout_view` argument which takes +[name of the view function](https://flask.palletsprojects.com/en/2.0.x/api/#flask.Flask.route) as parameter. This +argument is used to resolve URL for `post_logout_redirect_uris`. +```python +from flask import Blueprint + +blueprint = Blueprint(name='api', import_name=__name__) + +@blueprint.route('/logout') +@auth.oidc_logout(logout_view='api.logout') +def logout(): + return "You've been successfully logged out!" +``` + +`logout_view` argument is optional to provide in the decorator because you can directly specify +`post_logout_redirect_uris` as complete URL in the Flask-pyoidc configuration: ```python ClientMetadata(..., post_logout_redirect_uris=['https://example.com/post_logout']) # if using static client registration ClientRegistrationInfo(..., post_logout_redirect_uris=['https://example.com/post_logout']) # if using dynamic client registration diff --git a/src/flask_pyoidc/flask_pyoidc.py b/src/flask_pyoidc/flask_pyoidc.py index 2e09765..3bf7aa0 100644 --- a/src/flask_pyoidc/flask_pyoidc.py +++ b/src/flask_pyoidc/flask_pyoidc.py @@ -90,7 +90,7 @@ def init_app(self, app): def _get_urls_for_logout_views(self): try: - return [url_for(view.__name__, _external=True) for view in self._logout_views] + return [url_for(view, _external=True) for view in self._logout_views] except BuildError: logger.error('could not build url for logout view, it might be mounted under a custom endpoint') raise @@ -260,25 +260,32 @@ def _logout(self, post_logout_redirect_uri): return redirect(end_session_request.request(client.provider_end_session_endpoint), 303) return None - def oidc_logout(self, view_func): - self._logout_views.append(view_func) + def oidc_logout(self, logout_view: Optional[str] = None): - @functools.wraps(view_func) - def wrapper(*args, **kwargs): - if 'state' in flask.request.args: - # returning redirect from provider - if flask.request.args['state'] != flask.session.pop('end_session_state', None): - logger.error("Got unexpected state '%s' after logout redirect.", flask.request.args['state']) - return view_func(*args, **kwargs) + def logout_decorator(view_func): + + @functools.wraps(view_func) + def wrapper(*args, **kwargs): + if 'state' in flask.request.args: + # returning redirect from provider + if flask.request.args['state'] != flask.session.pop('end_session_state', None): + logger.error("Got unexpected state '%s' after logout redirect.", flask.request.args['state']) + return view_func(*args, **kwargs) + + post_logout_redirect_uri = flask.request.url + redirect_to_provider = self._logout(post_logout_redirect_uri) + if redirect_to_provider: + return redirect_to_provider - post_logout_redirect_uri = flask.request.url - redirect_to_provider = self._logout(post_logout_redirect_uri) - if redirect_to_provider: - return redirect_to_provider + return view_func(*args, **kwargs) - return view_func(*args, **kwargs) + return wrapper - return wrapper + if callable(logout_view): + self._logout_views.append(logout_view.__name__) + return logout_decorator(logout_view) + self._logout_views.append(logout_view) + return logout_decorator def error_view(self, view_func): self._error_view = view_func diff --git a/tests/test_flask_pyoidc.py b/tests/test_flask_pyoidc.py index 333a1ad..a07524a 100644 --- a/tests/test_flask_pyoidc.py +++ b/tests/test_flask_pyoidc.py @@ -662,6 +662,21 @@ def test_logout_handles_no_user_session(self): self.assert_view_mock(logout_view_mock, result) + def test_oidc_logout_when_endpoint_name_is_provided(self): + authn = self.init_app() + # Decorator with an argument. + view_func1 = authn.oidc_logout(logout_view='logout1')(self.get_view_mock('logout1')) + self.app.add_url_rule('/logout1', 'logout1', view_func=view_func1) + view_func2 = authn.oidc_logout(logout_view='test.logout')(self.get_view_mock('logout2')) + self.app.add_url_rule('/logout2', 'test.logout', view_func=view_func2) + # Decorator without an argument. + view_func3 = authn.oidc_logout(self.get_view_mock('logout3')) + self.app.add_url_rule('/logout3', 'logout3', view_func=view_func3) + + with self.app.app_context(): + assert authn._get_urls_for_logout_views() == [f'http://{self.CLIENT_DOMAIN}{endpoint}' + for endpoint in ('/logout1', '/logout2', '/logout3')] + def test_authentication_error_response_calls_to_error_view_if_set(self): state = 'test_tate' error_response = {'error': 'invalid_request', 'error_description': 'test error'}