diff --git a/src/multilspy/language_server.py b/src/multilspy/language_server.py index dd123c938..e910e43c3 100644 --- a/src/multilspy/language_server.py +++ b/src/multilspy/language_server.py @@ -227,7 +227,7 @@ def __init__( self.completions_available = asyncio.Event() if config.trace_lsp_communication: def logging_fn(source: str, target: str, msg: StringDict | str): - self.logger.log(f"LSP: {source} -> {target}: {str(msg)}", logging.DEBUG) + self.logger.log(f"LSP: {source} -> {target}: {str(msg)}", self.logger.logger.level) else: logging_fn = None diff --git a/src/multilspy/lsp_protocol_handler/server.py b/src/multilspy/lsp_protocol_handler/server.py index 3cba4d430..a22414332 100644 --- a/src/multilspy/lsp_protocol_handler/server.py +++ b/src/multilspy/lsp_protocol_handler/server.py @@ -33,6 +33,8 @@ import json import logging import os +import threading + import psutil from typing import Any, Callable, Dict, List, Optional, Union @@ -209,6 +211,14 @@ def __init__( self.task_counter = 0 self.loop = None self.start_independent_lsp_process = start_independent_lsp_process + + # Add thread locks for shared resources to prevent race conditions + self._stdin_lock = threading.Lock() + self._request_id_lock = threading.Lock() + self._response_handlers_lock = threading.Lock() + self._tasks_lock = threading.Lock() + + def is_running(self) -> bool: """ @@ -244,10 +254,14 @@ async def start(self) -> None: raise RuntimeError(f"Process terminated immediately with code {self.process.returncode}. Error: {error_message}") self.loop = asyncio.get_event_loop() - self.tasks[self.task_counter] = self.loop.create_task(self.run_forever()) - self.task_counter += 1 - self.tasks[self.task_counter] = self.loop.create_task(self.run_forever_stderr()) - self.task_counter += 1 + + # Use lock to prevent race conditions on tasks and task_counter during startup + with self._tasks_lock: + self.tasks[self.task_counter] = self.loop.create_task(self.run_forever()) + self.task_counter += 1 + self.tasks[self.task_counter] = self.loop.create_task(self.run_forever_stderr()) + self.task_counter += 1 + async def stop(self) -> None: @@ -269,18 +283,23 @@ async def stop(self) -> None: async def _cancel_pending_tasks(self): """Cancel all pending tasks and wait for them to complete or timeout.""" pending_tasks = [] - for task in self.tasks.values(): - if not task.done(): - task.cancel() - pending_tasks.append(task) + + # Use lock to safely access tasks dictionary + with self._tasks_lock: + for task in self.tasks.values(): + if not task.done(): + task.cancel() + pending_tasks.append(task) if pending_tasks: try: await asyncio.wait_for(asyncio.gather(*pending_tasks, return_exceptions=True), timeout=5.0) except (asyncio.TimeoutError, Exception): pass - - self.tasks = {} + + # Clear tasks dictionary under lock + with self._tasks_lock: + self.tasks = {} async def _cleanup_process(self, process): """Clean up a process: close stdin, terminate/kill process, close stdout/stderr.""" @@ -406,12 +425,15 @@ async def run_forever(self) -> bool: continue body = await self.process.stdout.readexactly(num_bytes) - self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(self._handle_body(body)) - self.task_counter += 1 + # Use lock to prevent race conditions on tasks and task_counter + with self._tasks_lock: + self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(self._handle_body(body)) + self.task_counter += 1 except (BrokenPipeError, ConnectionResetError, StopLoopException): pass return self._received_shutdown + async def run_forever_stderr(self) -> None: """ Continuously read from the language server process stderr and log the messages @@ -467,28 +489,40 @@ def send_response(self, request_id: Any, params: PayloadLike) -> None: """ Send response to the given request id to the server with the given parameters """ - self.tasks[self.task_counter] = asyncio.get_event_loop().create_task( - self._send_payload(make_response(request_id, params)) - ) - self.task_counter += 1 + # Use lock to prevent race conditions on tasks and task_counter + with self._tasks_lock: + self.tasks[self.task_counter] = asyncio.get_event_loop().create_task( + self._send_payload(make_response(request_id, params)) + ) + self.task_counter += 1 + def send_error_response(self, request_id: Any, err: Error) -> None: """ Send error response to the given request id to the server with the given error """ - self.tasks[self.task_counter] = asyncio.get_event_loop().create_task( - self._send_payload(make_error_response(request_id, err)) - ) - self.task_counter += 1 + # Use lock to prevent race conditions on tasks and task_counter + with self._tasks_lock: + self.tasks[self.task_counter] = asyncio.get_event_loop().create_task( + self._send_payload(make_error_response(request_id, err)) + ) + self.task_counter += 1 + async def send_request(self, method: str, params: Optional[dict] = None) -> PayloadLike: """ Send request to the server, register the request id, and wait for the response """ request = Request() - request_id = self.request_id - self.request_id += 1 - self._response_handlers[request_id] = request + + # Use lock to prevent race conditions on request_id and _response_handlers + with self._request_id_lock: + request_id = self.request_id + self.request_id += 1 + + with self._response_handlers_lock: + self._response_handlers[request_id] = request + async with request.cv: await self._send_payload(make_request(method, request_id, params)) self._log(f"Waiting for asyncio condition for request {method} with params:\n{params}") @@ -499,6 +533,7 @@ async def send_request(self, method: str, params: Optional[dict] = None) -> Payl self._log(f"Returning non-error result, which is:\n{request.result}") return request.result + def _send_payload_sync(self, payload: StringDict) -> None: """ Send the payload to the server by writing to its stdin synchronously @@ -508,7 +543,17 @@ def _send_payload_sync(self, payload: StringDict) -> None: msg = create_message(payload) if self.logger: self.logger("client", "server", payload) - self.process.stdin.writelines(msg) + + # Use lock to prevent concurrent writes to stdin that cause buffer corruption + with self._stdin_lock: + try: + self.process.stdin.writelines(msg) + except (BrokenPipeError, ConnectionResetError, OSError) as e: + # Log the error but don't raise to prevent cascading failures + if self.logger: + self.logger("client", "logger", f"Failed to write to stdin: {e}") + return + async def _send_payload(self, payload: StringDict) -> None: """ @@ -518,8 +563,18 @@ async def _send_payload(self, payload: StringDict) -> None: return self._log(payload) msg = create_message(payload) - self.process.stdin.writelines(msg) - await self.process.stdin.drain() + + # Use lock to prevent concurrent writes to stdin that cause buffer corruption + with self._stdin_lock: + try: + self.process.stdin.writelines(msg) + await self.process.stdin.drain() + except (BrokenPipeError, ConnectionResetError, OSError) as e: + # Log the error but don't raise to prevent cascading failures + if self.logger: + self.logger("client", "logger", f"Failed to write to stdin: {e}") + return + def on_request(self, method: str, cb) -> None: """ @@ -537,7 +592,9 @@ async def _response_handler(self, response: StringDict) -> None: """ Handle the response received from the server for a request, using the id to determine the request """ - request = self._response_handlers.pop(response["id"]) + with self._response_handlers_lock: + request = self._response_handlers.pop(response["id"]) + if "result" in response and "error" not in response: await request.on_result(response["result"]) elif "result" not in response and "error" in response: @@ -545,6 +602,7 @@ async def _response_handler(self, response: StringDict) -> None: else: await request.on_error(Error(ErrorCodes.InvalidRequest, "")) + async def _request_handler(self, response: StringDict) -> None: """ Handle the request received from the server: call the appropriate callback function and return the result