Skip to content

Commit 44a4aed

Browse files
committed
Add InsecureClient and reorganize code
1 parent a2ae8bb commit 44a4aed

File tree

1 file changed

+70
-19
lines changed

1 file changed

+70
-19
lines changed

authzed/api/v1/__init__.py

+70-19
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from typing import Callable, Any
12
import asyncio
23

34
import grpc
45
import grpc.aio
56

7+
from grpc_interceptor import ClientCallDetails, ClientInterceptor
8+
69
from authzed.api.v1.core_pb2 import (
710
AlgebraicSubjectSet,
811
ContextualizedCaveat,
@@ -70,47 +73,95 @@ class Client(SchemaServiceStub, PermissionsServiceStub, ExperimentalServiceStub,
7073
"""
7174

7275
def __init__(self, target, credentials, options=None, compression=None):
76+
channel = self.create_channel(target, credentials, options, compression)
77+
self.init_stubs(channel)
78+
79+
def init_stubs(self, channel):
80+
SchemaServiceStub.__init__(self, channel)
81+
PermissionsServiceStub.__init__(self, channel)
82+
ExperimentalServiceStub.__init__(self, channel)
83+
WatchServiceStub.__init__(self, channel)
84+
85+
def create_channel(self, target, credentials, options=None, compression=None):
7386
try:
7487
asyncio.get_running_loop()
7588
channelfn = grpc.aio.secure_channel
7689
except RuntimeError:
7790
channelfn = grpc.secure_channel
7891

79-
channel = channelfn(target, credentials, options, compression)
80-
SchemaServiceStub.__init__(self, channel)
81-
PermissionsServiceStub.__init__(self, channel)
82-
ExperimentalServiceStub.__init__(self, channel)
83-
WatchServiceStub.__init__(self, channel)
92+
return channelfn(target, credentials, options, compression)
8493

8594

86-
class AsyncClient(
87-
SchemaServiceStub, PermissionsServiceStub, ExperimentalServiceStub, WatchServiceStub
88-
):
95+
class AsyncClient(Client):
8996
"""
9097
v1 Authzed gRPC API client, for use with asyncio.
9198
"""
9299

93100
def __init__(self, target, credentials, options=None, compression=None):
94101
channel = grpc.aio.secure_channel(target, credentials, options, compression)
95-
SchemaServiceStub.__init__(self, channel)
96-
PermissionsServiceStub.__init__(self, channel)
97-
ExperimentalServiceStub.__init__(self, channel)
98-
WatchServiceStub.__init__(self, channel)
102+
self.init_stubs(channel)
99103

100104

101-
class SyncClient(
102-
SchemaServiceStub, PermissionsServiceStub, ExperimentalServiceStub, WatchServiceStub
103-
):
105+
class SyncClient(Client):
104106
"""
105107
v1 Authzed gRPC API client, running synchronously.
106108
"""
107109

108110
def __init__(self, target, credentials, options=None, compression=None):
109111
channel = grpc.secure_channel(target, credentials, options, compression)
110-
SchemaServiceStub.__init__(self, channel)
111-
PermissionsServiceStub.__init__(self, channel)
112-
ExperimentalServiceStub.__init__(self, channel)
113-
WatchServiceStub.__init__(self, channel)
112+
self.init_stubs(channel)
113+
114+
115+
class TokenAuthorization(ClientInterceptor):
116+
def __init__(self, token: str):
117+
self._token = token
118+
119+
def intercept(
120+
self,
121+
method: Callable,
122+
request_or_iterator: Any,
123+
call_details: grpc.ClientCallDetails,
124+
):
125+
metadata: list[tuple[str, str | bytes]] = [("authorization", f"Bearer {self._token}")]
126+
if call_details.metadata is not None:
127+
metadata = [*metadata, *call_details.metadata]
128+
129+
new_details = ClientCallDetails(
130+
call_details.method,
131+
call_details.timeout,
132+
metadata,
133+
call_details.credentials,
134+
call_details.wait_for_ready,
135+
call_details.compression,
136+
)
137+
138+
return method(request_or_iterator, new_details)
139+
140+
141+
class InsecureClient(Client):
142+
"""
143+
An insecure client variant for non-TLS contexts.
144+
145+
The default behavior of the python gRPC client is to restrict non-TLS
146+
calls to `localhost` only, which is frustrating in contexts like docker-compose,
147+
so we provide this as a convenience.
148+
"""
149+
150+
def __init__(
151+
self,
152+
target: str,
153+
token: str,
154+
options=None,
155+
compression=None,
156+
):
157+
fake_credentials = grpc.local_channel_credentials()
158+
channel = self.create_channel(target, fake_credentials, options, compression)
159+
auth_interceptor = TokenAuthorization(token)
160+
161+
insecure_channel = grpc.insecure_channel(target, options, compression)
162+
channel = grpc.intercept_channel(insecure_channel, auth_interceptor)
163+
164+
self.init_stubs(channel)
114165

115166

116167
__all__ = [

0 commit comments

Comments
 (0)