Skip to content

Commit 6b4941d

Browse files
committed
Create heartbeat device outside of mutex lock
1 parent 7fc799e commit 6b4941d

File tree

3 files changed

+121
-42
lines changed

3 files changed

+121
-42
lines changed

src/devices.rs

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
11
// jkcoxson
22

3-
use std::{collections::HashMap, io::Read, net::IpAddr, path::PathBuf, sync::Arc};
3+
use std::{collections::HashMap, io::Read, net::IpAddr, path::PathBuf};
44

55
use log::{debug, info, trace, warn};
6-
use tokio::{
7-
io::AsyncReadExt,
8-
sync::{oneshot::Sender, Mutex},
9-
};
10-
11-
use crate::heartbeat;
6+
use tokio::{io::AsyncReadExt, sync::oneshot::Sender};
127

138
pub struct SharedDevices {
149
pub devices: HashMap<String, MuxerDevice>,
1510
pub last_index: u64,
1611
pub last_interface_index: u64,
1712
plist_storage: String,
18-
use_heartbeat: bool,
1913
known_mac_addresses: HashMap<String, String>,
2014
paired_udids: Vec<String>,
2115
}
@@ -39,7 +33,7 @@ pub struct MuxerDevice {
3933
}
4034

4135
impl SharedDevices {
42-
pub async fn new(plist_storage: Option<String>, use_heartbeat: bool) -> Self {
36+
pub async fn new(plist_storage: Option<String>) -> Self {
4337
let plist_storage = if let Some(plist_storage) = plist_storage {
4438
info!("Plist storage specified, ensure the environment is aware");
4539
plist_storage
@@ -70,7 +64,6 @@ impl SharedDevices {
7064
last_index: 0,
7165
last_interface_index: 0,
7266
plist_storage,
73-
use_heartbeat,
7467
known_mac_addresses: HashMap::new(),
7568
paired_udids: Vec::new(),
7669
}
@@ -81,22 +74,14 @@ impl SharedDevices {
8174
network_address: IpAddr,
8275
service_name: String,
8376
connection_type: String,
84-
data: Arc<Mutex<Self>>,
77+
heartbeat_handle: Option<Sender<()>>,
8578
) -> Result<(), Box<dyn std::error::Error>> {
8679
if self.devices.contains_key(&udid) {
8780
trace!("Device has already been added, skipping");
8881
return Ok(());
8982
}
9083
self.last_index += 1;
9184
self.last_interface_index += 1;
92-
let pairing_file = self.get_pairing_record(udid.clone()).await?;
93-
let pairing_file = idevice::pairing_file::PairingFile::from_bytes(&pairing_file)?;
94-
95-
let handle = if self.use_heartbeat {
96-
Some(heartbeat::heartbeat(network_address, udid.clone(), pairing_file, data).await?)
97-
} else {
98-
None
99-
};
10085

10186
let dev = MuxerDevice {
10287
connection_type,
@@ -105,7 +90,7 @@ impl SharedDevices {
10590
interface_index: self.last_interface_index,
10691
network_address: Some(network_address),
10792
serial_number: udid.clone(),
108-
heartbeat_handle: handle,
93+
heartbeat_handle,
10994
connection_speed: None,
11095
location_id: None,
11196
product_id: None,
@@ -141,7 +126,7 @@ impl SharedDevices {
141126
}
142127
};
143128
}
144-
pub async fn get_pairing_record(&self, udid: String) -> Result<Vec<u8>, std::io::Error> {
129+
pub async fn get_pairing_record(&self, udid: &String) -> Result<Vec<u8>, std::io::Error> {
145130
let path = PathBuf::from(self.plist_storage.clone()).join(format!("{}.plist", udid));
146131
info!("Attempting to read pairing file: {path:?}");
147132
if !path.exists() {

src/main.rs

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::{fs, os::unix::prelude::PermissionsExt};
66

77
use crate::raw_packet::RawPacket;
88
use devices::SharedDevices;
9+
use idevice::pairing_file::PairingFile;
910
use log::{debug, error, info, trace, warn};
1011
use tokio::{
1112
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
@@ -110,9 +111,7 @@ async fn main() {
110111
}
111112
info!("Collected arguments, proceeding");
112113

113-
let data = Arc::new(Mutex::new(
114-
devices::SharedDevices::new(plist_storage, use_heartbeat).await,
115-
));
114+
let data = Arc::new(Mutex::new(devices::SharedDevices::new(plist_storage).await));
116115
info!("Created new central data");
117116
let data_clone = data.clone();
118117

@@ -137,7 +136,7 @@ async fn main() {
137136
}
138137
};
139138

140-
handle_stream(socket, data.clone()).await;
139+
handle_stream(socket, data.clone(), use_heartbeat).await;
141140
}
142141
});
143142
}
@@ -168,15 +167,15 @@ async fn main() {
168167
}
169168
};
170169

171-
handle_stream(socket, data.clone()).await;
170+
handle_stream(socket, data.clone(), use_heartbeat).await;
172171
}
173172
});
174173
}
175174

