@@ -14,8 +14,9 @@ use crate::util::parsing::{
1414use crate :: util:: workflows:: validate_workflow;
1515use crate :: workflows:: registry:: WorkflowRegistry ;
1616use crate :: workflows:: workflow:: { Workflow , WorkflowId } ;
17- use crate :: workflows:: { RasterWebsocketStreamHandler , VectorWebsocketStreamHandler } ;
17+ use crate :: workflows:: { WebsocketStreamTask , handle_websocket_message , send_websocket_message } ;
1818use actix_web:: { FromRequest , HttpRequest , HttpResponse , Responder , web} ;
19+ use futures:: StreamExt ;
1920use futures:: future:: join_all;
2021use geoengine_datatypes:: error:: { BoxedResultExt , ErrorSource } ;
2122use geoengine_datatypes:: primitives:: {
@@ -538,18 +539,46 @@ async fn raster_stream_websocket<C: ApplicationContext>(
538539 RasterStreamWebsocketResultType :: Arrow
539540 ) ) ;
540541
541- let stream_handler = RasterWebsocketStreamHandler :: new :: < C :: SessionContext > (
542+ let mut stream_task = WebsocketStreamTask :: new_raster :: < C :: SessionContext > (
542543 operator,
543544 query_rectangle,
544545 ctx. execution_context ( ) ?,
545546 ctx. query_context ( workflow_id. 0 , Uuid :: new_v4 ( ) ) ?,
546547 )
547548 . await ?;
548549
549- match actix_web_actors:: ws:: start ( stream_handler, & request, stream) {
550- Ok ( websocket) => Ok ( websocket) ,
551- Err ( e) => Ok ( e. error_response ( ) ) ,
552- }
550+ let ( response, mut session, mut msg_stream) = match actix_ws:: handle ( & request, stream) {
551+ Ok ( ( response, session, msg_stream) ) => ( response, session, msg_stream) ,
552+ Err ( e) => return Ok ( e. error_response ( ) ) ,
553+ } ;
554+
555+ actix_web:: rt:: spawn ( async move {
556+ loop {
557+ let indicator = tokio:: select! {
558+ Some ( Ok ( msg) ) = msg_stream. next( ) => {
559+ handle_websocket_message( msg, & mut stream_task, & mut session) . await
560+ }
561+
562+ tile = stream_task. receive_tile( ) => {
563+ send_websocket_message( tile, session. clone( ) ) . await
564+ }
565+
566+ else => {
567+ None
568+ }
569+ } ;
570+
571+ if indicator. is_none ( ) {
572+ // the stream ended or session was closed, stop processing
573+ break ;
574+ }
575+ }
576+
577+ stream_task. abort_processing ( ) ;
578+ let _ = session. close ( None ) . await ;
579+ } ) ;
580+
581+ Ok ( response)
553582}
554583
555584/// The query parameters for `raster_stream_websocket`.
@@ -624,18 +653,46 @@ async fn vector_stream_websocket<C: ApplicationContext>(
624653 RasterStreamWebsocketResultType :: Arrow
625654 ) ) ;
626655
627- let stream_handler = VectorWebsocketStreamHandler :: new :: < C :: SessionContext > (
656+ let mut stream_task = WebsocketStreamTask :: new_vector :: < C :: SessionContext > (
628657 operator,
629658 query_rectangle,
630659 ctx. execution_context ( ) ?,
631660 ctx. query_context ( workflow_id. 0 , Uuid :: new_v4 ( ) ) ?,
632661 )
633662 . await ?;
634663
635- match actix_web_actors:: ws:: start ( stream_handler, & request, stream) {
636- Ok ( websocket) => Ok ( websocket) ,
637- Err ( e) => Ok ( e. error_response ( ) ) ,
638- }
664+ let ( response, mut session, mut msg_stream) = match actix_ws:: handle ( & request, stream) {
665+ Ok ( ( response, session, msg_stream) ) => ( response, session, msg_stream) ,
666+ Err ( e) => return Ok ( e. error_response ( ) ) ,
667+ } ;
668+
669+ actix_web:: rt:: spawn ( async move {
670+ loop {
671+ let indicator = tokio:: select! {
672+ Some ( Ok ( msg) ) = msg_stream. next( ) => {
673+ handle_websocket_message( msg, & mut stream_task, & mut session) . await
674+ }
675+
676+ tile = stream_task. receive_tile( ) => {
677+ send_websocket_message( tile, session. clone( ) ) . await
678+ }
679+
680+ else => {
681+ None
682+ }
683+ } ;
684+
685+ if indicator. is_none ( ) {
686+ // the stream ended or session was closed, stop processing
687+ break ;
688+ }
689+ }
690+
691+ stream_task. abort_processing ( ) ;
692+ let _ = session. close ( None ) . await ;
693+ } ) ;
694+
695+ Ok ( response)
639696}
640697
641698#[ derive( Debug , Snafu ) ]
@@ -669,18 +726,22 @@ mod tests {
669726 use crate :: tasks:: util:: test:: wait_for_task_to_finish;
670727 use crate :: tasks:: { TaskManager , TaskStatus } ;
671728 use crate :: users:: UserAuth ;
729+ use crate :: util:: tests:: add_ports_to_datasets;
672730 use crate :: util:: tests:: admin_login;
673731 use crate :: util:: tests:: {
674732 TestDataUploads , add_ndvi_to_datasets, check_allowed_http_methods,
675733 check_allowed_http_methods2, read_body_string, register_ndvi_workflow_helper,
676734 send_test_request,
677735 } ;
736+ use crate :: util:: websocket_tests;
678737 use crate :: workflows:: registry:: WorkflowRegistry ;
679738 use actix_web:: dev:: ServiceResponse ;
680739 use actix_web:: { http:: Method , http:: header, test} ;
681740 use actix_web_httpauth:: headers:: authorization:: Bearer ;
741+ use futures:: StreamExt ;
682742 use geoengine_datatypes:: collections:: MultiPointCollection ;
683743 use geoengine_datatypes:: primitives:: CacheHint ;
744+ use geoengine_datatypes:: primitives:: DateTime ;
684745 use geoengine_datatypes:: primitives:: {
685746 ContinuousMeasurement , FeatureData , Measurement , MultiPoint , RasterQueryRectangle ,
686747 SpatialPartition2D , SpatialResolution , TimeInterval ,
@@ -689,6 +750,7 @@ mod tests {
689750 use geoengine_datatypes:: spatial_reference:: SpatialReference ;
690751 use geoengine_datatypes:: test_data;
691752 use geoengine_datatypes:: util:: ImageFormat ;
753+ use geoengine_datatypes:: util:: arrow:: arrow_ipc_file_to_record_batches;
692754 use geoengine_datatypes:: util:: assert_image_equals_with_format;
693755 use geoengine_operators:: engine:: {
694756 ExecutionContext , MultipleRasterOrSingleVectorSource , PlotOperator , RasterBandDescriptor ,
@@ -700,6 +762,8 @@ mod tests {
700762 MockRasterSourceParams ,
701763 } ;
702764 use geoengine_operators:: plot:: { Statistics , StatisticsParams } ;
765+ use geoengine_operators:: source:: OgrSource ;
766+ use geoengine_operators:: source:: OgrSourceParameters ;
703767 use geoengine_operators:: source:: { GdalSource , GdalSourceParameters } ;
704768 use geoengine_operators:: util:: input:: MultiRasterOrVectorOperator :: Raster ;
705769 use geoengine_operators:: util:: raster_stream_to_geotiff:: {
@@ -1494,4 +1558,161 @@ mod tests {
14941558 ImageFormat :: Tiff ,
14951559 ) ;
14961560 }
1561+
1562+ #[ ge_context:: test]
1563+ #[ allow( clippy:: too_many_lines) ]
1564+ async fn it_serves_raster_streams_via_websockets ( app_ctx : PostgresContext < NoTls > ) {
1565+ let session = app_ctx. create_anonymous_session ( ) . await . unwrap ( ) ;
1566+ let ctx = app_ctx. session_context ( session. clone ( ) ) ;
1567+
1568+ let ( _, dataset) = add_ndvi_to_datasets ( & app_ctx) . await ;
1569+
1570+ let workflow = Workflow {
1571+ operator : TypedOperator :: Raster (
1572+ GdalSource {
1573+ params : GdalSourceParameters { data : dataset } ,
1574+ }
1575+ . boxed ( ) ,
1576+ ) ,
1577+ } ;
1578+
1579+ let workflow_id = ctx. db ( ) . register_workflow ( workflow) . await . unwrap ( ) ;
1580+
1581+ let ( req, payload, mut input_tx, send_next_msg_trigger) =
1582+ websocket_tests:: test_client ( ) . await ;
1583+
1584+ tokio:: task:: spawn ( async move {
1585+ // Simulate sending messages to the websocket
1586+ for _ in 0 ..4 {
1587+ websocket_tests:: send_text ( & mut input_tx, "NEXT" ) . await ;
1588+ }
1589+ websocket_tests:: send_close ( & mut input_tx) . await ;
1590+ } ) ;
1591+
1592+ tokio:: task:: LocalSet :: new ( )
1593+ . run_until ( async move {
1594+ let response = raster_stream_websocket (
1595+ web:: Data :: new ( app_ctx. clone ( ) ) ,
1596+ session. clone ( ) ,
1597+ web:: Path :: from ( workflow_id) ,
1598+ web:: Query ( RasterStreamWebsocketQuery {
1599+ spatial_bounds : SpatialPartition2D :: new (
1600+ ( -180. , 90. ) . into ( ) ,
1601+ ( 180. , -90. ) . into ( ) ,
1602+ )
1603+ . unwrap ( ) ,
1604+ time_interval : TimeInterval :: new_instant ( DateTime :: new_utc (
1605+ 2014 , 3 , 1 , 0 , 0 , 0 ,
1606+ ) )
1607+ . unwrap ( )
1608+ . into ( ) ,
1609+ spatial_resolution : SpatialResolution :: one ( ) ,
1610+ attributes : geoengine_datatypes:: primitives:: BandSelection :: first ( ) . into ( ) ,
1611+ result_type : RasterStreamWebsocketResultType :: Arrow ,
1612+ } ) ,
1613+ req,
1614+ payload,
1615+ )
1616+ . await
1617+ . unwrap ( ) ;
1618+
1619+ let mut response_stream =
1620+ websocket_tests:: response_messages ( response, send_next_msg_trigger)
1621+ . boxed_local ( ) ;
1622+
1623+ for _ in 0 ..4 {
1624+ let tile_bytes = response_stream. next ( ) . await . unwrap ( ) ;
1625+
1626+ let record_batches = arrow_ipc_file_to_record_batches ( & tile_bytes) . unwrap ( ) ;
1627+ assert_eq ! ( record_batches. len( ) , 1 ) ;
1628+ let record_batch = record_batches. first ( ) . unwrap ( ) ;
1629+ let schema = record_batch. schema ( ) ;
1630+
1631+ assert_eq ! ( schema. metadata( ) [ "spatialReference" ] , "EPSG:4326" ) ;
1632+ }
1633+
1634+ assert ! ( response_stream. next( ) . await . is_none( ) ) ; // No more messages expected
1635+ } )
1636+ . await ;
1637+ }
1638+
1639+ #[ ge_context:: test]
1640+ #[ allow( clippy:: too_many_lines) ]
1641+ async fn it_serves_vector_streams_via_websockets ( app_ctx : PostgresContext < NoTls > ) {
1642+ let session = app_ctx. create_anonymous_session ( ) . await . unwrap ( ) ;
1643+ let ctx = app_ctx. session_context ( session. clone ( ) ) ;
1644+
1645+ let ( _, dataset) = add_ports_to_datasets ( & app_ctx, true , true ) . await ;
1646+
1647+ let workflow = Workflow {
1648+ operator : TypedOperator :: Vector (
1649+ OgrSource {
1650+ params : OgrSourceParameters {
1651+ data : dataset,
1652+ attribute_projection : None ,
1653+ attribute_filters : None ,
1654+ } ,
1655+ }
1656+ . boxed ( ) ,
1657+ ) ,
1658+ } ;
1659+
1660+ let workflow_id = ctx. db ( ) . register_workflow ( workflow) . await . unwrap ( ) ;
1661+
1662+ let ( req, payload, mut input_tx, send_next_msg_trigger) =
1663+ websocket_tests:: test_client ( ) . await ;
1664+
1665+ tokio:: task:: spawn ( async move {
1666+ // Simulate sending messages to the websocket
1667+ for _ in 0 ..1 {
1668+ websocket_tests:: send_text ( & mut input_tx, "NEXT" ) . await ;
1669+ }
1670+ websocket_tests:: send_close ( & mut input_tx) . await ;
1671+ } ) ;
1672+
1673+ tokio:: task:: LocalSet :: new ( )
1674+ . run_until ( async move {
1675+ let response = vector_stream_websocket (
1676+ web:: Data :: new ( app_ctx. clone ( ) ) ,
1677+ session. clone ( ) ,
1678+ web:: Path :: from ( workflow_id) ,
1679+ web:: Query ( VectorStreamWebsocketQuery {
1680+ spatial_bounds : BoundingBox2D :: new (
1681+ ( -180. , -90. ) . into ( ) ,
1682+ ( 180. , 90. ) . into ( ) ,
1683+ )
1684+ . unwrap ( ) ,
1685+ time_interval : TimeInterval :: new_instant ( DateTime :: new_utc (
1686+ 2014 , 3 , 1 , 0 , 0 , 0 ,
1687+ ) )
1688+ . unwrap ( )
1689+ . into ( ) ,
1690+ spatial_resolution : SpatialResolution :: one ( ) ,
1691+ result_type : RasterStreamWebsocketResultType :: Arrow ,
1692+ } ) ,
1693+ req,
1694+ payload,
1695+ )
1696+ . await
1697+ . unwrap ( ) ;
1698+
1699+ let mut response_stream =
1700+ websocket_tests:: response_messages ( response, send_next_msg_trigger)
1701+ . boxed_local ( ) ;
1702+
1703+ for _ in 0 ..1 {
1704+ let tile_bytes = response_stream. next ( ) . await . unwrap ( ) ;
1705+
1706+ let record_batches = arrow_ipc_file_to_record_batches ( & tile_bytes) . unwrap ( ) ;
1707+ assert_eq ! ( record_batches. len( ) , 1 ) ;
1708+ let record_batch = record_batches. first ( ) . unwrap ( ) ;
1709+ let schema = record_batch. schema ( ) ;
1710+
1711+ assert_eq ! ( schema. metadata( ) [ "spatialReference" ] , "EPSG:4326" ) ;
1712+ }
1713+
1714+ assert ! ( response_stream. next( ) . await . is_none( ) ) ; // No more messages expected
1715+ } )
1716+ . await ;
1717+ }
14971718}
0 commit comments