Skip to content

Commit 9c6a2f8

Browse files
Merge pull request #883 from AMWA-TV/fix-mock-auth
Fix authorization issues with mock resources
2 parents 5db9be1 + 969071d commit 9c6a2f8

File tree

6 files changed

+157
-115
lines changed

6 files changed

+157
-115
lines changed

nmostesting/IS10Utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
from Crypto.PublicKey import RSA
1616
from authlib.jose import jwt, JsonWebKey
1717

18+
import re
1819
import time
1920
import uuid
2021

2122
from .NMOSUtils import NMOSUtils
2223
from OpenSSL import crypto
2324
from cryptography.hazmat.primitives import serialization
2425
from cryptography import x509
26+
from flask import request
2527
from .TestHelper import get_default_ip, get_mocks_hostname
2628

2729
from . import Config as CONFIG
@@ -157,3 +159,36 @@ def is_any_contain(list, enum):
157159
if item in [e.name for e in enum]:
158160
return True
159161
return False
162+
163+
@staticmethod
164+
def check_authorization(auth, path, scope="x-nmos-registration", write=False):
165+
def _check_path_match(path, path_wildcards):
166+
path_match = False
167+
for path_wildcard in path_wildcards:
168+
pattern = path_wildcard.replace("*", ".*")
169+
if re.search(pattern, path):
170+
path_match = True
171+
break
172+
return path_match
173+
174+
if CONFIG.ENABLE_AUTH:
175+
try:
176+
if "Authorization" not in request.headers:
177+
return 400, "Authorization header not found"
178+
if not request.headers["Authorization"].startswith("Bearer "):
179+
return 400, "Bearer not found in Authorization header"
180+
token = request.headers["Authorization"].split(" ")[1]
181+
claims = jwt.decode(token, auth.generate_jwk())
182+
claims.validate()
183+
if claims["iss"] != auth.make_issuer():
184+
return 401, f"Unexpected issuer, expected: {auth.make_issuer()}, actual: {claims['iss']}"
185+
# TODO: Check 'aud' claim matches 'mocks.<domain>'
186+
if not _check_path_match(path, claims[scope]["read"]):
187+
return 403, f"Paths mismatch for {scope} read claims"
188+
if write and not _check_path_match(path, claims[scope]["write"]):
189+
return 403, f"Paths mismatch for {scope} write claims"
190+
except KeyError as err:
191+
return 400, f"KeyError: {err}"
192+
except Exception as err:
193+
return 400, f"Exception: {err}"
194+
return True, ""

nmostesting/IS12Utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,12 @@ def send_command(self, test, command_json):
171171
for tm in self.ncp_websocket.get_timestamped_messages():
172172
parsed_message = json.loads(tm.message)
173173

