Skip to content

Commit ce28084

Browse files
authored
feat(python-bindings): add session deletion API (#176)
# Description Expose session deletion API to python bindings, and add tests. Signed-off-by: Mauro Sardara <[email protected]>
1 parent d1a0303 commit ce28084

File tree

6 files changed

+96
-18
lines changed

6 files changed

+96
-18
lines changed

data-plane/gateway/service/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,7 @@ impl Service {
687687
true => Ok(()),
688688
false => {
689689
error!("error deleting session");
690-
Err(ServiceError::SessionError("unknown".to_string()))
690+
Err(ServiceError::SessionError("session not found".to_string()))
691691
}
692692
}
693693
}

data-plane/python-bindings/agp_bindings/__init__.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Optional
66

77
from ._agp_bindings import (
8-
__version__,
8+
__version__,
99
build_profile,
1010
build_info,
1111
SESSION_UNSPECIFIED,
@@ -21,6 +21,7 @@
2121
create_pyservice,
2222
create_rr_session,
2323
create_streaming_session,
24+
delete_session,
2425
disconnect,
2526
init_tracing as init_tracing,
2627
publish,
@@ -229,6 +230,27 @@ async def create_streaming_session(
229230
self.sessions[session.id] = (session, asyncio.Queue(queue_size))
230231
return session
231232

233+
async def delete_session(self, session_id: int):
234+
"""
235+
Delete a session.
236+
237+
Args:
238+
session_id (int): The ID of the session to delete.
239+
240+
Returns:
241+
None
242+
"""
243+
244+
# Check if the session ID is in the sessions map
245+
if session_id not in self.sessions:
246+
raise Exception("session not found", session_id)
247+
248+
# Remove the session from the map
249+
del self.sessions[session_id]
250+
251+
# Remove the session from the gateway
252+
await delete_session(self.svc, session_id)
253+
232254
async def run_server(self, config: dict):
233255
"""
234256
Start the server part of the Gateway service. The server will be started only
@@ -404,6 +426,10 @@ async def publish(
404426
None
405427
"""
406428

429+
# Make sure the sessions exists
430+
if session.id not in self.sessions:
431+
raise Exception("session not found", session.id)
432+
407433
dest = PyAgentType(organization, namespace, agent)
408434
await publish(self.svc, session, 1, msg, dest, agent_id)
409435

data-plane/python-bindings/agp_bindings/_agp_bindings.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ def create_rr_session(svc:PyService, config:PyRequestResponseConfiguration=...)
7272
def create_streaming_session(svc:PyService, config:PyStreamingConfiguration) -> typing.Any:
7373
...
7474

75+
def delete_session(svc:PyService, session_id:builtins.int) -> typing.Any:
76+
...
77+
7578
def disconnect(svc:PyService, conn:builtins.int) -> typing.Any:
7679
...
7780

data-plane/python-bindings/src/lib.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ mod _agp_bindings {
1616
#[pymodule_export]
1717
use pyservice::{
1818
PyService, connect, create_ff_session, create_pyservice, create_rr_session,
19-
create_streaming_session, disconnect, publish, receive, remove_route, run_server,
20-
set_route, stop_server, subscribe, unsubscribe,
19+
create_streaming_session, delete_session, disconnect, publish, receive, remove_route,
20+
run_server, set_route, stop_server, subscribe, unsubscribe,
2121
};
2222

2323
#[pymodule_export]

data-plane/python-bindings/src/pyservice.rs

+32-14
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ impl PyService {
9595
))
9696
}
9797

98+
async fn delete_session(&self, session_id: session::Id) -> Result<(), ServiceError> {
99+
self.sdk
100+
.service
101+
.delete_session(&self.sdk.agent, session_id)
102+
.await
103+
}
104+
98105
async fn run_server(&self, config: PyGrpcServerConfig) -> Result<(), ServiceError> {
99106
self.sdk.service.run_server(&config)
100107
}
@@ -250,7 +257,7 @@ pub fn create_ff_session(
250257
config.fire_and_forget_configuration,
251258
))
252259
.await
253-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
260+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
254261
})
255262
}
256263

@@ -267,7 +274,7 @@ pub fn create_rr_session(
267274
config.request_response_configuration,
268275
))
269276
.await
270-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
277+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
271278
})
272279
}
273280

@@ -284,7 +291,18 @@ pub fn create_streaming_session(
284291
config.streaming_configuration,
285292
))
286293
.await
287-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
294+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
295+
})
296+
}
297+
298+
#[gen_stub_pyfunction]
299+
#[pyfunction]
300+
#[pyo3(signature = (svc, session_id))]
301+
pub fn delete_session(py: Python, svc: PyService, session_id: u32) -> PyResult<Bound<PyAny>> {
302+
pyo3_async_runtimes::tokio::future_into_py(py, async move {
303+
svc.delete_session(session_id)
304+
.await
305+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
288306
})
289307
}
290308

@@ -299,7 +317,7 @@ pub fn run_server(py: Python, svc: PyService, config: Py<PyDict>) -> PyResult<Bo
299317
pyo3_async_runtimes::tokio::future_into_py(py, async move {
300318
svc.run_server(config)
301319
.await
302-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
320+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
303321
})
304322
}
305323

