diff --git a/lego/apps/oauth/models.py b/lego/apps/oauth/models.py index c1e6022a3..e9c610333 100644 --- a/lego/apps/oauth/models.py +++ b/lego/apps/oauth/models.py @@ -1,3 +1,6 @@ +import fnmatch +from urllib.parse import urlparse + from django.db import models from oauth2_provider.models import AbstractApplication @@ -12,3 +15,63 @@ class APIApplication(AbstractApplication): class Meta: permission_handler = APIApplicationPermissionHandler() + + def redirect_uri_allowed(self, uri): + """ + Check if a given URI matches one of the allowed redirect URIs. + Supports: + - Space-separated URIs + - Comma-separated URIs + - Wildcards (*) in host and path using fnmatch + + Allowed patterns: + - https://example.com/callback + - https://*.example.com/callback + - https://example.com/* + - https://*.example.com/* + + Not Allowed: + - * + - https://* + - https://*.com + """ + if not uri: + return False + + # Support both space and comma separated URIs + raw_uris = self.redirect_uris.replace(",", " ") + allowed_uris = [u.strip() for u in raw_uris.split() if u.strip()] + + parsed_uri = urlparse(uri) + if not parsed_uri.scheme or not parsed_uri.netloc: + return False + + for allowed_uri in allowed_uris: + parsed_allowed = urlparse(allowed_uri) + + # Check to avoid universal links such as https://*, to avoid wildcard abuse + allowed_host = parsed_allowed.hostname or "" + if allowed_host == "*" or ( + allowed_host.startswith("*.") and len(allowed_host.split(".")) < 3 + ): + continue + + if parsed_allowed.scheme != parsed_uri.scheme: + continue + + if not fnmatch.fnmatch( + parsed_uri.hostname or "", parsed_allowed.hostname or "" + ): + continue + + allowed_path = parsed_allowed.path or "/" + uri_path = parsed_uri.path or "/" + if not fnmatch.fnmatch(uri_path, allowed_path): + continue + + if parsed_allowed.port and parsed_allowed.port != parsed_uri.port: + continue + + return True + + return False diff --git a/lego/apps/oauth/tests/test_models.py b/lego/apps/oauth/tests/test_models.py index dbc396578..8d8df58ee 100644 --- a/lego/apps/oauth/tests/test_models.py +++ b/lego/apps/oauth/tests/test_models.py @@ -11,3 +11,95 @@ def test_initial_application(self): api_app = APIApplication.objects.get(pk=1) self.assertTrue(len(api_app.description)) self.assertEqual(api_app.authorization_grant_type, Application.GRANT_PASSWORD) + + +class RedirectURIAllowedTestCase(BaseTestCase): + + def setUp(self): + self.app = APIApplication( + name="Test App", + client_type=APIApplication.CLIENT_PUBLIC, + authorization_grant_type=APIApplication.GRANT_AUTHORIZATION_CODE, + ) + + def test_exact_match(self): + self.app.redirect_uris = "https://example.com/callback" + self.assertTrue(self.app.redirect_uri_allowed("https://example.com/callback")) + self.assertFalse(self.app.redirect_uri_allowed("https://example.com/other")) + + def test_space_separated_uris(self): + self.app.redirect_uris = "https://a.com/cb https://b.com/cb" + self.assertTrue(self.app.redirect_uri_allowed("https://a.com/cb")) + self.assertTrue(self.app.redirect_uri_allowed("https://b.com/cb")) + self.assertFalse(self.app.redirect_uri_allowed("https://c.com/cb")) + + def test_comma_separated_uris(self): + self.app.redirect_uris = "https://a.com/cb, https://b.com/cb" + self.assertTrue(self.app.redirect_uri_allowed("https://a.com/cb")) + self.assertTrue(self.app.redirect_uri_allowed("https://b.com/cb")) + self.assertFalse(self.app.redirect_uri_allowed("https://c.com/cb")) + + def test_mixed_comma_space_separated(self): + self.app.redirect_uris = "https://a.com/cb, https://b.com/cb https://c.com/cb" + self.assertTrue(self.app.redirect_uri_allowed("https://a.com/cb")) + self.assertTrue(self.app.redirect_uri_allowed("https://b.com/cb")) + self.assertTrue(self.app.redirect_uri_allowed("https://c.com/cb")) + + def test_path_wildcard(self): + self.app.redirect_uris = "https://example.com/*" + self.assertTrue(self.app.redirect_uri_allowed("https://example.com/callback")) + self.assertTrue(self.app.redirect_uri_allowed("https://example.com/any/path")) + self.assertFalse(self.app.redirect_uri_allowed("https://other.com/callback")) + + def test_subdomain_wildcard(self): + self.app.redirect_uris = "https://*.example.com/callback" + self.assertTrue( + self.app.redirect_uri_allowed("https://app.example.com/callback") + ) + self.assertTrue( + self.app.redirect_uri_allowed("https://api.example.com/callback") + ) + self.assertFalse(self.app.redirect_uri_allowed("https://example.com/callback")) + self.assertFalse( + self.app.redirect_uri_allowed("https://app.other.com/callback") + ) + + def test_combined_wildcards(self): + self.app.redirect_uris = "https://*.example.com/*" + self.assertTrue( + self.app.redirect_uri_allowed("https://app.example.com/callback") + ) + self.assertTrue( + self.app.redirect_uri_allowed("https://api.example.com/any/path") + ) + self.assertFalse(self.app.redirect_uri_allowed("https://example.com/callback")) + + def test_scheme_must_match(self): + self.app.redirect_uris = "https://example.com/callback" + self.assertFalse(self.app.redirect_uri_allowed("http://example.com/callback")) + + def test_port_matching(self): + self.app.redirect_uris = "https://example.com:8080/callback" + self.assertTrue( + self.app.redirect_uri_allowed("https://example.com:8080/callback") + ) + self.assertFalse( + self.app.redirect_uri_allowed("https://example.com:9090/callback") + ) + self.assertFalse(self.app.redirect_uri_allowed("https://example.com/callback")) + + def test_global_wildcard(self): + self.app.redirect_uris = "https://*" + self.assertFalse(self.app.redirect_uri_allowed("https://example.com/callback")) + self.assertFalse(self.app.redirect_uri_allowed("/")) + self.assertFalse(self.app.redirect_uri_allowed("https://*.com")) + + def test_empty_uri(self): + self.app.redirect_uris = "https://example.com/callback" + self.assertFalse(self.app.redirect_uri_allowed("")) + self.assertFalse(self.app.redirect_uri_allowed(None)) + + def test_invalid_uri(self): + self.app.redirect_uris = "https://example.com/callback" + self.assertFalse(self.app.redirect_uri_allowed("not-a-valid-uri")) + self.assertFalse(self.app.redirect_uri_allowed("/relative/path"))