Skip to content

Commit

Permalink
fix: remove inputs state from client (#6207)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Oct 1, 2024
1 parent ebbc251 commit 47eb5f0
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 45 deletions.
53 changes: 17 additions & 36 deletions jina/clients/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import inspect
import os
from abc import ABC
from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union
from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, Optional, Union, Tuple

from jina.excepts import BadClientInput
from jina.helper import T, parse_client, send_telemetry_event, typename
Expand Down Expand Up @@ -47,8 +47,6 @@ def __init__(
# affect users os-level envs.
os.unsetenv('http_proxy')
os.unsetenv('https_proxy')
self._inputs = None
self._inputs_length = None
self._setup_instrumentation(
name=(
self.args.name
Expand Down Expand Up @@ -125,60 +123,43 @@ def check_input(inputs: Optional['InputType'] = None, **kwargs) -> None:
raise BadClientInput from ex

def _get_requests(
self, **kwargs
) -> Union[Iterator['Request'], AsyncIterator['Request']]:
self, inputs, **kwargs
) -> Tuple[Union[Iterator['Request'], AsyncIterator['Request']], Optional[int]]:
"""
Get request in generator.
:param inputs: The inputs argument to get the requests from.
:param kwargs: Keyword arguments.
:return: Iterator of request.
:return: Iterator of request and the length of the inputs.
"""
_kwargs = vars(self.args)
_kwargs['data'] = self.inputs
if hasattr(inputs, '__call__'):
inputs = inputs()

_kwargs['data'] = inputs
# override by the caller-specific kwargs
_kwargs.update(kwargs)

if hasattr(self._inputs, '__len__'):
total_docs = len(self._inputs)
if hasattr(inputs, '__len__'):
total_docs = len(inputs)
elif 'total_docs' in _kwargs:
total_docs = _kwargs['total_docs']
else:
total_docs = None

if total_docs:
self._inputs_length = max(1, total_docs / _kwargs['request_size'])
inputs_length = max(1, total_docs / _kwargs['request_size'])
else:
inputs_length = None

if inspect.isasyncgen(self.inputs):
if inspect.isasyncgen(inputs):
from jina.clients.request.asyncio import request_generator

return request_generator(**_kwargs)
return request_generator(**_kwargs), inputs_length
else:
from jina.clients.request import request_generator

return request_generator(**_kwargs)

@property
def inputs(self) -> 'InputType':
"""
An iterator of bytes, each element represents a Document's raw content.
``inputs`` defined in the protobuf
:return: inputs
"""
return self._inputs

@inputs.setter
def inputs(self, bytes_gen: 'InputType') -> None:
"""
Set the input data.
:param bytes_gen: input type
"""
if hasattr(bytes_gen, '__call__'):
self._inputs = bytes_gen()
else:
self._inputs = bytes_gen
return request_generator(**_kwargs), inputs_length

@abc.abstractmethod
async def _get_results(
Expand Down
5 changes: 2 additions & 3 deletions jina/clients/base/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ async def _get_results(
else grpc.Compression.NoCompression
)

self.inputs = inputs
req_iter = self._get_requests(**kwargs)
req_iter, inputs_length = self._get_requests(inputs=inputs, **kwargs)
continue_on_error = self.continue_on_error
# while loop with retries, check in which state the `iterator` remains after failure
options = client_grpc_options(
Expand Down Expand Up @@ -120,7 +119,7 @@ async def _get_results(
self.logger.debug(f'connected to {self.args.host}:{self.args.port}')

with ProgressBar(
total_length=self._inputs_length, disable=not self.show_progress
total_length=inputs_length, disable=not self.show_progress
) as p_bar:
try:
if stream:
Expand Down
5 changes: 2 additions & 3 deletions jina/clients/base/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,14 @@ async def _get_results(
with ImportExtensions(required=True):
pass

self.inputs = inputs
request_iterator = self._get_requests(**kwargs)
request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs)
on = kwargs.get('on', '/post')
if len(self._endpoints) == 0:
await self._get_endpoints_from_openapi(**kwargs)

async with AsyncExitStack() as stack:
cm1 = ProgressBar(
total_length=self._inputs_length, disable=not self.show_progress
total_length=inputs_length, disable=not self.show_progress
)
p_bar = stack.enter_context(cm1)
proto = 'https' if self.args.tls else 'http'
Expand Down
5 changes: 2 additions & 3 deletions jina/clients/base/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,11 @@ async def _get_results(
with ImportExtensions(required=True):
pass

self.inputs = inputs
request_iterator = self._get_requests(**kwargs)
request_iterator, inputs_length = self._get_requests(inputs=inputs, **kwargs)

async with AsyncExitStack() as stack:
cm1 = ProgressBar(
total_length=self._inputs_length, disable=not (self.show_progress)
total_length=inputs_length, disable=not (self.show_progress)
)
p_bar = stack.enter_context(cm1)

Expand Down

0 comments on commit 47eb5f0

Please sign in to comment.