3333import json
3434import logging
3535import os
36+ import threading
37+
3638import psutil
3739from typing import Any , Callable , Dict , List , Optional , Union
3840
@@ -209,6 +211,14 @@ def __init__(
209211 self .task_counter = 0
210212 self .loop = None
211213 self .start_independent_lsp_process = start_independent_lsp_process
214+
215+ # Add thread locks for shared resources to prevent race conditions
216+ self ._stdin_lock = threading .Lock ()
217+ self ._request_id_lock = threading .Lock ()
218+ self ._response_handlers_lock = threading .Lock ()
219+ self ._tasks_lock = threading .Lock ()
220+
221+
212222
213223 def is_running (self ) -> bool :
214224 """
@@ -244,10 +254,14 @@ async def start(self) -> None:
244254 raise RuntimeError (f"Process terminated immediately with code { self .process .returncode } . Error: { error_message } " )
245255
246256 self .loop = asyncio .get_event_loop ()
247- self .tasks [self .task_counter ] = self .loop .create_task (self .run_forever ())
248- self .task_counter += 1
249- self .tasks [self .task_counter ] = self .loop .create_task (self .run_forever_stderr ())
250- self .task_counter += 1
257+
258+ # Use lock to prevent race conditions on tasks and task_counter during startup
259+ with self ._tasks_lock :
260+ self .tasks [self .task_counter ] = self .loop .create_task (self .run_forever ())
261+ self .task_counter += 1
262+ self .tasks [self .task_counter ] = self .loop .create_task (self .run_forever_stderr ())
263+ self .task_counter += 1
264+
251265
252266
253267 async def stop (self ) -> None :
@@ -269,18 +283,23 @@ async def stop(self) -> None:
269283 async def _cancel_pending_tasks (self ):
270284 """Cancel all pending tasks and wait for them to complete or timeout."""
271285 pending_tasks = []
272- for task in self .tasks .values ():
273- if not task .done ():
274- task .cancel ()
275- pending_tasks .append (task )
286+
287+ # Use lock to safely access tasks dictionary
288+ with self ._tasks_lock :
289+ for task in self .tasks .values ():
290+ if not task .done ():
291+ task .cancel ()
292+ pending_tasks .append (task )
276293
277294 if pending_tasks :
278295 try :
279296 await asyncio .wait_for (asyncio .gather (* pending_tasks , return_exceptions = True ), timeout = 5.0 )
280297 except (asyncio .TimeoutError , Exception ):
281298 pass
282-
283- self .tasks = {}
299+
300+ # Clear tasks dictionary under lock
301+ with self ._tasks_lock :
302+ self .tasks = {}
284303
285304 async def _cleanup_process (self , process ):
286305 """Clean up a process: close stdin, terminate/kill process, close stdout/stderr."""
@@ -406,12 +425,15 @@ async def run_forever(self) -> bool:
406425 continue
407426 body = await self .process .stdout .readexactly (num_bytes )
408427
409- self .tasks [self .task_counter ] = asyncio .get_event_loop ().create_task (self ._handle_body (body ))
410- self .task_counter += 1
428+ # Use lock to prevent race conditions on tasks and task_counter
429+ with self ._tasks_lock :
430+ self .tasks [self .task_counter ] = asyncio .get_event_loop ().create_task (self ._handle_body (body ))
431+ self .task_counter += 1
411432 except (BrokenPipeError , ConnectionResetError , StopLoopException ):
412433 pass
413434 return self ._received_shutdown
414435
436+
415437 async def run_forever_stderr (self ) -> None :
416438 """
417439 Continuously read from the language server process stderr and log the messages
@@ -467,28 +489,40 @@ def send_response(self, request_id: Any, params: PayloadLike) -> None:
467489 """
468490 Send response to the given request id to the server with the given parameters
469491 """
470- self .tasks [self .task_counter ] = asyncio .get_event_loop ().create_task (
471- self ._send_payload (make_response (request_id , params ))
472- )
473- self .task_counter += 1
492+ # Use lock to prevent race conditions on tasks and task_counter
493+ with self ._tasks_lock :
494+ self .tasks [self .task_counter ] = asyncio .get_event_loop ().create_task (
495+ self ._send_payload (make_response (request_id , params ))
496+ )
497+ self .task_counter += 1
498+
474499
475500 def send_error_response (self , request_id : Any , err : Error ) -> None :
476501 """
477502 Send error response to the given request id to the server with the given error
478503 """
479- self .tasks [self .task_counter ] = asyncio .get_event_loop ().create_task (
480- self ._send_payload (make_error_response (request_id , err ))
481- )
482- self .task_counter += 1
504+ # Use lock to prevent race conditions on tasks and task_counter
505+ with self ._tasks_lock :
506+ self .tasks [self .task_counter ] = asyncio .get_event_loop ().create_task (
507+ self ._send_payload (make_error_response (request_id , err ))
508+ )
509+ self .task_counter += 1
510+
483511
484512 async def send_request (self , method : str , params : Optional [dict ] = None ) -> PayloadLike :
485513 """
486514 Send request to the server, register the request id, and wait for the response
487515 """
488516 request = Request ()
489- request_id = self .request_id
490- self .request_id += 1
491- self ._response_handlers [request_id ] = request
517+
518+ # Use lock to prevent race conditions on request_id and _response_handlers
519+ with self ._request_id_lock :
520+ request_id = self .request_id
521+ self .request_id += 1
522+
523+ with self ._response_handlers_lock :
524+ self ._response_handlers [request_id ] = request
525+
492526 async with request .cv :
493527 await self ._send_payload (make_request (method , request_id , params ))
494528 self ._log (f"Waiting for asyncio condition for request { method } with params:\n { params } " )
@@ -499,6 +533,7 @@ async def send_request(self, method: str, params: Optional[dict] = None) -> Payl
499533 self ._log (f"Returning non-error result, which is:\n { request .result } " )
500534 return request .result
501535
536+
502537 def _send_payload_sync (self , payload : StringDict ) -> None :
503538 """
504539 Send the payload to the server by writing to its stdin synchronously
@@ -508,7 +543,17 @@ def _send_payload_sync(self, payload: StringDict) -> None:
508543 msg = create_message (payload )
509544 if self .logger :
510545 self .logger ("client" , "server" , payload )
511- self .process .stdin .writelines (msg )
546+
547+ # Use lock to prevent concurrent writes to stdin that cause buffer corruption
548+ with self ._stdin_lock :
549+ try :
550+ self .process .stdin .writelines (msg )
551+ except (BrokenPipeError , ConnectionResetError , OSError ) as e :
552+ # Log the error but don't raise to prevent cascading failures
553+ if self .logger :
554+ self .logger ("client" , "logger" , f"Failed to write to stdin: { e } " )
555+ return
556+
512557
513558 async def _send_payload (self , payload : StringDict ) -> None :
514559 """
@@ -518,8 +563,18 @@ async def _send_payload(self, payload: StringDict) -> None:
518563 return
519564 self ._log (payload )
520565 msg = create_message (payload )
521- self .process .stdin .writelines (msg )
522- await self .process .stdin .drain ()
566+
567+ # Use lock to prevent concurrent writes to stdin that cause buffer corruption
568+ with self ._stdin_lock :
569+ try :
570+ self .process .stdin .writelines (msg )
571+ await self .process .stdin .drain ()
572+ except (BrokenPipeError , ConnectionResetError , OSError ) as e :
573+ # Log the error but don't raise to prevent cascading failures
574+ if self .logger :
575+ self .logger ("client" , "logger" , f"Failed to write to stdin: { e } " )
576+ return
577+
523578
524579 def on_request (self , method : str , cb ) -> None :
525580 """
@@ -537,14 +592,17 @@ async def _response_handler(self, response: StringDict) -> None:
537592 """
538593 Handle the response received from the server for a request, using the id to determine the request
539594 """
540- request = self ._response_handlers .pop (response ["id" ])
595+ with self ._response_handlers_lock :
596+ request = self ._response_handlers .pop (response ["id" ])
597+
541598 if "result" in response and "error" not in response :
542599 await request .on_result (response ["result" ])
543600 elif "result" not in response and "error" in response :
544601 await request .on_error (Error .from_lsp (response ["error" ]))
545602 else :
546603 await request .on_error (Error (ErrorCodes .InvalidRequest , "" ))
547604
605+
548606 async def _request_handler (self , response : StringDict ) -> None :
549607 """
550608 Handle the request received from the server: call the appropriate callback function and return the result
0 commit comments