Skip to content

Commit d1a0303

Browse files
authored
feat(python-bindings): improve configuration handling and further refactoring (#167)
# Description This commit eliminates the duplication of configuration definition, which is defined in the python bindings and in the config structs. It also improves code organization and readability: - configuration can be now passed from the bindings for each new connection or server creation - configuration logic is defined in one place only Signed-off-by: Mauro Sardara <[email protected]>
1 parent 1347d49 commit d1a0303

26 files changed

+848
-947
lines changed

data-plane/Cargo.lock

+13-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

data-plane/gateway/datapath/src/messages/utils.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1114,7 +1114,7 @@ mod tests {
11141114
let service_type =
11151115
SessionHeaderType::try_from(i).expect("failed to convert int to service type");
11161116
let service_type_int = i32::from(service_type);
1117-
assert_eq!(service_type_int, service_type.into());
1117+
assert_eq!(service_type_int, i32::from(service_type),);
11181118
}
11191119

11201120
// Test invalid conversion

data-plane/gateway/gateway/build.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ fn set_env(name: &str, cmd: &mut Command) {
1414
println!("cargo:rustc-env={}={}", name, value);
1515
}
1616

17-
fn main() {
17+
pub fn main() {
1818
set_env(
1919
"GIT_SHA",
2020
Command::new("git").args(["rev-parse", "--short", "HEAD"]),

data-plane/gateway/service/src/errors.rs

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ pub enum ServiceError {
3535
SessionError(String),
3636
#[error("client already connected: {0}")]
3737
ClientAlreadyConnected(String),
38+
#[error("server not found: {0}")]
39+
ServerNotFound(String),
3840
#[error("unknown error")]
3941
Unknown,
4042
}

data-plane/gateway/service/src/lib.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,13 @@ impl Service {
296296
Ok(())
297297
}
298298

299-
pub fn stop_server(&self, endpoint: &str) {
299+
pub fn stop_server(&self, endpoint: &str) -> Result<(), ServiceError> {
300300
// stop the server
301301
if let Some(token) = self.cancellation_tokens.write().remove(endpoint) {
302302
token.cancel();
303+
Ok(())
303304
} else {
304-
error!("server {} not found", endpoint);
305+
Err(ServiceError::ServerNotFound(endpoint.to_string()))
305306
}
306307
}
307308

data-plane/gateway/service/src/streaming.rs

+28-10
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,10 @@ mod tests {
926926
let rtx_header = rtx_msg.get_session_header();
927927
assert_eq!(rtx_header.session_id, 0);
928928
assert_eq!(rtx_header.message_id, 1);
929-
assert_eq!(rtx_header.header_type, SessionHeaderType::RtxRequest.into());
929+
assert_eq!(
930+
rtx_header.header_type,
931+
i32::from(SessionHeaderType::RtxRequest)
932+
);
930933
}
931934

932935
time::sleep(Duration::from_millis(1000)).await;
@@ -988,7 +991,7 @@ mod tests {
988991
let msg_header = msg.get_session_header();
989992
assert_eq!(msg_header.session_id, 120);
990993
assert_eq!(msg_header.message_id, i);
991-
assert_eq!(msg_header.header_type, SessionHeaderType::Stream.into());
994+
assert_eq!(msg_header.header_type, i32::from(SessionHeaderType::Stream));
992995
}
993996

994997
let agp_header = Some(AgpHeader::new(
@@ -1024,7 +1027,10 @@ mod tests {
10241027
let msg_header = msg.get_session_header();
10251028
assert_eq!(msg_header.session_id, 120);
10261029
assert_eq!(msg_header.message_id, 2);
1027-
assert_eq!(msg_header.header_type, SessionHeaderType::RtxReply.into());
1030+
assert_eq!(
1031+
msg_header.header_type,
1032+
i32::from(SessionHeaderType::RtxReply)
1033+
);
10281034
assert_eq!(msg.get_payload().unwrap().blob, vec![0x1, 0x2, 0x3, 0x4]);
10291035
}
10301036

@@ -1090,7 +1096,7 @@ mod tests {
10901096
let msg_header = msg.get_session_header();
10911097
assert_eq!(msg_header.session_id, 0);
10921098
assert_eq!(msg_header.message_id, i);
1093-
assert_eq!(msg_header.header_type, SessionHeaderType::Stream.into());
1099+
assert_eq!(msg_header.header_type, i32::from(SessionHeaderType::Stream));
10941100

10951101
// the receiver should detect a loss for packet 1
10961102
if i != 1 {
@@ -1109,7 +1115,7 @@ mod tests {
11091115
let msg_header = msg.message.get_session_header();
11101116
assert_eq!(msg_header.session_id, 0);
11111117
assert_eq!(msg_header.message_id, 0);
1112-
assert_eq!(msg_header.header_type, SessionHeaderType::Stream.into());
1118+
assert_eq!(msg_header.header_type, i32::from(SessionHeaderType::Stream));
11131119
assert_eq!(
11141120
msg.message.get_source(),
11151121
Agent::from_strings("cisco", "default", "sender", 0)
@@ -1127,7 +1133,10 @@ mod tests {
11271133
let msg_header = msg.get_session_header();
11281134
assert_eq!(msg_header.session_id, 0);
11291135
assert_eq!(msg_header.message_id, 1);
1130-
assert_eq!(msg_header.header_type, SessionHeaderType::RtxRequest.into());
1136+
assert_eq!(
1137+
msg_header.header_type,
1138+
i32::from(SessionHeaderType::RtxRequest)
1139+
);
11311140
assert_eq!(
11321141
msg.get_source(),
11331142
Agent::from_strings("cisco", "default", "receiver", 0)
@@ -1144,7 +1153,10 @@ mod tests {
11441153
let msg_header = msg.get_session_header();
11451154
assert_eq!(msg_header.session_id, 0);
11461155
assert_eq!(msg_header.message_id, 1);
1147-
assert_eq!(msg_header.header_type, SessionHeaderType::RtxRequest.into());
1156+
assert_eq!(
1157+
msg_header.header_type,
1158+
i32::from(SessionHeaderType::RtxRequest)
1159+
);
11481160
assert_eq!(
11491161
msg.get_source(),
11501162
Agent::from_strings("cisco", "default", "receiver", 0)
@@ -1171,7 +1183,10 @@ mod tests {
11711183
let msg_header = msg.get_session_header();
11721184
assert_eq!(msg_header.session_id, 0);
11731185
assert_eq!(msg_header.message_id, 1);
1174-
assert_eq!(msg_header.header_type, SessionHeaderType::RtxReply.into());
1186+
assert_eq!(
1187+
msg_header.header_type,
1188+
i32::from(SessionHeaderType::RtxReply)
1189+
);
11751190
assert_eq!(
11761191
msg.get_source(),
11771192
Agent::from_strings("cisco", "default", "sender", 0)
@@ -1197,7 +1212,10 @@ mod tests {
11971212
let msg_header = msg.message.get_session_header();
11981213
assert_eq!(msg_header.session_id, 0);
11991214
assert_eq!(msg_header.message_id, 1);
1200-
assert_eq!(msg_header.header_type, SessionHeaderType::RtxReply.into());
1215+
assert_eq!(
1216+
msg_header.header_type,
1217+
i32::from(SessionHeaderType::RtxReply)
1218+
);
12011219
assert_eq!(
12021220
msg.message.get_source(),
12031221
Agent::from_strings("cisco", "default", "sender", 0)
@@ -1214,7 +1232,7 @@ mod tests {
12141232
let msg_header = msg.message.get_session_header();
12151233
assert_eq!(msg_header.session_id, 0);
12161234
assert_eq!(msg_header.message_id, 2);
1217-
assert_eq!(msg_header.header_type, SessionHeaderType::Stream.into());
1235+
assert_eq!(msg_header.header_type, i32::from(SessionHeaderType::Stream));
12181236
assert_eq!(
12191237
msg.message.get_source(),
12201238
Agent::from_strings("cisco", "default", "sender", 0)

data-plane/python-bindings/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ agp-datapath = { path = "../gateway/datapath", version = "0.5.0" }
1515
agp-service = { path = "../gateway/service", version = "0.3.0" }
1616
agp-tracing = { path = "../gateway/tracing", version = "0.1.4" }
1717
pyo3 = "0.24.1"
18-
pyo3-async-runtimes = { version = "0.24.0", features = ["tokio-runtime"] }
18+
pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"] }
1919
pyo3-stub-gen = "0.7.0"
2020
rand = "0.9.0"
21+
serde-pyobject = "0.6.1"
2122
tokio = "1.43.0"
2223

2324
[package.metadata.maturin]

data-plane/python-bindings/agp_bindings/__init__.py

+33-28
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from typing import Optional
66

77
from ._agp_bindings import (
8+
__version__,
9+
build_profile,
10+
build_info,
811
SESSION_UNSPECIFIED,
912
PyAgentType,
1013
PyFireAndForgetConfiguration,
11-
PyGatewayConfig as GatewayConfig,
1214
PyRequestResponseConfiguration,
1315
PyService,
1416
PySessionDirection as PySessionDirection,
@@ -24,9 +26,9 @@
2426
publish,
2527
receive,
2628
remove_route,
27-
serve,
29+
run_server,
2830
set_route,
29-
stop,
31+
stop_server,
3032
subscribe,
3133
unsubscribe,
3234
)
@@ -74,6 +76,9 @@ def __init__(
7476
self.local_name = PyAgentType(organization, namespace, agent)
7577
self.local_id = self.svc.id
7678

79+
# Create connection ID map
80+
self.conn_ids: dict[str, int] = {}
81+
7782
async def __aenter__(self):
7883
"""
7984
Start the receiver loop in the background.
@@ -158,19 +163,6 @@ def get_agent_id(self) -> int:
158163

159164
return self.svc.id
160165

161-
def configure(self, config: GatewayConfig):
162-
"""
163-
Configure the gateway.
164-
165-
Args:
166-
config (GatewayConfig): The gateway configuration class.
167-
168-
Returns:
169-
None
170-
"""
171-
172-
self.svc.configure(config)
173-
174166
async def create_ff_session(
175167
self,
176168
session_config: PyFireAndForgetConfiguration = PyFireAndForgetConfiguration(),
@@ -237,7 +229,7 @@ async def create_streaming_session(
237229
self.sessions[session.id] = (session, asyncio.Queue(queue_size))
238230
return session
239231

240-
async def run_server(self):
232+
async def run_server(self, config: dict):
241233
"""
242234
Start the server part of the Gateway service. The server will be started only
243235
if its configuration is set. Otherwise, it will raise an error.
@@ -249,9 +241,9 @@ async def run_server(self):
249241
None
250242
"""
251243

252-
await serve(self.svc)
244+
await run_server(self.svc, config)
253245

254-
async def stop_server(self):
246+
async def stop_server(self, endpoint: str):
255247
"""
256248
Stop the server part of the Gateway service.
257249
@@ -262,9 +254,9 @@ async def stop_server(self):
262254
None
263255
"""
264256

265-
await stop(self.svc)
257+
await stop_server(self.svc, endpoint)
266258

267-
async def connect(self) -> int:
259+
async def connect(self, client_config: dict) -> int:
268260
"""
269261
Connect to a remote gateway service.
270262
This function will block until the connection is established.
@@ -276,15 +268,24 @@ async def connect(self) -> int:
276268
int: The connection ID.
277269
"""
278270

279-
self.conn_id = await connect(self.svc)
271+
conn_id = await connect(
272+
self.svc,
273+
client_config,
274+
)
275+
276+
# Save the connection ID
277+
self.conn_ids[client_config["endpoint"]] = conn_id
278+
279+
# For the moment we manage one connection only
280+
self.conn_id = conn_id
280281

281282
# Subscribe to the local name
282-
await subscribe(self.svc, self.conn_id, self.local_name, self.local_id)
283+
await subscribe(self.svc, conn_id, self.local_name, self.local_id)
283284

284285
# return the connection ID
285-
return self.conn_id
286+
return conn_id
286287

287-
async def disconnect(self):
288+
async def disconnect(self, endpoint: str):
288289
"""
289290
Disconnect from a remote gateway service.
290291
This function will block until the disconnection is complete.
@@ -296,11 +297,15 @@ async def disconnect(self):
296297
None
297298
298299
"""
299-
300-
await disconnect(self.svc, self.conn_id)
300+
conn = self.conn_ids[endpoint]
301+
await disconnect(self.svc, conn)
301302

302303
async def set_route(
303-
self, organization: str, namespace: str, agent: str, id: Optional[int] = None
304+
self,
305+
organization: str,
306+
namespace: str,
307+
agent: str,
308+
id: Optional[int] = None,
304309
):
305310
"""
306311
Set route for outgoing messages via the connected gateway.

data-plane/python-bindings/agp_bindings/_agp_bindings.pyi

+3-23
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,6 @@ class PyFireAndForgetConfiguration:
2222
def __new__(cls,): ...
2323
...
2424

25-
class PyGatewayConfig:
26-
r"""
27-
gatewayconfig class
28-
"""
29-
endpoint: builtins.str
30-
insecure: builtins.bool
31-
insecure_skip_verify: builtins.bool
32-
tls_ca_path: typing.Optional[builtins.str]
33-
tls_ca_pem: typing.Optional[builtins.str]
34-
tls_cert_path: typing.Optional[builtins.str]
35-
tls_key_path: typing.Optional[builtins.str]
36-
tls_cert_pem: typing.Optional[builtins.str]
37-
tls_key_pem: typing.Optional[builtins.str]
38-
basic_auth_username: typing.Optional[builtins.str]
39-
basic_auth_password: typing.Optional[builtins.str]
40-
def __new__(cls,endpoint:builtins.str, insecure:builtins.bool=False, insecure_skip_verify:builtins.bool=False, tls_ca_path:typing.Optional[builtins.str]=None, tls_ca_pem:typing.Optional[builtins.str]=None, tls_cert_path:typing.Optional[builtins.str]=None, tls_key_path:typing.Optional[builtins.str]=None, tls_cert_pem:typing.Optional[builtins.str]=None, tls_key_pem:typing.Optional[builtins.str]=None, basic_auth_username:typing.Optional[builtins.str]=None, basic_auth_password:typing.Optional[builtins.str]=None): ...
41-
4225
class PyRequestResponseConfiguration:
4326
r"""
4427
request response session config
@@ -55,9 +38,6 @@ class PyRequestResponseConfiguration:
5538

5639
class PyService:
5740
id: builtins.int
58-
def configure(self, config:PyGatewayConfig) -> None:
59-
...
60-
6141

6242
class PySessionInfo:
6343
id: builtins.int
@@ -80,7 +60,7 @@ class PySessionDirection(Enum):
8060
RECEIVER = auto()
8161
BIDIRECTIONAL = auto()
8262

83-
def connect(svc:PyService) -> typing.Any:
63+
def connect(svc:PyService, config:dict) -> typing.Any:
8464
...
8565

8666
def create_ff_session(svc:PyService, config:PyFireAndForgetConfiguration=...) -> typing.Any:
@@ -104,13 +84,13 @@ def receive(svc:PyService) -> typing.Any:
10484
def remove_route(svc:PyService, conn:builtins.int, name:PyAgentType, id:typing.Optional[builtins.int]=None) -> typing.Any:
10585
...
10686

107-
def serve(svc:PyService) -> typing.Any:
87+
def run_server(svc:PyService, config:dict) -> typing.Any:
10888
...
10989

11090
def set_route(svc:PyService, conn:builtins.int, name:PyAgentType, id:typing.Optional[builtins.int]=None) -> typing.Any:
11191
...
11292

113-
def stop(svc:PyService) -> typing.Any:
93+
def stop_server(svc:PyService, endpoint:builtins.str) -> typing.Any:
11494
...
11595

11696
def subscribe(svc:PyService, conn:builtins.int, name:PyAgentType, id:typing.Optional[builtins.int]=None) -> typing.Any:

0 commit comments

Comments
 (0)