11# Copyright (c) 2017-2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4- from typing import List , Tuple , Union , cast
4+ from typing import Any , AsyncIterable , Callable , Iterable , List , Tuple , TypeVar , Union , cast
55from urllib .parse import urlparse
66
77from grpc import (
1313 metadata_call_credentials ,
1414 ssl_channel_credentials ,
1515)
16- from grpc .aio import Channel , insecure_channel , secure_channel
16+ from grpc .aio import (
17+ Channel ,
18+ ClientCallDetails ,
19+ StreamStreamCall ,
20+ StreamStreamClientInterceptor ,
21+ StreamUnaryCall ,
22+ StreamUnaryClientInterceptor ,
23+ UnaryStreamCall ,
24+ UnaryStreamClientInterceptor ,
25+ UnaryUnaryCall ,
26+ UnaryUnaryClientInterceptor ,
27+ insecure_channel ,
28+ secure_channel ,
29+ )
1730
1831from ..config import Config
1932
2033__all__ = ["create_channel" ]
2134
35+ RequestType = TypeVar ("RequestType" )
36+ RequestIterableType = Union [Iterable [Any ], AsyncIterable [Any ]]
37+ ResponseIterableType = AsyncIterable [Any ]
38+
2239
2340def create_channel (config : "Config" ) -> "Channel" :
2441 """
@@ -55,7 +72,15 @@ def create_channel(config: "Config") -> "Channel":
5572 ),
5673 )
5774 return secure_channel (u .netloc , credentials , tuple (options ))
75+
76+ elif config .access .token_version is not None :
77+ # Python/C++ libraries refuse to allow "credentials" objects to be passed around on
78+ # non-TLS channels, but they don't check interceptors; use an interceptor to inject
79+ # an Authorization header instead
80+ return insecure_channel (u .netloc , options , interceptors = [GrpcAuthInterceptor (config )])
81+
5882 else :
83+ # no TLS, no tokens--simply create an insecure channel with no adornments
5984 return insecure_channel (u .netloc , options )
6085
6186
@@ -74,3 +99,68 @@ def __call__(self, context: "AuthMetadataContext", callback: "AuthMetadataPlugin
7499 options .append (("authorization" , "Bearer " + self ._config .access .token ))
75100
76101 callback (tuple (options ), None )
102+
103+
104+ class GrpcAuthInterceptor (
105+ UnaryUnaryClientInterceptor ,
106+ UnaryStreamClientInterceptor ,
107+ StreamUnaryClientInterceptor ,
108+ StreamStreamClientInterceptor ,
109+ ):
110+ """
111+ An interceptor that injects "Authorization" metadata into a request.
112+
113+ This works around the fact that the C++ gRPC libraries (which Python is built on) highly
114+ discourage sending authorization data over the wire unless the connection is protected with TLS.
115+ """
116+
117+ # NOTE: There are a number of typing errors in the grpc.aio classes, so we're ignoring a handful
118+ # of lines until those problems are addressed.
119+
120+ def __init__ (self , config : "Config" ):
121+ self ._config = config
122+
123+ async def intercept_unary_unary (
124+ self ,
125+ continuation : "Callable[[ClientCallDetails, RequestType], UnaryUnaryCall]" ,
126+ client_call_details : ClientCallDetails ,
127+ request : RequestType ,
128+ ) -> "Union[UnaryUnaryCall, RequestType]" :
129+ return await continuation (self ._modify_client_call_details (client_call_details ), request )
130+
131+ async def intercept_unary_stream (
132+ self ,
133+ continuation : "Callable[[ClientCallDetails, RequestType], UnaryStreamCall]" ,
134+ client_call_details : ClientCallDetails ,
135+ request : RequestType ,
136+ ) -> "Union[ResponseIterableType, UnaryStreamCall]" :
137+ return await continuation (self ._modify_client_call_details (client_call_details ), request )
138+
139+ async def intercept_stream_unary (
140+ self ,
141+ continuation : "Callable[[ClientCallDetails, RequestType], StreamUnaryCall]" ,
142+ client_call_details : ClientCallDetails ,
143+ request_iterator : RequestIterableType ,
144+ ) -> StreamUnaryCall :
145+ return await continuation (
146+ self ._modify_client_call_details (client_call_details ), request_iterator # type: ignore
147+ )
148+
149+ async def intercept_stream_stream (
150+ self ,
151+ continuation : Callable [[ClientCallDetails , RequestType ], StreamStreamCall ],
152+ client_call_details : ClientCallDetails ,
153+ request_iterator : RequestIterableType ,
154+ ) -> "Union[ResponseIterableType, StreamStreamCall]" :
155+ return await continuation (
156+ self ._modify_client_call_details (client_call_details ), request_iterator # type: ignore
157+ )
158+
159+ def _modify_client_call_details (self , client_call_details : ClientCallDetails ):
160+ if (
161+ "authorization" not in client_call_details .metadata
162+ and self ._config .access .token_version is not None
163+ ):
164+ client_call_details .metadata .add ("authorization" , f"Bearer { self ._config .access .token } " )
165+
166+ return client_call_details
0 commit comments