Skip to content

Commit 20b8567

Browse files
refactor: tokio main, isolate crossbeam channels
1 parent d76481b commit 20b8567

File tree

3 files changed

+122
-70
lines changed

3 files changed

+122
-70
lines changed

crates/pg_lsp/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pg_base_db.workspace = true
3232
pg_schema_cache.workspace = true
3333
pg_workspace.workspace = true
3434
pg_diagnostics.workspace = true
35+
tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread", "sync"] }
3536

3637
[dev-dependencies]
3738

crates/pg_lsp/src/main.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
use lsp_server::Connection;
22
use pg_lsp::server::Server;
33

4-
fn main() -> anyhow::Result<()> {
4+
#[tokio::main]
5+
async fn main() -> anyhow::Result<()> {
56
let (connection, threads) = Connection::stdio();
6-
Server::init(connection)?;
7+
let server = Server::init(connection)?;
8+
9+
server.run().await?;
710
threads.join()?;
811

912
Ok(())

crates/pg_lsp/src/server.rs

Lines changed: 116 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ mod dispatch;
33
pub mod options;
44

55
use async_std::task::{self};
6-
use crossbeam_channel::{unbounded, Receiver, Sender};
76
use lsp_server::{Connection, ErrorCode, Message, RequestId};
87
use lsp_types::{
98
notification::{
@@ -33,6 +32,8 @@ use std::{collections::HashSet, sync::Arc, time::Duration};
3332
use text_size::TextSize;
3433
use threadpool::ThreadPool;
3534

35+
use tokio::sync::{mpsc, oneshot};
36+
3637
use crate::{
3738
client::{client_flags::ClientFlags, LspClient},
3839
utils::{file_path, from_proto, line_index_ext::LineIndexExt, normalize_uri, to_proto},
@@ -68,11 +69,39 @@ impl DbConnection {
6869
}
6970
}
7071

72+
/// `lsp-servers` `Connection` type uses a crossbeam channel, which is not compatible with tokio's async runtime.
73+
/// For now, we move it into a separate task and use tokio's channels to communicate.
74+
fn get_client_receiver(
75+
connection: Connection,
76+
) -> (mpsc::UnboundedReceiver<Message>, oneshot::Receiver<()>) {
77+
let (message_tx, message_rx) = mpsc::unbounded_channel();
78+
let (close_tx, close_rx) = oneshot::channel();
79+
80+
tokio::task::spawn(async move {
81+
// TODO: improve Result handling
82+
loop {
83+
let msg = connection.receiver.recv().unwrap();
84+
85+
match msg {
86+
Message::Request(r) if connection.handle_shutdown(&r).unwrap() => {
87+
close_tx.send(()).unwrap();
88+
return;
89+
}
90+
91+
_ => message_tx.send(msg).unwrap(),
92+
};
93+
}
94+
});
95+
96+
(message_rx, close_rx)
97+
}
98+
7199
pub struct Server {
72-
connection: Arc<Connection>,
100+
client_rx: mpsc::UnboundedReceiver<Message>,
101+
close_rx: oneshot::Receiver<()>,
73102
client: LspClient,
74-
internal_tx: Sender<InternalMessage>,
75-
internal_rx: Receiver<InternalMessage>,
103+
internal_tx: mpsc::UnboundedSender<InternalMessage>,
104+
internal_rx: mpsc::UnboundedReceiver<InternalMessage>,
76105
pool: Arc<ThreadPool>,
77106
client_flags: Arc<ClientFlags>,
78107
ide: Arc<Workspace>,
@@ -81,10 +110,10 @@ pub struct Server {
81110
}
82111

83112
impl Server {
84-
pub fn init(connection: Connection) -> anyhow::Result<()> {
113+
pub fn init(connection: Connection) -> anyhow::Result<Self> {
85114
let client = LspClient::new(connection.sender.clone());
86115

87-
let (internal_tx, internal_rx) = unbounded();
116+
let (internal_tx, internal_rx) = mpsc::unbounded_channel();
88117

89118
let (id, params) = connection.initialize_start()?;
90119
let params: InitializeParams = serde_json::from_value(params)?;
@@ -110,8 +139,11 @@ impl Server {
110139
let cloned_pool = pool.clone();
111140
let cloned_client = client.clone();
112141

142+
let (client_rx, close_rx) = get_client_receiver(connection);
143+
113144
let server = Self {
114-
connection: Arc::new(connection),
145+
close_rx,
146+
client_rx,
115147
internal_rx,
116148
internal_tx,
117149
client,
@@ -158,8 +190,7 @@ impl Server {
158190
pool,
159191
};
160192

161-
server.run()?;
162-
Ok(())
193+
Ok(server)
163194
}
164195

165196
fn compute_now(&self) {
@@ -763,67 +794,84 @@ impl Server {
763794
Ok(())
764795
}
765796

766-
fn process_messages(&mut self) -> anyhow::Result<()> {
797+
async fn process_messages(&mut self) -> anyhow::Result<()> {
767798
loop {
768-
crossbeam_channel::select! {
769-
recv(&self.connection.receiver) -> msg => {
770-
match msg? {
771-
Message::Request(request) => {
772-
if self.connection.handle_shutdown(&request)? {
773-
return Ok(());
774-
}
775-
776-
if let Some(response) = dispatch::RequestDispatcher::new(request)
777-
.on::<InlayHintRequest, _>(|id, params| self.inlay_hint(id, params))?
778-
.on::<HoverRequest, _>(|id, params| self.hover(id, params))?
779-
.on::<ExecuteCommand,_>(|id, params| self.execute_command(id, params))?
780-
.on::<Completion, _>(|id, params| {
781-
self.completion(id, params)
782-
})?
783-
.on::<CodeActionRequest, _>(|id, params| {
784-
self.code_actions(id, params)
785-
})?
786-
.default()
787-
{
788-
self.client.send_response(response)?;
789-
}
790-
}
791-
Message::Notification(notification) => {
792-
dispatch::NotificationDispatcher::new(notification)
793-
.on::<DidChangeConfiguration, _>(|params| {
794-
self.did_change_configuration(params)
795-
})?
796-
.on::<DidCloseTextDocument, _>(|params| self.did_close(params))?
797-
.on::<DidOpenTextDocument, _>(|params| self.did_open(params))?
798-
.on::<DidChangeTextDocument, _>(|params| self.did_change(params))?
799-
.on::<DidSaveTextDocument, _>(|params| self.did_save(params))?
800-
.on::<DidCloseTextDocument, _>(|params| self.did_close(params))?
801-
.default();
802-
}
803-
Message::Response(response) => {
804-
self.client.recv_response(response)?;
805-
}
806-
};
799+
tokio::select! {
800+
_ = &mut self.close_rx => {
801+
return Ok(())
807802
},
808-
recv(&self.internal_rx) -> msg => {
809-
match msg? {
810-
InternalMessage::SetSchemaCache(c) => {
811-
self.ide.set_schema_cache(c);
812-
self.compute_now();
813-
}
814-
InternalMessage::RefreshSchemaCache => {
815-
self.refresh_schema_cache();
816-
}
817-
InternalMessage::PublishDiagnostics(uri) => {
818-
self.publish_diagnostics(uri)?;
819-
}
820-
InternalMessage::SetOptions(options) => {
821-
self.update_options(options);
822-
}
823-
};
803+
804+
msg = self.internal_rx.recv() => {
805+
match msg {
806+
// TODO: handle internal sender close? Is that valid state?
807+
None => return Ok(()),
808+
Some(m) => self.handle_internal_message(m)
809+
}
810+
},
811+
812+
msg = self.client_rx.recv() => {
813+
match msg {
814+
// the client sender is closed, we can return
815+
None => return Ok(()),
816+
Some(m) => self.handle_message(m)
817+
}
818+
},
819+
}?;
820+
}
821+
}
822+
823+
fn handle_message(&mut self, msg: Message) -> anyhow::Result<()> {
824+
match msg {
825+
Message::Request(request) => {
826+
if let Some(response) = dispatch::RequestDispatcher::new(request)
827+
.on::<InlayHintRequest, _>(|id, params| self.inlay_hint(id, params))?
828+
.on::<HoverRequest, _>(|id, params| self.hover(id, params))?
829+
.on::<ExecuteCommand, _>(|id, params| self.execute_command(id, params))?
830+
.on::<Completion, _>(|id, params| self.completion(id, params))?
831+
.on::<CodeActionRequest, _>(|id, params| self.code_actions(id, params))?
832+
.default()
833+
{
834+
self.client.send_response(response)?;
824835
}
825-
};
836+
}
837+
Message::Notification(notification) => {
838+
dispatch::NotificationDispatcher::new(notification)
839+
.on::<DidChangeConfiguration, _>(|params| {
840+
self.did_change_configuration(params)
841+
})?
842+
.on::<DidCloseTextDocument, _>(|params| self.did_close(params))?
843+
.on::<DidOpenTextDocument, _>(|params| self.did_open(params))?
844+
.on::<DidChangeTextDocument, _>(|params| self.did_change(params))?
845+
.on::<DidSaveTextDocument, _>(|params| self.did_save(params))?
846+
.on::<DidCloseTextDocument, _>(|params| self.did_close(params))?
847+
.default();
848+
}
849+
Message::Response(response) => {
850+
self.client.recv_response(response)?;
851+
}
826852
}
853+
854+
Ok(())
855+
}
856+
857+
fn handle_internal_message(&mut self, msg: InternalMessage) -> anyhow::Result<()> {
858+
match msg {
859+
InternalMessage::SetSchemaCache(c) => {
860+
self.ide.set_schema_cache(c);
861+
self.compute_now();
862+
}
863+
InternalMessage::RefreshSchemaCache => {
864+
self.refresh_schema_cache();
865+
}
866+
InternalMessage::PublishDiagnostics(uri) => {
867+
self.publish_diagnostics(uri)?;
868+
}
869+
InternalMessage::SetOptions(options) => {
870+
self.update_options(options);
871+
}
872+
}
873+
874+
Ok(())
827875
}
828876

829877
fn pull_options(&mut self) {
@@ -881,10 +929,10 @@ impl Server {
881929
}
882930
}
883931

884-
pub fn run(mut self) -> anyhow::Result<()> {
932+
pub async fn run(mut self) -> anyhow::Result<()> {
885933
self.register_configuration();
886934
self.pull_options();
887-
self.process_messages()?;
935+
self.process_messages().await?;
888936
self.pool.join();
889937
Ok(())
890938
}

0 commit comments

Comments
 (0)