Skip to content

Commit 609287e

Browse files
varunshankar and tmbo authored
Port "connection timeout to action server" changes to 3.6.x - [ENG 689] (#12965)
* Merge pull request #106 from RasaHQ/ENG-680-DEFAULT_KEEP_ALIVE_TIMEOUT Fix connection to action server - [ENG 680] --------- Co-authored-by: Tom Bocklisch <tom@rasa.com>
1 parent 241bf28 commit 609287e

6 files changed

Lines changed: 136 additions & 157 deletions

File tree

rasa/core/agent.py

Lines changed: 51 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -112,53 +112,59 @@ async def _pull_model_and_fingerprint(
112112

113113
logger.debug(f"Requesting model from server {model_server.url}...")
114114

115-
try:
116-
params = model_server.combine_parameters()
117-
async with model_server.session.request(
118-
"GET",
119-
model_server.url,
120-
timeout=DEFAULT_REQUEST_TIMEOUT,
121-
headers=headers,
122-
params=params,
123-
) as resp:
124-
if resp.status in [204, 304]:
125-
logger.debug(
126-
"Model server returned {} status code, "
127-
"indicating that no new model is available. "
128-
"Current fingerprint: {}"
129-
"".format(resp.status, fingerprint)
130-
)
131-
return None
132-
elif resp.status == 404:
133-
logger.debug(
134-
"Model server could not find a model at the requested "
135-
"endpoint '{}'. It's possible that no model has been "
136-
"trained, or that the requested tag hasn't been "
137-
"assigned.".format(model_server.url)
138-
)
139-
return None
140-
elif resp.status != 200:
141-
logger.debug(
142-
"Tried to fetch model from server, but server response "
143-
"status code is {}. We'll retry later..."
144-
"".format(resp.status)
115+
async with model_server.session() as session:
116+
try:
117+
params = model_server.combine_parameters()
118+
async with session.request(
119+
"GET",
120+
model_server.url,
121+
timeout=DEFAULT_REQUEST_TIMEOUT,
122+
headers=headers,
123+
params=params,
124+
) as resp:
125+
126+
if resp.status in [204, 304]:
127+
logger.debug(
128+
"Model server returned {} status code, "
129+
"indicating that no new model is available. "
130+
"Current fingerprint: {}"
131+
"".format(resp.status, fingerprint)
132+
)
133+
return None
134+
elif resp.status == 404:
135+
logger.debug(
136+
"Model server could not find a model at the requested "
137+
"endpoint '{}'. It's possible that no model has been "
138+
"trained, or that the requested tag hasn't been "
139+
"assigned.".format(model_server.url)
140+
)
141+
return None
142+
elif resp.status != 200:
143+
logger.debug(
144+
"Tried to fetch model from server, but server response "
145+
"status code is {}. We'll retry later..."
146+
"".format(resp.status)
147+
)
148+
return None
149+
150+
model_path = Path(model_directory) / resp.headers.get(
151+
"filename", "model.tar.gz"
145152
)
146-
return None
147-
model_path = Path(model_directory) / resp.headers.get(
148-
"filename", "model.tar.gz"
153+
with open(model_path, "wb") as file:
154+
file.write(await resp.read())
155+
156+
logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))
157+
158+
# return the new fingerprint
159+
return resp.headers.get("ETag")
160+
161+
except aiohttp.ClientError as e:
162+
logger.debug(
163+
"Tried to fetch model from server, but "
164+
"couldn't reach server. We'll retry later... "
165+
"Error: {}.".format(e)
149166
)
150-
with open(model_path, "wb") as file:
151-
file.write(await resp.read())
152-
logger.debug("Saved model to '{}'".format(os.path.abspath(model_path)))
153-
# return the new fingerprint
154-
return resp.headers.get("ETag")
155-
except aiohttp.ClientError as e:
156-
logger.debug(
157-
"Tried to fetch model from server, but "
158-
"couldn't reach server. We'll retry later... "
159-
"Error: {}.".format(e)
160-
)
161-
return None
167+
return None
162168

163169

164170
async def _run_model_pulling_worker(model_server: EndpointConfig, agent: Agent) -> None:

rasa/core/constants.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,8 @@
2424

2525
DEFAULT_LOCK_LIFETIME = 60 # in seconds
2626

27+
DEFAULT_KEEP_ALIVE_TIMEOUT = 120 # in seconds
28+
2729
BEARER_TOKEN_PREFIX = "Bearer "
2830

2931
# The lowest priority is intended to be used by machine learning policies.

rasa/core/run.py

Lines changed: 37 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,19 @@
11
import asyncio
22
import logging
33
import uuid
4+
import platform
45
import os
56
from functools import partial
6-
from typing import Any, List, Optional, TYPE_CHECKING, Text, Union, Dict
7+
from typing import (
8+
Any,
9+
Callable,
10+
List,
11+
Optional,
12+
Text,
13+
Tuple,
14+
Union,
15+
Dict,
16+
)
717

818
import rasa.core.utils
919
from rasa.plugin import plugin_manager
@@ -23,8 +33,6 @@
2333
from sanic import Sanic
2434
from asyncio import AbstractEventLoop
2535

26-
if TYPE_CHECKING:
27-
from aiohttp import ClientSession
2836

2937
logger = logging.getLogger() # get the root logger
3038

@@ -80,6 +88,14 @@ def _create_app_without_api(cors: Optional[Union[Text, List[Text]]] = None) -> S
8088
return app
8189

8290

91+
def _is_apple_silicon_system() -> bool:
92+
# check if the system is MacOS
93+
if platform.system().lower() != "darwin":
94+
return False
95+
# check for arm architecture, indicating apple silicon
96+
return platform.machine().startswith("arm") or os.uname().machine.startswith("arm")
97+
98+
8399
def configure_app(
84100
input_channels: Optional[List["InputChannel"]] = None,
85101
cors: Optional[Union[Text, List[Text], None]] = None,
@@ -99,6 +115,9 @@ def configure_app(
99115
syslog_port: Optional[int] = None,
100116
syslog_protocol: Optional[Text] = None,
101117
request_timeout: Optional[int] = None,
118+
server_listeners: Optional[List[Tuple[Callable, Text]]] = None,
119+
use_uvloop: Optional[bool] = True,
120+
keep_alive_timeout: int = constants.DEFAULT_KEEP_ALIVE_TIMEOUT,
102121
) -> Sanic:
103122
"""Run the agent."""
104123
rasa.core.utils.configure_file_logging(
@@ -118,6 +137,14 @@ def configure_app(
118137
else:
119138
app = _create_app_without_api(cors)
120139

140+
app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout
141+
if _is_apple_silicon_system() or not use_uvloop:
142+
app.config.USE_UVLOOP = False
143+
# some library still sets the loop to uvloop, even if disabled for sanic
144+
# using uvloop leads to breakingio errors, see
145+
# https://rasahq.atlassian.net/browse/ENG-667
146+
asyncio.set_event_loop_policy(None)
147+
121148
if input_channels:
122149
channels.channel.register(input_channels, app, route=route)
123150
else:
@@ -150,6 +177,10 @@ async def run_cmdline_io(running_app: Sanic) -> None:
150177

151178
app.add_task(run_cmdline_io)
152179

180+
if server_listeners:
181+
for (listener, event) in server_listeners:
182+
app.register_listener(listener, event)
183+
153184
return app
154185

155186

@@ -179,6 +210,7 @@ def serve_application(
179210
syslog_port: Optional[int] = None,
180211
syslog_protocol: Optional[Text] = None,
181212
request_timeout: Optional[int] = None,
213+
server_listeners: Optional[List[Tuple[Callable, Text]]] = None,
182214
) -> None:
183215
"""Run the API entrypoint."""
184216
if not channel and not credentials:
@@ -204,6 +236,7 @@ def serve_application(
204236
syslog_port=syslog_port,
205237
syslog_protocol=syslog_protocol,
206238
request_timeout=request_timeout,
239+
server_listeners=server_listeners,
207240
)
208241

209242
ssl_context = server.create_ssl_context(
@@ -217,7 +250,7 @@ def serve_application(
217250
partial(load_agent_on_start, model_path, endpoints, remote_storage),
218251
"before_server_start",
219252
)
220-
app.register_listener(create_connection_pools, "after_server_start")
253+
221254
app.register_listener(close_resources, "after_server_stop")
222255

223256
number_of_workers = rasa.core.utils.number_of_sanic_workers(
@@ -279,44 +312,3 @@ async def close_resources(app: Sanic, _: AbstractEventLoop) -> None:
279312
event_broker = current_agent.tracker_store.event_broker
280313
if event_broker:
281314
await event_broker.close()
282-
283-
action_endpoint = current_agent.action_endpoint
284-
if action_endpoint:
285-
await action_endpoint.session.close()
286-
287-
model_server = current_agent.model_server
288-
if model_server:
289-
await model_server.session.close()
290-
291-
292-
async def create_connection_pools(app: Sanic, _: AbstractEventLoop) -> None:
293-
"""Create connection pools for the agent's action server and model server."""
294-
current_agent = getattr(app.ctx, "agent", None)
295-
if not current_agent:
296-
logger.debug("No agent found after server start.")
297-
return None
298-
299-
create_action_endpoint_connection_pool(current_agent)
300-
create_model_server_connection_pool(current_agent)
301-
302-
return None
303-
304-
305-
def create_action_endpoint_connection_pool(agent: Agent) -> Optional["ClientSession"]:
306-
"""Create a connection pool for the action endpoint."""
307-
action_endpoint = agent.action_endpoint
308-
if not action_endpoint:
309-
logger.debug("No action endpoint found after server start.")
310-
return None
311-
312-
return action_endpoint.session
313-
314-
315-
def create_model_server_connection_pool(agent: Agent) -> Optional["ClientSession"]:
316-
"""Create a connection pool for the model server."""
317-
model_server = agent.model_server
318-
if not model_server:
319-
logger.debug("No model server endpoint found after server start.")
320-
return None
321-
322-
return model_server.session

rasa/utils/endpoints.py

Lines changed: 39 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,6 @@
11
import ssl
2-
from functools import cached_property
32

43
import aiohttp
5-
import logging
64
import os
75
from aiohttp.client_exceptions import ContentTypeError
86
from sanic.request import Request
@@ -11,10 +9,11 @@
119
from rasa.shared.exceptions import FileNotFoundException
1210
import rasa.shared.utils.io
1311
import rasa.utils.io
12+
import structlog
1413
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
1514

1615

17-
logger = logging.getLogger(__name__)
16+
structlogger = structlog.get_logger()
1817

1918

2019
def read_endpoint_config(
@@ -32,9 +31,13 @@ def read_endpoint_config(
3231

3332
return EndpointConfig.from_dict(content[endpoint_type])
3433
except FileNotFoundError:
35-
logger.error(
36-
"Failed to read endpoint configuration "
37-
"from {}. No such file.".format(os.path.abspath(filename))
34+
structlogger.error(
35+
"endpoint.read.failed_no_such_file",
36+
filename=os.path.abspath(filename),
37+
event_info=(
38+
"Failed to read endpoint configuration file - "
39+
"the file was not found."
40+
),
3841
)
3942
return None
4043

@@ -56,9 +59,13 @@ def concat_url(base: Text, subpath: Optional[Text]) -> Text:
5659
"""
5760
if not subpath:
5861
if base.endswith("/"):
59-
logger.debug(
60-
f"The URL '{base}' has a trailing slash. Please make sure the "
61-
f"target server supports trailing slashes for this endpoint."
62+
structlogger.debug(
63+
"endpoint.concat_url.trailing_slash",
64+
url=base,
65+
event_info=(
66+
"The URL has a trailing slash. Please make sure the "
67+
"target server supports trailing slashes for this endpoint."
68+
),
6269
)
6370
return base
6471

@@ -95,7 +102,6 @@ def __init__(
95102
self.cafile = cafile
96103
self.kwargs = kwargs
97104

98-
@cached_property
99105
def session(self) -> aiohttp.ClientSession:
100106
"""Creates and returns a configured aiohttp client session."""
101107
# create authentication parameters
@@ -164,23 +170,26 @@ async def request(
164170
f"'{os.path.abspath(self.cafile)}' does not exist."
165171
) from e
166172

167-
async with self.session.request(
168-
method,
169-
url,
170-
headers=headers,
171-
params=self.combine_parameters(kwargs),
172-
compress=compress,
173-
ssl=sslcontext,
174-
**kwargs,
175-
) as response:
176-
if response.status >= 400:
177-
raise ClientResponseError(
178-
response.status, response.reason, await response.content.read()
179-
)
180-
try:
181-
return await response.json()
182-
except ContentTypeError:
183-
return None
173+
async with self.session() as session:
174+
async with session.request(
175+
method,
176+
url,
177+
headers=headers,
178+
params=self.combine_parameters(kwargs),
179+
compress=compress,
180+
ssl=sslcontext,
181+
**kwargs,
182+
) as response:
183+
if response.status >= 400:
184+
raise ClientResponseError(
185+
response.status,
186+
response.reason,
187+
await response.content.read(),
188+
)
189+
try:
190+
return await response.json()
191+
except ContentTypeError:
192+
return None
184193

185194
@classmethod
186195
def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig":
@@ -263,7 +272,7 @@ def float_arg(
263272
try:
264273
return float(str(arg))
265274
except (ValueError, TypeError):
266-
logger.warning(f"Failed to convert '{arg}' to float.")
275+
structlogger.warning("endpoint.float_arg.convert_failed", arg=arg, key=key)
267276
return default
268277

269278

@@ -291,5 +300,6 @@ def int_arg(
291300
try:
292301
return int(str(arg))
293302
except (ValueError, TypeError):
294-
logger.warning(f"Failed to convert '{arg}' to int.")
303+
304+
structlogger.warning("endpoint.int_arg.convert_failed", arg=arg, key=key)
295305
return default

0 commit comments

Comments
 (0)