Skip to content

Commit 057b5fb

Browse files
authored
Merge pull request #60 from mynhardtburger/case-insensitive-route-info
Bug fix: Make get_route_info() case insensitive
2 parents b7bb118 + 3fe4942 commit 057b5fb

File tree

2 files changed

+155
-18
lines changed

2 files changed

+155
-18
lines changed

caikit_tgis_backend/tgis_backend.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# Standard
1717
from copy import deepcopy
1818
from threading import Lock
19-
from typing import Any, Dict, Optional
19+
from typing import Any, Dict, Optional, Tuple, Union
2020

2121
# Third Party
2222
import grpc
@@ -196,6 +196,11 @@ def handle_runtime_context(
196196
{"hostname": route_info},
197197
fill_with_defaults=True,
198198
)
199+
else:
200+
log.debug(
201+
"<TGB32948346D> No %s context override found",
202+
self.ROUTE_INFO_HEADER_KEY,
203+
)
199204

200205
## Backend user interface ##
201206

@@ -351,6 +356,7 @@ def get_route_info(
351356
context: Optional[RuntimeServerContextType],
352357
) -> Optional[str]:
353358
"""Get the string value of the x-route-info header/metadata if present
359+
in a case insensitive manner.
354360
355361
Args:
356362
context (Optional[RuntimeServerContextType]): The grpc or fastapi
@@ -363,9 +369,12 @@ def get_route_info(
363369
if context is None:
364370
return context
365371
if isinstance(context, grpc.ServicerContext):
366-
return dict(context.invocation_metadata()).get(cls.ROUTE_INFO_HEADER_KEY)
372+
return TGISBackend._request_metadata_get(
373+
context.invocation_metadata(), cls.ROUTE_INFO_HEADER_KEY
374+
)
375+
367376
if HAVE_FASTAPI and isinstance(context, fastapi.Request):
368-
return context.headers.get(cls.ROUTE_INFO_HEADER_KEY)
377+
return TGISBackend._request_header_get(context, cls.ROUTE_INFO_HEADER_KEY)
369378
error.log_raise(
370379
"<TGB92615097E>",
371380
TypeError(f"context is of an unsupported type: {type(context)}"),
@@ -415,6 +424,38 @@ def _safely_update_state(
415424
if remote_models_cfg:
416425
self._remote_models_cfg.setdefault(model_id, remote_models_cfg)
417426

427+
@classmethod
428+
def _request_header_get(cls, request: fastapi.Request, key: str) -> Optional[str]:
429+
"""
430+
Returns the first matching value for the header key (case insensitive).
431+
If no matching header was found return None.
432+
"""
433+
# https://github.com/encode/starlette/blob/5ed55c441126687106109a3f5e051176f88cd3e6/starlette/datastructures.py#L543
434+
items: list[Tuple[str, str]] = request.headers.items()
435+
get_header_key = key.lower()
436+
437+
for header_key, header_value in items:
438+
if header_key.lower() == get_header_key:
439+
return header_value
440+
441+
@classmethod
442+
def _request_metadata_get(
443+
cls, metadata: Tuple[str, Union[str, bytes]], key: str
444+
) -> Optional[str]:
445+
"""
446+
Returns the first matching value for the metadata key (case insensitive).
447+
If no matching metadata was found return None.
448+
"""
449+
# https://grpc.github.io/grpc/python/glossary.html#term-metadatum
450+
get_metadata_key = key.lower()
451+
452+
for metadata_key, metadata_value in metadata:
453+
if str(metadata_key).lower() == get_metadata_key:
454+
if isinstance(metadata_value, str):
455+
return metadata_value
456+
if isinstance(metadata_value, bytes):
457+
return metadata_value.decode()
458+
418459

419460
# Register local backend
420461
register_backend_type(TGISBackend)

tests/test_tgis_backend.py

Lines changed: 111 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# Standard
1919
from copy import deepcopy
2020
from dataclasses import asdict
21-
from typing import Any, Dict, Optional, Union
21+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
2222
from unittest import mock
2323
import os
2424
import tempfile
@@ -97,17 +97,60 @@ def mock_tgis_fixture():
9797
mock_tgis.stop()
9898

9999

100-
class TestServicerContext:
100+
class TestServicerContext(grpc.ServicerContext):
101101
"""
102102
A dummy class for mimicking ServicerContext invocation metadata storage.
103103
"""
104104

105105
def __init__(self, metadata: Dict[str, Union[str, bytes]]):
106106
self.metadata = metadata
107107

108-
def invocation_metadata(self):
108+
def invocation_metadata(self) -> Sequence[Tuple[str, Union[str, bytes]]]:
109+
# https://grpc.github.io/grpc/python/glossary.html#term-metadata
109110
return list(self.metadata.items())
110111

112+
def is_active(self):
113+
raise NotImplementedError
114+
115+
def time_remaining(self):
116+
raise NotImplementedError
117+
118+
def cancel(self):
119+
raise NotImplementedError
120+
121+
def add_callback(self, callback):
122+
raise NotImplementedError
123+
124+
def peer(self):
125+
raise NotImplementedError
126+
127+
def peer_identities(self):
128+
raise NotImplementedError
129+
130+
def peer_identity_key(self):
131+
raise NotImplementedError
132+
133+
def auth_context(self):
134+
raise NotImplementedError
135+
136+
def send_initial_metadata(self, initial_metadata):
137+
raise NotImplementedError
138+
139+
def set_trailing_metadata(self, trailing_metadata):
140+
raise NotImplementedError
141+
142+
def abort(self, code, details):
143+
raise NotImplementedError
144+
145+
def abort_with_status(self, status):
146+
raise NotImplementedError
147+
148+
def set_code(self, code):
149+
raise NotImplementedError
150+
151+
def set_details(self, details):
152+
raise NotImplementedError
153+
111154

112155
## Conn Config #################################################################
113156

@@ -927,34 +970,84 @@ def test_tgis_backend_conn_testing_enabled(tgis_mock_insecure):
927970
{
928971
"type": "http",
929972
"headers": [
930-
(TGISBackend.ROUTE_INFO_HEADER_KEY.encode(), b"sometext")
973+
(
974+
TGISBackend.ROUTE_INFO_HEADER_KEY.encode("latin-1"),
975+
"http exact".encode("latin-1"),
976+
)
977+
],
978+
}
979+
),
980+
"http exact",
981+
),
982+
(
983+
fastapi.Request(
984+
{
985+
"type": "http",
986+
"headers": [
987+
(
988+
TGISBackend.ROUTE_INFO_HEADER_KEY.upper().encode("latin-1"),
989+
"http upper-case".encode("latin-1"),
990+
)
931991
],
932992
}
933993
),
934-
"sometext",
994+
"http upper-case",
935995
),
936996
(
937997
fastapi.Request(
938-
{"type": "http", "headers": [(b"route-info", b"sometext")]}
998+
{
999+
"type": "http",
1000+
"headers": [
1001+
(
1002+
TGISBackend.ROUTE_INFO_HEADER_KEY.title().encode("latin-1"),
1003+
"http title-case".encode("latin-1"),
1004+
)
1005+
],
1006+
}
1007+
),
1008+
"http title-case",
1009+
),
1010+
(
1011+
fastapi.Request(
1012+
{
1013+
"type": "http",
1014+
"headers": [
1015+
(
1016+
"route-info".encode("latin-1"),
1017+
"http not-found".encode("latin-1"),
1018+
)
1019+
],
1020+
}
9391021
),
9401022
None,
9411023
),
9421024
(
943-
TestServicerContext({TGISBackend.ROUTE_INFO_HEADER_KEY: "sometext"}),
944-
"sometext",
1025+
TestServicerContext({TGISBackend.ROUTE_INFO_HEADER_KEY: "grpc exact"}),
1026+
"grpc exact",
9451027
),
9461028
(
947-
TestServicerContext({"route-info": "sometext"}),
1029+
TestServicerContext(
1030+
{TGISBackend.ROUTE_INFO_HEADER_KEY.upper(): "grpc upper-case"}
1031+
),
1032+
"grpc upper-case",
1033+
),
1034+
(
1035+
TestServicerContext(
1036+
{TGISBackend.ROUTE_INFO_HEADER_KEY.title(): "grpc title-case"}
1037+
),
1038+
"grpc title-case",
1039+
),
1040+
(
1041+
TestServicerContext({"route-info": "grpc not found"}),
9481042
None,
9491043
),
950-
("should raise TypeError", None),
1044+
("should raise TypeError", TypeError()),
9511045
(None, None),
952-
# Uncertain how to create a grpc.ServicerContext object
9531046
],
9541047
)
955-
def test_get_route_info(context, route_info: Optional[str]):
956-
if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))):
957-
with pytest.raises(TypeError):
1048+
def test_get_route_info(context, route_info: Union[str, None, Exception]):
1049+
if isinstance(route_info, Exception):
1050+
with pytest.raises(type(route_info)):
9581051
TGISBackend.get_route_info(context)
9591052
else:
9601053
actual_route_info = TGISBackend.get_route_info(context)
@@ -970,7 +1063,10 @@ def test_handle_runtime_context_with_route_info():
9701063
{
9711064
"type": "http",
9721065
"headers": [
973-
(TGISBackend.ROUTE_INFO_HEADER_KEY.encode(), route_info.encode("utf-8"))
1066+
(
1067+
TGISBackend.ROUTE_INFO_HEADER_KEY.encode("latin-1"),
1068+
route_info.encode("latin-1"),
1069+
)
9741070
],
9751071
}
9761072
)

0 commit comments

Comments
 (0)