Skip to content

Commit e842f18

Browse files
committed
udp_incoming_handler
1 parent 5efbf00 commit e842f18

2 files changed

Lines changed: 62 additions & 67 deletions

File tree

examples/dns2socks.rs

Lines changed: 61 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use std::{
1212
use tokio::{
1313
io::{AsyncReadExt, AsyncWriteExt, BufStream},
1414
net::{TcpListener, TcpStream, ToSocketAddrs, UdpSocket},
15-
sync::mpsc::{self, Receiver},
1615
};
1716
use trust_dns_proto::{
1817
op::{Message, Query, ResponseCode::NoError},
@@ -102,77 +101,23 @@ async fn main() -> Result<()> {
102101
}
103102

104103
async fn udp_thread(opt: CmdOpt, user_key: Option<UserKey>, cache: Cache<Vec<Query>, Message>, timeout: Duration) -> Result<()> {
105-
let udp_listener = Arc::new(UdpSocket::bind(&opt.listen_addr).await?);
104+
let listener = Arc::new(UdpSocket::bind(&opt.listen_addr).await?);
106105
log::info!("Udp listening on: {}", opt.listen_addr);
107-
let (sender, mut receiver) = mpsc::channel::<(SocketAddr, Vec<u8>)>(1024);
108-
109-
let listener = udp_listener.clone();
110-
111-
// to avoid move semantic occurs, we defined a function instead of a closure
112-
async fn channel_end(
113-
receiver: &mut Receiver<(SocketAddr, Vec<u8>)>,
114-
opt: &CmdOpt,
115-
cache: &Cache<Vec<Query>, Message>,
116-
listener: &Arc<UdpSocket>,
117-
user_key: &Option<UserKey>,
118-
timeout: Duration,
119-
) -> Result<()> {
120-
while let Some((src, mut buf)) = receiver.recv().await {
121-
let message = parse_data_to_dns_message(&buf, false)?;
122-
let domain = extract_domain_from_dns_message(&message)?;
123-
124-
if opt.cache_records {
125-
if let Some(cached_message) = dns_cache_get_message(cache, &message) {
126-
let data = cached_message.to_vec().map_err(|e| e.to_string())?;
127-
listener.send_to(&data, &src).await?;
128-
log_dns_message("DNS query via UDP cache hit", &domain, &cached_message);
129-
continue;
130-
}
131-
}
132-
133-
let proxy_addr = opt.socks5_server;
134-
let udp_server_addr = opt.dns_remote_server;
135-
let auth = user_key.clone();
136-
137-
let data = if opt.force_tcp {
138-
let mut new_buf = (buf.len() as u16).to_be_bytes().to_vec();
139-
new_buf.append(&mut buf);
140-
tcp_via_socks5_server(proxy_addr, udp_server_addr, auth, &new_buf, timeout).await?
141-
} else {
142-
client::UdpClientImpl::datagram(proxy_addr, udp_server_addr, auth)
143-
.await?
144-
.transfer_data(&buf, timeout)
145-
.await?
146-
};
147-
let message = parse_data_to_dns_message(&data, opt.force_tcp)?;
148-
let msg_buf = message.to_vec().map_err(|e| e.to_string())?;
149-
150-
listener.send_to(&msg_buf, &src).await?;
151-
152-
log_dns_message("DNS query via UDP", &domain, &message);
153-
if opt.cache_records {
154-
dns_cache_put_message(cache, &message).await;
155-
}
156-
}
157-
Ok::<(), Error>(())
158-
}
159-
160-
tokio::spawn(async move {
161-
loop {
162-
if let Err(e) = channel_end(&mut receiver, &opt, &cache, &listener, &user_key, timeout).await {
163-
log::error!("UDP channel_end thread error \"{}\"", e);
164-
}
165-
}
166-
});
167106

168107
loop {
169-
let udp_listener = udp_listener.clone();
170-
let sender = sender.clone();
108+
let listener = listener.clone();
109+
let opt = opt.clone();
110+
let cache = cache.clone();
111+
let auth = user_key.clone();
171112
let block = async move {
172113
let mut buf = vec![0u8; MAX_BUFFER_SIZE];
173-
let (len, src) = udp_listener.recv_from(&mut buf).await?;
114+
let (len, src) = listener.recv_from(&mut buf).await?;
174115
buf.resize(len, 0);
175-
sender.send((src, buf)).await.map_err(|e| e.to_string())?;
116+
tokio::spawn(async move {
117+
if let Err(e) = udp_incoming_handler(listener, buf, src, opt, cache, auth, timeout).await {
118+
log::error!("DNS query via UDP incoming handler error \"{}\"", e);
119+
}
120+
});
176121
Ok::<(), Error>(())
177122
};
178123
if let Err(e) = block.await {
@@ -181,6 +126,56 @@ async fn udp_thread(opt: CmdOpt, user_key: Option<UserKey>, cache: Cache<Vec<Que
181126
}
182127
}
183128

129+
async fn udp_incoming_handler(
130+
listener: Arc<UdpSocket>,
131+
mut buf: Vec<u8>,
132+
src: SocketAddr,
133+
opt: CmdOpt,
134+
cache: Cache<Vec<Query>, Message>,
135+
auth: Option<UserKey>,
136+
timeout: Duration,
137+
) -> Result<()> {
138+
let message = parse_data_to_dns_message(&buf, false)?;
139+
let domain = extract_domain_from_dns_message(&message)?;
140+
141+
if opt.cache_records {
142+
if let Some(cached_message) = dns_cache_get_message(&cache, &message) {
143+
let data = cached_message.to_vec().map_err(|e| e.to_string())?;
144+
listener.send_to(&data, &src).await?;
145+
log_dns_message("DNS query via UDP cache hit", &domain, &cached_message);
146+
return Ok(());
147+
}
148+
}
149+
150+
let proxy_addr = opt.socks5_server;
151+
let udp_server_addr = opt.dns_remote_server;
152+
153+
let data = if opt.force_tcp {
154+
let mut new_buf = (buf.len() as u16).to_be_bytes().to_vec();
155+
new_buf.append(&mut buf);
156+
tcp_via_socks5_server(proxy_addr, udp_server_addr, auth, &new_buf, timeout)
157+
.await
158+
.map_err(|e| format!("querying \"{domain}\" {e}"))?
159+
} else {
160+
client::UdpClientImpl::datagram(proxy_addr, udp_server_addr, auth)
161+
.await
162+
.map_err(|e| format!("preparing to query \"{domain}\" {e}"))?
163+
.transfer_data(&buf, timeout)
164+
.await
165+
.map_err(|e| format!("querying \"{domain}\" {e}"))?
166+
};
167+
let message = parse_data_to_dns_message(&data, opt.force_tcp)?;
168+
let msg_buf = message.to_vec().map_err(|e| e.to_string())?;
169+
170+
listener.send_to(&msg_buf, &src).await?;
171+
172+
log_dns_message("DNS query via UDP", &domain, &message);
173+
if opt.cache_records {
174+
dns_cache_put_message(&cache, &message).await;
175+
}
176+
Ok::<(), Error>(())
177+
}
178+
184179
async fn tcp_thread(opt: CmdOpt, user_key: Option<UserKey>, cache: Cache<Vec<Query>, Message>, timeout: Duration) -> Result<()> {
185180
let listener = TcpListener::bind(&opt.listen_addr).await?;
186181
log::info!("TCP listening on: {}", opt.listen_addr);

src/error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ pub enum Error {
3737
#[error("Utf8Error: {0}")]
3838
Utf8Error(#[from] std::str::Utf8Error),
3939

40-
#[error("String error: {0}")]
40+
#[error("{0}")]
4141
String(String),
4242
}
4343

0 commit comments

Comments
 (0)