Skip to content

Commit fd1f2a1

Browse files
committed
fix udp: more reliable protocol
1 parent 7d9e605 commit fd1f2a1

File tree

10 files changed

+190
-78
lines changed

10 files changed

+190
-78
lines changed

Android/app/src/main/java/io/github/teamclouday/AndroidMic/domain/streaming/UdpStreamer.kt

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,41 @@ class UdpStreamer(private val scope: CoroutineScope, val ip: String, var port: I
2525

2626
override fun connect(): Boolean {
2727
socket.soTimeout = 1500
28-
return true
28+
29+
val message = Message.MessageWrapper.newBuilder()
30+
.setConnect(
31+
Message.ConnectMessage.newBuilder()
32+
.build()
33+
)
34+
.build()
35+
36+
val pack = message.toByteArray()
37+
val combined = pack.size.toBigEndianU32() + pack
38+
39+
val packet = DatagramPacket(
40+
combined,
41+
0,
42+
combined.size,
43+
address,
44+
port
45+
)
46+
47+
try {
48+
socket.send(packet)
49+
} catch (_: Exception) {
50+
return false
51+
}
52+
53+
val buff = ByteArray(CHECK_2.length)
54+
val recvPacket = DatagramPacket(buff, buff.size)
55+
56+
try {
57+
socket.receive(recvPacket)
58+
} catch (_: Exception) {
59+
return false
60+
}
61+
62+
return recvPacket.data.contentEquals(CHECK_2.toByteArray())
2963
}
3064

3165
override fun disconnect(): Boolean {
@@ -45,18 +79,23 @@ class UdpStreamer(private val scope: CoroutineScope, val ip: String, var port: I
4579
streamJob = scope.launch {
4680
audioStream.collect { data ->
4781
try {
48-
val message = Message.AudioPacketMessageOrdered.newBuilder()
49-
.setSequenceNumber(sequenceIdx++)
82+
83+
val message = Message.MessageWrapper.newBuilder()
5084
.setAudioPacket(
51-
Message.AudioPacketMessage.newBuilder()
52-
.setBuffer(ByteString.copyFrom(data.buffer))
53-
.setSampleRate(data.sampleRate)
54-
.setAudioFormat(data.audioFormat)
55-
.setChannelCount(data.channelCount)
85+
Message.AudioPacketMessageOrdered.newBuilder()
86+
.setSequenceNumber(sequenceIdx++)
87+
.setAudioPacket(
88+
Message.AudioPacketMessage.newBuilder()
89+
.setBuffer(ByteString.copyFrom(data.buffer))
90+
.setSampleRate(data.sampleRate)
91+
.setAudioFormat(data.audioFormat)
92+
.setChannelCount(data.channelCount)
93+
)
5694
.build()
5795
)
5896
.build()
5997

98+
6099
val pack = message.toByteArray()
61100
val combined = pack.size.toBigEndianU32() + pack
62101

@@ -65,7 +104,7 @@ class UdpStreamer(private val scope: CoroutineScope, val ip: String, var port: I
65104
0,
66105
combined.size,
67106
address,
68-
port!!
107+
port
69108
)
70109

71110
socket.send(packet)
Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
syntax = "proto3";
22

33
message AudioPacketMessage {
4-
bytes buffer = 1;
5-
uint32 sample_rate = 2;
6-
uint32 channel_count = 3;
7-
uint32 audio_format = 4;
4+
bytes buffer = 1;
5+
uint32 sample_rate = 2;
6+
uint32 channel_count = 3;
7+
uint32 audio_format = 4;
88
}
99

1010
message AudioPacketMessageOrdered {
11-
uint32 sequence_number = 1;
12-
AudioPacketMessage audio_packet = 2;
11+
uint32 sequence_number = 1;
12+
AudioPacketMessage audio_packet = 2;
13+
}
14+
15+
message ConnectMessage {}
16+
17+
message MessageWrapper {
18+
oneof payload {
19+
AudioPacketMessageOrdered audio_packet = 1;
20+
ConnectMessage connect = 2;
21+
}
1322
}

RustApp/src/proto/message.proto

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
syntax = "proto3";
22

3-
package message;
4-
53
message AudioPacketMessage {
6-
bytes buffer = 1;
7-
uint32 sample_rate = 2;
8-
uint32 channel_count = 3;
9-
uint32 audio_format = 4;
4+
bytes buffer = 1;
5+
uint32 sample_rate = 2;
6+
uint32 channel_count = 3;
7+
uint32 audio_format = 4;
108
}
119

1210
message AudioPacketMessageOrdered {
13-
uint32 sequence_number = 1;
14-
AudioPacketMessage audio_packet = 2;
11+
uint32 sequence_number = 1;
12+
AudioPacketMessage audio_packet = 2;
13+
}
14+
15+
message ConnectMessage {}
16+
17+
message MessageWrapper {
18+
oneof payload {
19+
AudioPacketMessageOrdered audio_packet = 1;
20+
ConnectMessage connect = 2;
21+
}
1522
}

RustApp/src/streamer/adb_streamer.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use tokio::process::Command;
33

44
use crate::{
55
config::ConnectionMode,
6-
streamer::{StreamerMsg, DEFAULT_PC_PORT, tcp_streamer},
6+
streamer::{DEFAULT_PC_PORT, StreamerMsg, tcp_streamer},
77
};
88

99
use super::{
@@ -68,12 +68,8 @@ async fn exec_cmd(mut cmd: Command) -> Result<String, ConnectError> {
6868
}
6969

7070
pub async fn new(stream_config: AudioStream) -> Result<AdbStreamer, ConnectError> {
71-
let tcp_streamer = tcp_streamer::new(
72-
"127.0.0.1".parse().unwrap(),
73-
DEFAULT_PC_PORT,
74-
stream_config,
75-
)
76-
.await?;
71+
let tcp_streamer =
72+
tcp_streamer::new("127.0.0.1".parse().unwrap(), DEFAULT_PC_PORT, stream_config).await?;
7773

7874
let devices = get_connected_devices().await?;
7975
if devices.is_empty() {

RustApp/src/streamer/message.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,20 @@ pub struct AudioPacketMessageOrdered {
1717
#[prost(message, optional, tag = "2")]
1818
pub audio_packet: ::core::option::Option<AudioPacketMessage>,
1919
}
20+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
21+
pub struct ConnectMessage {}
22+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
23+
pub struct MessageWrapper {
24+
#[prost(oneof = "message_wrapper::Payload", tags = "1, 2")]
25+
pub payload: ::core::option::Option<message_wrapper::Payload>,
26+
}
27+
/// Nested message and enum types in `MessageWrapper`.
28+
pub mod message_wrapper {
29+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
30+
pub enum Payload {
31+
#[prost(message, tag = "1")]
32+
AudioPacket(super::AudioPacketMessageOrdered),
33+
#[prost(message, tag = "2")]
34+
Connect(super::ConnectMessage),
35+
}
36+
}

RustApp/src/streamer/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ mod usb_streamer;
2424
#[cfg(feature = "usb")]
2525
use crate::streamer::usb_streamer::UsbStreamer;
2626

27-
pub use message::{AudioPacketMessage, AudioPacketMessageOrdered};
27+
pub use message::AudioPacketMessage;
2828
pub use streamer_runner::{ConnectOption, StreamerCommand, StreamerMsg, sub};
2929

3030
use crate::{audio::AudioProcessParams, config::AudioFormat};
@@ -90,8 +90,8 @@ enum Streamer {
9090

9191
#[derive(Debug, Error)]
9292
enum ConnectError {
93-
#[error("can't bind a port on the pc: {0}")]
94-
CantBindPort(io::Error),
93+
#[error("can't bind port {0} on the pc: {1}")]
94+
CantBindPort(u16, io::Error),
9595
#[error("can't find a local address: {0}")]
9696
NoLocalAddress(io::Error),
9797
#[error("accept failed: {0}")]

RustApp/src/streamer/tcp_streamer.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,14 @@ pub enum TcpStreamerState {
3737
},
3838
}
3939

40-
pub async fn new(ip: IpAddr, port: u16, stream_config: AudioStream) -> Result<TcpStreamer, ConnectError> {
40+
pub async fn new(
41+
ip: IpAddr,
42+
port: u16,
43+
stream_config: AudioStream,
44+
) -> Result<TcpStreamer, ConnectError> {
4145
let listener = TcpListener::bind((ip, port))
4246
.await
43-
.map_err(ConnectError::CantBindPort)?;
47+
.map_err(|e| ConnectError::CantBindPort(port, e))?;
4448

4549
let addr = TcpListener::local_addr(&listener).map_err(ConnectError::NoLocalAddress)?;
4650

RustApp/src/streamer/udp_streamer.rs

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@ use tokio_util::{codec::LengthDelimitedCodec, udp::UdpFramed};
77

88
use crate::{
99
config::ConnectionMode,
10-
streamer::{AudioPacketMessage, WriteError},
10+
streamer::{
11+
AudioPacketMessage, CHECK_2, WriteError,
12+
message::{MessageWrapper, message_wrapper::Payload},
13+
},
1114
};
1215

13-
use super::{AudioPacketMessageOrdered, AudioStream, ConnectError, StreamerMsg, StreamerTrait};
16+
use super::{AudioStream, ConnectError, StreamerMsg, StreamerTrait};
1417

1518
const MAX_WAIT_TIME: Duration = Duration::from_millis(1500);
1619

@@ -25,10 +28,14 @@ pub struct UdpStreamer {
2528
tracked_sequence: u32,
2629
}
2730

28-
pub async fn new(ip: IpAddr, port: u16, stream_config: AudioStream) -> Result<UdpStreamer, ConnectError> {
31+
pub async fn new(
32+
ip: IpAddr,
33+
port: u16,
34+
stream_config: AudioStream,
35+
) -> Result<UdpStreamer, ConnectError> {
2936
let socket = UdpSocket::bind((ip, port))
3037
.await
31-
.map_err(ConnectError::CantBindPort)?;
38+
.map_err(|e| ConnectError::CantBindPort(port, e))?;
3239

3340
let addr = socket.local_addr().map_err(ConnectError::NoLocalAddress)?;
3441

@@ -66,48 +73,77 @@ impl StreamerTrait for UdpStreamer {
6673

6774
async fn next(&mut self) -> Result<Option<StreamerMsg>, ConnectError> {
6875
match tokio::time::timeout(
69-
Duration::from_secs(if self.is_listening { Duration::MAX.as_secs() } else { 1 }),
76+
Duration::from_secs(if self.is_listening {
77+
Duration::MAX.as_secs()
78+
} else {
79+
1
80+
}),
7081
self.framed.next(),
7182
)
7283
.await
7384
{
7485
Ok(res) => match res {
7586
Some(Ok((frame, addr))) => {
76-
match AudioPacketMessageOrdered::decode(frame) {
87+
match MessageWrapper::decode(frame) {
7788
Ok(packet) => {
78-
if self.is_listening {
79-
self.is_listening = false;
80-
return Ok(Some(StreamerMsg::Connected {
81-
ip: Some(self.ip),
82-
port: Some(self.port),
83-
mode: ConnectionMode::Udp,
84-
}));
85-
}
86-
87-
if packet.sequence_number < self.tracked_sequence {
88-
// drop packet
89-
info!(
90-
"dropped packet: old sequence number {} < {}",
91-
packet.sequence_number, self.tracked_sequence
92-
);
93-
}
94-
self.tracked_sequence = packet.sequence_number;
95-
96-
let packet = packet.audio_packet.unwrap();
97-
let buffer_size = packet.buffer.len();
98-
let sample_rate = packet.sample_rate;
99-
100-
match self.stream_config.process_audio_packet(packet) {
101-
Ok(Some(buffer)) => {
102-
debug!("From {:?}, received {} bytes", addr, buffer_size);
103-
Ok(Some(StreamerMsg::UpdateAudioWave {
104-
data: AudioPacketMessage::to_wave_data(
105-
&buffer,
106-
sample_rate,
107-
),
108-
}))
89+
match packet.payload {
90+
Some(payload) => {
91+
let message = match payload {
92+
Payload::AudioPacket(packet) => {
93+
if packet.sequence_number < self.tracked_sequence {
94+
// drop packet
95+
info!(
96+
"dropped packet: old sequence number {} < {}",
97+
packet.sequence_number, self.tracked_sequence
98+
);
99+
}
100+
self.tracked_sequence = packet.sequence_number;
101+
102+
let packet = packet.audio_packet.unwrap();
103+
let buffer_size = packet.buffer.len();
104+
let sample_rate = packet.sample_rate;
105+
106+
match self.stream_config.process_audio_packet(packet) {
107+
Ok(Some(buffer)) => {
108+
debug!(
109+
"From {:?}, received {} bytes",
110+
addr, buffer_size
111+
);
112+
Some(StreamerMsg::UpdateAudioWave {
113+
data: AudioPacketMessage::to_wave_data(
114+
&buffer,
115+
sample_rate,
116+
),
117+
})
118+
}
119+
_ => None,
120+
}
121+
}
122+
Payload::Connect(_) => {
123+
self.framed
124+
.get_ref()
125+
.send_to(CHECK_2.as_bytes(), &addr)
126+
.await
127+
.map_err(|e| {
128+
ConnectError::HandShakeFailed("writing", e)
129+
})?;
130+
131+
None
132+
}
133+
};
134+
135+
if self.is_listening {
136+
self.is_listening = false;
137+
Ok(Some(StreamerMsg::Connected {
138+
ip: Some(self.ip),
139+
port: Some(self.port),
140+
mode: ConnectionMode::Udp,
141+
}))
142+
} else {
143+
Ok(message)
144+
}
109145
}
110-
_ => Ok(None),
146+
None => todo!(),
111147
}
112148
}
113149
Err(e) => Err(ConnectError::WriteError(WriteError::Deserializer(e))),

RustApp/src/ui/app.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,10 @@ impl AppState {
211211
return self.add_log(e);
212212
};
213213

214-
ConnectOption::Tcp { ip, port: config.port }
214+
ConnectOption::Tcp {
215+
ip,
216+
port: config.port,
217+
}
215218
}
216219
ConnectionMode::Udp => {
217220
let Some(ip) = config.ip_or_default() else {
@@ -220,7 +223,10 @@ impl AppState {
220223
error!("failed to start audio stream: {e}");
221224
return self.add_log(e);
222225
};
223-
ConnectOption::Udp { ip, port: config.port }
226+
ConnectOption::Udp {
227+
ip,
228+
port: config.port,
229+
}
224230
}
225231
#[cfg(feature = "adb")]
226232
ConnectionMode::Adb => ConnectOption::Adb,

0 commit comments

Comments
 (0)