@@ -63,28 +63,31 @@ def __init__(
6363 self .config = config
6464 self .auto_retry = auto_retry
6565 self .mu = threading .Lock ()
66- self .conn : HLLConnection | None = None
66+ self .conns : dict [ int , HLLConnection ] = {}
6767
6868 @contextmanager
6969 def with_connection (self ) -> Generator [HLLConnection , None , None ]:
70- # Not sure if multithreading is still a thing we use...
71- logger .debug ("Waiting to acquire lock %s" , threading .get_ident ())
70+ # TODO: Cleanup connections from threads that no longer exist
71+
72+ thread_id = threading .get_ident ()
73+ logger .debug ("Waiting to acquire lock %s" , thread_id )
7274 if not self .mu .acquire (timeout = 30 ):
7375 raise TimeoutError ()
7476
7577 try :
76- if self .conn is None :
77- self .conn = HLLConnection ()
78+ conn = self .conns .get (thread_id )
79+ if conn is None :
80+ conn = HLLConnection ()
7881 try :
79- self ._connect (self . conn )
82+ self ._connect (conn )
8083 except Exception :
81- self .conn = None
8284 raise
85+ self .conns [thread_id ] = conn
8386 finally :
8487 self .mu .release ()
8588
8689 try :
87- yield self . conn
90+ yield conn
8891 except Exception as e :
8992 # All other errors, that might be caught (like UnicodeDecodeError) do not really qualify as an error of the
9093 # connection itself. Instead of reconnecting the existing connection here (conditionally), we simply discard
@@ -94,22 +97,38 @@ def with_connection(self) -> Generator[HLLConnection, None, None]:
9497 ):
9598 logger .warning (
9699 "Connection (%s) errored in thread %s: %s, removing" ,
97- self . conn .id ,
98- threading . get_ident () ,
100+ conn .id ,
101+ thread_id ,
99102 e ,
100103 )
101- self .conn .close ()
102- self .conn = None
104+
105+ if not self .mu .acquire (timeout = 30 ):
106+ raise TimeoutError ()
107+
108+ try :
109+ conn .close ()
110+ self .conns .pop (thread_id , None )
111+ finally :
112+ self .mu .release ()
113+
103114 raise
104115
105116 elif exception_in_chain (e , HLLBrokenConnectionError ):
106117 logger .warning (
107118 "Connection (%s) marked as broken in thread %s, removing" ,
108- self . conn .id ,
109- threading . get_ident () ,
119+ conn .id ,
120+ thread_id ,
110121 )
111- self .conn .close ()
112- self .conn = None
122+
123+ if not self .mu .acquire (timeout = 30 ):
124+ raise TimeoutError ()
125+
126+ try :
127+ conn .close ()
128+ self .conns .pop (thread_id , None )
129+ finally :
130+ self .mu .release ()
131+
113132 if e .__context__ is not None :
114133 raise e .__context__
115134 raise
@@ -281,19 +300,7 @@ def get_name(self) -> str:
281300 return self .exchange ("GetServerInformation" , 2 , {"Name" : "session" , "Value" : "" }).content_dict ["serverName" ]
282301
283302 def get_map (self ) -> str :
284- # TODO: Currently returns pretty name instead of map name, f.e. "CARENTAN" instead of "carentan_warfare"
285- session = self .exchange ("GetServerInformation" , 2 , {"Name" : "session" , "Value" : "" }).content_dict
286- layer = next (
287- (
288- l for l in LAYERS .values ()
289- if l .map .name == session ["mapName" ]
290- and l .game_mode == GameMode (session ["gameMode" ].lower ())
291- ),
292- None ,
293- )
294- if not layer :
295- layer = LAYERS [UNKNOWN_MAP_NAME ]
296- return layer .id
303+ return self .get_gamestate ()["current_map" ]["id" ]
297304
298305 def get_maps (self ) -> list [str ]:
299306 details = self .exchange ("GetClientReferenceData" , 2 , "AddMapToRotation" )
@@ -374,14 +381,14 @@ def get_autobalance_enabled(self) -> bool:
374381
375382 def get_logs (
376383 self ,
377- since_min_ago : str | int ,
384+ since_min_ago : int ,
378385 filter_ : str = "" ,
379386 conn : HLLConnection | None = None ,
380387 ) -> list [str ]:
381388 return [
382389 entry ["message" ]
383390 for entry in self .exchange ("GetAdminLog" , 2 , {
384- "LogBackTrackTime" : since_min_ago ,
391+ "LogBackTrackTime" : since_min_ago * 60 ,
385392 "Filters" : filter_
386393 }, conn = conn ).content_dict ["entries" ]
387394 ]
@@ -585,6 +592,18 @@ def get_gamestate(self) -> GameStateType:
585592 seconds_remaining = int (time_remaining .total_seconds ())
586593 raw_time_remaining = f"{ seconds_remaining // 3600 } :{ (seconds_remaining // 60 ) % 60 :02} :{ seconds_remaining % 60 :02} "
587594
595+ game_mode = GameMode (s ["gameMode" ].lower ())
596+ current_map = next (
597+ (
598+ l for l in LAYERS .values ()
599+ if l .map .name == s ["mapName" ]
600+ and l .game_mode == game_mode
601+ ),
602+ None ,
603+ )
604+ if not current_map :
605+ current_map = LAYERS [UNKNOWN_MAP_NAME ]
606+
588607 # TODO: next_map is not included in session, map_name is pretty name instead of ID
589608 return GameStateType (
590609 next_map = LAYERS [UNKNOWN_MAP_NAME ].model_dump (),
@@ -594,22 +613,10 @@ def get_gamestate(self) -> GameStateType:
594613 allied_score = s ["alliedScore" ],
595614 allied_faction = s ["alliedFaction" ],
596615 num_allied_players = s ["alliedPlayerCount" ],
597- current_map = LayerType (
598- id = s ["mapName" ],
599- map = next (
600- (m for m in MAPS .values () if m .name == s ["mapName" ]),
601- MAPS [UNKNOWN_MAP_NAME ]
602- ).model_dump (),
603- game_mode = s ["gameMode" ].lower (),
604- attackers = None ,
605- environment = Environment .DAY ,
606- pretty_name = s ["mapName" ].capitalize (),
607- image_name = "" ,
608- image_url = "" ,
609- ),
616+ current_map = current_map .model_dump (),
610617 raw_time_remaining = raw_time_remaining ,
611618 time_remaining = time_remaining ,
612- game_mode = GameMode ( s [ "gameMode" ]. lower ()) ,
619+ game_mode = game_mode ,
613620 match_time = s ["matchTime" ],
614621 queue_count = s ["queueCount" ],
615622 max_queue_count = s ["maxQueueCount" ],
0 commit comments