176175
if use_mdns {
177176
let local = tokio::task::LocalSet::new();
178177
local.spawn_local(async move {
179-
mdns::discover(data_clone).await;
178+
mdns::discover(data_clone, use_heartbeat).await;
180179
error!("mDNS discovery stopped, how the heck did you break this");
181180
});
182181
local.await;
@@ -197,6 +196,7 @@ enum Directions {
197196
async fn handle_stream(
198197
mut socket: impl AsyncRead + AsyncWrite + Unpin + Send + 'static,
199198
data: Arc<Mutex<SharedDevices>>,
199+
use_heartbeat: bool,
200200
) {
201201
tokio::spawn(async move {
202202
let mut current_directions = Directions::None;
@@ -297,6 +297,14 @@ async fn handle_stream(
297297
}
298298
};
299299

300+
let ip_address = match ip_address.parse() {
301+
Ok(i) => i,
302+
Err(_) => {
303+
warn!("Bad IP requested: {ip_address}");
304+
return;
305+
}
306+
};
307+
300308
let udid = match parsed.plist.get("DeviceID") {
301309
Some(plist::Value::String(u)) => u,
302310
_ => {
@@ -305,21 +313,50 @@ async fn handle_stream(
305313
}
306314
};
307315

308-
let mut central_data = data.lock().await;
309-
let ip_address = match ip_address.parse() {
310-
Ok(i) => i,
311-
Err(_) => {
312-
warn!("Bad IP requested: {ip_address}");
313-
return;
314-
}
316+
let heartbeat_handle = if use_heartbeat {
317+
let pairing_file =
318+
match data.lock().await.get_pairing_record(udid).await {
319+
Ok(p) => match PairingFile::from_bytes(&p) {
320+
Ok(p) => p,
321+
Err(e) => {
322+
log::error!("Failed to parse pair record: {e:?}");
323+
return;
324+
}
325+
},
326+
Err(e) => {
327+
log::error!(
328+
"Failed to get pairing file for device: {e:?}"
329+
);
330+
return;
331+
}
332+
};
333+
let heartbeat_handle = match heartbeat::heartbeat(
334+
ip_address,
335+
udid.clone(),
336+
pairing_file,
337+
data.clone(),
338+
)
339+
.await
340+
{
341+
Ok(h) => h,
342+
Err(e) => {
343+
warn!("Failed to start heartbeat: {e:?}");
344+
return;
345+
}
346+
};
347+
Some(heartbeat_handle)
348+
} else {
349+
None
315350
};
351+
352+
let mut central_data = data.lock().await;
316353
let res = match central_data
317354
.add_network_device(
318355
udid.to_owned(),
319356
ip_address,
320357
service_name.to_owned(),
321358
connection_type.to_owned(),
322-
data.clone(),
359+
heartbeat_handle,
323360
)
324361
.await
325362
{
@@ -391,7 +428,7 @@ async fn handle_stream(
391428
let lock = data.lock().await;
392429
let pair_file = match lock
393430
.get_pairing_record(match parsed.plist.get("PairRecordID") {
394-
Some(plist::Value::String(p)) => p.to_owned(),
431+
Some(plist::Value::String(p)) => p,
395432
_ => {
396433
warn!("Request did not contain PairRecordID");
397434
return;

src/mdns.rs

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// jkcoxson
22

3-
use crate::devices::SharedDevices;
3+
use crate::{devices::SharedDevices, heartbeat};
4+
use idevice::pairing_file::PairingFile;
45
use log::{info, warn};
56
use std::net::IpAddr;
67
use std::sync::Arc;
7-
88
use tokio::sync::Mutex;
99

1010
#[cfg(not(feature = "zeroconf"))]
@@ -24,7 +24,7 @@ const SERVICE_NAME: &str = "apple-mobdev2";
2424
const SERVICE_PROTOCOL: &str = "tcp";
2525

2626
#[cfg(feature = "zeroconf")]
27-
pub async fn discover(data: Arc<Mutex<SharedDevices>>) {
27+
pub async fn discover(data: Arc<Mutex<SharedDevices>>, use_heartbeat: bool) {
2828
let service_name = format!("_{}._{}.local", SERVICE_NAME, SERVICE_PROTOCOL);
2929
println!("Starting mDNS discovery for {} with zeroconf", service_name);
3030

@@ -66,13 +66,42 @@ pub async fn discover(data: Arc<Mutex<SharedDevices>>) {
6666
}
6767
println!("Adding device {}", udid);
6868

69+
let heartbeat_handle = if use_heartbeat {
70+
let pairing_file = match data.lock().await.get_pairing_record(&udid).await {
71+
Ok(p) => match PairingFile::from_bytes(&p) {
72+
Ok(p) => p,
73+
Err(e) => {
74+
log::error!("Failed to parse pair record: {e:?}");
75+
continue;
76+
}
77+
},
78+
Err(e) => {
79+
log::error!("Failed to get pairing file for device: {e:?}");
80+
continue;
81+
}
82+
};
83+
let heartbeat_handle =
84+
match heartbeat::heartbeat(addr, udid.clone(), pairing_file, data.clone())
85+
.await
86+
{
87+
Ok(h) => h,
88+
Err(e) => {
89+
warn!("Failed to start heartbeat: {e:?}");
90+
return;
91+
}
92+
};
93+
Some(heartbeat_handle)
94+
} else {
95+
None
96+
};
97+
6998
if let Err(e) = lock
7099
.add_network_device(
71100
udid.clone(),
72101
addr,
73102
service_name.clone(),
74103
"Network".to_string(),
75-
data.clone(),
104+
heartbeat_handle,
76105
)
77106
.await
78107
{
@@ -84,7 +113,7 @@ pub async fn discover(data: Arc<Mutex<SharedDevices>>) {
84113
}
85114

86115
#[cfg(not(feature = "zeroconf"))]
87-
pub async fn discover(data: Arc<Mutex<SharedDevices>>) {
116+
pub async fn discover(data: Arc<Mutex<SharedDevices>>, use_heartbeat: bool) {
88117
use log::warn;
89118

90119
let service_name = format!("_{}._{}.local", SERVICE_NAME, SERVICE_PROTOCOL);
@@ -125,14 +154,42 @@ pub async fn discover(data: Arc<Mutex<SharedDevices>>) {
125154
continue;
126155
}
127156
println!("Adding device {}", udid);
157+
let heartbeat_handle = if use_heartbeat {
158+
let pairing_file = match data.lock().await.get_pairing_record(&udid).await {
159+
Ok(p) => match PairingFile::from_bytes(&p) {
160+
Ok(p) => p,
161+
Err(e) => {
162+
log::error!("Failed to parse pair record: {e:?}");
163+
continue;
164+
}
165+
},
166+
Err(e) => {
167+
log::error!("Failed to get pairing file for device: {e:?}");
168+
continue;
169+
}
170+
};
171+
let heartbeat_handle =
172+
match heartbeat::heartbeat(addr, udid.clone(), pairing_file, data.clone())
173+
.await
174+
{
175+
Ok(h) => h,
176+
Err(e) => {
177+
warn!("Failed to start heartbeat: {e:?}");
178+
return;
179+
}
180+
};
181+
Some(heartbeat_handle)
182+
} else {
183+
None
184+
};
128185

129186
if let Err(e) = lock
130187
.add_network_device(
131188
udid.clone(),
132189
addr,
133190
service_name.clone(),
134191
"Network".to_string(),
135-
data.clone(),
192+
heartbeat_handle,
136193
)
137194
.await
138195
{

0 commit comments

Comments
 (0)