@@ -313,7 +331,7 @@ pub fn stop_server(py: Python, svc: PyService, endpoint: String) -> PyResult<Bou
313331
pyo3_async_runtimes::tokio::future_into_py(py, async move {
314332
svc.stop_server(&endpoint)
315333
.await
316-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
334+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
317335
})
318336
}
319337

@@ -329,7 +347,7 @@ pub fn connect(py: Python, svc: PyService, config: Py<PyDict>) -> PyResult<Bound
329347
pyo3_async_runtimes::tokio::future_into_py(py, async move {
330348
svc.connect(config)
331349
.await
332-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
350+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
333351
})
334352
}
335353

@@ -339,7 +357,7 @@ pub fn disconnect(py: Python, svc: PyService, conn: u64) -> PyResult<Bound<PyAny
339357
pyo3_async_runtimes::tokio::future_into_py(py, async move {
340358
svc.disconnect(conn)
341359
.await
342-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
360+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
343361
})
344362
}
345363

@@ -356,7 +374,7 @@ pub fn subscribe(
356374
pyo3_async_runtimes::tokio::future_into_py(py, async move {
357375
svc.subscribe(conn, name, id)
358376
.await
359-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
377+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
360378
})
361379
}
362380

@@ -373,7 +391,7 @@ pub fn unsubscribe(
373391
pyo3_async_runtimes::tokio::future_into_py(py, async move {
374392
svc.unsubscribe(conn, name, id)
375393
.await
376-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
394+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
377395
})
378396
}
379397

@@ -390,7 +408,7 @@ pub fn set_route(
390408
pyo3_async_runtimes::tokio::future_into_py(py, async move {
391409
svc.set_route(conn, name, id)
392410
.await
393-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
411+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
394412
})
395413
}
396414

@@ -407,7 +425,7 @@ pub fn remove_route(
407425
pyo3_async_runtimes::tokio::future_into_py(py, async move {
408426
svc.remove_route(conn, name, id)
409427
.await
410-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
428+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
411429
})
412430
}
413431

@@ -426,7 +444,7 @@ pub fn publish(
426444
pyo3_async_runtimes::tokio::future_into_py(py, async move {
427445
svc.publish(session_info.session_info, fanout, blob, name, id)
428446
.await
429-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
447+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
430448
})
431449
}
432450

@@ -440,7 +458,7 @@ pub fn receive(py: Python, svc: PyService) -> PyResult<Bound<PyAny>> {
440458
async move {
441459
svc.receive()
442460
.await
443-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
461+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
444462
},
445463
)
446464
}
@@ -457,6 +475,6 @@ pub fn create_pyservice(
457475
pyo3_async_runtimes::tokio::future_into_py(py, async move {
458476
PyService::create_pyservice(organization, namespace, agent_type, id)
459477
.await
460-
.map_err(|e| PyErr::new::<PyException, _>(format!("{}", e.to_string())))
478+
.map_err(|e| PyErr::new::<PyException, _>(e.to_string()))
461479
})
462480
}

data-plane/python-bindings/tests/test_bindings.py

+31
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,28 @@ async def test_end_to_end(server):
6363
# check if the message is correct
6464
assert msg_rcv == bytes(msg)
6565

66+
# delete sessions
67+
await agp_bindings.delete_session(svc_alice, session_info.id)
68+
await agp_bindings.delete_session(svc_bob, session_info.id)
69+
70+
# try to send a message after deleting the session - this should raise an exception
71+
try:
72+
await agp_bindings.publish(svc_alice, session_info, 1, msg, bob_class, None)
73+
except Exception as e:
74+
assert "session not found" in str(e), f"Unexpected error message: {str(e)}"
75+
6676
# disconnect alice
6777
await agp_bindings.disconnect(svc_alice, conn_id_alice)
6878

6979
# disconnect bob
7080
await agp_bindings.disconnect(svc_bob, conn_id_bob)
7181

82+
# try to delete a random session, we should get an exception
83+
try:
84+
await agp_bindings.delete_session(svc_alice, 123456789)
85+
except Exception as e:
86+
assert "session not found" in str(e)
87+
7288

7389
@pytest.mark.asyncio
7490
@pytest.mark.parametrize("server", ["127.0.0.1:12345"], indirect=True)
@@ -131,6 +147,21 @@ async def test_gateway_wrapper(server):
131147
# check if the message is correct
132148
assert msg_rcv == bytes(msg)
133149

150+
# delete sessions
151+
await gateway1.delete_session(session_info.id)
152+
await gateway2.delete_session(session_info.id)
153+
154+
# try to send a message after deleting the session - this should raise an exception
155+
try:
156+
await gateway1.publish(session_info, msg, org, ns, agent1)
157+
except Exception as e:
158+
assert "session not found" in str(e), f"Unexpected error message: {str(e)}"
159+
160+
# try to delete a random session, we should get an exception
161+
try:
162+
await gateway1.delete_session(123456789)
163+
except Exception as e:
164+
assert "session not found" in str(e), f"Unexpected error message: {str(e)}"
134165

135166
@pytest.mark.asyncio
136167
@pytest.mark.parametrize("server", ["127.0.0.1:12346"], indirect=True)

0 commit comments

Comments
 (0)