@@ -12,7 +12,6 @@ use std::{
1212use tokio:: {
1313 io:: { AsyncReadExt , AsyncWriteExt , BufStream } ,
1414 net:: { TcpListener , TcpStream , ToSocketAddrs , UdpSocket } ,
15- sync:: mpsc:: { self , Receiver } ,
1615} ;
1716use trust_dns_proto:: {
1817 op:: { Message , Query , ResponseCode :: NoError } ,
@@ -102,77 +101,23 @@ async fn main() -> Result<()> {
102101}
103102
104103async fn udp_thread ( opt : CmdOpt , user_key : Option < UserKey > , cache : Cache < Vec < Query > , Message > , timeout : Duration ) -> Result < ( ) > {
105- let udp_listener = Arc :: new ( UdpSocket :: bind ( & opt. listen_addr ) . await ?) ;
104+ let listener = Arc :: new ( UdpSocket :: bind ( & opt. listen_addr ) . await ?) ;
106105 log:: info!( "Udp listening on: {}" , opt. listen_addr) ;
107- let ( sender, mut receiver) = mpsc:: channel :: < ( SocketAddr , Vec < u8 > ) > ( 1024 ) ;
108-
109- let listener = udp_listener. clone ( ) ;
110-
111- // to avoid move semantic occurs, we defined a function instead of a closure
112- async fn channel_end (
113- receiver : & mut Receiver < ( SocketAddr , Vec < u8 > ) > ,
114- opt : & CmdOpt ,
115- cache : & Cache < Vec < Query > , Message > ,
116- listener : & Arc < UdpSocket > ,
117- user_key : & Option < UserKey > ,
118- timeout : Duration ,
119- ) -> Result < ( ) > {
120- while let Some ( ( src, mut buf) ) = receiver. recv ( ) . await {
121- let message = parse_data_to_dns_message ( & buf, false ) ?;
122- let domain = extract_domain_from_dns_message ( & message) ?;
123-
124- if opt. cache_records {
125- if let Some ( cached_message) = dns_cache_get_message ( cache, & message) {
126- let data = cached_message. to_vec ( ) . map_err ( |e| e. to_string ( ) ) ?;
127- listener. send_to ( & data, & src) . await ?;
128- log_dns_message ( "DNS query via UDP cache hit" , & domain, & cached_message) ;
129- continue ;
130- }
131- }
132-
133- let proxy_addr = opt. socks5_server ;
134- let udp_server_addr = opt. dns_remote_server ;
135- let auth = user_key. clone ( ) ;
136-
137- let data = if opt. force_tcp {
138- let mut new_buf = ( buf. len ( ) as u16 ) . to_be_bytes ( ) . to_vec ( ) ;
139- new_buf. append ( & mut buf) ;
140- tcp_via_socks5_server ( proxy_addr, udp_server_addr, auth, & new_buf, timeout) . await ?
141- } else {
142- client:: UdpClientImpl :: datagram ( proxy_addr, udp_server_addr, auth)
143- . await ?
144- . transfer_data ( & buf, timeout)
145- . await ?
146- } ;
147- let message = parse_data_to_dns_message ( & data, opt. force_tcp ) ?;
148- let msg_buf = message. to_vec ( ) . map_err ( |e| e. to_string ( ) ) ?;
149-
150- listener. send_to ( & msg_buf, & src) . await ?;
151-
152- log_dns_message ( "DNS query via UDP" , & domain, & message) ;
153- if opt. cache_records {
154- dns_cache_put_message ( cache, & message) . await ;
155- }
156- }
157- Ok :: < ( ) , Error > ( ( ) )
158- }
159-
160- tokio:: spawn ( async move {
161- loop {
162- if let Err ( e) = channel_end ( & mut receiver, & opt, & cache, & listener, & user_key, timeout) . await {
163- log:: error!( "UDP channel_end thread error \" {}\" " , e) ;
164- }
165- }
166- } ) ;
167106
168107 loop {
169- let udp_listener = udp_listener. clone ( ) ;
170- let sender = sender. clone ( ) ;
108+ let listener = listener. clone ( ) ;
109+ let opt = opt. clone ( ) ;
110+ let cache = cache. clone ( ) ;
111+ let auth = user_key. clone ( ) ;
171112 let block = async move {
172113 let mut buf = vec ! [ 0u8 ; MAX_BUFFER_SIZE ] ;
173- let ( len, src) = udp_listener . recv_from ( & mut buf) . await ?;
114+ let ( len, src) = listener . recv_from ( & mut buf) . await ?;
174115 buf. resize ( len, 0 ) ;
175- sender. send ( ( src, buf) ) . await . map_err ( |e| e. to_string ( ) ) ?;
116+ tokio:: spawn ( async move {
117+ if let Err ( e) = udp_incoming_handler ( listener, buf, src, opt, cache, auth, timeout) . await {
118+ log:: error!( "DNS query via UDP incoming handler error \" {}\" " , e) ;
119+ }
120+ } ) ;
176121 Ok :: < ( ) , Error > ( ( ) )
177122 } ;
178123 if let Err ( e) = block. await {
@@ -181,6 +126,56 @@ async fn udp_thread(opt: CmdOpt, user_key: Option<UserKey>, cache: Cache<Vec<Que
181126 }
182127}
183128
129+ async fn udp_incoming_handler (
130+ listener : Arc < UdpSocket > ,
131+ mut buf : Vec < u8 > ,
132+ src : SocketAddr ,
133+ opt : CmdOpt ,
134+ cache : Cache < Vec < Query > , Message > ,
135+ auth : Option < UserKey > ,
136+ timeout : Duration ,
137+ ) -> Result < ( ) > {
138+ let message = parse_data_to_dns_message ( & buf, false ) ?;
139+ let domain = extract_domain_from_dns_message ( & message) ?;
140+
141+ if opt. cache_records {
142+ if let Some ( cached_message) = dns_cache_get_message ( & cache, & message) {
143+ let data = cached_message. to_vec ( ) . map_err ( |e| e. to_string ( ) ) ?;
144+ listener. send_to ( & data, & src) . await ?;
145+ log_dns_message ( "DNS query via UDP cache hit" , & domain, & cached_message) ;
146+ return Ok ( ( ) ) ;
147+ }
148+ }
149+
150+ let proxy_addr = opt. socks5_server ;
151+ let udp_server_addr = opt. dns_remote_server ;
152+
153+ let data = if opt. force_tcp {
154+ let mut new_buf = ( buf. len ( ) as u16 ) . to_be_bytes ( ) . to_vec ( ) ;
155+ new_buf. append ( & mut buf) ;
156+ tcp_via_socks5_server ( proxy_addr, udp_server_addr, auth, & new_buf, timeout)
157+ . await
158+ . map_err ( |e| format ! ( "querying \" {domain}\" {e}" ) ) ?
159+ } else {
160+ client:: UdpClientImpl :: datagram ( proxy_addr, udp_server_addr, auth)
161+ . await
162+ . map_err ( |e| format ! ( "preparing to query \" {domain}\" {e}" ) ) ?
163+ . transfer_data ( & buf, timeout)
164+ . await
165+ . map_err ( |e| format ! ( "querying \" {domain}\" {e}" ) ) ?
166+ } ;
167+ let message = parse_data_to_dns_message ( & data, opt. force_tcp ) ?;
168+ let msg_buf = message. to_vec ( ) . map_err ( |e| e. to_string ( ) ) ?;
169+
170+ listener. send_to ( & msg_buf, & src) . await ?;
171+
172+ log_dns_message ( "DNS query via UDP" , & domain, & message) ;
173+ if opt. cache_records {
174+ dns_cache_put_message ( & cache, & message) . await ;
175+ }
176+ Ok :: < ( ) , Error > ( ( ) )
177+ }
178+
184179async fn tcp_thread ( opt : CmdOpt , user_key : Option < UserKey > , cache : Cache < Vec < Query > , Message > , timeout : Duration ) -> Result < ( ) > {
185180 let listener = TcpListener :: bind ( & opt. listen_addr ) . await ?;
186181 log:: info!( "TCP listening on: {}" , opt. listen_addr) ;
0 commit comments