diff --git a/Cargo.lock b/Cargo.lock index 038ea3b35f..cb3902e295 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -808,6 +808,12 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "doctest-file" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2db04e74f0a9a93103b50e90b96024c9b2bdca8bce6a632ec71b88736d3d359" + [[package]] name = "dupe" version = "0.9.1" @@ -1496,6 +1502,19 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "71dd52191aae121e8611f1e8dc3e324dd0dd1dee1e6dd91d10ee07a3cfb4d9d8" +[[package]] +name = "interprocess" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6be5e5c847dbdb44564bd85294740d031f4f8aeb3464e5375ef7141f7538db69" +dependencies = [ + "doctest-file", + "libc", + "recvmsg", + "widestring", + "windows-sys 0.52.0", +] + [[package]] name = "is-macro" version = "0.3.6" @@ -2226,6 +2245,7 @@ dependencies = [ "fuzzy-matcher", "fxhash", "indicatif", + "interprocess", "itertools 0.14.0", "lsp-server", "lsp-types", @@ -2609,6 +2629,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "recvmsg" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3edd4d5d42c92f0a659926464d4cce56b562761267ecf0f469d85b7de384175" + [[package]] name = "redox_syscall" version = "0.2.10" @@ -3922,6 +3948,12 @@ dependencies = [ "rustix 0.38.44", ] +[[package]] +name = "widestring" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72069c3113ab32ab29e5584db3c6ec55d416895e60715417b5b883a357c3e471" + [[package]] name = "winapi" version = "0.3.9" diff --git a/crates/tsp_types/protocol_generator/tsp.json b/crates/tsp_types/protocol_generator/tsp.json index 06ce1b9f3d..d2d65ec9e8 100644 --- a/crates/tsp_types/protocol_generator/tsp.json +++ b/crates/tsp_types/protocol_generator/tsp.json @@ -65,6 +65,20 @@ } ], "requests": [ + { + "method": "typeServer/connection", + "typeName": "ConnectionRequest", + "messageDirection": "clientToServer", + "documentation": "Main-connection-only request used to open or close an extra TSP transport. Extra transports must remain read-only and must not be used for LSP traffic.", + "params": { + "kind": "reference", + "name": "ConnectionRequestParams" + }, + "result": { + "kind": "reference", + "name": "ConnectionRequestResult" + } + }, { "method": "typeServer/getComputedType", "typeName": "GetComputedTypeRequest", @@ -301,13 +315,21 @@ "v0_1_0": "0.1.0", "v0_2_0": "0.2.0", "v0_3_0": "0.3.0", - "current": "0.4.0" + "v0_4_0": "0.4.0", + "current": "0.5.0" }, "valueDocumentation": { "v0_1_0": "Initial protocol version", "v0_2_0": "Added new request types and fields", "v0_3_0": "Switch to more complex types", - "current": "Switch to Type union and using stubs" + "v0_4_0": "Switch to Type union and using stubs", + "current": "Add multi-connection negotiation and control requests" + } + }, + "ConnectionTransportKind": { + "kind": "stringEnum", + "values": { + "Ipc": "ipc" } }, "Variance": { @@ -345,6 +367,77 @@ ], "documentation": "Represents a location in source code (a node in the AST). Used to point to specific declarations, expressions, or statements in Python source files. Used for: - Pointing to where a type is declared - Identifying the location of expressions for type inference - Error reporting and diagnostics - Linking types back to their source definitions Examples: - For `def foo():`, the node points to the function declaration - For a variable `x = 42`, the node points to the assignment - For default parameter values in functions" }, + "TypeServerMultiConnectionCapability": { + "kind": "interface", + "properties": [ + { + "name": "supportedTransports", + "type": { + "kind": "array", + "element": { + "kind": "reference", + "name": "ConnectionTransportKind" + } + }, + "optional": false + } + ], + "documentation": "Capability shape exchanged via the LSP initialize request/response under `capabilities.experimental.typeServerMultiConnection`." + }, + "ConnectionRequestParams": { + "kind": "interface", + "properties": [ + { + "name": "type", + "type": { + "kind": "base", + "name": "string" + }, + "optional": false + }, + { + "name": "kind", + "type": { + "kind": "reference", + "name": "ConnectionTransportKind" + }, + "optional": false + }, + { + "name": "args", + "type": { + "kind": "array", + "element": { + "kind": "base", + "name": "string" + } + }, + "optional": true + } + ], + "documentation": "Main-connection-only control request used to open or close extra read-only TSP channels after LSP initialization has completed." + }, + "ConnectionRequestResult": { + "kind": "interface", + "properties": [ + { + "name": "success", + "type": { + "kind": "base", + "name": "boolean" + }, + "optional": false + }, + { + "name": "message", + "type": { + "kind": "base", + "name": "string" + }, + "optional": true + } + ] + }, "ModuleName": { "kind": "interface", "properties": [ @@ -501,8 +594,8 @@ { "name": "returnType", "type": { - "kind": "reference", - "name": "Type" + "kind": "reference", + "name": "Type" }, "optional": true, "documentation": "Specialized type of the declared return type. Undefined if there is no declared return type. Example: For `def foo[T](x: T) -> T` specialized to `T=int`, returnType=int." @@ -589,7 +682,7 @@ "documentation": "Discriminator field that determines which declaration variant this is. Regular: Has source code and AST node Synthesized: Created by type checker, no source node" } ], - "documentation": "Base interface for all declaration types. Provides the discriminator field for the Declaration union." + "documentation": "Base interface for all declaration types. Provides the discriminator field for the Declaration union. This is a generic interface that is extended by: - RegularDeclaration (kind = Regular) - SynthesizedDeclaration (kind = Synthesized) The type parameter T ensures that the kind field matches the implementing interface. Used for type-safe discrimination: ```typescript if (declaration.kind === DeclarationKind.Regular) { // TypeScript knows this is RegularDeclaration const node = declaration.node; } ```" }, "RegularDeclaration": { "kind": "interface", @@ -615,8 +708,8 @@ { "name": "name", "type": { - "kind": "base", - "name": "string" + "kind": "base", + "name": "string" }, "optional": true, "documentation": "Name of the declared symbol, or undefined for anonymous declarations. Example: \"foo\" for `def foo():`, undefined for lambda functions." diff --git a/crates/tsp_types/src/protocol.rs b/crates/tsp_types/src/protocol.rs index d5b3ec3c50..2355b1e6d3 100644 --- a/crates/tsp_types/src/protocol.rs +++ b/crates/tsp_types/src/protocol.rs @@ -13,8 +13,7 @@ // 1. Create tsp.json and tsp.schema.json from typeServerProtocol.ts // 2. Install lsprotocol generator: `pip install git+https://github.com/microsoft/lsprotocol.git` // 3. Run: `python generate_protocol.py` -use serde::Deserialize; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use serde_repr::Deserialize_repr; use serde_repr::Serialize_repr; @@ -110,6 +109,8 @@ pub enum LSPNull { #[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] pub enum TSPRequestMethods { + #[serde(rename = "typeServer/connection")] + TypeServerConnection, #[serde(rename = "typeServer/getComputedType")] TypeServerGetComputedType, #[serde(rename = "typeServer/getDeclaredType")] @@ -129,6 +130,11 @@ pub enum TSPRequestMethods { #[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] #[serde(tag = "method")] pub enum TSPRequests { + #[serde(rename = "typeServer/connection")] + ConnectionRequest { + id: serde_json::Value, + params: ConnectionRequestParams, + }, #[serde(rename = "typeServer/getComputedType")] GetComputedTypeRequest { id: serde_json::Value, @@ -347,9 +353,19 @@ pub enum TypeServerVersion { /// Switch to Type union and using stubs #[serde(rename = "0.4.0")] + V040, + + /// Add multi-connection negotiation and control requests + #[serde(rename = "0.5.0")] Current, } +#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] +pub enum ConnectionTransportKind { + #[serde(rename = "ipc")] + Ipc, +} + #[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] pub enum Variance { /// Variance not yet determined, will be inferred @@ -418,6 +434,33 @@ pub struct Node { pub uri: String, } +/// Capability shape exchanged via the LSP initialize request/response under `capabilities.experimental.typeServerMultiConnection`. +#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub struct TypeServerMultiConnectionCapability { + pub supported_transports: Vec, +} + +/// Main-connection-only control request used to open or close extra read-only TSP channels after LSP initialization has completed. +#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub struct ConnectionRequestParams { + pub args: Option>, + + pub kind: ConnectionTransportKind, + + #[serde(rename = "type")] + pub type_: String, +} + +#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub struct ConnectionRequestResult { + pub message: Option, + + pub success: bool, +} + /// Represents a Python module name, handling both absolute and relative imports. Used for: - Import statement resolution - Tracking module dependencies - Resolving relative imports (from . import, from .. import) Examples: - `import os.path`: leadingDots=0, nameParts=['os', 'path'] - `from . import utils`: leadingDots=1, nameParts=['utils'] - `from ...parent import module`: leadingDots=3, nameParts=['parent', 'module'] - `import mymodule`: leadingDots=0, nameParts=['mymodule'] #[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] #[serde(rename_all = "camelCase", deny_unknown_fields)] @@ -510,7 +553,7 @@ pub struct SentinelLiteral { pub module_name: String, } -/// Base interface for all declaration types. Provides the discriminator field for the Declaration union. +/// Base interface for all declaration types. Provides the discriminator field for the Declaration union. This is a generic interface that is extended by: - RegularDeclaration (kind = Regular) - SynthesizedDeclaration (kind = Synthesized) The type parameter T ensures that the kind field matches the implementing interface. Used for type-safe discrimination: ```typescript if (declaration.kind === DeclarationKind.Regular) { // TypeScript knows this is RegularDeclaration const node = declaration.node; } ``` #[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] #[serde(rename_all = "camelCase", deny_unknown_fields)] pub struct DeclarationBase { @@ -864,6 +907,25 @@ pub enum LSPIdOptional { None, } +/// Main-connection-only request used to open or close an extra TSP transport. Extra transports must remain read-only and must not be used for LSP traffic. +#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub struct ConnectionRequest { + /// The version of the JSON RPC protocol. + pub jsonrpc: String, + + /// The method to be invoked. + pub method: TSPRequestMethods, + + /// The request id. + pub id: LSPId, + + pub params: ConnectionRequestParams, +} + +/// Response to the [ConnectionRequest]. +pub type ConnectionResponse = ConnectionRequestResult; + /// Requests and notifications for the type server protocol. Request for the computed type of a declaration or node. Computed type is the type that is inferred based on the code flow. Example: def foo(a: int | str): if instanceof(a, int): b = a + 1 # Computed type of 'b' is 'int' #[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Clone)] #[serde(rename_all = "camelCase", deny_unknown_fields)] diff --git a/crates/tsp_types/tests/protocol_types.rs b/crates/tsp_types/tests/protocol_types.rs index e3f4bd454c..4e3a387714 100644 --- a/crates/tsp_types/tests/protocol_types.rs +++ b/crates/tsp_types/tests/protocol_types.rs @@ -76,7 +76,7 @@ fn test_variance_round_trip() { fn test_type_server_version_round_trip() { let v = TypeServerVersion::Current; let json = serde_json::to_value(&v).unwrap(); - assert_eq!(json, serde_json::json!("0.4.0")); + assert_eq!(json, serde_json::json!("0.5.0")); let back: TypeServerVersion = serde_json::from_value(json).unwrap(); assert_eq!(back, v); } diff --git a/pyrefly/Cargo.toml b/pyrefly/Cargo.toml index 47208849c2..89451aece9 100644 --- a/pyrefly/Cargo.toml +++ b/pyrefly/Cargo.toml @@ -34,6 +34,7 @@ faster-hex = "0.6.1" fuzzy-matcher = "0.3.7" fxhash = "0.2.1" indicatif = { version = "0.18.4", features = ["futures", "improved_unicode", "rayon", "tokio"] } +interprocess = "2.2.3" itertools = "0.14.0" lsp-server = "0.7.9" lsp-types = { git = "https://github.com/astral-sh/lsp-types", rev = "3512a9f33eadc5402cfab1b8f7340824c8ca1439" } diff --git a/pyrefly/lib/commands/tsp.rs b/pyrefly/lib/commands/tsp.rs index 7fde731129..5cac02a223 100644 --- a/pyrefly/lib/commands/tsp.rs +++ b/pyrefly/lib/commands/tsp.rs @@ -38,6 +38,10 @@ pub struct TspArgs { /// Note that indexing files is a performance-intensive task. #[arg(long, default_value_t = if cfg!(fbcode_build) {0} else {2000})] pub(crate) workspace_indexing_limit: usize, + /// Selects the transport for the main JSON-RPC connection. + /// Use `stdio` (default) or `ipc://` for a local socket / named pipe. + #[arg(long, default_value = "stdio")] + pub(crate) transport: String, } pub fn run_tsp( @@ -107,9 +111,7 @@ impl TspArgs { // Note that we must have our logging only write out to stderr. eprintln!("starting TSP server"); - // Create the transport. Includes the stdio (stdin and stdout) versions but this could - // also be implemented to use sockets or HTTP. - let (connection, reader, io_threads) = Connection::stdio(); + let (connection, reader, io_threads) = Connection::from_transport(&self.transport)?; run_tsp(connection, reader, self, telemetry, wrapper, thread_count)?; io_threads.join()?; diff --git a/pyrefly/lib/lsp/non_wasm/server.rs b/pyrefly/lib/lsp/non_wasm/server.rs index a88bc78e57..c7ff596ac5 100644 --- a/pyrefly/lib/lsp/non_wasm/server.rs +++ b/pyrefly/lib/lsp/non_wasm/server.rs @@ -31,6 +31,14 @@ use crossbeam_channel::Receiver; use crossbeam_channel::Sender; use dupe::Dupe; use dupe::OptionDupedExt; +use interprocess::TryClone; +#[cfg(unix)] +use interprocess::local_socket::GenericFilePath; +use interprocess::local_socket::ToFsName; +use interprocess::local_socket::prelude::LocalSocketStream; +use interprocess::local_socket::traits::Stream; +#[cfg(windows)] +use interprocess::os::windows::local_socket::NamedPipe; use itertools::Itertools; use lsp_server::ErrorCode; use lsp_server::RequestId; @@ -494,6 +502,7 @@ pub struct Connection { pub enum MessageReader { Channel(Receiver), Stdio(BufReader), + Stream(BufReader>), } impl MessageReader { @@ -504,6 +513,7 @@ impl MessageReader { match self { MessageReader::Channel(r) => r.recv().ok(), MessageReader::Stdio(r) => read_lsp_message(r).ok().flatten(), + MessageReader::Stream(r) => read_lsp_message(r).ok().flatten(), } } } @@ -544,6 +554,47 @@ impl Connection { ) } + pub fn ipc(pipe_name: &str) -> std::io::Result<(Self, MessageReader, IoThread)> { + #[cfg(windows)] + let socket_name = pipe_name.to_fs_name::()?; + #[cfg(unix)] + let socket_name = pipe_name.to_fs_name::()?; + + let stream = LocalSocketStream::connect(socket_name)?; + let reader_stream = stream.try_clone()?; + let (writer_sender, writer_receiver) = crossbeam_channel::unbounded(); + let writer = std::thread::spawn(move || { + let mut output = stream; + while let Ok(msg) = writer_receiver.recv() { + write_lsp_message(&mut output, msg)?; + } + Ok(()) + }); + Ok(( + Self { + sender: writer_sender, + channel_receiver: None, + }, + MessageReader::Stream(BufReader::new(Box::new(reader_stream))), + IoThread { writer }, + )) + } + + pub fn from_transport(transport: &str) -> std::io::Result<(Self, MessageReader, IoThread)> { + if transport == "stdio" { + return Ok(Self::stdio()); + } + + if let Some(pipe_name) = transport.strip_prefix("ipc://") { + return Self::ipc(pipe_name); + } + + Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("Unsupported TSP transport: {transport}"), + )) + } + pub fn memory() -> ((Self, MessageReader), (Self, MessageReader)) { let (s1, r1) = crossbeam_channel::unbounded(); let (s2, r2) = crossbeam_channel::unbounded(); @@ -1077,7 +1128,7 @@ pub fn initialize_start( #[serde(rename_all = "camelCase")] pub struct ServerCapabilitiesWithTypeHierarchy { #[serde(flatten)] - base: ServerCapabilities, + pub(crate) base: ServerCapabilities, #[serde(skip_serializing_if = "Option::is_none")] type_hierarchy_provider: Option, } diff --git a/pyrefly/lib/tsp/requests/get_computed_type.rs b/pyrefly/lib/tsp/requests/get_computed_type.rs index 6949547ca4..143b47452c 100644 --- a/pyrefly/lib/tsp/requests/get_computed_type.rs +++ b/pyrefly/lib/tsp/requests/get_computed_type.rs @@ -14,7 +14,7 @@ use tsp_types::Type; use crate::lsp::non_wasm::server::TspInterface; use crate::tsp::server::TspServer; -impl TspServer { +impl TspServer { /// Return the computed (inferred) type at the given position. /// /// The computed type reflects the type checker's analysis of the code diff --git a/pyrefly/lib/tsp/requests/get_declared_type.rs b/pyrefly/lib/tsp/requests/get_declared_type.rs index 4e833ea21a..c82b544df0 100644 --- a/pyrefly/lib/tsp/requests/get_declared_type.rs +++ b/pyrefly/lib/tsp/requests/get_declared_type.rs @@ -14,7 +14,7 @@ use tsp_types::Type; use crate::lsp::non_wasm::server::TspInterface; use crate::tsp::server::TspServer; -impl TspServer { +impl TspServer { /// Return the declared type at the given position. /// /// The declared type is the annotation explicitly written by the user. diff --git a/pyrefly/lib/tsp/requests/get_expected_type.rs b/pyrefly/lib/tsp/requests/get_expected_type.rs index 9c1ecdf95b..7177332e8e 100644 --- a/pyrefly/lib/tsp/requests/get_expected_type.rs +++ b/pyrefly/lib/tsp/requests/get_expected_type.rs @@ -14,7 +14,7 @@ use tsp_types::Type; use crate::lsp::non_wasm::server::TspInterface; use crate::tsp::server::TspServer; -impl TspServer { +impl TspServer { /// Return the expected type at the given position. /// /// The expected type is the type that a surrounding context demands. diff --git a/pyrefly/lib/tsp/requests/get_python_search_paths.rs b/pyrefly/lib/tsp/requests/get_python_search_paths.rs index b64195dc55..db35b018ef 100644 --- a/pyrefly/lib/tsp/requests/get_python_search_paths.rs +++ b/pyrefly/lib/tsp/requests/get_python_search_paths.rs @@ -19,7 +19,7 @@ use crate::tsp::server::TspServer; use crate::tsp::validation::internal_error; use crate::tsp::validation::parse_file_uri; -impl TspServer { +impl TspServer { /// Handle a `typeServer/getPythonSearchPaths` request. /// /// Validates the snapshot, parses the `from_uri`, and delegates to diff --git a/pyrefly/lib/tsp/requests/resolve_import.rs b/pyrefly/lib/tsp/requests/resolve_import.rs index 8c56c85d63..ecb15aa8ca 100644 --- a/pyrefly/lib/tsp/requests/resolve_import.rs +++ b/pyrefly/lib/tsp/requests/resolve_import.rs @@ -27,7 +27,7 @@ use crate::tsp::server::TspServer; use crate::tsp::validation::invalid_params_error; use crate::tsp::validation::parse_file_uri; -impl TspServer { +impl TspServer { /// Handle a `typeServer/resolveImport` request. /// /// Converts the TSP [`ResolveImportParams`] into pyrefly's internal diff --git a/pyrefly/lib/tsp/server.rs b/pyrefly/lib/tsp/server.rs index f66fcee3cb..edce559cae 100644 --- a/pyrefly/lib/tsp/server.rs +++ b/pyrefly/lib/tsp/server.rs @@ -5,11 +5,14 @@ * LICENSE file in the root directory of this source tree. */ +use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; use std::sync::Mutex; +use lsp_server::ErrorCode; use lsp_server::RequestId; +use lsp_server::ResponseError; use lsp_types::InitializeParams; use pyrefly_util::telemetry::QueueName; use pyrefly_util::telemetry::Telemetry; @@ -17,6 +20,9 @@ use pyrefly_util::telemetry::TelemetryEvent; use pyrefly_util::telemetry::TelemetryEventKind; use tracing::info; use tracing::warn; +use tsp_types::ConnectionRequestParams; +use tsp_types::ConnectionRequestResult; +use tsp_types::ConnectionTransportKind; use tsp_types::GetTypeParams; use tsp_types::TSPNotificationMethods; use tsp_types::TSPRequests; @@ -27,6 +33,7 @@ use crate::lsp::non_wasm::protocol::Notification; use crate::lsp::non_wasm::protocol::Request; use crate::lsp::non_wasm::protocol::Response; use crate::lsp::non_wasm::queue::LspEvent; +use crate::lsp::non_wasm::server::Connection; use crate::lsp::non_wasm::server::InitializeInfo; use crate::lsp::non_wasm::server::MessageReader; use crate::lsp::non_wasm::server::ProcessEvent; @@ -36,18 +43,52 @@ use crate::lsp::non_wasm::server::capabilities; use crate::lsp::non_wasm::transaction_manager::TransactionManager; use crate::tsp::type_conversion::convert_type_with_resolver; +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum ConnectionRole { + Main, + ReadOnlyExtra, +} + +struct ExtraConnectionHandle { + close_tx: crossbeam_channel::Sender<()>, +} + /// TSP server that delegates to LSP server infrastructure while handling only TSP requests pub struct TspServer { - pub inner: T, + pub inner: Arc, /// Current snapshot version, updated on RecheckFinished events pub(crate) current_snapshot: Arc>, + response_sender: crossbeam_channel::Sender, + extra_connections: Arc>>, + role: ConnectionRole, } -impl TspServer { +impl TspServer { pub fn new(lsp_server: T) -> Self { + let inner = Arc::new(lsp_server); + let response_sender = inner.sender().clone(); + Self::with_connection( + inner, + Arc::new(Mutex::new(0)), + response_sender, + Arc::new(Mutex::new(HashMap::new())), + ConnectionRole::Main, + ) + } + + fn with_connection( + inner: Arc, + current_snapshot: Arc>, + response_sender: crossbeam_channel::Sender, + extra_connections: Arc>>, + role: ConnectionRole, + ) -> Self { Self { - inner: lsp_server, - current_snapshot: Arc::new(Mutex::new(0)), // Start at 0, increments on RecheckFinished + inner, + current_snapshot, + response_sender, + extra_connections, + role, } } @@ -86,9 +127,9 @@ impl TspServer { return Ok(ProcessEvent::Continue); } // If it's not a TSP request, let the LSP server reject it since TSP server shouldn't handle LSP requests - self.inner.send_response(Response::new_err( + self.send_response(Response::new_err( request.id.clone(), - lsp_server::ErrorCode::MethodNotFound as i32, + ErrorCode::MethodNotFound as i32, format!("TSP server does not support LSP method: {}", request.method), )); return Ok(ProcessEvent::Continue); @@ -128,9 +169,12 @@ impl TspServer { .expect("TSPNotificationMethods serializes to a string") .to_owned(); + if self.role != ConnectionRole::Main { + return; + } + if let Err(e) = self - .inner - .sender() + .response_sender .send(Message::Notification(Notification { method: method_str, params: serde_json::json!({ "old": old_snapshot, "new": new_snapshot }), @@ -141,6 +185,12 @@ impl TspServer { } } + pub(crate) fn send_response(&self, response: Response) { + if let Err(error) = self.response_sender.send(Message::Response(response)) { + warn!("Failed to send TSP response: {error}"); + } + } + fn handle_tsp_request<'a>( &'a self, ide_transaction_manager: &mut TransactionManager<'a>, @@ -158,7 +208,28 @@ impl TspServer { return Ok(false); }; + if self.role == ConnectionRole::ReadOnlyExtra + && matches!(msg, TSPRequests::ConnectionRequest { .. }) + { + self.send_err( + request.id.clone(), + ResponseError { + code: ErrorCode::InvalidRequest as i32, + message: format!( + "TSP method {} is only allowed on the main connection", + request.method + ), + data: None, + }, + ); + return Ok(true); + } + match msg { + TSPRequests::ConnectionRequest { params, .. } => { + self.handle_connection_request(request.id.clone(), params); + Ok(true) + } TSPRequests::GetSupportedProtocolVersionRequest { .. } => { self.send_ok(request.id.clone(), self.get_supported_protocol_version()); Ok(true) @@ -228,10 +299,192 @@ impl TspServer { } } } + + fn handle_connection_request(&self, id: RequestId, params: ConnectionRequestParams) { + if self.role != ConnectionRole::Main { + self.send_err( + id, + ResponseError { + code: ErrorCode::InvalidRequest as i32, + message: "Connection management is only supported on the main TSP connection" + .to_owned(), + data: None, + }, + ); + return; + } + + let result = match params.type_.as_str() { + "open" => self.open_extra_connection(params), + "close" => Ok(self.close_extra_connection(params)), + other => Err(crate::tsp::validation::invalid_params_error(&format!( + "Unsupported connection request type: {other}" + ))), + }; + + match result { + Ok(connection_result) => self.send_ok(id, connection_result), + Err(error) => self.send_err(id, error), + } + } + + fn open_extra_connection( + &self, + params: ConnectionRequestParams, + ) -> Result { + let pipe_name = self.get_pipe_name(¶ms)?; + + let mut extra_connections = self.extra_connections.lock().map_err(|_| { + crate::tsp::validation::internal_error("extra connection state was poisoned") + })?; + + if extra_connections.contains_key(&pipe_name) { + return Ok(ConnectionRequestResult { + success: true, + message: Some(format!("Extra connection already open: {pipe_name}")), + }); + } + + let (connection, mut reader, _io_thread) = + Connection::ipc(&pipe_name).map_err(|error| { + crate::tsp::validation::internal_error(&format!( + "Failed to connect to IPC endpoint {pipe_name}: {error}" + )) + })?; + + let extra_server = Self::with_connection( + self.inner.clone(), + self.current_snapshot.clone(), + connection.sender.clone(), + self.extra_connections.clone(), + ConnectionRole::ReadOnlyExtra, + ); + let (message_tx, message_rx) = crossbeam_channel::unbounded(); + let (close_tx, close_rx) = crossbeam_channel::bounded::<()>(1); + let pipe_name_for_thread = pipe_name.clone(); + + extra_connections.insert( + pipe_name.clone(), + ExtraConnectionHandle { + close_tx: close_tx.clone(), + }, + ); + drop(extra_connections); + + std::thread::spawn(move || { + std::thread::spawn(move || { + while let Some(message) = reader.recv() { + if message_tx.send(message).is_err() { + break; + } + } + }); + + loop { + crossbeam_channel::select! { + recv(close_rx) -> _ => break, + recv(message_rx) -> message => { + let Ok(message) = message else { + break; + }; + + match message { + Message::Request(request) => { + let mut ide_transaction_manager = TransactionManager::default(); + if let Err(error) = extra_server.handle_extra_request(&mut ide_transaction_manager, request) { + warn!("Extra TSP connection exited with error: {error}"); + break; + } + } + Message::Notification(_) | Message::Response(_) => { + // Extra connections are read-only query channels. + } + } + } + } + } + + if let Ok(mut handles) = extra_server.extra_connections.lock() { + handles.remove(&pipe_name_for_thread); + } + }); + + Ok(ConnectionRequestResult { + success: true, + message: Some(format!("Opened extra IPC connection: {pipe_name}")), + }) + } + + fn close_extra_connection(&self, params: ConnectionRequestParams) -> ConnectionRequestResult { + let Ok(pipe_name) = self.get_pipe_name(¶ms) else { + return ConnectionRequestResult { + success: false, + message: Some("Missing IPC pipe name in connection args".to_owned()), + }; + }; + + let handle = self + .extra_connections + .lock() + .ok() + .and_then(|mut handles| handles.remove(&pipe_name)); + + if let Some(handle) = handle { + let _ = handle.close_tx.send(()); + ConnectionRequestResult { + success: true, + message: Some(format!("Closing extra IPC connection: {pipe_name}")), + } + } else { + ConnectionRequestResult { + success: true, + message: Some(format!("Extra IPC connection already closed: {pipe_name}")), + } + } + } + + fn get_pipe_name(&self, params: &ConnectionRequestParams) -> Result { + if params.kind != ConnectionTransportKind::Ipc { + return Err(crate::tsp::validation::invalid_params_error( + "Only IPC extra connections are supported", + )); + } + + params + .args + .as_ref() + .and_then(|args| args.first()) + .filter(|pipe_name| !pipe_name.is_empty()) + .cloned() + .ok_or_else(|| { + crate::tsp::validation::invalid_params_error( + "Connection request args must include the IPC pipe name", + ) + }) + } + + fn handle_extra_request<'a>( + &'a self, + ide_transaction_manager: &mut TransactionManager<'a>, + request: Request, + ) -> anyhow::Result<()> { + if !self.handle_tsp_request(ide_transaction_manager, &request)? { + self.send_response(Response::new_err( + request.id, + ErrorCode::MethodNotFound as i32, + format!( + "Extra TSP connection does not support method: {}", + request.method + ), + )); + } + + Ok(()) + } } pub fn tsp_loop( - lsp_server: impl TspInterface, + lsp_server: impl TspInterface + 'static, mut reader: MessageReader, _initialization: InitializeInfo, telemetry: &impl Telemetry, @@ -295,7 +548,11 @@ pub fn tsp_capabilities( indexing_mode: IndexingMode, initialization_params: &InitializeParams, ) -> ServerCapabilitiesWithTypeHierarchy { - // Use the same capabilities as LSP - TSP server supports the same features - // but will only respond to TSP protocol requests - capabilities(indexing_mode, initialization_params) + let mut result = capabilities(indexing_mode, initialization_params); + result.base.experimental = Some(serde_json::json!({ + "typeServerMultiConnection": { + "supportedTransports": ["ipc"] + } + })); + result } diff --git a/pyrefly/lib/tsp/type_conversion.rs b/pyrefly/lib/tsp/type_conversion.rs index 3643e9fdcc..20b1b71b20 100644 --- a/pyrefly/lib/tsp/type_conversion.rs +++ b/pyrefly/lib/tsp/type_conversion.rs @@ -739,19 +739,16 @@ mod tests { #[test] fn test_type_flags_bitwise_operations() { - // Test BitOr let combined = TypeFlags::INSTANCE | TypeFlags::CALLABLE; assert!(combined.contains(TypeFlags::INSTANCE)); assert!(combined.contains(TypeFlags::CALLABLE)); assert!(!combined.contains(TypeFlags::LITERAL)); - // Test BitOrAssign let mut flags = TypeFlags::NONE; flags |= TypeFlags::INSTANTIABLE; assert!(flags.contains(TypeFlags::INSTANTIABLE)); assert!(!flags.contains(TypeFlags::INSTANCE)); - // Test with_ builders let flags = TypeFlags::new().with_instance().with_callable(); assert!(flags.contains(TypeFlags::INSTANCE)); assert!(flags.contains(TypeFlags::CALLABLE)); @@ -760,20 +757,16 @@ mod tests { #[test] fn test_type_flags_serialization() { - // INSTANCE = 2 let json = serde_json::to_value(TypeFlags::INSTANCE).unwrap(); assert_eq!(json, serde_json::json!(2)); - // CALLABLE = 4 let json = serde_json::to_value(TypeFlags::CALLABLE).unwrap(); assert_eq!(json, serde_json::json!(4)); - // Combined flags (INSTANCE | CALLABLE = 6) let combined = TypeFlags::INSTANCE | TypeFlags::CALLABLE; let json = serde_json::to_value(combined).unwrap(); assert_eq!(json, serde_json::json!(6)); - // Deserialization let flags: TypeFlags = serde_json::from_value(serde_json::json!(6)).unwrap(); assert!(flags.contains(TypeFlags::INSTANCE)); assert!(flags.contains(TypeFlags::CALLABLE)); diff --git a/pyrefly/lib/tsp/validation.rs b/pyrefly/lib/tsp/validation.rs index 995c65a8da..c435c6c854 100644 --- a/pyrefly/lib/tsp/validation.rs +++ b/pyrefly/lib/tsp/validation.rs @@ -86,7 +86,7 @@ pub fn parse_file_uri(uri: &str) -> Result { // Snapshot validation // --------------------------------------------------------------------------- -impl TspServer { +impl TspServer { /// Validate that the client-supplied snapshot matches the server's current /// snapshot. Returns `Ok(())` on match or `Err(ResponseError)` on mismatch. pub fn validate_snapshot(&self, client_snapshot: i32) -> Result<(), ResponseError> { @@ -100,12 +100,12 @@ impl TspServer { /// Send a successful JSON-RPC response for `id` with `result`. pub fn send_ok(&self, id: RequestId, result: R) { - self.inner.send_response(new_response(id, Ok(result))); + self.send_response(new_response(id, Ok(result))); } /// Send a JSON-RPC error response for `id`. pub fn send_err(&self, id: RequestId, error: ResponseError) { - self.inner.send_response(Response { + self.send_response(Response { id, result: None, error: Some(error),