Skip to content

Commit d1450b0

Browse files
committed
feat: pass request-headers as metadata
1 parent a8098f2 commit d1450b0

File tree

4 files changed

+157
-8
lines changed

4 files changed

+157
-8
lines changed

jina/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def _ignore_google_warnings():
3535
'ignore',
3636
category=DeprecationWarning,
3737
message='Deprecated call to `pkg_resources.declare_namespace(\'google\')`.',
38-
append=True
38+
append=True,
3939
)
4040

4141

@@ -81,7 +81,7 @@ def _ignore_google_warnings():
8181

8282
# do not change this line manually
8383
# this is managed by proto/build-proto.sh and updated on every execution
84-
__proto_version__ = '0.1.27'
84+
__proto_version__ = '0.1.28'
8585

8686
try:
8787
__docarray_version__ = _docarray.__version__

jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,7 @@ async def post(body: input_model, response: Response, request: Request):
210210
docs,
211211
exec_endpoint=endpoint_path,
212212
parameters=body.parameters,
213-
metadata=dict(
214-
request.headers or {
215-
"no_headers": "true"
216-
}
217-
),
213+
metadata=dict(request.headers or {"no_headers": "true"}),
218214
target_executor=target_executor,
219215
request_id=req_id,
220216
return_results=True,
@@ -252,6 +248,8 @@ def add_streaming_routes(
252248
endpoint_path,
253249
input_doc_model=None,
254250
):
251+
from fastapi import Request
252+
255253
@app.api_route(
256254
path=f'/{endpoint_path.strip("/")}',
257255
methods=['GET'],

jina/serve/runtimes/worker/http_fastapi_app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def post(body: input_model, response: Response, request: Request):
9797

9898
if body.parameters is not None:
9999
req.parameters = body.parameters
100-
req.metadata = dict(request.headers or {"no_headers": "true"})
100+
req.metadata = dict(request.headers or {})
101101
req.header.exec_endpoint = endpoint_path
102102
data = body.data
103103
if isinstance(data, list):
@@ -152,6 +152,7 @@ async def streaming_get(request: Request = None, body: input_doc_model = None):
152152
body = Document.from_pydantic_model(body)
153153
req = DataRequest()
154154
req.header.exec_endpoint = endpoint_path
155+
req.metadata = dict(request.headers or {})
155156
if not docarray_v2:
156157
req.data.docs = DocumentArray([body])
157158
else:
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import logging
2+
from typing import Dict, List, Literal, Optional
3+
4+
import pytest
5+
from docarray import BaseDoc, DocList
6+
7+
from jina import Client, Deployment, Executor, requests
8+
from jina.helper import random_port
9+
10+
11+
class PortGetter:
    """Hold one randomly chosen port per (protocol, include_gateway) pair."""

    def __init__(self):
        # Nested mapping: protocol name -> include_gateway flag -> port.
        # Ports are drawn in the order http/True, http/False, grpc/True,
        # grpc/False.
        self.ports = {
            protocol: {flag: random_port() for flag in (True, False)}
            for protocol in ("http", "grpc")
        }

    def get_port(self, protocol: Literal["http", "grpc"], include_gateway: bool) -> int:
        """Return the port stored for ``protocol`` and ``include_gateway``."""
        ports_for_protocol = self.ports[protocol]
        return ports_for_protocol[include_gateway]

    @property
    def gateway_ports(self) -> List[int]:
        """Ports stored under the ``True`` flag, ordered http then grpc."""
        return [self.ports[protocol][True] for protocol in ("http", "grpc")]

    @property
    def no_gateway_ports(self) -> List[int]:
        """Ports stored under the ``False`` flag, ordered http then grpc."""
        return [self.ports[protocol][False] for protocol in ("http", "grpc")]
34+
35+
36+
@pytest.fixture(scope='module')
def port_getter() -> PortGetter:
    """Provide one ``PortGetter`` shared by every test in this module.

    The return annotation names the class that is actually returned; the
    previous ``callable`` annotation was the builtin function, not a type.
    """
    return PortGetter()
40+
41+
42+
# Document type used by the executor and tests below: it wraps a single
# dictionary payload.
class DictDoc(BaseDoc):
    data: dict  # required field; the executor fills it with request metadata
44+
45+
46+
class MetadataExecutor(Executor):
    """Executor that echoes the ``metadata`` it receives back to the caller."""

    @requests(on="/get-metadata-headers")
    def post_endpoint(
        self,
        docs: DocList[DictDoc],
        parameters: Optional[Dict] = None,
        metadata: Optional[Dict] = None,
        **kwargs,
    ) -> DocList[DictDoc]:
        """Return a one-element ``DocList`` whose document carries ``metadata``.

        :param docs: incoming documents (not inspected).
        :param parameters: request parameters (not inspected).
        :param metadata: metadata mapping to echo back.
        :param kwargs: any further keyword arguments (ignored).
        :return: a ``DocList[DictDoc]`` holding a single ``DictDoc``.
        """
        # ``DictDoc.data`` is a required ``dict`` field, so fall back to an
        # empty mapping when no metadata is supplied -- the same guard that
        # ``stream_task`` applies.
        return DocList[DictDoc]([DictDoc(data=metadata or {})])

    @requests(on='/stream-metadata-headers')
    async def stream_task(
        self, doc: DictDoc, metadata: Optional[dict] = None, **kwargs
    ) -> DictDoc:
        """Yield one ``DictDoc`` per metadata entry, then a final marker doc.

        :param doc: incoming document (not inspected).
        :param metadata: metadata mapping to stream back; ``None`` is treated
            as empty.
        :param kwargs: any further keyword arguments (ignored).
        """
        # Entries are emitted in sorted key order so the stream is deterministic.
        for k, v in sorted((metadata or {}).items()):
            yield DictDoc(data={k: v})

        # Sentinel document marking the end of the stream.
        yield DictDoc(data={"DONE": "true"})
65+
66+
67+
@pytest.fixture(scope='module')
def deployment_no_gateway(port_getter: PortGetter) -> Deployment:
    """Yield a running ``MetadataExecutor`` deployment without a gateway.

    The deployment serves both http and grpc on ``port_getter.no_gateway_ports``
    and is shut down when the module's tests are finished.
    """
    deployment = Deployment(
        uses=MetadataExecutor,
        protocol=["http", "grpc"],
        port=port_getter.no_gateway_ports,
        include_gateway=False,
    )
    with deployment as running:
        yield running
77+
78+
79+
@pytest.fixture(scope='module')
def deployment_gateway(port_getter: PortGetter) -> Deployment:
    """Yield a running ``MetadataExecutor`` deployment fronted by a gateway.

    The deployment serves both http and grpc on ``port_getter.gateway_ports``
    and is shut down when the module's tests are finished.
    """
    with Deployment(
        uses=MetadataExecutor,
        protocol=["http", "grpc"],
        port=port_getter.gateway_ports,
        # Was ``False``, which made this fixture identical to the no-gateway
        # one, so the ``include_gateway=True`` test cases never went through
        # a gateway.
        include_gateway=True,
    ) as dep:
        yield dep
89+
90+
91+
@pytest.fixture(scope='module')
def deployments(deployment_gateway, deployment_no_gateway) -> Dict[bool, Deployment]:
    """Map the ``include_gateway`` flag to the matching running deployment."""
    by_gateway_flag = {}
    by_gateway_flag[True] = deployment_gateway
    by_gateway_flag[False] = deployment_no_gateway
    return by_gateway_flag
97+
98+
99+
@pytest.mark.parametrize('include_gateway', [False, True])
def test_headers_in_http_metadata(
    include_gateway, port_getter: PortGetter, deployments
):
    """HTTP request headers must reach the executor as ``metadata``.

    ``deployments`` is requested only so that both deployments are running;
    the client connects through the port looked up in ``port_getter``.
    """
    port = port_getter.get_port("http", include_gateway)
    data = {
        "data": [{"text": "test"}],
        "parameters": {
            "parameter1": "value1",
        },
    }
    logging.info("Posting to %s", port)
    client = Client(port=port, protocol="http")
    resp = client.post(
        on='/get-metadata-headers',
        inputs=DocList([DictDoc(data=data)]),
        headers={
            "header1": "value1",
            "header2": "value2",
        },
        return_type=DocList[DictDoc],
    )
    # The executor echoes the received metadata in the first returned document;
    # check every header that was sent, not just the first.
    assert resp[0].data['header1'] == 'value1'
    assert resp[0].data['header2'] == 'value2'
122+
123+
124+
@pytest.mark.asyncio
@pytest.mark.parametrize('include_gateway', [False, True])
async def test_headers_in_http_metadata_streaming(
    include_gateway, port_getter: PortGetter, deployments
):
    """HTTP request headers must reach a streaming endpoint as ``metadata``.

    ``deployments`` is requested only so that both deployments are running;
    the client connects through the port looked up in ``port_getter``.
    """
    client = Client(
        port=port_getter.get_port("http", include_gateway),
        protocol="http",
        asyncio=True,
    )
    data = {"data": [{"text": "test"}], "parameters": {"parameter1": "value1"}}

    chunks = [
        doc
        async for doc in client.stream_doc(
            on='/stream-metadata-headers',
            inputs=DictDoc(data=data),
            headers={
                "header1": "value1",
                "header2": "value2",
            },
            return_type=DictDoc,
        )
    ]
    # At least the two custom headers plus the final marker document.
    assert len(chunks) > 2

    assert DictDoc(data={'header1': 'value1'}) in chunks
    assert DictDoc(data={'header2': 'value2'}) in chunks

0 commit comments

Comments
 (0)