Skip to content

Commit ed6297f

Browse files
stuqdognpmenard
andauthored
RSDK-828 - fix file handle leak (#7)
Co-authored-by: Nicolas Menard <[email protected]>
1 parent 2595535 commit ed6297f

File tree

6 files changed

+150
-50
lines changed

6 files changed

+150
-50
lines changed

src/ffi/dial_ffi.rs

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@ use tokio::sync::oneshot;
1111
use tracing::Level;
1212

1313
use crate::rpc::dial::{
14-
CredentialsExt, DialBuilder, DialOptions, WithCredentials, WithoutCredentials,
14+
CredentialsExt, DialBuilder, DialOptions, ViamChannel, WithCredentials, WithoutCredentials,
1515
};
1616
use libc::c_char;
1717

1818
use crate::proxy;
1919
use hyper::Server;
2020
use std::ffi::{CStr, CString};
21-
use tower::{make::Shared, ServiceBuilder};
21+
use tower::{make::Shared, util::Either, ServiceBuilder};
2222
use tower_http::{
23+
auth::AddAuthorization,
2324
trace::{DefaultMakeSpan, DefaultOnRequest, DefaultOnResponse, TraceLayer},
2425
LatencyUnit,
2526
};
@@ -32,6 +33,7 @@ use crate::proxy::grpc_proxy::GRPCProxy;
3233
pub struct DialFfi {
3334
runtime: Option<Runtime>,
3435
sigs: Option<Vec<oneshot::Sender<()>>>,
36+
channels: Vec<Either<AddAuthorization<ViamChannel>, ViamChannel>>,
3537
}
3638

3739
impl Drop for DialFfi {
@@ -48,6 +50,7 @@ impl DialFfi {
4850
Self {
4951
runtime: Some(Runtime::new().unwrap()),
5052
sigs: None,
53+
channels: vec![],
5154
}
5255
}
5356
fn push_signal(&mut self, sig: oneshot::Sender<()>) {
@@ -172,18 +175,19 @@ pub unsafe extern "C" fn dial(
172175
} else {
173176
disable_webrtc = uri_str.contains(".local") || uri_str.contains("localhost");
174177
}
175-
let server = match runtime.block_on(async move {
176-
let dial = match payload {
177-
Some(p) => tower::util::Either::A(
178+
let (server, channel) = match runtime.block_on(async move {
179+
let channel = match payload {
180+
Some(p) => Either::A(
178181
dial_with_cred(uri_str, p.to_str()?, allow_insec, disable_webrtc)?
179182
.connect()
180183
.await?,
181184
),
182185
None => {
183186
let c = dial_without_cred(uri_str, allow_insec, disable_webrtc)?;
184-
tower::util::Either::B(c.connect().await?)
187+
Either::B(c.connect().await?)
185188
}
186189
};
190+
let dial = channel.clone();
187191
let g = GRPCProxy::new(dial, uri);
188192
let service = ServiceBuilder::new()
189193
.layer(
@@ -200,14 +204,15 @@ pub unsafe extern "C" fn dial(
200204
let server = Server::builder(conn)
201205
.http2_only(true)
202206
.serve(Shared::new(service));
203-
Ok::<_, Box<dyn std::error::Error>>(server)
207+
Ok::<_, Box<dyn std::error::Error>>((server, channel))
204208
}) {
205209
Ok(s) => s,
206210
Err(e) => {
207211
println!("Error building GRPC proxy reason : {e:?}");
208212
return ptr::null_mut();
209213
}
210214
};
215+
ctx.channels.push(channel);
211216
let server = server.with_graceful_shutdown(async {
212217
rx.await.ok();
213218
});
@@ -248,13 +253,25 @@ pub extern "C" fn free_rust_runtime(rt_ptr: Option<Box<DialFfi>>) -> i32 {
248253
return -1;
249254
}
250255
};
251-
match ctx.sigs.take() {
252-
Some(sigs) => {
253-
for sig in sigs {
254-
let _ = sig.send(());
255-
}
256+
if let Some(sigs) = ctx.sigs.take() {
257+
for sig in sigs {
258+
let _ = sig.send(());
259+
}
260+
}
261+
262+
for channel in &ctx.channels {
263+
let channel = match channel {
264+
Either::A(chan) => chan.get_ref(),
265+
Either::B(chan) => chan,
266+
};
267+
match channel {
268+
ViamChannel::Direct(_) => (),
269+
ViamChannel::WebRTC(chan) => ctx
270+
.runtime
271+
.as_ref()
272+
.map(|rt| rt.block_on(async move { chan.close().await }))
273+
.unwrap_or_default(),
256274
}
257-
None => {}
258275
}
259276
log::debug!("Freeing rust runtime");
260277
0

src/ffi/spatialmath/vector3.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,42 @@ pub unsafe extern "C" fn free_vector_memory(ptr: *mut Vector3) {
1313
if ptr.is_null() {
1414
return;
1515
}
16-
Box::from_raw(ptr);
16+
let _ = Box::from_raw(ptr);
1717
}
1818

1919
#[no_mangle]
2020
pub unsafe extern "C" fn vector_get_components(vec_ptr: *const Vector3) -> *const c_double {
2121
null_pointer_check!(vec_ptr);
2222
let vec = *vec_ptr;
23-
let components: [c_double;3] = [vec.x, vec.y, vec.z];
23+
let components: [c_double; 3] = [vec.x, vec.y, vec.z];
2424
Box::into_raw(Box::new(components)) as *const _
2525
}
2626

2727
#[no_mangle]
2828
pub unsafe extern "C" fn vector_set_x(vec_ptr: *mut Vector3, x_val: f64) {
2929
null_pointer_check!(vec_ptr);
30-
let vec = &mut*vec_ptr;
30+
let vec = &mut *vec_ptr;
3131
vec.x = x_val
3232
}
3333

3434
#[no_mangle]
3535
pub unsafe extern "C" fn vector_set_y(vec_ptr: *mut Vector3, y_val: f64) {
3636
null_pointer_check!(vec_ptr);
37-
let vec = &mut*vec_ptr;
37+
let vec = &mut *vec_ptr;
3838
vec.y = y_val
3939
}
4040

4141
#[no_mangle]
4242
pub unsafe extern "C" fn vector_set_z(vec_ptr: *mut Vector3, z_val: f64) {
4343
null_pointer_check!(vec_ptr);
44-
let vec = &mut*vec_ptr;
44+
let vec = &mut *vec_ptr;
4545
vec.z = z_val
4646
}
4747

4848
#[no_mangle]
4949
pub unsafe extern "C" fn normalize_vector(vec_ptr: *mut Vector3) {
5050
null_pointer_check!(vec_ptr);
51-
let vec = &mut*vec_ptr;
51+
let vec = &mut *vec_ptr;
5252
vec.normalize();
5353
}
5454

@@ -63,7 +63,7 @@ pub unsafe extern "C" fn vector_get_normalized(vec_ptr: *const Vector3) -> *mut
6363
#[no_mangle]
6464
pub unsafe extern "C" fn scale_vector(vec_ptr: *mut Vector3, factor: f64) {
6565
null_pointer_check!(vec_ptr);
66-
let vec = &mut*vec_ptr;
66+
let vec = &mut *vec_ptr;
6767
vec.scale(factor);
6868
}
6969

@@ -77,7 +77,8 @@ pub unsafe extern "C" fn vector_get_scaled(vec_ptr: *const Vector3, factor: f64)
7777

7878
#[no_mangle]
7979
pub unsafe extern "C" fn vector_add(
80-
vec_ptr_1: *const Vector3, vec_ptr_2: *const Vector3
80+
vec_ptr_1: *const Vector3,
81+
vec_ptr_2: *const Vector3,
8182
) -> *mut Vector3 {
8283
null_pointer_check!(vec_ptr_1);
8384
null_pointer_check!(vec_ptr_2);
@@ -88,7 +89,8 @@ pub unsafe extern "C" fn vector_add(
8889

8990
#[no_mangle]
9091
pub unsafe extern "C" fn vector_subtract(
91-
vec_ptr_1: *const Vector3, vec_ptr_2: *const Vector3
92+
vec_ptr_1: *const Vector3,
93+
vec_ptr_2: *const Vector3,
9294
) -> *mut Vector3 {
9395
null_pointer_check!(vec_ptr_1);
9496
null_pointer_check!(vec_ptr_2);
@@ -98,7 +100,10 @@ pub unsafe extern "C" fn vector_subtract(
98100
}
99101

100102
#[no_mangle]
101-
pub unsafe extern "C" fn vector_dot_product(vec_ptr_1: *const Vector3, vec_ptr_2: *const Vector3) -> f64 {
103+
pub unsafe extern "C" fn vector_dot_product(
104+
vec_ptr_1: *const Vector3,
105+
vec_ptr_2: *const Vector3,
106+
) -> f64 {
102107
null_pointer_check!(vec_ptr_1, f64::NAN);
103108
null_pointer_check!(vec_ptr_2, f64::NAN);
104109
let vec1 = &*vec_ptr_1;
@@ -108,12 +113,13 @@ pub unsafe extern "C" fn vector_dot_product(vec_ptr_1: *const Vector3, vec_ptr_2
108113

109114
#[no_mangle]
110115
pub unsafe extern "C" fn vector_cross_product(
111-
vec_ptr_1: *mut Vector3, vec_ptr_2: *mut Vector3
116+
vec_ptr_1: *mut Vector3,
117+
vec_ptr_2: *mut Vector3,
112118
) -> *mut Vector3 {
113119
null_pointer_check!(vec_ptr_1);
114120
null_pointer_check!(vec_ptr_2);
115-
let vec1 = &mut*vec_ptr_1;
116-
let vec2 = &mut*vec_ptr_2;
121+
let vec1 = &mut *vec_ptr_1;
122+
let vec2 = &mut *vec_ptr_2;
117123
let vec = vec1.cross(vec2);
118124
vec.to_raw_pointer()
119125
}

src/rpc/base_channel.rs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
use anyhow::Result;
2-
use std::sync::{
3-
atomic::{AtomicBool, AtomicPtr, Ordering},
4-
Arc,
2+
use std::{
3+
fmt::Debug,
4+
sync::{
5+
atomic::{AtomicBool, AtomicPtr, Ordering},
6+
Arc,
7+
},
58
};
69
use webrtc::{data_channel::RTCDataChannel, peer_connection::RTCPeerConnection};
710

@@ -14,16 +17,28 @@ pub struct WebRTCBaseChannel {
1417
closed: AtomicBool,
1518
}
1619

20+
impl Debug for WebRTCBaseChannel {
21+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22+
f.debug_struct("WebRTCBaseChannel")
23+
.field("Peer connection id", &self.peer_connection.get_stats_id())
24+
.field("Data channel id", &self.data_channel.id())
25+
.finish()
26+
}
27+
}
28+
1729
impl WebRTCBaseChannel {
1830
pub(crate) async fn new(
1931
peer_connection: Arc<RTCPeerConnection>,
2032
data_channel: Arc<RTCDataChannel>,
2133
) -> Arc<Self> {
2234
let dc = data_channel.clone();
23-
let pc = peer_connection.clone();
35+
let pc = Arc::downgrade(&peer_connection);
2436
peer_connection
2537
.on_ice_connection_state_change(Box::new(move |conn_state| {
26-
let pc = pc.clone();
38+
let pc = match pc.upgrade(){
39+
Some(pc) => pc,
40+
None => return Box::pin(async {}),
41+
};
2742
Box::pin(async move {
2843
let sctp = pc.sctp();
2944
let transport = sctp.transport();
@@ -44,9 +59,12 @@ impl WebRTCBaseChannel {
4459
closed: AtomicBool::new(false),
4560
});
4661

47-
let c = channel.clone();
62+
let c = Arc::downgrade(&channel);
4863
dc.on_error(Box::new(move |err: webrtc::Error| {
49-
let c = c.clone();
64+
let c = match c.upgrade() {
65+
Some(c) => c,
66+
None => return Box::pin(async {}),
67+
};
5068
Box::pin(async move {
5169
if let Err(e) = c.close_with_reason(Some(anyhow::Error::from(err))).await {
5270
log::error!("error closing channel: {e}")

src/rpc/client_channel.rs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@ use anyhow::Result;
77
use chashmap::CHashMap;
88
use hyper::Body;
99
use prost::Message;
10-
use std::sync::{
11-
atomic::{AtomicBool, AtomicPtr, AtomicU64, Ordering},
12-
Arc,
10+
use std::{
11+
fmt::Debug,
12+
sync::{
13+
atomic::{AtomicBool, AtomicPtr, AtomicU64, Ordering},
14+
Arc,
15+
},
1316
};
1417
use webrtc::{
1518
data_channel::{data_channel_message::DataChannelMessage, RTCDataChannel},
@@ -27,7 +30,35 @@ pub struct WebRTCClientChannel {
2730
pub(crate) receiver_bodies: CHashMap<u64, hyper::Body>,
2831
}
2932

33+
impl Debug for WebRTCClientChannel {
34+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35+
f.debug_struct("WebRTCClientChannel")
36+
.field("stream_id_counter", &self.stream_id_counter)
37+
.field("base channel", &self.base_channel)
38+
.finish()
39+
}
40+
}
41+
42+
impl Drop for WebRTCClientChannel {
43+
fn drop(&mut self) {
44+
let bc = self.base_channel.clone();
45+
if !bc.is_closed() {
46+
let _ = tokio::spawn(async move {
47+
if let Err(e) = bc.close().await {
48+
log::error!("Error closing base channel: {e}");
49+
}
50+
});
51+
};
52+
log::debug!("Dropping client channel {:?}", &self);
53+
}
54+
}
55+
3056
impl WebRTCClientChannel {
57+
pub async fn close(&self) {
58+
self.base_channel.close().await.unwrap();
59+
self.base_channel.data_channel.close().await.unwrap();
60+
self.base_channel.peer_connection.close().await.unwrap();
61+
}
3162
pub(crate) async fn new(
3263
peer_connection: Arc<RTCPeerConnection>,
3364
data_channel: Arc<RTCDataChannel>,
@@ -42,17 +73,25 @@ impl WebRTCClientChannel {
4273

4374
let channel = Arc::new(channel);
4475
let ret_channel = channel.clone();
76+
let channel = Arc::downgrade(&channel);
4577

4678
data_channel
4779
.on_message(Box::new(move |msg: DataChannelMessage| {
4880
let channel = channel.clone();
4981
Box::pin(async move {
82+
let channel = match channel.upgrade() {
83+
Some(channel) => channel,
84+
None => {
85+
return;
86+
}
87+
};
5088
if let Err(e) = channel.on_channel_message(msg).await {
5189
log::error!("error deserializing message: {e}");
5290
}
5391
})
5492
}))
5593
.await;
94+
log::debug!("Client channel created");
5695
ret_channel
5796
}
5897

0 commit comments

Comments
 (0)