diff --git a/src/microdot/csrf.py b/src/microdot/csrf.py new file mode 100644 index 0000000..a5b047a --- /dev/null +++ b/src/microdot/csrf.py @@ -0,0 +1,154 @@ +import binascii +import hashlib +import hmac +import os +import time + +from microdot import abort + + +class CSRF: + """CSRF protection for Microdot routes. + + This class adds CSRF protection to all requests that use state changing + verbs (all methods except ``GET``, ``QUERY``, ``HEAD`` and ``OPTIONS``). + + :param app: The application instance. + :param secret_key: The secret key for token signing, as a string or bytes + object. + :param time_limit: the duration of the CSRF token, in seconds. Defaults to + one hour. + :param cookie_options: A dictionary with cookie options to pass as + arguments to :meth:`Response.set_cookie() + `. + :param protect_all: If ``True``, all state changing routes are protected by + default. If ``False``, only routes decorated with the + :meth:`protect` decorator are protected. + + The CSRF token is returned to the client in a ``csrf_token`` cookie, and + the client is expected to send this token back in requests that have CSRF + protection enabled. If the request includes a form submission, then the + token can be added as a ``csrf_token`` form field. Otherwise, the client + must include it in a ``X-CSRF-Token`` header. + """ + def __init__(self, app=None, secret_key=None, time_limit=60 * 60, + cookie_options=None, protect_all=True): + self.app = None + if isinstance(secret_key, str): + self.secret_key = secret_key.encode() + else: + self.secret_key = secret_key + self.time_limit = time_limit + self.cookie_options = cookie_options or {} + self.protect_all = protect_all + self.exempt_routes = [] + self.protected_routes = [] + if app is not None: + self.initialize(app, secret_key) + + def generate_csrf_token(self): + csrf_token = binascii.hexlify(os.urandom(20)) + expiration = str(time.time() + self.time_limit).encode() + signature = hmac.new(self.secret_key, csrf_token + b'$' + expiration, + hashlib.sha256).hexdigest() + return csrf_token.decode() + '$' + expiration.decode() + '$' + \ + signature + + def validate_csrf_token(self, token): + try: + csrf_token, expiration, signature = token.split('$') + if time.time() > float(expiration): + return False + except Exception: + return False + return hmac.new(self.secret_key, + (csrf_token + '$' + expiration).encode(), + hashlib.sha256).hexdigest() == signature + + def initialize(self, app, secret_key=None, time_limit=None, + cookie_options=None, protect_all=None): + """Initialize the CSRF class. + + :param app: The application instance. + :param secret_key: The secret key for token signing, as a string or + bytes object. + :param time_limit: the duration of the CSRF token, in seconds. Defaults + to one hour. + :param cookie_options: A dictionary with cookie options to pass as + arguments to :meth:`Response.set_cookie() + `. + :param protect_all: If ``True``, all state changing routes are + protected by default. If ``False``, only routes + decorated with the :meth:`protect` decorator are + protected. + """ + self.app = app + if secret_key is not None: + if isinstance(secret_key, str): + self.secret_key = secret_key.encode() + else: + self.secret_key = secret_key + if time_limit is not None: + self.time_limit = time_limit + if cookie_options is not None: + self.cookie_options = cookie_options + if protect_all is not None: + self.protect_all = protect_all + + @self.app.before_request + async def csrf_before_request(request): + if ( + self.protect_all + and request.method not in ['GET', 'QUERY', 'HEAD', 'OPTIONS'] + and request.route not in self.exempt_routes + ) or request.route in self.protected_routes: + # ensure that a valid CSRF token was provided + csrf_token = None + if request.method == 'POST' and request.form and \ + 'csrf_token' in request.form: + csrf_token = request.form['csrf_token'] + else: + csrf_token = request.headers.get('X-CSRF-Token') + if not self.validate_csrf_token(csrf_token): + abort(403, 'Invalid CSRF token') + + @self.app.after_request + def csrf_after_request(request, response): + if 'csrf_token' not in request.cookies: + options = self.cookie_options + if 'secure' not in options and request.scheme == 'https': + options['secure'] = True + response.set_cookie('csrf_token', self.generate_csrf_token(), + **options) + + def exempt(self, f): + """Decorator to exempt a route from CSRF protection. + + This decorator must be added immediately after the route decorator to + disable CSRF protection on the route. Example:: + + @app.post('/submit') + @csrf.exempt + # add additional decorators here + def submit(request): + # ... + """ + self.exempt_routes.append(f) + return f + + def protect(self, f): + """Decorator to protect a route against CSRF attacks. + + This is useful when it is necessary to protect a request that uses one + of the safe methods that are not supposed to make state changes. The + decorator must be added immediately after the route decorator to + disable CSRF protection on the route. Example:: + + @app.get('/data') + @csrf.force + # add additional decorators here + def get_data(request): + # ... + """ + self.protected_routes.append(f) + return f diff --git a/src/microdot/microdot.py b/src/microdot/microdot.py index 727fecc..1ffcff6 100644 --- a/src/microdot/microdot.py +++ b/src/microdot/microdot.py @@ -321,7 +321,7 @@ class G: def __init__(self, app, client_addr, method, url, http_version, headers, body=None, stream=None, sock=None, url_prefix='', - subapp=None, scheme=None): + subapp=None, scheme=None, route=None): #: The application instance to which this request belongs. self.app = app #: The address of the client, as a tuple (host, port). @@ -339,6 +339,8 @@ def __init__(self, app, client_addr, method, url, http_version, headers, #: The sub-application instance, or `None` if this isn't a mounted #: endpoint. self.subapp = subapp + #: The route function that handles this request. + self.route = route #: The path portion of the URL. self.path = url #: The query string portion of the URL. @@ -1429,6 +1431,8 @@ async def dispatch_request(self, req): try: res = None if callable(f): + req.route = f + # invoke the before request handlers for handler in self.get_request_handlers( req, 'before_request', False): diff --git a/src/microdot/session.py b/src/microdot/session.py index cbfb0a8..07de334 100644 --- a/src/microdot/session.py +++ b/src/microdot/session.py @@ -23,7 +23,7 @@ def delete(self): class Session: - """ + """Session handling :param app: The application instance. :param secret_key: The secret key, as a string or bytes object. :param cookie_options: A dictionary with cookie options to pass as diff --git a/tests/test_csrf.py b/tests/test_csrf.py new file mode 100644 index 0000000..185c0e3 --- /dev/null +++ b/tests/test_csrf.py @@ -0,0 +1,175 @@ +import asyncio +import time +import unittest +from microdot import Microdot +from microdot.csrf import CSRF +from microdot.test_client import TestClient + + +class TestMicrodot(unittest.TestCase): + @classmethod + def setUpClass(cls): + if hasattr(asyncio, 'set_event_loop'): + asyncio.set_event_loop(asyncio.new_event_loop()) + cls.loop = asyncio.get_event_loop() + + def _run(self, coro): + return self.loop.run_until_complete(coro) + + def test_protect_all_true(self): + app = Microdot() + csrf = CSRF(app, 'top-secret') + + @app.get('/') + def index(request): + return 204 + + @app.post('/submit') + def submit(request): + return 204 + + @app.post('/submit-exempt') + @csrf.exempt + def submit_exempt(request): + return 204 + + client = TestClient(app) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 204) + csrf_token = client.cookies['csrf_token'] + + res = self._run(client.post('/submit')) + self.assertEqual(res.status_code, 403) + + res = self._run(client.post( + '/submit', + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + body='csrf_token=' + csrf_token, + )) + self.assertEqual(res.status_code, 204) + + res = self._run(client.post( + '/submit', + headers={'X-CSRF-Token': csrf_token}, + )) + self.assertEqual(res.status_code, 204) + + res = self._run(client.post('/submit-exempt')) + self.assertEqual(res.status_code, 204) + + def test_protect_all_false(self): + app = Microdot() + csrf = CSRF() + csrf.initialize(app, b'top-secret', protect_all=False) + + @app.get('/') + def index(request): + return 204 + + @app.post('/submit') + @csrf.protect + def submit(request): + return 204 + + @app.post('/submit-exempt') + def submit_exempt(request): + return 204 + + client = TestClient(app) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 204) + csrf_token = client.cookies['csrf_token'] + + res = self._run(client.post('/submit')) + self.assertEqual(res.status_code, 403) + + res = self._run(client.post( + '/submit', + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + body='csrf_token=' + csrf_token, + )) + self.assertEqual(res.status_code, 204) + + res = self._run(client.post( + '/submit', + headers={'X-CSRF-Token': csrf_token}, + )) + self.assertEqual(res.status_code, 204) + + res = self._run(client.post('/submit-exempt')) + self.assertEqual(res.status_code, 204) + + def test_initialize_with_defaults(self): + app = Microdot() + csrf = CSRF(secret_key='top-secret', time_limit=600, + cookie_options={'path': '/'}) + csrf.initialize(app) + self.assertEqual(csrf.secret_key, b'top-secret') + self.assertEqual(csrf.time_limit, 600) + self.assertEqual(csrf.cookie_options, {'path': '/'}) + self.assertEqual(csrf.protect_all, True) + + def test_initialize_with_overrides(self): + app = Microdot() + csrf = CSRF(secret_key='top-secret', time_limit=600, + cookie_options={'path': '/'}) + csrf.initialize(app, secret_key=b'another-key', time_limit=1200, + cookie_options={}, protect_all=False) + self.assertEqual(csrf.secret_key, b'another-key') + self.assertEqual(csrf.time_limit, 1200) + self.assertEqual(csrf.cookie_options, {}) + self.assertEqual(csrf.protect_all, False) + + def test_token_expired(self): + app = Microdot() + CSRF(app, 'top-secret', time_limit=0.25) + + @app.get('/') + def index(request): + return 204 + + @app.post('/submit') + def submit(request): + return 204 + + client = TestClient(app) + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 204) + csrf_token = client.cookies['csrf_token'] + + res = self._run(client.post( + '/submit', + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + body='csrf_token=' + csrf_token, + )) + self.assertEqual(res.status_code, 204) + + time.sleep(0.25) + res = self._run(client.post( + '/submit', + headers={'Content-Type': 'application/x-www-form-urlencoded'}, + body='csrf_token=' + csrf_token, + )) + self.assertEqual(res.status_code, 403) + + def test_cookie_is_secure(self): + app = Microdot() + CSRF(app, 'top-secret', time_limit=0.25) + + @app.get('/') + def index(request): + print(request.url) + return 204 + + @app.post('/submit') + def submit(request): + return 204 + + client = TestClient(app, scheme='https') + + res = self._run(client.get('/')) + self.assertEqual(res.status_code, 204) + self.assertIn('; Secure', res.headers['Set-Cookie'][0])