diff --git a/src/pyramid/csrf.py b/src/pyramid/csrf.py index 304ce4020..c9c65b3d0 100644 --- a/src/pyramid/csrf.py +++ b/src/pyramid/csrf.py @@ -1,5 +1,5 @@ +import secrets from urllib.parse import urlparse -import uuid from webob.cookies import CookieProfile from zope.interface import implementer @@ -64,7 +64,7 @@ class SessionCSRFStoragePolicy: """ - _token_factory = staticmethod(lambda: text_(uuid.uuid4().hex)) + _token_factory = staticmethod(lambda: text_(secrets.token_hex())) def __init__(self, key='_csrft_'): self.key = key @@ -109,7 +109,7 @@ class CookieCSRFStoragePolicy: """ - _token_factory = staticmethod(lambda: text_(uuid.uuid4().hex)) + _token_factory = staticmethod(lambda: text_(secrets.token_hex())) def __init__( self, @@ -161,6 +161,55 @@ def check_csrf_token(self, request, supplied_token): ) +@implementer(ICSRFStoragePolicy) +class HttpHeaderCSRFStoragePolicy: + """A CSRF storage policy that persists the CSRF token in an HTTP header. + + ``header_name`` + + The header name in which the CSRF token will be stored. + Default: `X-CSRF-Token`. + + .. versionadded: 2.0.3 + + """ + + _token_factory = staticmethod(lambda: text_(secrets.token_hex())) + + def __init__(self, header_name='X-CSRF-Token'): + self.header_name = header_name + + def new_csrf_token(self, request): + """Sets a new CSRF token into the header and returns it.""" + token = self._token_factory() + request.headers[self.header_name] = token + + def set_header(request, response): + response.headers.add(self.header_name, token) + + request.add_response_callback(set_header) + + return token + + def get_csrf_token(self, request): + """Returns the currently active CSRF token from the header, + generating a new one if needed.""" + token = request.headers.get(self.header_name) + + if not token: + token = self.new_csrf_token(request) + + return token + + def check_csrf_token(self, request, supplied_token): + """Returns ``True`` if the ``supplied_token`` is valid.""" + expected_token = self.get_csrf_token(request) + + return not strings_differ( + bytes_(expected_token), bytes_(supplied_token) + ) + + def get_csrf_token(request): """Get the currently active CSRF token for the request passed, generating a new one using ``new_csrf_token(request)`` if one does not exist. This