174+
if parsed_message is None:
175+
raise NMOSTestException(test.FAIL(
176+
f"Null message received for command: {str(command_json)}",
177+
f"https://specs.amwa.tv/is-12/branches/{self.apis[CONTROL_API_KEY]['spec_branch']}"
178+
"/docs/Protocol_messaging.html#command-message-type"))
179+
174180
if self.message_type_to_schema_name(parsed_message.get("messageType")):
175181
self._validate_is12_schema(
176182
test,
@@ -230,7 +236,7 @@ def get_notifications(self):
230236
# Get any timestamped messages that have arrived in the interim period
231237
for tm in self.ncp_websocket.get_timestamped_messages():
232238
parsed_message = json.loads(tm.message)
233-
if parsed_message["messageType"] == MessageTypes.Notification:
239+
if parsed_message and parsed_message["messageType"] == MessageTypes.Notification:
234240
self.notifications += [IS12Notification(n, tm.received_time)
235241
for n in parsed_message["notifications"]]
236242
return self.notifications

nmostesting/NMOSTesting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@
169169
# Primary Authorization server
170170
if CONFIG.ENABLE_AUTH:
171171
auth_app = Flask(__name__)
172+
CORS(auth_app)
172173
auth_app.debug = False
173174
auth_app.config['AUTH_INSTANCE'] = 0
174175
auth_app.config['PORT'] = PRIMARY_AUTH.port

nmostesting/mocks/Auth.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def __init__(self, port_increment, version="v1.0"):
113113
self.host = get_mocks_hostname()
114114
# authorization code of the authorization code flow
115115
self.code = None
116+
self.scopes_cache = {} # remember client scopes
116117

117118
def make_mdns_info(self, priority=0, api_ver=None, ip=None):
118119
"""Get an mDNS ServiceInfo object in order to create an advertisement"""
@@ -302,10 +303,6 @@ def auth_auth():
302303
# Recommended parameters
303304
# state
304305

305-
ctype_valid, ctype_message = check_content_type(request.headers, "application/x-www-form-urlencoded")
306-
if not ctype_valid:
307-
raise AuthException("invalid_request", ctype_message)
308-
309306
# hmm, no client authorization done, just redirects a random authorization code back to the client
310307
# TODO: add web pages for client authorization for the future
311308

@@ -342,6 +339,8 @@ def auth_auth():
342339
if not scope_found:
343340
error = "invalid_request"
344341
error_description = "scope: {} are not supported".format(scopes)
342+
# cache the client scopes
343+
auth.scopes_cache[request.args["client_id"]] = scopes
345344

346345
vars = {}
347346
if error:
@@ -370,7 +369,6 @@ def auth_auth():
370369
def auth_token():
371370
auth = AUTHS[flask.current_app.config["AUTH_INSTANCE"]]
372371
try:
373-
auth_header_required = False
374372
scopes = []
375373

376374
ctype_valid, ctype_message = check_content_type(request.headers, "application/x-www-form-urlencoded")
@@ -395,7 +393,13 @@ def auth_token():
395393

396394
refresh_token = query["refresh_token"][0] if "refresh_token" in query else None
397395

398-
scopes = query["scope"][0].split() if "scope" in query else SCOPE.split() if SCOPE else []
396+
# Scope query parameter is OPTIONAL
397+
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.2
398+
# and https://datatracker.ietf.org/doc/html/rfc6749#section-6
399+
# Use scopes cached from when the token was created if not provided in query
400+
cached_scopes = auth.scopes_cache[client_id] if client_id in auth.scopes_cache else []
401+
scopes = query["scope"][0].split() if "scope" in query else cached_scopes \
402+
if len(cached_scopes) else SCOPE.split() if SCOPE else []
399403
if scopes:
400404
scope_found = IS10Utils.is_any_contain(scopes, SCOPES)
401405
if not scope_found:
@@ -484,8 +488,6 @@ def auth_token():
484488
else:
485489
raise AuthException("unsupported_grant_type",
486490
"missing client_assertion_type used for private_key_jwt client authentication")
487-
else:
488-
auth_header_required = True
489491

490492
# for the Confidential client, client_id and client_secret are embedded in the Authorization header
491493
auth_header = request.headers.get("Authorization", None)
@@ -504,8 +506,6 @@ def auth_token():
504506
"missing client_id or client_secret from authorization header")
505507
else:
506508
raise AuthException("invalid_client", "invalid authorization header")
507-
elif auth_header_required:
508-
raise AuthException("invalid_client", "invalid authorization header", HTTPStatus.UNAUTHORIZED)
509509

510510
# client_id MUST be provided by all types of client
511511
if not client_id:

nmostesting/mocks/Node.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from .. import Config as CONFIG
2525
from ..TestHelper import get_default_ip, do_request
2626
from ..IS04Utils import IS04Utils
27+
from ..IS10Utils import IS10Utils
28+
from .Auth import PRIMARY_AUTH
2729

2830

2931
class Node(object):
@@ -39,6 +41,7 @@ def reset(self):
3941
self.receivers = {}
4042
self.senders = {}
4143
self.patched_sdp = {}
44+
self.auth_cache = {}
4245

4346
def get_sender(self, media_type="video/raw", version="v1.3"):
4447
protocol = "http"
@@ -360,11 +363,49 @@ def patch_staged(self, resource, resource_id, request_json):
360363

361364
return response_data, response_code
362365

366+
def check_authorization(self, auth, path, scope, write=False):
367+
if not CONFIG.ENABLE_AUTH:
368+
return True, ""
369+
370+
if "Authorization" in request.headers and request.headers["Authorization"].startswith("Bearer ") \
371+
and scope in self.auth_cache and \
372+
((write and self.auth_cache[scope]["Write"]) or self.auth_cache[scope]["Read"]):
373+
return True, ""
374+
375+
authorized, error_message = IS10Utils.check_authorization(auth,
376+
path,
377+
scope=scope,
378+
write=write)
379+
if authorized:
380+
if scope not in self.auth_cache:
381+
self.auth_cache[scope] = {"Read": True, "Write": write}
382+
else:
383+
self.auth_cache[scope]["Read"] = True
384+
self.auth_cache[scope]["Write"] = self.auth_cache[scope]["Write"] or write
385+
return authorized, error_message
386+
363387

