Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/multilspy/language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def __init__(
self.completions_available = asyncio.Event()
if config.trace_lsp_communication:
def logging_fn(source: str, target: str, msg: StringDict | str):
self.logger.log(f"LSP: {source} -> {target}: {str(msg)}", logging.DEBUG)
self.logger.log(f"LSP: {source} -> {target}: {str(msg)}", self.logger.logger.level)
else:
logging_fn = None

Expand Down
112 changes: 85 additions & 27 deletions src/multilspy/lsp_protocol_handler/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import json
import logging
import os
import threading

import psutil
from typing import Any, Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -209,6 +211,14 @@ def __init__(
self.task_counter = 0
self.loop = None
self.start_independent_lsp_process = start_independent_lsp_process

# Add thread locks for shared resources to prevent race conditions
self._stdin_lock = threading.Lock()
self._request_id_lock = threading.Lock()
self._response_handlers_lock = threading.Lock()
self._tasks_lock = threading.Lock()



def is_running(self) -> bool:
"""
Expand Down Expand Up @@ -244,10 +254,14 @@ async def start(self) -> None:
raise RuntimeError(f"Process terminated immediately with code {self.process.returncode}. Error: {error_message}")

self.loop = asyncio.get_event_loop()
self.tasks[self.task_counter] = self.loop.create_task(self.run_forever())
self.task_counter += 1
self.tasks[self.task_counter] = self.loop.create_task(self.run_forever_stderr())
self.task_counter += 1

# Use lock to prevent race conditions on tasks and task_counter during startup
with self._tasks_lock:
self.tasks[self.task_counter] = self.loop.create_task(self.run_forever())
self.task_counter += 1
self.tasks[self.task_counter] = self.loop.create_task(self.run_forever_stderr())
self.task_counter += 1



async def stop(self) -> None:
Expand All @@ -269,18 +283,23 @@ async def stop(self) -> None:
async def _cancel_pending_tasks(self):
    """Cancel all pending reader/sender tasks and wait for them to finish.

    Snapshots the task table under ``self._tasks_lock``, cancels every
    task that has not yet completed, waits up to 5 seconds for the
    cancellations to take effect, then clears the table.
    """
    pending_tasks = []

    # Snapshot and cancel under the lock so concurrent registrations
    # (e.g. from run_forever / send_response) cannot mutate self.tasks
    # while we iterate it.
    with self._tasks_lock:
        for task in self.tasks.values():
            if not task.done():
                task.cancel()
                pending_tasks.append(task)

    if pending_tasks:
        # gather(return_exceptions=True) never raises for task failures or
        # cancellations — it folds them into the result list — so the only
        # exception wait_for can raise here is TimeoutError. The previous
        # `except (asyncio.TimeoutError, Exception)` tuple was redundant
        # (Exception subsumes TimeoutError) and over-broad.
        try:
            await asyncio.wait_for(
                asyncio.gather(*pending_tasks, return_exceptions=True),
                timeout=5.0,
            )
        except asyncio.TimeoutError:
            # Tasks that ignored cancellation within the grace period are
            # abandoned; we still clear our references below.
            pass

    # Reset the table under the lock; late-finishing tasks are simply dropped.
    with self._tasks_lock:
        self.tasks = {}

async def _cleanup_process(self, process):
"""Clean up a process: close stdin, terminate/kill process, close stdout/stderr."""
Expand Down Expand Up @@ -406,12 +425,15 @@ async def run_forever(self) -> bool:
continue
body = await self.process.stdout.readexactly(num_bytes)

self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(self._handle_body(body))
self.task_counter += 1
# Use lock to prevent race conditions on tasks and task_counter
with self._tasks_lock:
self.tasks[self.task_counter] = asyncio.get_event_loop().create_task(self._handle_body(body))
self.task_counter += 1
except (BrokenPipeError, ConnectionResetError, StopLoopException):
pass
return self._received_shutdown


async def run_forever_stderr(self) -> None:
"""
Continuously read from the language server process stderr and log the messages
def send_response(self, request_id: Any, params: PayloadLike) -> None:
    """
    Send response to the given request id to the server with the given parameters
    """
    # Reserve a slot and register the send task atomically so the task
    # table and its counter stay consistent under concurrent callers.
    with self._tasks_lock:
        slot = self.task_counter
        self.task_counter += 1
        coro = self._send_payload(make_response(request_id, params))
        self.tasks[slot] = asyncio.get_event_loop().create_task(coro)


def send_error_response(self, request_id: Any, err: Error) -> None:
    """
    Send error response to the given request id to the server with the given error
    """
    # Reserve a slot and register the send task atomically so the task
    # table and its counter stay consistent under concurrent callers.
    with self._tasks_lock:
        slot = self.task_counter
        self.task_counter += 1
        coro = self._send_payload(make_error_response(request_id, err))
        self.tasks[slot] = asyncio.get_event_loop().create_task(coro)


async def send_request(self, method: str, params: Optional[dict] = None) -> PayloadLike:
"""
Send request to the server, register the request id, and wait for the response
"""
request = Request()
request_id = self.request_id
self.request_id += 1
self._response_handlers[request_id] = request

# Use lock to prevent race conditions on request_id and _response_handlers
with self._request_id_lock:
request_id = self.request_id
self.request_id += 1

with self._response_handlers_lock:
self._response_handlers[request_id] = request

async with request.cv:
await self._send_payload(make_request(method, request_id, params))
self._log(f"Waiting for asyncio condition for request {method} with params:\n{params}")
Expand All @@ -499,6 +533,7 @@ async def send_request(self, method: str, params: Optional[dict] = None) -> Payl
self._log(f"Returning non-error result, which is:\n{request.result}")
return request.result


def _send_payload_sync(self, payload: StringDict) -> None:
"""
Send the payload to the server by writing to its stdin synchronously
Expand All @@ -508,7 +543,17 @@ def _send_payload_sync(self, payload: StringDict) -> None:
msg = create_message(payload)
if self.logger:
self.logger("client", "server", payload)
self.process.stdin.writelines(msg)

# Use lock to prevent concurrent writes to stdin that cause buffer corruption
with self._stdin_lock:
try:
self.process.stdin.writelines(msg)
except (BrokenPipeError, ConnectionResetError, OSError) as e:
# Log the error but don't raise to prevent cascading failures
if self.logger:
self.logger("client", "logger", f"Failed to write to stdin: {e}")
return


async def _send_payload(self, payload: StringDict) -> None:
"""
Expand All @@ -518,8 +563,18 @@ async def _send_payload(self, payload: StringDict) -> None:
return
self._log(payload)
msg = create_message(payload)
self.process.stdin.writelines(msg)
await self.process.stdin.drain()

# Use lock to prevent concurrent writes to stdin that cause buffer corruption
with self._stdin_lock:
try:
self.process.stdin.writelines(msg)
await self.process.stdin.drain()
except (BrokenPipeError, ConnectionResetError, OSError) as e:
# Log the error but don't raise to prevent cascading failures
if self.logger:
self.logger("client", "logger", f"Failed to write to stdin: {e}")
return


def on_request(self, method: str, cb) -> None:
"""
Expand All @@ -537,14 +592,17 @@ async def _response_handler(self, response: StringDict) -> None:
"""
Handle the response received from the server for a request, using the id to determine the request
"""
request = self._response_handlers.pop(response["id"])
with self._response_handlers_lock:
request = self._response_handlers.pop(response["id"])

if "result" in response and "error" not in response:
await request.on_result(response["result"])
elif "result" not in response and "error" in response:
await request.on_error(Error.from_lsp(response["error"]))
else:
await request.on_error(Error(ErrorCodes.InvalidRequest, ""))


async def _request_handler(self, response: StringDict) -> None:
"""
Handle the request received from the server: call the appropriate callback function and return the result
Expand Down
Loading