Skip to content

Commit 954e779

Browse files
refactor(services): use actix-ws instead of actix-web-actors (#1064)
* rust 1.88 * clippy auto-fixes * manual clippy fixes * update deps * cargo update * update onnx * cargo fmt * update sqlfluff * new websocket library
1 parent f1fdc28 commit 954e779

File tree

11 files changed

+622
-632
lines changed

11 files changed

+622
-632
lines changed

Cargo.lock

Lines changed: 15 additions & 19 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ actix-http = { version = "3.11", features = ["ws"] }
6868
actix-multipart = "0.7"
6969
actix-rt = "2.9"
7070
actix-web = "4.11"
71-
actix-web-actors = "4.3"
7271
actix-web-httpauth = "0.8"
72+
actix-ws = "0.3"
7373
aes-gcm = "0.10.3"
7474
anyhow = "1.0"
7575
approx = "0.5"

services/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ actix-http = { workspace = true }
1717
actix-multipart = { workspace = true }
1818
actix-rt = { workspace = true }
1919
actix-web = { workspace = true }
20-
actix-web-actors = { workspace = true }
2120
actix-web-httpauth = { workspace = true }
21+
actix-ws = { workspace = true }
2222
aes-gcm = { workspace = true }
2323
anyhow.workspace = true
2424
aruna-rust-api = { workspace = true }

services/src/api/handlers/workflows.rs

Lines changed: 232 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ use crate::util::parsing::{
1414
use crate::util::workflows::validate_workflow;
1515
use crate::workflows::registry::WorkflowRegistry;
1616
use crate::workflows::workflow::{Workflow, WorkflowId};
17-
use crate::workflows::{RasterWebsocketStreamHandler, VectorWebsocketStreamHandler};
17+
use crate::workflows::{WebsocketStreamTask, handle_websocket_message, send_websocket_message};
1818
use actix_web::{FromRequest, HttpRequest, HttpResponse, Responder, web};
19+
use futures::StreamExt;
1920
use futures::future::join_all;
2021
use geoengine_datatypes::error::{BoxedResultExt, ErrorSource};
2122
use geoengine_datatypes::primitives::{
@@ -538,18 +539,46 @@ async fn raster_stream_websocket<C: ApplicationContext>(
538539
RasterStreamWebsocketResultType::Arrow
539540
));
540541

541-
let stream_handler = RasterWebsocketStreamHandler::new::<C::SessionContext>(
542+
let mut stream_task = WebsocketStreamTask::new_raster::<C::SessionContext>(
542543
operator,
543544
query_rectangle,
544545
ctx.execution_context()?,
545546
ctx.query_context(workflow_id.0, Uuid::new_v4())?,
546547
)
547548
.await?;
548549

549-
match actix_web_actors::ws::start(stream_handler, &request, stream) {
550-
Ok(websocket) => Ok(websocket),
551-
Err(e) => Ok(e.error_response()),
552-
}
550+
let (response, mut session, mut msg_stream) = match actix_ws::handle(&request, stream) {
551+
Ok((response, session, msg_stream)) => (response, session, msg_stream),
552+
Err(e) => return Ok(e.error_response()),
553+
};
554+
555+
actix_web::rt::spawn(async move {
556+
loop {
557+
let indicator = tokio::select! {
558+
Some(Ok(msg)) = msg_stream.next() => {
559+
handle_websocket_message(msg, &mut stream_task, &mut session).await
560+
}
561+
562+
tile = stream_task.receive_tile() => {
563+
send_websocket_message(tile, session.clone()).await
564+
}
565+
566+
else => {
567+
None
568+
}
569+
};
570+
571+
if indicator.is_none() {
572+
// the stream ended or session was closed, stop processing
573+
break;
574+
}
575+
}
576+
577+
stream_task.abort_processing();
578+
let _ = session.close(None).await;
579+
});
580+
581+
Ok(response)
553582
}
554583

555584
/// The query parameters for `raster_stream_websocket`.
@@ -624,18 +653,46 @@ async fn vector_stream_websocket<C: ApplicationContext>(
624653
RasterStreamWebsocketResultType::Arrow
625654
));
626655

627-
let stream_handler = VectorWebsocketStreamHandler::new::<C::SessionContext>(
656+
let mut stream_task = WebsocketStreamTask::new_vector::<C::SessionContext>(
628657
operator,
629658
query_rectangle,
630659
ctx.execution_context()?,
631660
ctx.query_context(workflow_id.0, Uuid::new_v4())?,
632661
)
633662
.await?;
634663

635-
match actix_web_actors::ws::start(stream_handler, &request, stream) {
636-
Ok(websocket) => Ok(websocket),
637-
Err(e) => Ok(e.error_response()),
638-
}
664+
let (response, mut session, mut msg_stream) = match actix_ws::handle(&request, stream) {
665+
Ok((response, session, msg_stream)) => (response, session, msg_stream),
666+
Err(e) => return Ok(e.error_response()),
667+
};
668+
669+
actix_web::rt::spawn(async move {
670+
loop {
671+
let indicator = tokio::select! {
672+
Some(Ok(msg)) = msg_stream.next() => {
673+
handle_websocket_message(msg, &mut stream_task, &mut session).await
674+
}
675+
676+
tile = stream_task.receive_tile() => {
677+
send_websocket_message(tile, session.clone()).await
678+
}
679+
680+
else => {
681+
None
682+
}
683+
};
684+
685+
if indicator.is_none() {
686+
// the stream ended or session was closed, stop processing
687+
break;
688+
}
689+
}
690+
691+
stream_task.abort_processing();
692+
let _ = session.close(None).await;
693+
});
694+
695+
Ok(response)
639696
}
640697

641698
#[derive(Debug, Snafu)]
@@ -669,18 +726,22 @@ mod tests {
669726
use crate::tasks::util::test::wait_for_task_to_finish;
670727
use crate::tasks::{TaskManager, TaskStatus};
671728
use crate::users::UserAuth;
729+
use crate::util::tests::add_ports_to_datasets;
672730
use crate::util::tests::admin_login;
673731
use crate::util::tests::{
674732
TestDataUploads, add_ndvi_to_datasets, check_allowed_http_methods,
675733
check_allowed_http_methods2, read_body_string, register_ndvi_workflow_helper,
676734
send_test_request,
677735
};
736+
use crate::util::websocket_tests;
678737
use crate::workflows::registry::WorkflowRegistry;
679738
use actix_web::dev::ServiceResponse;
680739
use actix_web::{http::Method, http::header, test};
681740
use actix_web_httpauth::headers::authorization::Bearer;
741+
use futures::StreamExt;
682742
use geoengine_datatypes::collections::MultiPointCollection;
683743
use geoengine_datatypes::primitives::CacheHint;
744+
use geoengine_datatypes::primitives::DateTime;
684745
use geoengine_datatypes::primitives::{
685746
ContinuousMeasurement, FeatureData, Measurement, MultiPoint, RasterQueryRectangle,
686747
SpatialPartition2D, SpatialResolution, TimeInterval,
@@ -689,6 +750,7 @@ mod tests {
689750
use geoengine_datatypes::spatial_reference::SpatialReference;
690751
use geoengine_datatypes::test_data;
691752
use geoengine_datatypes::util::ImageFormat;
753+
use geoengine_datatypes::util::arrow::arrow_ipc_file_to_record_batches;
692754
use geoengine_datatypes::util::assert_image_equals_with_format;
693755
use geoengine_operators::engine::{
694756
ExecutionContext, MultipleRasterOrSingleVectorSource, PlotOperator, RasterBandDescriptor,
@@ -700,6 +762,8 @@ mod tests {
700762
MockRasterSourceParams,
701763
};
702764
use geoengine_operators::plot::{Statistics, StatisticsParams};
765+
use geoengine_operators::source::OgrSource;
766+
use geoengine_operators::source::OgrSourceParameters;
703767
use geoengine_operators::source::{GdalSource, GdalSourceParameters};
704768
use geoengine_operators::util::input::MultiRasterOrVectorOperator::Raster;
705769
use geoengine_operators::util::raster_stream_to_geotiff::{
@@ -1494,4 +1558,161 @@ mod tests {
14941558
ImageFormat::Tiff,
14951559
);
14961560
}
1561+
1562+
#[ge_context::test]
1563+
#[allow(clippy::too_many_lines)]
1564+
async fn it_serves_raster_streams_via_websockets(app_ctx: PostgresContext<NoTls>) {
1565+
let session = app_ctx.create_anonymous_session().await.unwrap();
1566+
let ctx = app_ctx.session_context(session.clone());
1567+
1568+
let (_, dataset) = add_ndvi_to_datasets(&app_ctx).await;
1569+
1570+
let workflow = Workflow {
1571+
operator: TypedOperator::Raster(
1572+
GdalSource {
1573+
params: GdalSourceParameters { data: dataset },
1574+
}
1575+
.boxed(),
1576+
),
1577+
};
1578+
1579+
let workflow_id = ctx.db().register_workflow(workflow).await.unwrap();
1580+
1581+
let (req, payload, mut input_tx, send_next_msg_trigger) =
1582+
websocket_tests::test_client().await;
1583+
1584+
tokio::task::spawn(async move {
1585+
// Simulate sending messages to the websocket
1586+
for _ in 0..4 {
1587+
websocket_tests::send_text(&mut input_tx, "NEXT").await;
1588+
}
1589+
websocket_tests::send_close(&mut input_tx).await;
1590+
});
1591+
1592+
tokio::task::LocalSet::new()
1593+
.run_until(async move {
1594+
let response = raster_stream_websocket(
1595+
web::Data::new(app_ctx.clone()),
1596+
session.clone(),
1597+
web::Path::from(workflow_id),
1598+
web::Query(RasterStreamWebsocketQuery {
1599+
spatial_bounds: SpatialPartition2D::new(
1600+
(-180., 90.).into(),
1601+
(180., -90.).into(),
1602+
)
1603+
.unwrap(),
1604+
time_interval: TimeInterval::new_instant(DateTime::new_utc(
1605+
2014, 3, 1, 0, 0, 0,
1606+
))
1607+
.unwrap()
1608+
.into(),
1609+
spatial_resolution: SpatialResolution::one(),
1610+
attributes: geoengine_datatypes::primitives::BandSelection::first().into(),
1611+
result_type: RasterStreamWebsocketResultType::Arrow,
1612+
}),
1613+
req,
1614+
payload,
1615+
)
1616+
.await
1617+
.unwrap();
1618+
1619+
let mut response_stream =
1620+
websocket_tests::response_messages(response, send_next_msg_trigger)
1621+
.boxed_local();
1622+
1623+
for _ in 0..4 {
1624+
let tile_bytes = response_stream.next().await.unwrap();
1625+
1626+
let record_batches = arrow_ipc_file_to_record_batches(&tile_bytes).unwrap();
1627+
assert_eq!(record_batches.len(), 1);
1628+
let record_batch = record_batches.first().unwrap();
1629+
let schema = record_batch.schema();
1630+
1631+
assert_eq!(schema.metadata()["spatialReference"], "EPSG:4326");
1632+
}
1633+
1634+
assert!(response_stream.next().await.is_none()); // No more messages expected
1635+
})
1636+
.await;
1637+
}
1638+
1639+
#[ge_context::test]
1640+
#[allow(clippy::too_many_lines)]
1641+
async fn it_serves_vector_streams_via_websockets(app_ctx: PostgresContext<NoTls>) {
1642+
let session = app_ctx.create_anonymous_session().await.unwrap();
1643+
let ctx = app_ctx.session_context(session.clone());
1644+
1645+
let (_, dataset) = add_ports_to_datasets(&app_ctx, true, true).await;
1646+
1647+
let workflow = Workflow {
1648+
operator: TypedOperator::Vector(
1649+
OgrSource {
1650+
params: OgrSourceParameters {
1651+
data: dataset,
1652+
attribute_projection: None,
1653+
attribute_filters: None,
1654+
},
1655+
}
1656+
.boxed(),
1657+
),
1658+
};
1659+
1660+
let workflow_id = ctx.db().register_workflow(workflow).await.unwrap();
1661+
1662+
let (req, payload, mut input_tx, send_next_msg_trigger) =
1663+
websocket_tests::test_client().await;
1664+
1665+
tokio::task::spawn(async move {
1666+
// Simulate sending messages to the websocket
1667+
for _ in 0..1 {
1668+
websocket_tests::send_text(&mut input_tx, "NEXT").await;
1669+
}
1670+
websocket_tests::send_close(&mut input_tx).await;
1671+
});
1672+
1673+
tokio::task::LocalSet::new()
1674+
.run_until(async move {
1675+
let response = vector_stream_websocket(
1676+
web::Data::new(app_ctx.clone()),
1677+
session.clone(),
1678+
web::Path::from(workflow_id),
1679+
web::Query(VectorStreamWebsocketQuery {
1680+
spatial_bounds: BoundingBox2D::new(
1681+
(-180., -90.).into(),
1682+
(180., 90.).into(),
1683+
)
1684+
.unwrap(),
1685+
time_interval: TimeInterval::new_instant(DateTime::new_utc(
1686+
2014, 3, 1, 0, 0, 0,
1687+
))
1688+
.unwrap()
1689+
.into(),
1690+
spatial_resolution: SpatialResolution::one(),
1691+
result_type: RasterStreamWebsocketResultType::Arrow,
1692+
}),
1693+
req,
1694+
payload,
1695+
)
1696+
.await
1697+
.unwrap();
1698+
1699+
let mut response_stream =
1700+
websocket_tests::response_messages(response, send_next_msg_trigger)
1701+
.boxed_local();
1702+
1703+
for _ in 0..1 {
1704+
let tile_bytes = response_stream.next().await.unwrap();
1705+
1706+
let record_batches = arrow_ipc_file_to_record_batches(&tile_bytes).unwrap();
1707+
assert_eq!(record_batches.len(), 1);
1708+
let record_batch = record_batches.first().unwrap();
1709+
let schema = record_batch.schema();
1710+
1711+
assert_eq!(schema.metadata()["spatialReference"], "EPSG:4326");
1712+
}
1713+
1714+
assert!(response_stream.next().await.is_none()); // No more messages expected
1715+
})
1716+
.await;
1717+
}
14971718
}

0 commit comments

Comments
 (0)