Skip to content

Commit d156d19

Browse files
committed
Lock sending payloads in LSServerHandler
1 parent 3a20b07 commit d156d19

1 file changed

Lines changed: 85 additions & 27 deletions

File tree

src/multilspy/lsp_protocol_handler/server.py

Lines changed: 85 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
import json
3434
import logging
3535
import os
36+
import threading
37+
3638
import psutil
3739
from typing import Any, Callable, Dict, List, Optional, Union
3840

@@ -209,6 +211,14 @@ def __init__(
209211
self.task_counter = 0
210212
self.loop = None
211213
self.start_independent_lsp_process = start_independent_lsp_process
214+
215+
# Add thread locks for shared resources to prevent race conditions
216+
self._stdin_lock = threading.Lock()
217+
self._request_id_lock = threading.Lock()
218+
self._response_handlers_lock = threading.Lock()
219+
self._tasks_lock = threading.Lock()
220+
221+
212222

213223
def is_running(self) -> bool:
214224
"""
@@ -244,10 +254,14 @@ async def start(self) -> None:
244254
raise RuntimeError(f"Process terminated immediately with code {self.process.returncode}. Error: {error_message}")
245255

246256
self.loop = asyncio.get_event_loop()
247-
self.tasks[self.task_counter] = self.loop.create_task(self.run_forever())
248-
self.task_counter += 1
249-
self.tasks[self.task_counter] = self.loop.create_task(self.run_forever_stderr())
250-
self.task_counter += 1
257+
258+
# Use lock to prevent race conditions on tasks and task_counter during startup
259+
with self._tasks_lock:
260+
self.tasks[self.task_counter] = self.loop.create_task(self.run_forever())
261+
self.task_counter += 1
262+
self.tasks[self.task_counter] = self.loop.create_task(self.run_forever_stderr())
263+
self.task_counter += 1
264+
251265

252266

253267
async def stop(self) -> None:
@@ -269,18 +283,23 @@ async def stop(self) -> None:
269283
async def _cancel_pending_tasks(self):
270284
"""Cancel all pending tasks and wait for them to complete or timeout."""
271285
pending_tasks = []
272-
for task in self.tasks.values():
273-
if not task.done():
274-
task.cancel()
275-
pending_tasks.append(task)
286+
287+
# Use lock to safely access tasks dictionary
288+
with self._tasks_lock:
289+
for task in self.tasks.values():
290+
if not task.done():
291+
task.cancel()
292+
pending_tasks.append(task)
276293

277294
if pending_tasks:
278295
try:
279296
await asyncio.wait_for(asyncio.gather(*pending_tasks, return_exceptions=True), timeout=5.0)
280297
except (asyncio.TimeoutError, Exception):
281298
pass
282-
283-
self.tasks = {}
299+
300+
# Clear tasks dictionary under lock
301+
with self._tasks_lock:
302+
self.tasks = {}
284303

285304
async def _cleanup_process(self, process):
286305
"""Clean up a process: close stdin, terminate/kill process, close stdout/stderr."""
@@ -406,12 +425,15 @@ async def run_forever(self) -> bool:
406425
continue
407426
body = await self.process.stdout.readexactly(num_bytes)
408427

409-
self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(self._handle_body(body))
410-
self.task_counter += 1
428+
# Use lock to prevent race conditions on tasks and task_counter
429+
with self._tasks_lock:
430+
self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(self._handle_body(body))
431+
self.task_counter += 1
411432
except (BrokenPipeError, ConnectionResetError, StopLoopException):
412433
pass
413434
return self._received_shutdown
414435

436+
415437
async def run_forever_stderr(self) -> None:
416438
"""
417439
Continuously read from the language server process stderr and log the messages
@@ -467,28 +489,40 @@ def send_response(self, request_id: Any, params: PayloadLike) -> None:
467489
"""
468490
Send response to the given request id to the server with the given parameters
469491
"""
470-
self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(
471-
self._send_payload(make_response(request_id, params))
472-
)
473-
self.task_counter += 1
492+
# Use lock to prevent race conditions on tasks and task_counter
493+
with self._tasks_lock:
494+
self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(
495+
self._send_payload(make_response(request_id, params))
496+
)
497+
self.task_counter += 1
498+
474499

475500
def send_error_response(self, request_id: Any, err: Error) -> None:
476501
"""
477502
Send error response to the given request id to the server with the given error
478503
"""
479-
self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(
480-
self._send_payload(make_error_response(request_id, err))
481-
)
482-
self.task_counter += 1
504+
# Use lock to prevent race conditions on tasks and task_counter
505+
with self._tasks_lock:
506+
self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(
507+
self._send_payload(make_error_response(request_id, err))
508+
)
509+
self.task_counter += 1
510+
483511

484512
async def send_request(self, method: str, params: Optional[dict] = None) -> PayloadLike:
485513
"""
486514
Send request to the server, register the request id, and wait for the response
487515
"""
488516
request = Request()
489-
request_id = self.request_id
490-
self.request_id += 1
491-
self._response_handlers[request_id] = request
517+
518+
# Use lock to prevent race conditions on request_id and _response_handlers
519+
with self._request_id_lock:
520+
request_id = self.request_id
521+
self.request_id += 1
522+
523+
with self._response_handlers_lock:
524+
self._response_handlers[request_id] = request
525+
492526
async with request.cv:
493527
await self._send_payload(make_request(method, request_id, params))
494528
self._log(f"Waiting for asyncio condition for request {method} with params:\n{params}")
@@ -499,6 +533,7 @@ async def send_request(self, method: str, params: Optional[dict] = None) -> Payl
499533
self._log(f"Returning non-error result, which is:\n{request.result}")
500534
return request.result
501535

536+
502537
def _send_payload_sync(self, payload: StringDict) -> None:
503538
"""
504539
Send the payload to the server by writing to its stdin synchronously
@@ -508,7 +543,17 @@ def _send_payload_sync(self, payload: StringDict) -> None:
508543
msg = create_message(payload)
509544
if self.logger:
510545
self.logger("client", "server", payload)
511-
self.process.stdin.writelines(msg)
546+
547+
# Use lock to prevent concurrent writes to stdin that cause buffer corruption
548+
with self._stdin_lock:
549+
try:
550+
self.process.stdin.writelines(msg)
551+
except (BrokenPipeError, ConnectionResetError, OSError) as e:
552+
# Log the error but don't raise to prevent cascading failures
553+
if self.logger:
554+
self.logger("client", "logger", f"Failed to write to stdin: {e}")
555+
return
556+
512557

513558
async def _send_payload(self, payload: StringDict) -> None:
514559
"""
@@ -518,8 +563,18 @@ async def _send_payload(self, payload: StringDict) -> None:
518563
return
519564
self._log(payload)
520565
msg = create_message(payload)
521-
self.process.stdin.writelines(msg)
522-
await self.process.stdin.drain()
566+
567+
# Use lock to prevent concurrent writes to stdin that cause buffer corruption
568+
with self._stdin_lock:
569+
try:
570+
self.process.stdin.writelines(msg)
571+
await self.process.stdin.drain()
572+
except (BrokenPipeError, ConnectionResetError, OSError) as e:
573+
# Log the error but don't raise to prevent cascading failures
574+
if self.logger:
575+
self.logger("client", "logger", f"Failed to write to stdin: {e}")
576+
return
577+
523578

524579
def on_request(self, method: str, cb) -> None:
525580
"""
@@ -537,14 +592,17 @@ async def _response_handler(self, response: StringDict) -> None:
537592
"""
538593
Handle the response received from the server for a request, using the id to determine the request
539594
"""
540-
request = self._response_handlers.pop(response["id"])
595+
with self._response_handlers_lock:
596+
request = self._response_handlers.pop(response["id"])
597+
541598
if "result" in response and "error" not in response:
542599
await request.on_result(response["result"])
543600
elif "result" not in response and "error" in response:
544601
await request.on_error(Error.from_lsp(response["error"]))
545602
else:
546603
await request.on_error(Error(ErrorCodes.InvalidRequest, ""))
547604

605+
548606
async def _request_handler(self, response: StringDict) -> None:
549607
"""
550608
Handle the request received from the server: call the appropriate callback function and return the result

0 commit comments

Comments
 (0)