11pub mod collect_servers;
22pub mod migrate_mongo_to_postgres;
33
4- use std:: { collections:: HashSet , net:: Ipv4Addr , sync:: Arc , time:: Duration } ;
4+ use std:: {
5+ collections:: HashSet ,
6+ fmt:: { self , Display } ,
7+ net:: Ipv4Addr ,
8+ ops:: Deref ,
9+ str:: FromStr ,
10+ sync:: Arc ,
11+ time:: Duration ,
12+ } ;
513
614use futures_util:: stream:: StreamExt ;
715use lru_cache:: LruCache ;
816use parking_lot:: Mutex ;
917use rustc_hash:: FxHashMap ;
10- use sqlx:: { PgPool , Row } ;
18+ use sqlx:: {
19+ PgPool , Postgres , Row ,
20+ encode:: IsNull ,
21+ error:: BoxDynError ,
22+ postgres:: { PgArgumentBuffer , PgTypeInfo , PgValueFormat , PgValueRef } ,
23+ } ;
1124use tracing:: { error, info} ;
1225
1326use crate :: database:: collect_servers:: CollectServersCache ;
@@ -80,8 +93,8 @@ impl Database {
8093 . fetch_all ( & self . pool )
8194 . await ?;
8295 for row in rows {
83- let ip = Ipv4Addr :: from_bits ( row. get :: < i32 , _ > ( 0 ) as u32 ) ;
84- let allowed_port = row. get :: < i16 , _ > ( 1 ) as u16 ;
96+ let ip = Ipv4Addr :: from_bits ( row. get :: < PgU32 , _ > ( 0 ) . 0 ) ;
97+ let allowed_port = row. get :: < PgU16 , _ > ( 1 ) . 0 ;
8598 aliased_ips_to_allowed_port. insert ( ip, allowed_port) ;
8699 }
87100 self . shared . lock ( ) . aliased_ips_to_allowed_port = aliased_ips_to_allowed_port;
@@ -102,14 +115,14 @@ impl Database {
102115 let mut txn = self . pool . begin ( ) . await ?;
103116
104117 sqlx:: query ( "INSERT INTO ips_with_aliased_servers (ip, allowed_port) VALUES ($1, $2)" )
105- . bind ( ip. to_bits ( ) as i32 )
106- . bind ( allowed_port as i16 )
118+ . bind ( PgU32 ( ip. to_bits ( ) ) )
119+ . bind ( PgU16 ( allowed_port) )
107120 . execute ( & mut * txn)
108121 . await ?;
109122 // delete all servers with this ip that aren't on the allowed port
110123 let delete_res = sqlx:: query ( "DELETE FROM servers WHERE ip = $1 AND port != $2" )
111- . bind ( ip. to_bits ( ) as i32 )
112- . bind ( allowed_port as i16 )
124+ . bind ( PgU32 ( ip. to_bits ( ) ) )
125+ . bind ( PgU16 ( allowed_port) )
113126 . execute ( & mut * txn)
114127 . await ?;
115128 let deleted_count = delete_res. rows_affected ( ) ;
@@ -146,7 +159,7 @@ impl Database {
146159 . fetch ( & self . pool ) ;
147160
148161 while let Some ( Ok ( row) ) = rows. next ( ) . await {
149- let ip = row. get :: < i32 , _ > ( 0 ) ;
162+ let ip = row. get :: < PgU32 , _ > ( 0 ) . 0 ;
150163 let player_count = row. get :: < i64 , _ > ( 1 ) ;
151164
152165 let delete_count = player_count - KEEP_PLAYER_COUNT ;
@@ -160,7 +173,7 @@ impl Database {
160173 )
161174 " ,
162175 )
163- . bind ( ip )
176+ . bind ( PgU32 ( ip ) )
164177 . bind ( delete_count)
165178 . execute ( & self . pool )
166179 . await ?;
@@ -174,3 +187,79 @@ impl Database {
174187pub fn sanitize_text_for_postgres ( s : & str ) -> String {
175188 s. replace ( '\0' , "" )
176189}
190+
191+ pub struct PgU32 ( pub u32 ) ;
192+ impl Deref for PgU32 {
193+ type Target = u32 ;
194+ fn deref ( & self ) -> & Self :: Target {
195+ & self . 0
196+ }
197+ }
198+ impl Display for PgU32 {
199+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
200+ write ! ( f, "{}" , self . 0 )
201+ }
202+ }
203+ impl FromStr for PgU32 {
204+ type Err = std:: num:: ParseIntError ;
205+ fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
206+ Ok ( Self ( s. parse ( ) ?) )
207+ }
208+ }
209+ impl sqlx:: Type < Postgres > for PgU32 {
210+ fn type_info ( ) -> PgTypeInfo {
211+ PgTypeInfo :: with_name ( "uint4" )
212+ }
213+ }
214+ impl sqlx:: Decode < ' _ , Postgres > for PgU32 {
215+ fn decode ( value : PgValueRef < ' _ > ) -> Result < Self , BoxDynError > {
216+ Ok ( match value. format ( ) {
217+ PgValueFormat :: Binary => Self ( u32:: from_be_bytes ( value. as_bytes ( ) ?. try_into ( ) ?) ) ,
218+ PgValueFormat :: Text => value. as_str ( ) ?. parse ( ) ?,
219+ } )
220+ }
221+ }
222+ impl sqlx:: Encode < ' _ , Postgres > for PgU32 {
223+ fn encode_by_ref ( & self , buf : & mut PgArgumentBuffer ) -> Result < IsNull , BoxDynError > {
224+ buf. extend ( & self . to_be_bytes ( ) ) ;
225+ Ok ( IsNull :: No )
226+ }
227+ }
228+
229+ pub struct PgU16 ( pub u16 ) ;
230+ impl Deref for PgU16 {
231+ type Target = u16 ;
232+ fn deref ( & self ) -> & Self :: Target {
233+ & self . 0
234+ }
235+ }
236+ impl Display for PgU16 {
237+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
238+ write ! ( f, "{}" , self . 0 )
239+ }
240+ }
241+ impl FromStr for PgU16 {
242+ type Err = std:: num:: ParseIntError ;
243+ fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
244+ Ok ( Self ( s. parse ( ) ?) )
245+ }
246+ }
247+ impl sqlx:: Type < Postgres > for PgU16 {
248+ fn type_info ( ) -> PgTypeInfo {
249+ PgTypeInfo :: with_name ( "uint2" )
250+ }
251+ }
252+ impl sqlx:: Decode < ' _ , Postgres > for PgU16 {
253+ fn decode ( value : PgValueRef < ' _ > ) -> Result < Self , BoxDynError > {
254+ Ok ( match value. format ( ) {
255+ PgValueFormat :: Binary => Self ( u16:: from_be_bytes ( value. as_bytes ( ) ?. try_into ( ) ?) ) ,
256+ PgValueFormat :: Text => value. as_str ( ) ?. parse ( ) ?,
257+ } )
258+ }
259+ }
260+ impl sqlx:: Encode < ' _ , Postgres > for PgU16 {
261+ fn encode_by_ref ( & self , buf : & mut PgArgumentBuffer ) -> Result < IsNull , BoxDynError > {
262+ buf. extend ( & self . to_be_bytes ( ) ) ;
263+ Ok ( IsNull :: No )
264+ }
265+ }
0 commit comments