364388
NODE = Node(1)
365389
NODE_API = Blueprint('node_api', __name__)
366390

367391

392+
# Authorization decorator
393+
def check_authorization(func):
394+
def wrapper(*args, **kwargs):
395+
write = (request.method == 'PATCH')
396+
authorized, error_message = NODE.check_authorization(PRIMARY_AUTH,
397+
request.path,
398+
scope="x-nmos-connection",
399+
write=write)
400+
if authorized is not True:
401+
abort(authorized, description=error_message)
402+
403+
return func(*args, **kwargs)
404+
# Rename wrapper to allow decoration of decorator
405+
wrapper.__name__ = func.__name__
406+
return wrapper
407+
408+
368409
@NODE_API.route('/x-nmos', methods=['GET'], strict_slashes=False)
369410
def x_nmos_root():
370411
base_data = ['connection/']
@@ -373,27 +414,31 @@ def x_nmos_root():
373414

374415

375416
@NODE_API.route('/x-nmos/connection', methods=['GET'], strict_slashes=False)
417+
@check_authorization
376418
def connection_root():
377419
base_data = ['v1.0/', 'v1.1/']
378420

379421
return make_response(Response(json.dumps(base_data), mimetype='application/json'))
380422

381423

382424
@NODE_API.route('/x-nmos/connection/<version>', methods=['GET'], strict_slashes=False)
425+
@check_authorization
383426
def version(version):
384427
base_data = ['bulk/', 'single/']
385428

386429
return make_response(Response(json.dumps(base_data), mimetype='application/json'))
387430

388431

389432
@NODE_API.route('/x-nmos/connection/<version>/single', methods=['GET'], strict_slashes=False)
433+
@check_authorization
390434
def single(version):
391435
base_data = ['senders/', 'receivers/']
392436

393437
return make_response(Response(json.dumps(base_data), mimetype='application/json'))
394438

395439

396440
@NODE_API.route('/x-nmos/connection/<version>/single/<resource>/', methods=["GET"], strict_slashes=False)
441+
@check_authorization
397442
def resources(version, resource):
398443
if resource == 'senders':
399444
base_data = [r + '/' for r in [*NODE.senders]]
@@ -404,6 +449,7 @@ def resources(version, resource):
404449

405450

406451
@NODE_API.route('/x-nmos/connection/<version>/single/<resource>/<resource_id>', methods=["GET"], strict_slashes=False)
452+
@check_authorization
407453
def connection(version, resource, resource_id):
408454
if resource != 'senders' and resource != 'receivers':
409455
abort(404)
@@ -440,6 +486,7 @@ def _get_constraints(resource):
440486

441487
@NODE_API.route('/x-nmos/connection/<version>/single/<resource>/<resource_id>/constraints',
442488
methods=["GET"], strict_slashes=False)
489+
@check_authorization
443490
def constraints(version, resource, resource_id):
444491
base_data = [_get_constraints(resource)]
445492

@@ -472,6 +519,7 @@ def _check_constraint(constraint, transport_param):
472519

473520
@NODE_API.route('/x-nmos/connection/<version>/single/<resource>/<resource_id>/staged',
474521
methods=["GET", "PATCH"], strict_slashes=False)
522+
@check_authorization
475523
def staged(version, resource, resource_id):
476524
"""
477525
GET returns current staged data for given resource
@@ -515,6 +563,7 @@ def staged(version, resource, resource_id):
515563

516564
@NODE_API.route('/x-nmos/connection/<version>/single/<resource>/<resource_id>/active',
517565
methods=["GET"], strict_slashes=False)
566+
@check_authorization
518567
def active(version, resource, resource_id):
519568
try:
520569
if resource == 'senders':
@@ -529,6 +578,7 @@ def active(version, resource, resource_id):
529578

530579
@NODE_API.route('/x-nmos/connection/<version>/single/<resource>/<resource_id>/transporttype',
531580
methods=["GET"], strict_slashes=False)
581+
@check_authorization
532582
def transport_type(version, resource, resource_id):
533583
# TODO fetch from resource info
534584
base_data = "urn:x-nmos:transport:rtp"
@@ -583,6 +633,7 @@ def node_sdp(media_type, media_subtype):
583633

584634
@NODE_API.route('/x-nmos/connection/<version>/single/<resource>/<resource_id>/transportfile',
585635
methods=["GET"], strict_slashes=False)
636+
@check_authorization
586637
def transport_file(version, resource, resource_id):
587638
# GET should either redirect to the location of the transport file or return it directly
588639
try:

0 commit comments

Comments
 (0)