diff --git a/Settings-default.toml b/Settings-default.toml
index 30f2fe3a6..bac509662 100644
--- a/Settings-default.toml
+++ b/Settings-default.toml
@@ -25,6 +25,17 @@ origin_coordinate_y = 0.0
 tile_shape_pixels_x = 512
 tile_shape_pixels_y = 512
 
+[executor]
+queue_size = 10
+# Timeout for the raster scheduler in ms (a value of 0 means that the scheduler makes no attempt to merge tasks)
+raster_scheduler_timeout_ms = 100
+# Determines how much dead space in percent (viewing queries as a 3D cube) is allowed when merging 2 requests
+raster_scheduler_merge_threshold = 0.01
+# Timeout for the feature scheduler in ms (a value of 0 means that the scheduler makes no attempt to merge tasks)
+feature_scheduler_timeout_ms = 0
+# Determines how much dead space in percent (viewing queries as a 3D cube) is allowed when merging 2 requests
+feature_scheduler_merge_threshold = 0.01
+
 [query_context]
 chunk_byte_size = 1048576 # TODO: find reasonable default
 
diff --git a/datatypes/src/collections/feature_collection.rs b/datatypes/src/collections/feature_collection.rs
index c88c7e7ac..6b50832c9 100644
--- a/datatypes/src/collections/feature_collection.rs
+++ b/datatypes/src/collections/feature_collection.rs
@@ -23,7 +23,7 @@ use std::rc::Rc;
 use std::sync::Arc;
 use std::{mem, slice};
 
-use crate::primitives::{BoolDataRef, Coordinate2D, DateTimeDataRef, TimeInstance};
+use crate::primitives::{BoolDataRef, BoundingBox2D, Coordinate2D, DateTimeDataRef, TimeInstance};
 use crate::primitives::{
     CategoryDataRef, FeatureData, FeatureDataRef, FeatureDataType, FeatureDataValue, FloatDataRef,
     Geometry, IntDataRef, TextDataRef, TimeInterval,
@@ -39,6 +39,7 @@ use crate::{
     collections::{FeatureCollectionError, IntoGeometryOptionsIterator},
     operations::reproject::CoordinateProjection,
 };
+use geo::intersects::Intersects;
 use std::iter::FromIterator;
 
 use super::{geo_feature_collection::ReplaceRawArrayCoords, GeometryCollection};
@@ -833,6 +834,21 @@ impl<'a, GeometryRef> FeatureCollectionRow<'a, GeometryRef> {
     }
 }
 
+impl<'a, GR> Intersects<BoundingBox2D> for FeatureCollectionRow<'a, GR>
+where
+    GR: Intersects<BoundingBox2D>,
+{
+    fn intersects(&self, rhs: &BoundingBox2D) -> bool {
+        self.geometry.intersects(rhs)
+    }
+}
+
+impl<'a, GR> Intersects<TimeInterval> for FeatureCollectionRow<'a, GR> {
+    fn intersects(&self, rhs: &TimeInterval) -> bool {
+        self.time_interval.intersects(rhs)
+    }
+}
+
 pub struct FeatureCollectionIterator<'a, GeometryIter> {
     geometries: GeometryIter,
     time_intervals: slice::Iter<'a, TimeInterval>,
diff --git a/datatypes/src/primitives/bounding_box.rs b/datatypes/src/primitives/bounding_box.rs
index dafc4b819..c037226b0 100644
--- a/datatypes/src/primitives/bounding_box.rs
+++ b/datatypes/src/primitives/bounding_box.rs
@@ -688,6 +688,17 @@ mod tests {
         assert!(bbox.contains_bbox(&bbox_in));
     }
 
+    #[test]
+    fn bounding_box_contains_bbox_equal() {
+        let ll = Coordinate2D::new(1.0, 1.0);
+        let ur = Coordinate2D::new(4.0, 4.0);
+        let bbox = BoundingBox2D::new(ll, ur).unwrap();
+
+        let bbox_in = BoundingBox2D::new(ll, ur).unwrap();
+
+        assert!(bbox.contains_bbox(&bbox_in));
+    }
+
     #[test]
     fn bounding_box_contains_bbox_overlap() {
         let ll = Coordinate2D::new(1.0, 1.0);
diff --git a/datatypes/src/primitives/multi_line_string.rs b/datatypes/src/primitives/multi_line_string.rs
index d90952d3f..9504ad894 100644
--- a/datatypes/src/primitives/multi_line_string.rs
+++ b/datatypes/src/primitives/multi_line_string.rs
@@ -3,14 +3,14 @@ use std::convert::TryFrom;
 use arrow::array::{ArrayBuilder, BooleanArray};
 use arrow::error::ArrowError;
 use float_cmp::{ApproxEq,
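Together, the two impls above make spatio-temporal row filtering generic over the geometry type. A minimal sketch (this helper and its bounds are illustrative, not part of the patch); note the fully qualified calls, since both impls supply an `intersects` method:

```rust
use geo::intersects::Intersects;
use geoengine_datatypes::primitives::{BoundingBox2D, TimeInterval};

// Keep only rows that overlap the query window in space AND time.
fn spatio_temporal_filter<R>(rows: Vec<R>, bbox: &BoundingBox2D, time: &TimeInterval) -> Vec<R>
where
    R: Intersects<BoundingBox2D> + Intersects<TimeInterval>,
{
    rows.into_iter()
        .filter(|row| {
            Intersects::<BoundingBox2D>::intersects(row, bbox)
                && Intersects::<TimeInterval>::intersects(row, time)
        })
        .collect()
}
```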
F64Margin}; -use geo::algorithm::intersects::Intersects; +use geo::intersects::Intersects; use serde::{Deserialize, Serialize}; use snafu::ensure; use crate::collections::VectorDataType; use crate::error::Error; use crate::primitives::{ - error, BoundingBox2D, GeometryRef, MultiPoint, PrimitivesError, TypedGeometry, + error, BoundingBox2D, GeometryRef, MultiPoint, PrimitivesError, SpatialBounded, TypedGeometry, }; use crate::primitives::{Coordinate2D, Geometry}; use crate::util::arrow::{downcast_array, ArrowTyped}; @@ -280,6 +280,19 @@ impl<'g> MultiLineStringAccess for MultiLineStringRef<'g> { &self.point_coordinates } } +impl<'g> SpatialBounded for MultiLineStringRef<'g> { + fn spatial_bounds(&self) -> BoundingBox2D { + let coords = self.point_coordinates.iter().flat_map(|&x| x.iter()); + BoundingBox2D::from_coord_ref_iter(coords) + .expect("there must be at least one coordinate in a multilinestring") + } +} + +impl<'g> Intersects for MultiLineStringRef<'g> { + fn intersects(&self, rhs: &BoundingBox2D) -> bool { + self.spatial_bounds().intersects_bbox(rhs) + } +} impl<'g> From> for geojson::Geometry { fn from(geometry: MultiLineStringRef<'g>) -> geojson::Geometry { @@ -354,6 +367,32 @@ mod tests { ); } + #[test] + fn test_ref_intersects() { + let coordinates = vec![vec![(0.0, 0.0).into(), (10.0, 10.0).into()]]; + let multi_line_string_ref = + MultiLineStringRef::new(coordinates.iter().map(AsRef::as_ref).collect()).unwrap(); + + assert!( + multi_line_string_ref.intersects(&BoundingBox2D::new_unchecked( + (-1., -1.,).into(), + (11., 11.).into() + )) + ); + assert!( + multi_line_string_ref.intersects(&BoundingBox2D::new_unchecked( + (2., 2.,).into(), + (9., 9.).into() + )) + ); + assert!( + !multi_line_string_ref.intersects(&BoundingBox2D::new_unchecked( + (-2., -2.,).into(), + (-2., 12.).into() + )) + ); + } + #[test] fn approx_equal() { let a = MultiLineString::new(vec![ diff --git a/datatypes/src/primitives/multi_point.rs b/datatypes/src/primitives/multi_point.rs index 295bffc49..24c4e976d 100644 --- a/datatypes/src/primitives/multi_point.rs +++ b/datatypes/src/primitives/multi_point.rs @@ -3,6 +3,7 @@ use std::convert::{TryFrom, TryInto}; use arrow::array::{ArrayBuilder, BooleanArray}; use arrow::error::ArrowError; use float_cmp::{ApproxEq, F64Margin}; +use geo::intersects::Intersects; use serde::{Deserialize, Serialize}; use snafu::ensure; @@ -286,6 +287,14 @@ where } } +impl<'f> Intersects for MultiPointRef<'f> { + fn intersects(&self, rhs: &BoundingBox2D) -> bool { + self.point_coordinates + .iter() + .any(|c| rhs.contains_coordinate(c)) + } +} + impl ApproxEq for &MultiPoint { type Margin = F64Margin; @@ -392,4 +401,23 @@ mod tests { assert!(!approx_eq!(&MultiPoint, &a, &b, F64Margin::default())); } + + #[test] + fn ref_intersects_bbox() -> Result<()> { + let bbox = BoundingBox2D::new((0.0, 0.0).into(), (1.0, 1.0).into())?; + + let v1: Vec = vec![(0.5, 0.5).into()]; + let v2: Vec = vec![(1.0, 1.0).into()]; + let v3: Vec = vec![(0.5, 0.5).into(), (1.5, 1.5).into()]; + let v4: Vec = vec![(1.1, 1.1).into()]; + let v5: Vec = vec![(-0.1, -0.1).into(), (1.1, 1.1).into()]; + + assert!(MultiPointRef::new(&v1)?.intersects(&bbox)); + assert!(MultiPointRef::new(&v2)?.intersects(&bbox)); + assert!(MultiPointRef::new(&v3)?.intersects(&bbox)); + assert!(!MultiPointRef::new(&v4)?.intersects(&bbox)); + assert!(!MultiPointRef::new(&v5)?.intersects(&bbox)); + + Ok(()) + } } diff --git a/datatypes/src/primitives/multi_polygon.rs b/datatypes/src/primitives/multi_polygon.rs index 869c90977..67ffb2977 
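For the point case the predicate is containment-based: a multi-point intersects a box as soon as at least one coordinate lies inside it (boundary included). A small sketch mirroring the `ref_intersects_bbox` test above:

```rust
use geo::intersects::Intersects;
use geoengine_datatypes::primitives::{BoundingBox2D, Coordinate2D, MultiPointRef};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let bbox = BoundingBox2D::new((0.0, 0.0).into(), (1.0, 1.0).into())?;
    let coords: Vec<Coordinate2D> = vec![(0.5, 0.5).into(), (9.0, 9.0).into()];
    // One coordinate inside the box is enough.
    assert!(MultiPointRef::new(&coords)?.intersects(&bbox));
    Ok(())
}
```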
100644 --- a/datatypes/src/primitives/multi_polygon.rs +++ b/datatypes/src/primitives/multi_polygon.rs @@ -10,7 +10,8 @@ use snafu::ensure; use crate::collections::VectorDataType; use crate::error::Error; use crate::primitives::{ - error, BoundingBox2D, GeometryRef, MultiLineString, PrimitivesError, TypedGeometry, + error, BoundingBox2D, GeometryRef, MultiLineString, PrimitivesError, SpatialBounded, + TypedGeometry, }; use crate::primitives::{Coordinate2D, Geometry}; use crate::util::arrow::{downcast_array, ArrowTyped}; @@ -356,6 +357,25 @@ impl<'g> MultiPolygonAccess for MultiPolygonRef<'g> { } } +impl<'g> SpatialBounded for MultiPolygonRef<'g> { + fn spatial_bounds(&self) -> BoundingBox2D { + let outer_ring_coords = self + .polygons + .iter() + // Use exterior ring (first ring of a polygon) + .filter_map(|p| p.iter().next()) + .flat_map(|&exterior| exterior.iter()); + BoundingBox2D::from_coord_ref_iter(outer_ring_coords) + .expect("there must be at least one coordinate in a multipolygon") + } +} + +impl<'g> Intersects for MultiPolygonRef<'g> { + fn intersects(&self, rhs: &BoundingBox2D) -> bool { + self.spatial_bounds().intersects_bbox(rhs) + } +} + impl<'g> From> for geojson::Geometry { fn from(geometry: MultiPolygonRef<'g>) -> geojson::Geometry { geojson::Geometry::new(match geometry.polygons.len() { @@ -481,6 +501,48 @@ mod tests { assert_eq!(aggregate(&multi_polygon), aggregate(&multi_polygon_ref)); } + #[test] + fn test_ref_intersects() { + let coordinates = vec![vec![ + vec![ + (0.0, 0.0).into(), + (10.0, 0.0).into(), + (10.0, 10.0).into(), + (0.0, 10.0).into(), + (0.0, 0.0).into(), + ], + vec![ + (4.0, 4.0).into(), + (6.0, 4.0).into(), + (6.0, 6.0).into(), + (4.0, 6.0).into(), + (4.0, 4.0).into(), + ], + ]]; + let multi_polygon_ref = MultiPolygonRef::new( + coordinates + .iter() + .map(|r| r.iter().map(AsRef::as_ref).collect()) + .collect(), + ) + .unwrap(); + + assert!(multi_polygon_ref.intersects(&BoundingBox2D::new_unchecked( + (-1., -1.,).into(), + (11., 11.).into() + ))); + + assert!(multi_polygon_ref.intersects(&BoundingBox2D::new_unchecked( + (4.5, 4.5,).into(), + (5.5, 5.5).into() + ))); + + assert!(!multi_polygon_ref.intersects(&BoundingBox2D::new_unchecked( + (-11., -1.,).into(), + (-1., 11.).into() + ))); + } + #[test] fn approx_equal() { let a = MultiPolygon::new(vec![ diff --git a/datatypes/src/primitives/no_geometry.rs b/datatypes/src/primitives/no_geometry.rs index dbd0d714c..960409b3f 100644 --- a/datatypes/src/primitives/no_geometry.rs +++ b/datatypes/src/primitives/no_geometry.rs @@ -4,6 +4,7 @@ use std::convert::TryFrom; use arrow::array::{Array, ArrayBuilder, ArrayData, ArrayRef, BooleanArray, JsonEqual}; use arrow::datatypes::DataType; use arrow::error::ArrowError; +use geo::prelude::Intersects; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -26,6 +27,12 @@ impl Geometry for NoGeometry { } } +impl Intersects for NoGeometry { + fn intersects(&self, _rhs: &BoundingBox2D) -> bool { + true + } +} + impl GeometryRef for NoGeometry {} impl TryFrom for NoGeometry { diff --git a/datatypes/src/primitives/spatial_partition.rs b/datatypes/src/primitives/spatial_partition.rs index 026666740..8ca6acb81 100644 --- a/datatypes/src/primitives/spatial_partition.rs +++ b/datatypes/src/primitives/spatial_partition.rs @@ -167,11 +167,11 @@ impl SpatialPartition2D { } fn contains_x(&self, other: &Self) -> bool { - crate::util::ranges::value_in_range( + crate::util::ranges::value_in_range_inclusive( other.upper_left_coordinate.x, self.upper_left_coordinate.x, 
             self.lower_right_coordinate.x,
-        ) && crate::util::ranges::value_in_range(
+        ) && crate::util::ranges::value_in_range_inclusive(
             other.lower_right_coordinate.x,
             self.upper_left_coordinate.x,
             self.lower_right_coordinate.x,
@@ -179,11 +179,11 @@ impl SpatialPartition2D {
     }
 
     fn contains_y(&self, other: &Self) -> bool {
-        crate::util::ranges::value_in_range_inv(
+        crate::util::ranges::value_in_range_inclusive(
             other.lower_right_coordinate.y,
             self.lower_right_coordinate.y,
             self.upper_left_coordinate.y,
-        ) && crate::util::ranges::value_in_range_inv(
+        ) && crate::util::ranges::value_in_range_inclusive(
             other.upper_left_coordinate.y,
             self.lower_right_coordinate.y,
             self.upper_left_coordinate.y,
@@ -305,6 +305,14 @@ mod tests {
         assert!(!p2.contains(&p1));
     }
 
+    #[test]
+    fn it_contains_itself() {
+        let p1 = SpatialPartition2D::new_unchecked((0., 1.).into(), (1., 0.).into());
+        let p2 = SpatialPartition2D::new_unchecked((0., 1.).into(), (1., 0.).into());
+        assert!(p1.contains(&p2));
+        assert!(p2.contains(&p1));
+    }
+
     #[test]
     fn it_contains_coord() {
         let p1 = SpatialPartition2D::new_unchecked((0., 1.).into(), (1., 0.).into());
diff --git a/operators/Cargo.toml b/operators/Cargo.toml
index 6ffa9dfae..553c86833 100644
--- a/operators/Cargo.toml
+++ b/operators/Cargo.toml
@@ -30,7 +30,7 @@ libloading = "0.7"
 log = "0.4"
 num-traits = "0.2"
 num = "0.4"
-ouroboros = "0.14"
+ouroboros = "0.15"
 paste = "1.0"
 pest = "2.1"
 pest_derive = "2.1"
@@ -52,6 +52,7 @@ uuid = { version = "0.8", features = ["serde", "v4", "v5"] }
 [dev-dependencies]
 async-stream = "0.3"
 geo-rand = { git = "https://github.com/lelongg/geo-rand", tag = "v0.3.0" }
+tokio-util = "0.6"
 rand = "0.8"
 
diff --git a/operators/src/error.rs b/operators/src/error.rs
index 6a08fbcfe..543c9fe4d 100644
--- a/operators/src/error.rs
+++ b/operators/src/error.rs
@@ -271,10 +271,18 @@ pub enum Error {
         source: crate::util::statistics::StatisticsError,
     },
 
+    #[cfg(feature = "pro")]
+    #[snafu(display("Executor error: {}", source))]
+    Executor {
+        #[snafu(implicit)]
+        source: crate::pro::executor::error::ExecutorError,
+    },
+
     #[snafu(display("SparseTilesFillAdapter error: {}", source))]
     SparseTilesFillAdapter {
         source: crate::adapters::SparseTilesFillAdapterError,
     },
+
     #[snafu(context(false))]
     ExpressionOperator {
         source: crate::processing::ExpressionError,
diff --git a/operators/src/pro/executor/error.rs b/operators/src/pro/executor/error.rs
new file mode 100644
index 000000000..ae3c20169
--- /dev/null
+++ b/operators/src/pro/executor/error.rs
@@ -0,0 +1,48 @@
+use crate::error::Error;
+use snafu::Snafu;
+use tokio::sync::mpsc::error::SendError;
+use tokio::sync::oneshot::error::RecvError;
+use tokio::task::JoinError;
+
+pub type Result<T> = std::result::Result<T, ExecutorError>;
+
+#[derive(Debug, Clone, Snafu)]
+pub enum ExecutorError {
+    Submission { message: String },
+    Panic,
+    Cancelled,
+}
+
+impl From<JoinError> for ExecutorError {
+    fn from(src: JoinError) -> Self {
+        if src.is_cancelled() {
+            ExecutorError::Cancelled
+        } else {
+            ExecutorError::Panic
+        }
+    }
+}
+
+impl<T> From<SendError<T>> for ExecutorError {
+    fn from(e: SendError<T>) -> Self {
+        Self::Submission {
+            message: e.to_string(),
+        }
+    }
+}
+
+impl From<RecvError> for ExecutorError {
+    fn from(e: RecvError) -> Self {
+        Self::Submission {
+            message: e.to_string(),
+        }
+    }
+}
+
+impl From<Error> for ExecutorError {
+    fn from(e: Error) -> Self {
+        Self::Submission {
+            message: e.to_string(),
+        }
+    }
+}
diff --git a/operators/src/pro/executor/mod.rs b/operators/src/pro/executor/mod.rs
new file mode 100644
index 000000000..1a442b931
--- /dev/null
+++ 
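These `From` conversions are what let the executor internals use `?` on task and channel errors; in particular, a `JoinError` is mapped onto the dedicated `Cancelled`/`Panic` variants. A hypothetical sketch (the function name is an assumption):

```rust
use geoengine_operators::pro::executor::error::Result;

// JoinError -> ExecutorError::Cancelled or ExecutorError::Panic via the From impl above.
async fn await_driver(handle: tokio::task::JoinHandle<()>) -> Result<()> {
    Ok(handle.await?)
}
```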
b/operators/src/pro/executor/mod.rs
@@ -0,0 +1,808 @@
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::hash::Hash;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use futures::future::BoxFuture;
+use futures::stream::{BoxStream, FuturesUnordered};
+use futures::{Future, Stream, StreamExt};
+use tokio::sync::mpsc;
+use tokio::task::{JoinError, JoinHandle};
+
+use error::{ExecutorError, Result};
+use replay::SendError;
+
+pub mod error;
+pub mod operators;
+pub mod replay;
+
+/// Description of an [Executor] task. The description provides a primary
+/// key (`KeyType`) to identify identical computations.
+pub trait ExecutorTaskDescription: Clone + Send + Sync + Debug + 'static {
+    /// A unique identifier for the actual computation of a task
+    type KeyType: Debug + Clone + Send + PartialEq + Eq + Hash;
+    /// The result type of the computation
+    type ResultType: Sync + Send;
+    /// Returns the unique identifier
+    fn primary_key(&self) -> &Self::KeyType;
+    /// Determines if this computation can be satisfied by using the result
+    /// of `other`. E.g., if the bounding box of `other` contains this task's bounding box.
+    ///
+    /// Note: It is not required to check the `primary_key`s in this method.
+    fn is_contained_in(&self, other: &Self) -> bool;
+    /// Extracts the result for this task from the result of a possibly 'greater'
+    /// result.
+    ///
+    /// If the computation returns a stream of `ResultType`, irrelevant elements may be filtered
+    /// out by returning `None`. Moreover, it is also possible to extract a subset of elements
+    /// (e.g., features or pixels) from the given result. In those cases this method simply
+    /// returns a new instance of `ResultType`.
+    ///
+    fn slice_result(&self, result: &Self::ResultType) -> Option<Self::ResultType>;
+}
+
+/// Size of the task submission queue.
+const TASK_SUBMISSION_QUEUE_SIZE: usize = 128;
+
+type TerminationMessage<Desc> = std::result::Result<
+    (),
+    SendError<std::result::Result<Arc<<Desc as ExecutorTaskDescription>::ResultType>, ExecutorError>>,
+>;
+
+type ReplayReceiver<Desc> =
+    replay::Receiver<Result<Arc<<Desc as ExecutorTaskDescription>::ResultType>>>;
+
+/// Encapsulates a requested stream computation to send
+/// it to the executor task.
+struct KeyedComputation<Desc>
+where
+    Desc: ExecutorTaskDescription,
+{
+    key: Desc,
+    response: tokio::sync::oneshot::Sender<Result<ReplayReceiver<Desc>>>,
+    stream_future: BoxFuture<'static, Result<BoxStream<'static, Desc::ResultType>>>,
+}
+
+/// A helper to retrieve a computation's key even if
+/// the executing task failed.
+#[pin_project::pin_project]
+struct KeyedJoinHandle<Desc>
+where
+    Desc: ExecutorTaskDescription,
+{
+    // The computation's key
+    key: Desc,
+    // A unique sequence number of the computation. See `ComputationEntry`
+    // for a detailed description.
+    id: usize,
+    // The sender side of the replay channel.
+    sender: replay::Sender<Result<Arc<Desc::ResultType>>>,
+    // The join handle of the underlying task
+    #[pin]
+    handle: JoinHandle<()>,
+}
+
+/// The result when waiting on a `KeyedJoinHandle`
+struct KeyedJoinResult<Desc>
+where
+    Desc: ExecutorTaskDescription,
+{
+    // The computation's key
+    key: Desc,
+    // A unique sequence number of the computation. See `ComputationEntry`
+    // for a detailed description.
+    id: usize,
+    // The sender side of the replay channel.
+    sender: replay::Sender<Result<Arc<Desc::ResultType>>>,
+    // The result returned by the task.
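As context for the machinery below, a sketch of what an implementation of this trait can look like (the type and its fields are assumptions, not part of this patch): computations are keyed per dataset, a query may join a running one whose bounding box covers it, and results pass through unsliced.

```rust
use geoengine_datatypes::primitives::BoundingBox2D;

#[derive(Clone, Debug)]
pub struct BBoxTask {
    pub dataset: String,
    pub bbox: BoundingBox2D,
}

impl ExecutorTaskDescription for BBoxTask {
    type KeyType = String;
    type ResultType = Vec<u8>; // e.g. an encoded result chunk; chosen for brevity

    fn primary_key(&self) -> &Self::KeyType {
        &self.dataset
    }

    fn is_contained_in(&self, other: &Self) -> bool {
        // A larger query's result covers ours if its bbox contains ours.
        other.bbox.contains_bbox(&self.bbox)
    }

    fn slice_result(&self, result: &Self::ResultType) -> Option<Self::ResultType> {
        Some(result.clone()) // a real implementation would clip to `self.bbox`
    }
}
```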
+    result: std::result::Result<(), JoinError>,
+}
+
+impl<Desc> Future for KeyedJoinHandle<Desc>
+where
+    Desc: ExecutorTaskDescription,
+{
+    type Output = KeyedJoinResult<Desc>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = self.project();
+        match this.handle.poll(cx) {
+            Poll::Pending => Poll::Pending,
+            Poll::Ready(result) => Poll::Ready(KeyedJoinResult {
+                key: this.key.clone(),
+                id: *this.id,
+                sender: this.sender.clone(),
+                result,
+            }),
+        }
+    }
+}
+
+/// An entry for the map of currently running computations.
+/// It contains a sender from the associated replay channel
+/// in order to register new consumers.
+///
+/// See [`replay::Sender::subscribe`].
+struct ComputationEntry<Desc>
+where
+    Desc: ExecutorTaskDescription,
+{
+    // A unique sequence number of the computation.
+    // This is used to decide whether or not to drop computation
+    // infos after completion. If a computation task for a given key
+    // is finished, it might be that another one took its place
+    // in the list of running computations. This happens if a computation
+    // has progressed too far already and a new consumer cannot join it anymore.
+    // In those cases a new task is started and takes the slot in the list
+    // of active computations. Now, if the old computation terminates, the newly
+    // started computation must *NOT* be removed from this list.
+    // We use this ID to handle such cases.
+    id: usize,
+    key: Desc,
+    sender: replay::Sender<Result<Arc<Desc::ResultType>>>,
+}
+
+/// This encapsulates the main loop of the executor task.
+/// It manages all new, running and finished computations.
+struct ExecutorLooper<Desc>
+where
+    Desc: ExecutorTaskDescription,
+{
+    buffer_size: usize,
+    receiver: mpsc::Receiver<KeyedComputation<Desc>>,
+    computations: HashMap<Desc::KeyType, Vec<ComputationEntry<Desc>>>,
+    tasks: FuturesUnordered<KeyedJoinHandle<Desc>>,
+    termination_msgs: FuturesUnordered<BoxFuture<'static, TerminationMessage<Desc>>>,
+    id_seq: usize,
+}
+
+impl<Desc> ExecutorLooper<Desc>
+where
+    Desc: ExecutorTaskDescription,
+{
+    /// Creates a new Looper.
+    fn new(
+        buffer_size: usize,
+        receiver: mpsc::Receiver<KeyedComputation<Desc>>,
+    ) -> ExecutorLooper<Desc> {
+        ExecutorLooper {
+            buffer_size,
+            receiver,
+            computations: HashMap::new(),
+            tasks: FuturesUnordered::new(),
+            termination_msgs: FuturesUnordered::new(),
+            id_seq: 0,
+        }
+    }
+
+    fn next_id(current: &mut usize) -> usize {
+        *current += 1;
+        *current
+    }
+
+    /// Handles a new computation request. It first tries to join
+    /// a running computation. If this is not possible, a new
+    /// computation is started.
+    async fn handle_new_task(&mut self, kc: KeyedComputation<Desc>) {
+        let key = kc.key;
+
+        let entry = self.computations.entry(key.primary_key().clone());
+        let receiver = match entry {
+            // There is a computation running
+            std::collections::hash_map::Entry::Occupied(mut oe) => {
+                let entries = oe.get_mut();
+                let append = entries
+                    .iter_mut()
+                    .filter(|ce| key.is_contained_in(&ce.key))
+                    .find_map(|ce| match ce.sender.subscribe() {
+                        Ok(rx) => Some((&ce.key, rx)),
+                        Err(_) => None,
+                    });
+
+                match append {
+                    Some((desc, rx)) => {
+                        log::debug!(
+                            "Joining running computation for request. New: {:?}, Running: {:?}",
+                            &key,
+                            desc
+                        );
+                        rx
+                    }
+                    None => {
+                        log::debug!("Stream progressed too far or results do not cover requested result.
Starting new computation for request: {:?}", &key); + + let stream = match kc.stream_future.await { + Ok(s) => s, + Err(e) => return kc.response.send(Err(e)).unwrap_or(()), + }; + + let (entry, rx) = Self::start_computation( + self.buffer_size, + Self::next_id(&mut self.id_seq), + key, + stream, + &mut self.tasks, + ); + entries.push(entry); + rx + } + } + } + // Start a new computation + std::collections::hash_map::Entry::Vacant(ve) => { + log::debug!("Starting new computation for request: {:?}", &key); + + let stream = match kc.stream_future.await { + Ok(s) => s, + Err(e) => return kc.response.send(Err(e)).unwrap_or(()), + }; + + let (entry, rx) = Self::start_computation( + self.buffer_size, + Self::next_id(&mut self.id_seq), + key, + stream, + &mut self.tasks, + ); + ve.insert(vec![entry]); + rx + } + }; + + // This can only fail, if the receiver side is dropped + if kc.response.send(Ok(receiver)).is_err() { + log::warn!("Result consumer dropped unexpectedly."); + } + } + + /// Starts a new computation. It spawns a separate task + /// in which the computation is executed. Therefore it establishes + /// a new [`replay::channel`] and returns infos about the computation + /// and the receiving side of the replay channel. + fn start_computation( + buffer_size: usize, + id: usize, + key: Desc, + mut stream: BoxStream<'static, Desc::ResultType>, + tasks: &mut FuturesUnordered>, + ) -> (ComputationEntry, ReplayReceiver) { + let (tx, rx) = replay::channel(buffer_size); + + let entry = ComputationEntry { + id, + key: key.clone(), + sender: tx.clone(), + }; + + let jh = { + let tx = tx.clone(); + let key = key.clone(); + tokio::spawn(async move { + while let Some(v) = stream.next().await { + if let Err(replay::SendError::Closed(_)) = tx.send(Ok(Arc::new(v))).await { + log::debug!("All consumers left. Cancelling task: {:?}", &key); + break; + } + } + }) + }; + + tasks.push(KeyedJoinHandle { + key, + id, + sender: tx, + handle: jh, + }); + (entry, rx) + } + + /// Handles computations that ran to completion. It removes + /// them from the map of running computations and notifies + /// consumers. Furthermore, if the computation task was cancelled + /// or panicked, this is also propagated. + /// + fn handle_completed_task(&mut self, completed_task: KeyedJoinResult) { + let id = completed_task.id; + + // Remove the map entry only, if the completed task's id matches the stored task's id + // There may be older tasks around (with smaller ids) that should not trigger a removal from the map. + if let std::collections::hash_map::Entry::Occupied(mut oe) = self + .computations + .entry(completed_task.key.primary_key().clone()) + { + if let Some(idx) = oe.get().iter().position(|x| x.id == id) { + oe.get_mut().swap_remove(idx); + } + if oe.get().is_empty() { + oe.remove(); + } + } + + match completed_task.result { + Err(e) => { + self.termination_msgs.push(Box::pin(async move { + if e.try_into_panic().is_ok() { + log::warn!( + "Stream task panicked. Notifying consumer streams. Request: {:?}", + &completed_task.key + ); + completed_task.sender.send(Err(ExecutorError::Panic)).await + } else { + log::warn!( + "Stream task was cancelled. Notifying consumer streams. Request: {:?}", + &completed_task.key + ); + completed_task + .sender + .send(Err(ExecutorError::Cancelled)) + .await + } + })); + } + Ok(_) => { + log::debug!( + "Computation finished. Notifying consumer streams. 
Request: {:?}", + &completed_task.key + ); + // After destroying the sender all remaining receivers will receive end-of-stream + } + } + } + + /// This is the main loop. Here we check for new and completed tasks, + /// and also drive termination messages forward. + pub async fn main_loop(&mut self) { + log::info!("Starting executor loop."); + loop { + tokio::select! { + new_task = self.receiver.recv() => { + if let Some(kc) = new_task { + self.handle_new_task(kc).await; + } + else { + log::info!("Executor terminated."); + break; + } + }, + Some(completed_task) = self.tasks.next() => { + self.handle_completed_task(completed_task); + }, + Some(_) = self.termination_msgs.next() => { + log::debug!("Successfully delivered termination message."); + } + } + } + log::info!("Finished executor loop."); + } +} + +/// The `Executor` runs async (streaming) computations. It allows multiple consumers +/// per stream so that results are computed only once. +/// A pre-defined buffer size determines how many elements of a stream are kept. This +/// size can be seen as a window that slides forward, if the slowest consumer consumes +/// the oldest element. +/// New consumers join the same computation if the window did not slide forward at +/// the time of task submission. Otherwise, a new computation task is started. +pub struct Executor +where + Desc: ExecutorTaskDescription, +{ + sender: mpsc::Sender>, + driver: JoinHandle<()>, +} + +impl Executor +where + Desc: ExecutorTaskDescription, +{ + /// Creates a new `Executor` instance, ready to serve computations. The buffer + /// size determines how much elements are at most kept in memory per computation. + pub fn new(buffer_size: usize) -> Executor { + let (sender, receiver) = + tokio::sync::mpsc::channel::>(TASK_SUBMISSION_QUEUE_SIZE); + + let mut looper = ExecutorLooper::new(buffer_size, receiver); + + // This is the task that is responsible for driving the async computations and + // notifying consumers about success and failure. + let driver = tokio::spawn(async move { looper.main_loop().await }); + + Executor { sender, driver } + } + + /// Submits a streaming computation to this executor. + /// + /// #Errors + /// This call fails, if the `Executor` was already closed. + pub async fn submit_stream( + &self, + key: Desc, + stream: Pin + Send + 'static>>, + ) -> Result> { + self.submit_stream_future(key, futures::future::ok(stream)) + .await + } + + /// Submits a streaming computation to this executor. + /// + /// #Errors + /// This call fails, if the `Executor` was already closed. + pub async fn submit_stream_future( + &self, + key: Desc, + stream_future: F, + ) -> Result> + where + F: Future + Send + 'static>>>> + + Send + + 'static, + { + let (tx, rx) = tokio::sync::oneshot::channel(); + + let kc = KeyedComputation { + key: key.clone(), + stream_future: Box::pin(stream_future), + response: tx, + }; + + self.sender.send(kc).await?; + let res = rx.await??; + + Ok(StreamReceiver::new(key.clone(), res.into())) + } + + /// Submits a single-result computation to this executor. + /// + /// #Errors + /// This call fails, if the `Executor` was already closed. + #[allow(clippy::missing_panics_doc)] + pub async fn submit(&self, key: Desc, f: F) -> Result + where + F: Future + Send + 'static, + { + let mut stream = self + .submit_stream(key, Box::pin(futures::stream::once(f))) + .await?; + Ok(stream + .next() + .await + .expect("Futures always produce a result.")) + } + + pub async fn close(self) -> Result<()> { + drop(self.sender); + Ok(self.driver.await?) 
+ } +} + +#[pin_project::pin_project] +pub struct StreamReceiver +where + Desc: ExecutorTaskDescription, +{ + key: Desc, + #[pin] + input: replay::ReceiverStream>>, +} + +impl StreamReceiver +where + Desc: ExecutorTaskDescription, +{ + fn new( + key: Desc, + input: replay::ReceiverStream>>, + ) -> StreamReceiver { + StreamReceiver { key, input } + } +} + +impl Stream for StreamReceiver +where + Desc: ExecutorTaskDescription, +{ + type Item = Desc::ResultType; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + loop { + match this.input.as_mut().poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Ready(Some(Ok(res))) => { + if let Some(sliced) = this.key.slice_result(res.as_ref()) { + // Slicing produced a result -> Otherwise loop again + return Poll::Ready(Some(sliced)); + } + } + Poll::Ready(Some(Err(ExecutorError::Panic))) => panic!("Executor task panicked!"), + Poll::Ready(Some(Err(ExecutorError::Cancelled))) => { + panic!("Executor task cancelled!") + } + // Submission already succeeded -> Unreachable. + Poll::Ready(Some(Err(ExecutorError::Submission { .. }))) => unreachable!(), + } + } + } +} + +#[cfg(test)] +mod tests { + use std::pin::Pin; + use std::sync::Arc; + use std::task::{Context, Poll}; + + use futures::{Stream, StreamExt}; + + use crate::pro::executor::error::ExecutorError; + use crate::pro::executor::ExecutorTaskDescription; + + use super::Executor; + + impl ExecutorTaskDescription for i32 { + type KeyType = i32; + type ResultType = i32; + + fn primary_key(&self) -> &Self::KeyType { + self + } + + fn is_contained_in(&self, other: &Self) -> bool { + self == other + } + + fn slice_result(&self, result: &Self::ResultType) -> Option { + Some(*result) + } + } + + #[tokio::test] + async fn test_stream_empty_stream() -> Result<(), ExecutorError> { + let e = Executor::::new(5); + + let sf1 = e + .submit_stream(1, Box::pin(futures::stream::iter(Vec::::new()))) + .await + .unwrap(); + + let results = sf1.collect::>().await; + + assert!(results.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn test_stream_single_consumer() -> Result<(), ExecutorError> { + let e = Executor::::new(5); + + let sf1 = e + .submit_stream(1, Box::pin(futures::stream::iter(vec![1, 2, 3]))) + .await + .unwrap(); + + let results = sf1.collect::>().await; + + assert_eq!(vec![1, 2, 3], results); + Ok(()) + } + + #[tokio::test] + async fn test_stream_two_consumers() -> Result<(), ExecutorError> { + let e = Executor::new(5); + + let sf1 = e.submit_stream(1, Box::pin(futures::stream::iter(vec![1, 2, 3]))); + let sf2 = e.submit_stream(1, Box::pin(futures::stream::iter(vec![1, 2, 3]))); + + let (sf1, sf2) = tokio::join!(sf1, sf2); + + let (mut sf1, mut sf2) = (sf1?, sf2?); + + let mut res1 = vec![]; + let mut res2 = vec![]; + + loop { + tokio::select! 
{ + Some(v) = sf1.next() => { + res1.push(v); + }, + Some(v) = sf2.next() => { + res2.push(v); + }, + else => { + break; + } + } + } + + assert_eq!(vec![1, 2, 3], res1); + assert_eq!(vec![1, 2, 3], res2); + + Ok(()) + } + + struct PanicStream {} + + impl Stream for PanicStream { + type Item = i32; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + panic!("Expected panic!"); + } + } + + #[tokio::test] + #[should_panic(expected = "Executor task panicked!")] + async fn test_stream_propagate_panic() { + let e = Executor::new(5); + let sf = e.submit_stream(1, Box::pin(PanicStream {})); + let mut sf = sf.await.unwrap(); + sf.next().await; + } + + #[tokio::test] + async fn test_stream_consumer_drop() { + let e = Executor::new(2); + let chk = { + e.submit_stream(1, Box::pin(futures::stream::iter(vec![1, 2, 3]))) + .await + .unwrap() + .next() + .await + }; + // Assert that the task is dropped if all consumers are dropped. + // Therefore, resubmit another one and ensure we produce new results. + let chk2 = { + e.submit_stream(1, Box::pin(futures::stream::iter(vec![2, 2, 3]))) + .await + .unwrap() + .next() + .await + }; + + assert_eq!(Some(1), chk); + assert_eq!(Some(2), chk2); + } + + #[tokio::test] + async fn test_simple() -> Result<(), ExecutorError> { + let e = Executor::new(5); + let f = e.submit(1, async { 2 }); + + assert_eq!(2, f.await?); + + let f = e.submit(1, async { 42 }); + assert_eq!(42, f.await?); + + Ok(()) + } + + #[tokio::test] + async fn test_multi_consumers() -> Result<(), ExecutorError> { + let e = Executor::new(5); + let f = e.submit(1, async { 2 }); + let f2 = e.submit(1, async { 2 }); + + let (r1, r2) = tokio::join!(f, f2); + let (r1, r2) = (r1?, r2?); + + assert_eq!(r1, r2); + + let f = e.submit(1, async { 2 }); + let f2 = e.submit(1, async { 2 }); + + let r1 = f.await?; + let r2 = f2.await?; + assert_eq!(r1, r2); + + Ok(()) + } + + #[tokio::test] + #[should_panic] + async fn test_panic() { + let e = Executor::new(5); + let f = e.submit(1, async { panic!("booom") }); + + f.await.unwrap(); + } + + #[tokio::test] + async fn test_close() -> Result<(), ExecutorError> { + let e = Executor::new(5); + let f = e.submit(1, async { 2 }); + assert_eq!(2, f.await?); + let c = e.close(); + c.await?; + + Ok(()) + } + + #[derive(Clone, Debug)] + struct IntDesc { + key: i32, + do_slice: bool, + } + + impl ExecutorTaskDescription for IntDesc { + type KeyType = i32; + type ResultType = Arc; + + fn primary_key(&self) -> &Self::KeyType { + &self.key + } + + fn is_contained_in(&self, other: &Self) -> bool { + self.key == other.key + } + + fn slice_result(&self, result: &Self::ResultType) -> Option { + if !self.do_slice || *result.as_ref() % 2 == 0 { + Some(result.clone()) + } else { + None + } + } + } + + #[tokio::test] + async fn test_slicing() { + let d1 = IntDesc { + key: 1, + do_slice: false, + }; + + let d2 = IntDesc { + key: 1, + do_slice: true, + }; + + let e = Executor::new(1); + + let mut s1 = e + .submit_stream( + d1, + Box::pin(futures::stream::iter(vec![ + Arc::new(1), + Arc::new(2), + Arc::new(3), + Arc::new(4), + ])), + ) + .await + .unwrap(); + + let mut s2 = e + .submit_stream( + d2, + Box::pin(futures::stream::iter(vec![ + Arc::new(1), + Arc::new(2), + Arc::new(3), + Arc::new(4), + ])), + ) + .await + .unwrap(); + + let mut res1: Vec> = vec![]; + let mut res2: Vec> = vec![]; + + loop { + tokio::select! 
{ + Some(v) = s1.next() => { + res1.push(v); + }, + Some(v) = s2.next() => { + res2.push(v); + }, + else => { + break; + } + } + } + + assert_eq!(4, res1.len()); + assert_eq!(2, res2.len()); + assert!(Arc::ptr_eq(res1.get(1).unwrap(), res2.get(0).unwrap())); + assert!(Arc::ptr_eq(res1.get(3).unwrap(), res2.get(1).unwrap())); + } +} diff --git a/operators/src/pro/executor/operators.rs b/operators/src/pro/executor/operators.rs new file mode 100644 index 000000000..bb945c9ea --- /dev/null +++ b/operators/src/pro/executor/operators.rs @@ -0,0 +1,339 @@ +use crate::engine::{QueryContext, RasterQueryProcessor, VectorQueryProcessor}; +use crate::util::Result; +use futures::stream::BoxStream; +use futures::task::{Context, Poll}; +use futures::Stream; +use geoengine_datatypes::primitives::{ + AxisAlignedRectangle, BoundingBox2D, QueryRectangle, RasterQueryRectangle, SpatialPartition2D, + VectorQueryRectangle, +}; +use geoengine_datatypes::raster::{Pixel, RasterTile2D}; +use ouroboros::self_referencing; +use std::pin::Pin; +use std::sync::Arc; + +/// Turns a [`QueryProcessor`][crate::engine::QueryProcessor] into a [Stream] of results. +#[async_trait::async_trait] +pub trait OneshotQueryProcessor { + type BBox: AxisAlignedRectangle; + type Output: Stream + Send + 'static; + + async fn into_stream( + self, + qr: QueryRectangle, + ctx: Arc, + ) -> Result; +} + +/// Helper struct to tie together a [result stream][Stream] with the bounding [`VectorQueryProcessor`]. +#[self_referencing] +struct FeatureStreamWrapper +where + V: 'static + Send, +{ + proc: Arc>, + ctx: Arc, + #[borrows(proc, ctx)] + #[covariant] + stream: BoxStream<'this, Result>, +} + +/// A helper to generate a static [`Stream`] for a given [`VectorQueryProcessor`]. +pub struct FeatureStreamBoxer +where + V: 'static + Send, +{ + h: FeatureStreamWrapper, +} + +impl FeatureStreamBoxer +where + V: 'static + Send, +{ + /// Consumes a [`VectorQueryProcessor`], a [`VectorQueryRectangle`] and a [`QueryContext`] and + /// returns a `'static` [`Stream`] that can be passed to an [`Executor`][crate::pro::executor::Executor]. + pub async fn new( + proc: Arc>, + qr: VectorQueryRectangle, + ctx: Arc, + ) -> Result { + let h = FeatureStreamWrapperAsyncSendTryBuilder { + proc, + ctx, + stream_builder: |proc, ctx| { + Box::pin(async move { proc.vector_query(qr, ctx.as_ref()).await }) + }, + } + .try_build() + .await?; + + Ok(Self { h }) + } +} + +impl Stream for FeatureStreamBoxer +where + V: 'static + Send, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut() + .h + .with_stream_mut(|s| Pin::new(s).poll_next(cx)) + } +} + +#[async_trait::async_trait] +impl OneshotQueryProcessor for Box> +where + V: 'static + Send, +{ + type BBox = BoundingBox2D; + type Output = FeatureStreamBoxer; + + async fn into_stream( + self, + qr: QueryRectangle, + ctx: Arc, + ) -> Result { + FeatureStreamBoxer::new(self.into(), qr, ctx).await + } +} + +#[async_trait::async_trait] +impl OneshotQueryProcessor for Arc> +where + V: 'static + Send, +{ + type BBox = BoundingBox2D; + type Output = FeatureStreamBoxer; + + async fn into_stream( + self, + qr: QueryRectangle, + ctx: Arc, + ) -> Result { + FeatureStreamBoxer::new(self, qr, ctx).await + } +} + +/// Helper struct to tie together a [result stream][Stream] with the bounding [`RasterQueryProcessor`]. 
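For orientation, this is how the boxed streams are meant to be consumed: `into_stream` returns a `'static` stream that owns processor and context, so it can be drained locally or moved into other tasks (e.g. via `Executor::submit_stream`). A sketch using the vector side (`processor` is an assumed `Box<dyn VectorQueryProcessor<...>>`; the query mirrors the tests below):

```rust
use std::sync::Arc;
use futures::StreamExt;
use geoengine_datatypes::primitives::{
    BoundingBox2D, SpatialResolution, TimeInterval, VectorQueryRectangle,
};

let ctx = Arc::new(MockQueryContext::test_default());
let query = VectorQueryRectangle {
    spatial_bounds: BoundingBox2D::new((0., 0.).into(), (10., 10.).into())?,
    time_interval: TimeInterval::new(0, 10)?,
    spatial_resolution: SpatialResolution::one(),
};

// The stream is self-contained ('static), i.e. safe to hand to the Executor.
let chunks: Vec<_> = processor.into_stream(query, ctx).await?.collect().await;
```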
+#[self_referencing] +struct RasterStreamWrapper +where + RasterType: 'static + Send, +{ + proc: Arc>, + ctx: Arc, + #[borrows(proc, ctx)] + #[covariant] + stream: BoxStream<'this, Result>>, +} + +pub struct RasterStreamBoxer +where + RasterType: 'static + Send, +{ + h: RasterStreamWrapper, +} + +impl RasterStreamBoxer +where + RasterType: Pixel, +{ + /// Consumes a [`RasterQueryProcessor`], a [`RasterQueryRectangle`] and a [`QueryContext`] and + /// returns a `'static` [`Stream`] that can be passed to an [`Executor`][crate::pro::executor::Executor]. + pub async fn new( + proc: Arc>, + qr: RasterQueryRectangle, + ctx: Arc, + ) -> Result { + let h = RasterStreamWrapperAsyncSendTryBuilder { + proc, + ctx, + stream_builder: |proc, ctx| { + Box::pin(async move { proc.raster_query(qr, ctx.as_ref()).await }) + }, + } + .try_build() + .await?; + Ok(Self { h }) + } +} + +impl Stream for RasterStreamBoxer +where + RasterType: Pixel, +{ + type Item = Result>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut() + .h + .with_stream_mut(|s| Pin::new(s).poll_next(cx)) + } +} + +#[async_trait::async_trait] +impl OneshotQueryProcessor + for Box> +where + RasterType: Pixel, +{ + type BBox = SpatialPartition2D; + type Output = RasterStreamBoxer; + + async fn into_stream( + self, + qr: QueryRectangle, + ctx: Arc, + ) -> Result { + RasterStreamBoxer::new(self.into(), qr, ctx).await + } +} + +#[async_trait::async_trait] +impl OneshotQueryProcessor + for Arc> +where + RasterType: Pixel, +{ + type BBox = SpatialPartition2D; + type Output = RasterStreamBoxer; + + async fn into_stream( + self, + qr: QueryRectangle, + ctx: Arc, + ) -> Result { + RasterStreamBoxer::new(self, qr, ctx).await + } +} + +#[cfg(test)] +mod tests { + use crate::engine::{ + MockQueryContext, QueryContext, RasterQueryProcessor, VectorQueryProcessor, + }; + use crate::error::Error; + use crate::pro::executor::operators::OneshotQueryProcessor; + use futures::prelude::stream::BoxStream; + use futures::StreamExt; + use geoengine_datatypes::primitives::{ + BoundingBox2D, NoGeometry, RasterQueryRectangle, SpatialPartition2D, SpatialResolution, + TimeInterval, VectorQueryRectangle, + }; + use geoengine_datatypes::raster::RasterTile2D; + use geoengine_datatypes::util::test::TestDefault; + use std::sync::Arc; + + struct TestProcesor { + fail: bool, + } + + #[async_trait::async_trait] + impl VectorQueryProcessor for TestProcesor { + type VectorType = NoGeometry; + + async fn vector_query<'a>( + &'a self, + _query: VectorQueryRectangle, + _ctx: &'a dyn QueryContext, + ) -> crate::util::Result>> { + if self.fail { + Err(Error::QueryProcessor) + } else { + let s = futures::stream::empty(); + Ok(Box::pin(s)) + } + } + } + + #[async_trait::async_trait] + impl RasterQueryProcessor for TestProcesor { + type RasterType = u8; + + async fn raster_query<'a>( + &'a self, + _query: RasterQueryRectangle, + _ctx: &'a dyn QueryContext, + ) -> crate::util::Result>>> + { + if self.fail { + Err(Error::QueryProcessor) + } else { + let s = futures::stream::empty(); + Ok(Box::pin(s)) + } + } + } + + #[tokio::test] + async fn test_vector_ok() { + let tp: Box> = + Box::new(TestProcesor { fail: false }); + + let ctx = Arc::new(MockQueryContext::test_default()); + let qr = VectorQueryRectangle { + spatial_bounds: BoundingBox2D::new((0., 0.).into(), (10., 10.).into()).unwrap(), + time_interval: TimeInterval::new(0, 10).unwrap(), + spatial_resolution: SpatialResolution::one(), + }; + + let s = tp.into_stream(qr, ctx).await; + 
assert!(s.is_ok());
+        let v = s.unwrap().collect::<Vec<_>>().await;
+        assert!(v.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_vector_fail() {
+        let tp: Box<dyn VectorQueryProcessor<VectorType = NoGeometry>> =
+            Box::new(TestProcesor { fail: true });
+
+        let ctx = Arc::new(MockQueryContext::test_default());
+        let qr = VectorQueryRectangle {
+            spatial_bounds: BoundingBox2D::new((0., 0.).into(), (10., 10.).into()).unwrap(),
+            time_interval: TimeInterval::new(0, 10).unwrap(),
+            spatial_resolution: SpatialResolution::one(),
+        };
+
+        let s = tp.into_stream(qr, ctx).await;
+        assert!(s.is_err());
+    }
+
+    #[tokio::test]
+    async fn test_raster_ok() {
+        let tp: Box<dyn RasterQueryProcessor<RasterType = u8>> =
+            Box::new(TestProcesor { fail: false });
+
+        let ctx = Arc::new(MockQueryContext::test_default());
+        let qr = RasterQueryRectangle {
+            spatial_bounds: SpatialPartition2D::new((0., 10.).into(), (10., 0.).into()).unwrap(),
+            time_interval: TimeInterval::new(0, 10).unwrap(),
+            spatial_resolution: SpatialResolution::one(),
+        };
+
+        let s = tp.into_stream(qr, ctx).await;
+        assert!(s.is_ok());
+        let v = s.unwrap().collect::<Vec<_>>().await;
+        assert!(v.is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_raster_fail() {
+        let tp: Box<dyn RasterQueryProcessor<RasterType = u8>> =
+            Box::new(TestProcesor { fail: true });
+
+        let ctx = Arc::new(MockQueryContext::test_default());
+        let qr = RasterQueryRectangle {
+            spatial_bounds: SpatialPartition2D::new((0., 10.).into(), (10., 0.).into()).unwrap(),
+            time_interval: TimeInterval::new(0, 10).unwrap(),
+            spatial_resolution: SpatialResolution::one(),
+        };
+
+        let s = tp.into_stream(qr, ctx).await;
+        assert!(s.is_err());
+    }
+}
diff --git a/operators/src/pro/executor/replay/mod.rs b/operators/src/pro/executor/replay/mod.rs
new file mode 100644
index 000000000..428e2812a
--- /dev/null
+++ b/operators/src/pro/executor/replay/mod.rs
@@ -0,0 +1,868 @@
+use futures::{Future, Stream};
+use std::collections::{HashMap, VecDeque};
+use std::pin::Pin;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, Mutex};
+use std::task::{Context, Poll, Waker};
+
+/// Possible errors when sending data to the channel
+pub enum SendError<T> {
+    /// The channel is closed (i.e., all receivers were dropped)
+    Closed(T),
+    /// The replay buffer is full - try again later
+    Full(T),
+}
+
+/// Possible errors when receiving data from the channel
+#[derive(Clone, Debug)]
+pub enum ReceiveError {
+    /// The channel is closed - no more data will be delivered
+    Closed,
+    /// Currently no message present - try again later
+    NoMessage,
+}
+
+/// Possible errors when subscribing to the channel
+#[derive(Clone, Debug)]
+pub enum SubscribeError {
+    /// The channel progressed too far, joining not possible anymore
+    WouldLag,
+    /// The channel is closed already, joining not possible anymore
+    Closed,
+}
+
+/// Creates a new replay channel with the specified queue size. This means
+/// that the channel holds at most `queue_size` results.
+/// In contrast to, e.g., tokio's broadcast channel, lagging receivers
+/// are not dropped and do not miss any result.
+/// However, this means that results can only be processed at the pace
+/// of the slowest receiver.
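Condensed, the window semantics (as exercised by the unit tests further down): with `queue_size = 2`, a third send is refused until every receiver has consumed the head element.

```rust
let (tx, rx) = channel::<i32>(2);
let rx2 = tx.subscribe().unwrap(); // second consumer, still at element 0

assert!(tx.try_send(1).is_ok());
assert!(tx.try_send(2).is_ok());
assert!(matches!(tx.try_send(3), Err(SendError::Full(3)))); // both receivers lag

assert!(matches!(rx.try_recv(), Ok(1)));
assert!(matches!(rx2.try_recv(), Ok(1))); // head now consumed by everyone
assert!(tx.try_send(3).is_ok()); // window slid forward
```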
+pub fn channel<T: Clone>(queue_size: usize) -> (Sender<T>, Receiver<T>) {
+    let mut inner = Inner::new(queue_size);
+
+    let (tx_id, rx_id) = {
+        (
+            inner.create_sender(),
+            inner
+                .create_receiver(true)
+                .expect("Initial receiver creation must succeed."),
+        )
+    };
+
+    let inner = Arc::new(Mutex::new(inner));
+
+    let tx = Sender {
+        id: tx_id,
+        pending_futures: Default::default(),
+        inner: inner.clone(),
+    };
+    let rx = Receiver {
+        id: rx_id,
+        pending_futures: Default::default(),
+        inner,
+    };
+    (tx, rx)
+}
+
+/// An entry in the replay queue. Every entry
+/// consists of the actual result and the expected number
+/// of receivers.
+/// This count is decremented every time a receiver consumes
+/// the result. When this count reaches `0`, the entry can
+/// safely be discarded from the queue.
+struct QueueEntry<T> {
+    expected_receivers: usize,
+    value: T,
+}
+
+/// The actual channel construct.
+struct Inner<T> {
+    receiver_id_seq: u32,
+    sender_id_seq: u32,
+    queue: VecDeque<QueueEntry<T>>,
+    queue_size: usize,
+    offset: usize,
+    receivers: HashMap<u32, InnerReceiver>,
+    senders: HashMap<u32, InnerSender>,
+}
+
+/// A registered receiver that carries
+/// its next read index.
+struct InnerReceiver {
+    idx: usize,
+    waker: Option<Waker>,
+}
+
+/// A registered sender
+struct InnerSender {
+    waker: Option<Waker>,
+}
+
+impl InnerSender {
+    fn update_waker(&mut self, waker: &Waker) {
+        update_waker(&mut self.waker, waker);
+    }
+
+    fn remove_waker(&mut self) {
+        self.waker = None;
+    }
+}
+
+impl InnerReceiver {
+    fn update_waker(&mut self, waker: &Waker) {
+        update_waker(&mut self.waker, waker);
+    }
+
+    fn remove_waker(&mut self) {
+        self.waker = None;
+    }
+}
+
+fn update_waker(opt: &mut Option<Waker>, waker: &Waker) {
+    match opt {
+        Some(w) if w.will_wake(waker) => {}
+        _ => {
+            let _ignore = opt.replace(waker.clone());
+        }
+    };
+}
+
+impl<T> Inner<T>
+where
+    T: Clone,
+{
+    /// Creates a new Inner instance with the given `queue_size`.
+    fn new(queue_size: usize) -> Inner<T> {
+        Inner {
+            receiver_id_seq: 0,
+            sender_id_seq: 0,
+            queue: VecDeque::with_capacity(queue_size),
+            queue_size,
+            offset: 0,
+            receivers: HashMap::new(),
+            senders: HashMap::new(),
+        }
+    }
+
+    /// Creates a new sender and returns its `id`.
+    fn create_sender(&mut self) -> u32 {
+        let id = self.sender_id_seq;
+        self.sender_id_seq += 1;
+
+        assert!(self
+            .senders
+            .insert(id, InnerSender { waker: None })
+            .is_none());
+        id
+    }
+
+    /// Removes the sender with the given `id`.
+    fn remove_sender(&mut self, id: u32) {
+        if self.senders.remove(&id).is_some() && self.senders.is_empty() {
+            self.notify_receivers();
+        }
+    }
+
+    /// Tries to make space for a new result in the queue.
+    /// This attempt may fail if the queue is full and
+    /// at least one consumer did not receive the oldest element yet.
+    fn clean_up_queue(&mut self) -> bool {
+        while self.queue.len() >= self.queue_size
+            && self
+                .queue
+                .front()
+                .map_or(false, |e| e.expected_receivers == 0)
+        {
+            self.queue.pop_front().expect("queue must not be empty here");
+            self.offset += 1;
+        }
+        self.queue.len() < self.queue_size
+    }
+
+    /// Tries to send the given value into the channel.
+    ///
+    /// # Errors
+    /// This method fails with [`Closed`](SendError::Closed) if no receivers are left.
+    /// Moreover, [`Full`](SendError::Full) is returned if there is no more space
+    /// within the queue.
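On the caller side, the `Full`/`Closed` contract hands the value back, so a retry does not need to clone. A sketch (the helper is illustrative; spinning is for demonstration only, and the async `send` defined below is the real waiting mechanism):

```rust
fn try_send_with_retry<T: Clone>(tx: &Sender<T>, mut v: T) -> bool {
    loop {
        match tx.try_send(v) {
            Ok(()) => return true,
            Err(SendError::Full(unsent)) => {
                v = unsent; // value handed back; try again later
                std::thread::yield_now();
            }
            Err(SendError::Closed(_)) => return false, // all receivers gone
        }
    }
}
```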
+    fn try_send(&mut self, value: T) -> Result<(), SendError<T>> {
+        if self.receivers.is_empty() {
+            Err(SendError::Closed(value))
+        } else if self.clean_up_queue() {
+            self.queue.push_back(QueueEntry {
+                expected_receivers: self.receivers.len(),
+                value,
+            });
+            self.notify_receivers();
+            Ok(())
+        } else {
+            Err(SendError::Full(value))
+        }
+    }
+
+    /// Notifies all pending senders.
+    fn notify_senders(&mut self) {
+        for v in self.senders.values_mut() {
+            if let Some(w) = v.waker.take() {
+                w.wake();
+            }
+        }
+    }
+
+    /// Updates the waker for the sender with the given `sender_id`.
+    fn update_send_waker(&mut self, sender_id: u32, waker: &Waker) {
+        if let Some(s) = self.senders.get_mut(&sender_id) {
+            s.update_waker(waker);
+        }
+    }
+
+    /// Removes the waker for the sender with the given `sender_id`.
+    fn remove_send_waker(&mut self, sender_id: u32) {
+        if let Some(s) = self.senders.get_mut(&sender_id) {
+            s.remove_waker();
+        }
+    }
+
+    /// Creates a new receiver and returns its id.
+    ///
+    /// # Errors
+    /// This method fails with [`WouldLag`](SubscribeError::WouldLag) if the first element of
+    /// the result stream was evicted already. Moreover, it returns [`Closed`](SubscribeError::Closed)
+    /// if the channel was closed already (i.e., all senders or receivers were dropped).
+    ///
+    fn create_receiver(&mut self, first: bool) -> Result<u32, SubscribeError> {
+        if self.offset > 0 {
+            return Err(SubscribeError::WouldLag);
+        } else if !first && self.receivers.is_empty() {
+            return Err(SubscribeError::Closed);
+        }
+
+        // Increment expected read count
+        for v in &mut self.queue {
+            v.expected_receivers += 1;
+        }
+
+        let id = self.receiver_id_seq;
+        self.receiver_id_seq += 1;
+
+        assert!(self
+            .receivers
+            .insert(
+                id,
+                InnerReceiver {
+                    idx: 0,
+                    waker: None,
+                },
+            )
+            .is_none());
+        Ok(id)
+    }
+
+    /// Removes the receiver with the given `id`.
+    fn remove_receiver(&mut self, id: u32) {
+        if let Some(r) = self.receivers.remove(&id) {
+            let idx = r.idx - self.offset;
+            let mut notify = false;
+            for i in idx..self.queue.len() {
+                let e = &mut self.queue[i];
+                e.expected_receivers -= 1;
+                notify |= e.expected_receivers == 0;
+            }
+            if notify {
+                self.notify_senders();
+            }
+        }
+    }
+
+    /// Tries to receive the next message for the receiver with the given `receiver_id`.
+    ///
+    /// # Errors
+    /// Returns [`NoMessage`](ReceiveError::NoMessage) if there is currently no new message available
+    /// for the given receiver - please try again later.
+    ///
+    /// Moreover, it returns [`Closed`](ReceiveError::Closed) if the channel is closed. No more data
+    /// will be delivered in this case.
+    fn try_recv(&mut self, receiver_id: u32) -> Result<T, ReceiveError> {
+        let recv = self
+            .receivers
+            .get_mut(&receiver_id)
+            .expect("Unknown receiver id.");
+
+        let q_idx = recv.idx - self.offset;
+
+        let res = match self.queue.get_mut(q_idx) {
+            Some(e) => {
+                e.expected_receivers -= 1;
+                Ok(e.value.clone())
+            }
+            None if self.senders.is_empty() => Err(ReceiveError::Closed),
+            None => Err(ReceiveError::NoMessage),
+        }?;
+
+        recv.idx += 1;
+        self.notify_senders();
+        Ok(res)
+    }
+
+    /// Notifies all waiting receivers.
+    fn notify_receivers(&mut self) {
+        for v in self.receivers.values_mut() {
+            if let Some(w) = v.waker.take() {
+                w.wake();
+            }
+        }
+    }
+
+    /// Updates the waker for the receiver with the given `receiver_id`.
+    fn update_receiver_waker(&mut self, receiver_id: u32, waker: &Waker) {
+        self.receivers
+            .get_mut(&receiver_id)
+            .unwrap()
+            .update_waker(waker);
+    }
+
+    /// Removes the waker for the receiver with the given `receiver_id`.
+    fn remove_receiver_waker(&mut self, receiver_id: u32) {
+        self.receivers.get_mut(&receiver_id).unwrap().remove_waker();
+    }
+}
+
+// /////////////////////////////////////////////////////////////////////////////
+//
+// SENDER SIDE
+//
+// /////////////////////////////////////////////////////////////////////////////
+
+/// The sender side of a replay channel
+pub struct Sender<T>
+where
+    T: Clone,
+{
+    id: u32,
+    pending_futures: AtomicUsize,
+    inner: Arc<Mutex<Inner<T>>>,
+}
+
+impl<T> Sender<T>
+where
+    T: Clone,
+{
+    /// Registers a new receiver to the channel and returns it.
+    ///
+    /// # Errors
+    /// This method fails with [`WouldLag`](SubscribeError::WouldLag) if the first element of
+    /// the result stream was evicted already. Moreover, it returns [`Closed`](SubscribeError::Closed)
+    /// if the channel was closed already.
+    pub fn subscribe(&self) -> Result<Receiver<T>, SubscribeError> {
+        let id = {
+            let mut lock = crate::util::safe_lock_mutex(&self.inner);
+            lock.create_receiver(false)
+        }?;
+
+        Ok(Receiver {
+            id,
+            pending_futures: Default::default(),
+            inner: self.inner.clone(),
+        })
+    }
+
+    /// Tries to send the given value immediately. If this is not possible, the
+    /// value is returned with the error.
+    ///
+    /// # Errors
+    /// This method fails with [`Closed`](SendError::Closed) if no receivers are left.
+    /// Moreover, [`Full`](SendError::Full) is returned if there is no more space
+    /// within the queue.
+    pub fn try_send(&self, v: T) -> Result<(), SendError<T>> {
+        let mut lock = crate::util::safe_lock_mutex(&self.inner);
+        lock.try_send(v)
+    }
+}
+
+impl<T> Sender<T>
+where
+    T: Clone + Unpin,
+{
+    /// Sends a value, waiting until there is capacity.
+    pub fn send(&self, v: T) -> Send<'_, T> {
+        let _ignore = self.pending_futures.fetch_add(1, Ordering::Relaxed);
+        Send {
+            sender: self,
+            value: Some(v),
+        }
+    }
+
+    fn poll_send(
+        &self,
+        send: &mut Send<'_, T>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Result<(), SendError<T>>> {
+        let value = send.value.take().expect("Send without message.");
+        let mut lock = crate::util::safe_lock_mutex(&self.inner);
+        match lock.try_send(value) {
+            Ok(_) => Poll::Ready(Ok(())),
+            Err(SendError::Full(v)) => {
+                lock.update_send_waker(self.id, cx.waker());
+                let _ignore = send.value.insert(v);
+                Poll::Pending
+            }
+            Err(e) => Poll::Ready(Err(e)),
+        }
+    }
+
+    fn send_dropped(&self) {
+        if self.pending_futures.fetch_sub(1, Ordering::Relaxed) == 1 {
+            let mut lock = crate::util::safe_lock_mutex(&self.inner);
+            lock.remove_send_waker(self.id);
+        }
+    }
+}
+
+impl<T> Clone for Sender<T>
+where
+    T: Clone,
+{
+    fn clone(&self) -> Self {
+        let id = {
+            let mut lock = crate::util::safe_lock_mutex(&self.inner);
+            lock.create_sender()
+        };
+        Sender {
+            id,
+            pending_futures: Default::default(),
+            inner: self.inner.clone(),
+        }
+    }
+}
+
+impl<T> Drop for Sender<T>
+where
+    T: Clone,
+{
+    fn drop(&mut self) {
+        let mut lock = crate::util::safe_lock_mutex(&self.inner);
+        lock.remove_sender(self.id);
+    }
+}
+
+/// A message waiting to be sent.
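Because `send` only resolves once queue space frees up, producers typically pair it with a cancellation signal so that a stalled consumer cannot block them forever. A fragment of such a producer loop (`shutdown` is an assumed `tokio_util::sync::CancellationToken`, as used in the tests below):

```rust
tokio::select! {
    res = tx.send(item) => {
        if let Err(SendError::Closed(_)) = res {
            return; // every receiver is gone; stop producing
        }
    }
    _ = shutdown.cancelled() => return,
}
```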
+pub struct Send<'sender, T> +where + T: Clone + Unpin, +{ + sender: &'sender Sender, + value: Option, +} + +impl Drop for Send<'_, T> +where + T: Clone + Unpin, +{ + fn drop(&mut self) { + self.sender.send_dropped(); + } +} + +impl Future for Send<'_, T> +where + T: Clone + Unpin, +{ + type Output = Result<(), SendError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.sender.poll_send(self.get_mut(), cx) + } +} + +// ///////////////////////////////////////////////////////////////////////////// +// +// RECEIVER SIDE +// +// ///////////////////////////////////////////////////////////////////////////// + +/// The receiver side of a replay channel +pub struct Receiver +where + T: Clone, +{ + id: u32, + pending_futures: AtomicUsize, + inner: Arc>>, +} + +impl Receiver +where + T: Clone, +{ + /// Tries to receive the next message immediately. + /// + /// # Errors + /// Returns [`NoMessage`](ReceiveError::NoMessage) if there is currently no new message available + /// - please try again later. + /// + /// Moreover, it returns [`Closed`](ReceiveError::Closed) if the channel is closed. No more data + /// will be delivered in this case. + pub fn try_recv(&self) -> Result { + let mut lock = crate::util::safe_lock_mutex(&self.inner); + lock.try_recv(self.id) + } + + /// Receives the next value from the channel, possibly waiting for it to arrive. + pub fn recv(&self) -> Recv { + let _ignore = self.pending_futures.fetch_add(1, Ordering::Relaxed); + Recv { receiver: self } + } + + fn poll_recv(&self, cx: &mut Context<'_>) -> Poll> { + let mut lock = crate::util::safe_lock_mutex(&self.inner); + match lock.try_recv(self.id) { + Ok(v) => Poll::Ready(Some(v)), + Err(ReceiveError::Closed) => Poll::Ready(None), + Err(ReceiveError::NoMessage) => { + lock.update_receiver_waker(self.id, cx.waker()); + Poll::Pending + } + } + } + + pub fn recv_dropped(&self) { + if self.pending_futures.fetch_sub(1, Ordering::Relaxed) == 1 { + let mut lock = crate::util::safe_lock_mutex(&self.inner); + lock.remove_receiver_waker(self.id); + } + } +} + +impl Drop for Receiver +where + T: Clone, +{ + fn drop(&mut self) { + let mut lock = crate::util::safe_lock_mutex(&self.inner); + lock.remove_receiver(self.id); + } +} + +/// A message waiting to be received. +pub struct Recv<'receiver, T> +where + T: Clone, +{ + receiver: &'receiver Receiver, +} + +impl Drop for Recv<'_, T> +where + T: Clone, +{ + fn drop(&mut self) { + self.receiver.recv_dropped(); + } +} + +impl Future for Recv<'_, T> +where + T: Clone, +{ + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.receiver.poll_recv(cx) + } +} + +/// A stream representation of the `Receiver`. 
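Converting a `Receiver` into this stream gives access to the usual combinators; the stream ends once all senders are dropped (mirroring `test_receive_stream` below):

```rust
use futures::StreamExt;

let s: ReceiverStream<i32> = rx.into();
let values: Vec<i32> = s.collect().await;
```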
+pub struct ReceiverStream +where + T: Clone, +{ + inner: Receiver, +} + +impl From> for ReceiverStream +where + T: Clone, +{ + fn from(r: Receiver) -> Self { + ReceiverStream { inner: r } + } +} + +impl Stream for ReceiverStream +where + T: Clone, +{ + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_recv(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::StreamExt; + use std::matches; + use tokio::time::Duration; + + #[test] + fn test_send() { + let (tx, rx) = channel(2); + assert!(tx.try_send(1).is_ok()); + assert!(tx.try_send(2).is_ok()); + assert!(matches!(tx.try_send(3), Err(SendError::Full(3)))); + assert!(matches!(rx.try_recv(), Ok(1))); + assert!(tx.try_send(3).is_ok()); + } + + #[test] + fn test_all_receivers_gone() { + let (tx, rx) = channel(2); + assert!(tx.try_send(1).is_ok()); + drop(rx); + assert!(matches!(tx.try_send(2), Err(SendError::Closed(2)))); + } + + #[test] + fn test_send_drop_lagging_receiver() { + let (tx, rx) = channel(2); + let rx2 = tx.subscribe().unwrap(); + assert!(tx.try_send(1).is_ok()); + assert!(tx.try_send(2).is_ok()); + assert!(matches!(tx.try_send(3), Err(SendError::Full(3)))); + assert!(matches!(rx.try_recv(), Ok(1))); + drop(rx2); + assert!(tx.try_send(3).is_ok()); + } + + #[test] + fn test_receive() { + let (tx, rx) = channel(2); + let rx2 = tx.subscribe().unwrap(); + assert!(tx.try_send(1).is_ok()); + assert!(tx.try_send(2).is_ok()); + assert!(matches!(tx.try_send(3), Err(SendError::Full(3)))); + assert!(matches!(rx.try_recv(), Ok(1))); + assert!(matches!(rx.try_recv(), Ok(2))); + assert!(matches!(rx.try_recv(), Err(ReceiveError::NoMessage))); + assert!(matches!(rx2.try_recv(), Ok(1))); + assert!(tx.try_send(3).is_ok()); + assert!(matches!(rx.try_recv(), Ok(3))); + assert!(matches!(rx2.try_recv(), Ok(2))); + assert!(matches!(rx2.try_recv(), Ok(3))); + assert!(matches!(rx.try_recv(), Err(ReceiveError::NoMessage))); + assert!(matches!(rx2.try_recv(), Err(ReceiveError::NoMessage))); + } + + #[test] + fn test_receive_after_tx_close() { + let (tx, rx) = channel(2); + let rx2 = tx.subscribe().unwrap(); + assert!(tx.try_send(1).is_ok()); + assert!(tx.try_send(2).is_ok()); + assert!(matches!(tx.try_send(3), Err(SendError::Full(3)))); + assert!(matches!(rx.try_recv(), Ok(1))); + assert!(matches!(rx.try_recv(), Ok(2))); + assert!(matches!(rx.try_recv(), Err(ReceiveError::NoMessage))); + assert!(matches!(rx2.try_recv(), Ok(1))); + assert!(tx.try_send(3).is_ok()); + drop(tx); + assert!(matches!(rx.try_recv(), Ok(3))); + assert!(matches!(rx2.try_recv(), Ok(2))); + assert!(matches!(rx2.try_recv(), Ok(3))); + assert!(matches!(rx.try_recv(), Err(ReceiveError::Closed))); + assert!(matches!(rx2.try_recv(), Err(ReceiveError::Closed))); + } + + #[tokio::test] + async fn test_send_async() { + let (tx, rx) = channel(2); + + let t1 = tokio::spawn(async move { + for i in 1..=3 { + assert!(tx.send(i).await.is_ok()); + } + }); + + let mut result = Vec::new(); + while let Some(v) = rx.recv().await { + result.push(v); + } + + assert_eq!(vec![1, 2, 3], result); + assert!(t1.await.is_ok()); + } + + #[tokio::test] + async fn test_all_receivers_gone_async() { + let (tx, rx) = channel(2); + drop(rx); + let t1 = tokio::spawn(async move { + let res = tx.send(1).await; + assert!(matches!(res, Err(SendError::Closed(1)))); + }); + assert!(t1.await.is_ok()); + } + + #[tokio::test] + async fn test_send_drop_lagging_receiver_async() { + let (tx, rx) = channel(2); + let rx2 = tx.subscribe().unwrap(); + + let ct = 
tokio_util::sync::CancellationToken::new(); + let cloned_token = ct.clone(); + + let t1 = tokio::spawn(async move { + for i in 1..=3 { + tokio::select! { + res = tx.send(i) => { + assert!(res.is_ok()); + }, + _ = cloned_token.cancelled() => { + return false; + } + } + } + true + }); + + assert!(matches!(rx.recv().await, Some(1))); + assert!(matches!(rx.recv().await, Some(2))); + assert!(matches!(rx.try_recv(), Err(ReceiveError::NoMessage))); + drop(rx2); + // Wait until we cancel + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + ct.cancel(); + assert!(matches!(t1.await, Ok(true))); + assert!(matches!(rx.try_recv(), Ok(3))); + } + + #[tokio::test] + async fn test_receive_async() { + let (tx, rx) = channel(2); + let rx2 = tx.subscribe().unwrap(); + + let f1 = async move { + for i in 1..=3 { + assert!(tx.send(i).await.is_ok()); + } + }; + + let f2 = async move { + let mut res = Vec::new(); + while let Some(v) = rx.recv().await { + res.push(v); + } + res + }; + + let f3 = async move { + let mut res = Vec::new(); + while let Some(v) = rx2.recv().await { + res.push(v); + } + res + }; + + let (_, r1, r2) = tokio::join!(f1, f2, f3); + + assert_eq!(vec![1, 2, 3], r1); + assert_eq!(vec![1, 2, 3], r2); + } + + #[tokio::test] + async fn test_receive_after_tx_close_async() { + let (tx, rx) = channel(2); + + assert!(tx.send(1).await.is_ok()); + assert!(tx.send(2).await.is_ok()); + assert!(matches!(rx.recv().await, Some(1))); + assert!(tx.send(3).await.is_ok()); + drop(tx); + assert!(matches!(rx.recv().await, Some(2))); + assert!(matches!(rx.recv().await, Some(3))); + assert!(matches!(rx.recv().await, None)); + } + + #[tokio::test] + async fn test_multiple_waiting() { + let (tx, rx) = channel(2); + + let f1 = tx.send(1); + let f2 = tx.send(2); + let f3 = tx.send(3); + let f4 = tx.send(4); + + let (_r1, _r2) = tokio::join!(f1, f2); + let (_tx1, _tx2, r1, r2, r3, r4) = + tokio::join!(f3, f4, rx.recv(), rx.recv(), rx.recv(), rx.recv()); + + assert_eq!(Some(1), r1); + assert_eq!(Some(2), r2); + assert_eq!(Some(3), r3); + assert_eq!(Some(4), r4); + } + + #[tokio::test] + async fn test_send_multiple_tasks() { + let (tx, rx) = channel(2); + let rx2 = tx.subscribe().unwrap(); + + let t = tokio::spawn(async move { + let f1 = tx.send(1); + let f2 = tx.send(2); + let f3 = tx.send(3); + let (_, _, _) = tokio::join!(f1, f2, f3); + }); + + let rt = { + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(100)).await; + let (r1, r2, r3, r4) = tokio::join!(rx2.recv(), rx2.recv(), rx2.recv(), rx2.recv()); + assert_eq!(Some(1), r1); + assert_eq!(Some(2), r2); + assert_eq!(Some(3), r3); + assert_eq!(None, r4); + }) + }; + + let (r1, r2, r3, r4) = tokio::join!(rx.recv(), rx.recv(), rx.recv(), rx.recv()); + + assert_eq!(Some(1), r1); + assert_eq!(Some(2), r2); + assert_eq!(Some(3), r3); + assert_eq!(None, r4); + + let (r1, r2) = tokio::join!(t, rt); + assert!(r1.is_ok()); + assert!(r2.is_ok()); + } + + #[tokio::test] + async fn test_receive_stream() { + let (tx, rx) = channel(2); + let rx2 = tx.subscribe().unwrap(); + + let f1 = async move { + for i in 1..=3 { + assert!(tx.send(i).await.is_ok()); + } + }; + + let f2 = async move { + let s: ReceiverStream = rx.into(); + s.collect::>().await + }; + + let f3 = async move { + let s: ReceiverStream = rx2.into(); + s.collect::>().await + }; + + let (_, results1, results2) = tokio::join!(f1, f2, f3); + + assert_eq!(vec![1, 2, 3], results1); + assert_eq!(vec![1, 2, 3], results2); + } +} diff --git a/operators/src/pro/mod.rs 
b/operators/src/pro/mod.rs index f70e69421..545dffb11 100644 --- a/operators/src/pro/mod.rs +++ b/operators/src/pro/mod.rs @@ -1 +1,2 @@ // This is an inclusion point of Geo Engine Pro +pub mod executor; diff --git a/operators/src/util/raster_stream_to_png.rs b/operators/src/util/raster_stream_to_png.rs index 16ce12e74..1b926bcf5 100644 --- a/operators/src/util/raster_stream_to_png.rs +++ b/operators/src/util/raster_stream_to_png.rs @@ -1,3 +1,4 @@ +use futures::stream::BoxStream; use futures::StreamExt; use geoengine_datatypes::{ operations::image::{Colorizer, RgbaColor, ToPng}, @@ -24,9 +25,33 @@ pub async fn raster_stream_to_png_bytes( where T: Pixel, { - let colorizer = colorizer.unwrap_or(default_colorizer_gradient::()?); - let tile_stream = processor.query(query_rect, &query_ctx).await?; + raster_stream_to_png_bytes_stream::( + tile_stream, + query_rect, + width, + height, + time, + colorizer, + no_data_value, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub async fn raster_stream_to_png_bytes_stream( + tile_stream: BoxStream<'_, Result>>, + query_rect: RasterQueryRectangle, + width: u32, + height: u32, + time: Option, + colorizer: Option, + no_data_value: Option, +) -> Result> +where + T: Pixel, +{ + let colorizer = colorizer.unwrap_or(default_colorizer_gradient::()?); let x_query_resolution = query_rect.spatial_bounds.size_x() / f64::from(width); let y_query_resolution = query_rect.spatial_bounds.size_y() / f64::from(height); diff --git a/operators/test-data/raster/msg/20121212_1200.tif b/operators/test-data/raster/msg/20121212_1200.tif new file mode 100644 index 000000000..cc5b55f5f Binary files /dev/null and b/operators/test-data/raster/msg/20121212_1200.tif differ diff --git a/services/src/contexts/mod.rs b/services/src/contexts/mod.rs index c872e67e3..cb4efba12 100644 --- a/services/src/contexts/mod.rs +++ b/services/src/contexts/mod.rs @@ -19,6 +19,7 @@ use geoengine_operators::engine::{ ChunkByteSize, ExecutionContext, MetaData, MetaDataProvider, QueryContext, RasterResultDescriptor, VectorResultDescriptor, }; + use geoengine_operators::mock::MockDatasetDataSourceLoadingInfo; use geoengine_operators::source::{GdalLoadingInfo, OgrSourceDataset}; diff --git a/services/src/error.rs b/services/src/error.rs index c5038478f..0cd1a6a4a 100644 --- a/services/src/error.rs +++ b/services/src/error.rs @@ -212,6 +212,8 @@ pub enum Error { MissingSpatialReference, + ExternalAddressNotConfigured, + WcsVersionNotSupported, WcsGridOriginMustEqualBoundingboxUpperLeft, WcsBoundingboxCrsMustEqualGridBaseCrs, @@ -308,6 +310,7 @@ pub enum Error { InvalidAPIToken { message: String, }, + MissingNFDIMetaData, #[snafu(context(false))] @@ -328,6 +331,19 @@ pub enum Error { }, BaseUrlMustEndWithSlash, + + #[snafu(context(false))] + #[cfg(feature = "pro")] + #[snafu(display("Executor error: {}", source))] + Executor { + source: geoengine_operators::pro::executor::error::ExecutorError, + }, + + #[cfg(feature = "pro")] + #[snafu(display("Executor error: {}", message))] + ExecutorComputation { + message: String, + }, } impl actix_web::error::ResponseError for Error { diff --git a/services/src/handlers/plots.rs b/services/src/handlers/plots.rs index 21d819427..907cdc624 100644 --- a/services/src/handlers/plots.rs +++ b/services/src/handlers/plots.rs @@ -37,6 +37,14 @@ pub(crate) struct GetPlot { pub spatial_resolution: SpatialResolution, } +#[derive(Debug, Clone, PartialEq, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct WrappedPlotOutput { + pub(crate) output_format: 
PlotOutputFormat, + pub(crate) plot_type: &'static str, + pub(crate) data: serde_json::Value, +} + /// Generates a [plot](WrappedPlotOutput). /// /// # Example @@ -117,12 +125,23 @@ pub(crate) struct GetPlot { /// ] /// } /// ``` + async fn get_plot_handler( id: web::Path, params: web::Query, session: C::Session, ctx: web::Data, ) -> Result { + let output = process_plot_request(id, params, session, ctx).await?; + Ok(web::Json(output)) +} + +pub(crate) async fn process_plot_request( + id: web::Path, + params: web::Query, + session: C::Session, + ctx: web::Data, +) -> Result { let workflow = ctx .workflow_registry_ref() .await @@ -193,21 +212,11 @@ async fn get_plot_handler( } }; - let output = WrappedPlotOutput { + Ok(WrappedPlotOutput { output_format, plot_type, data, - }; - - Ok(web::Json(output)) -} - -#[derive(Debug, Clone, PartialEq, Serialize)] -#[serde(rename_all = "camelCase")] -struct WrappedPlotOutput { - output_format: PlotOutputFormat, - plot_type: &'static str, - data: serde_json::Value, + }) } #[cfg(test)] diff --git a/services/src/handlers/wfs.rs b/services/src/handlers/wfs.rs index 68f3f7095..de0b7254d 100644 --- a/services/src/handlers/wfs.rs +++ b/services/src/handlers/wfs.rs @@ -1,5 +1,5 @@ use actix_web::{web, FromRequest, HttpResponse}; -use geoengine_datatypes::primitives::VectorQueryRectangle; +use geoengine_datatypes::primitives::{BoundingBox2D, QueryRectangle, VectorQueryRectangle}; use reqwest::Url; use snafu::{ensure, ResultExt}; @@ -13,6 +13,7 @@ use crate::util::config::get_config_element; use crate::util::user_input::QueryEx; use crate::workflows::registry::WorkflowRegistry; use crate::workflows::workflow::{Workflow, WorkflowId}; +use futures::stream::BoxStream; use futures::StreamExt; use geoengine_datatypes::collections::ToGeoJson; use geoengine_datatypes::{ @@ -23,10 +24,9 @@ use geoengine_datatypes::{ primitives::{FeatureData, Geometry, MultiPoint, TimeInstance, TimeInterval}, spatial_reference::SpatialReference, }; -use geoengine_operators::engine::{ - QueryContext, ResultDescriptor, TypedVectorQueryProcessor, VectorQueryProcessor, -}; +use geoengine_operators::call_on_generic_vector_processor; use geoengine_operators::engine::{QueryProcessor, VectorOperator}; +use geoengine_operators::engine::{ResultDescriptor, TypedVectorQueryProcessor}; use geoengine_operators::processing::{Reprojection, ReprojectionParams}; use serde_json::json; use std::str::FromStr; @@ -154,7 +154,7 @@ async fn wfs_handler( /// /// ``` #[allow(clippy::too_many_lines)] -async fn get_capabilities( +pub(crate) async fn get_capabilities( _request: &GetCapabilities, ctx: &C, session: C::Session, @@ -399,6 +399,28 @@ async fn get_feature( session: C::Session, endpoint: WorkflowId, ) -> Result { + if request.type_names.feature_type == "93d6785e-5eea-4e0e-8074-e7f78733d988" { + return get_feature_mock(request); + } + + let (processor, query_rect) = + extract_operator_and_bounding_box(request, ctx, session.clone(), endpoint).await?; + + let query_ctx = ctx.query_context()?; + + let json = call_on_generic_vector_processor!(processor, p => { + let stream = p.query(query_rect, &query_ctx).await?; + vector_stream_to_geojson(stream).await + })?; + Ok(HttpResponse::Ok().json(json)) +} + +pub(crate) async fn extract_operator_and_bounding_box( + request: &GetFeature, + ctx: &C, + session: C::Session, + endpoint: WorkflowId, +) -> Result<(TypedVectorQueryProcessor, QueryRectangle)> { let type_names = match request.type_names.namespace.as_deref() { None => 
WorkflowId::from_str(&request.type_names.feature_type)?, Some(_) => { @@ -414,12 +436,6 @@ async fn get_feature( } ); - // TODO: validate request further - - if request.type_names.feature_type == "93d6785e-5eea-4e0e-8074-e7f78733d988" { - return get_feature_mock(request); - } - let workflow: Workflow = ctx.workflow_registry_ref().await.load(&type_names).await?; let operator = workflow.operator.get_vector().context(error::Operator)?; @@ -469,30 +485,12 @@ async fn get_feature( // TODO: find a reasonable fallback, e.g., dependent on the SRS or BBox .unwrap_or_else(SpatialResolution::zero_point_one), }; - let query_ctx = ctx.query_context()?; - let json = match processor { - TypedVectorQueryProcessor::Data(p) => { - vector_stream_to_geojson(p, query_rect, &query_ctx).await - } - TypedVectorQueryProcessor::MultiPoint(p) => { - vector_stream_to_geojson(p, query_rect, &query_ctx).await - } - TypedVectorQueryProcessor::MultiLineString(p) => { - vector_stream_to_geojson(p, query_rect, &query_ctx).await - } - TypedVectorQueryProcessor::MultiPolygon(p) => { - vector_stream_to_geojson(p, query_rect, &query_ctx).await - } - }?; - - Ok(HttpResponse::Ok().json(json)) + Ok((processor, query_rect)) } -async fn vector_stream_to_geojson( - processor: Box>>, - query_rect: VectorQueryRectangle, - query_ctx: &dyn QueryContext, +pub(crate) async fn vector_stream_to_geojson( + stream: BoxStream<'_, geoengine_operators::util::Result>>, ) -> Result where G: Geometry + 'static, @@ -500,9 +498,6 @@ where { let features: Vec = Vec::new(); - // TODO: more efficient merging of the partial feature collections - let stream = processor.query(query_rect, query_ctx).await?; - let features = stream .fold( Result::, error::Error>::Ok(features), diff --git a/services/src/handlers/wms.rs b/services/src/handlers/wms.rs index 5831994c4..75b9175d7 100644 --- a/services/src/handlers/wms.rs +++ b/services/src/handlers/wms.rs @@ -3,7 +3,7 @@ use reqwest::Url; use snafu::{ensure, ResultExt}; use geoengine_datatypes::primitives::{ - AxisAlignedRectangle, RasterQueryRectangle, SpatialPartition2D, + AxisAlignedRectangle, QueryRectangle, RasterQueryRectangle, SpatialPartition2D, }; use geoengine_datatypes::{ operations::image::Colorizer, primitives::SpatialResolution, @@ -22,7 +22,7 @@ use crate::workflows::registry::WorkflowRegistry; use crate::workflows::workflow::WorkflowId; use geoengine_datatypes::primitives::{TimeInstance, TimeInterval}; -use geoengine_operators::engine::{RasterOperator, ResultDescriptor}; +use geoengine_operators::engine::{RasterOperator, ResultDescriptor, TypedRasterQueryProcessor}; use geoengine_operators::processing::{Reprojection, ReprojectionParams}; use geoengine_operators::{ call_on_generic_raster_processor, util::raster_stream_to_png::raster_stream_to_png_bytes, @@ -116,7 +116,7 @@ async fn wms_handler( /// /// /// ``` -async fn get_capabilities( +pub(crate) async fn get_capabilities( _request: &GetCapabilities, ctx: &C, session: C::Session, @@ -227,6 +227,33 @@ async fn get_map( session: C::Session, endpoint: WorkflowId, ) -> Result { + let (processor, query_rect, colorizer, no_data_value) = + extract_operator_and_bounding_box_and_colorizer(request, ctx, session, endpoint).await?; + + let query_ctx = ctx.query_context()?; + + let image_bytes = call_on_generic_raster_processor!( + processor, + p => + raster_stream_to_png_bytes(p, query_rect, query_ctx, request.width, request.height, request.time, colorizer, no_data_value.map(AsPrimitive::as_)).await + ).map_err(error::Error::from)?; + + 
Ok(HttpResponse::Ok() + .content_type(mime::IMAGE_PNG) + .body(image_bytes)) +} + +pub(crate) async fn extract_operator_and_bounding_box_and_colorizer( + request: &GetMap, + ctx: &C, + session: C::Session, + endpoint: WorkflowId, +) -> Result<( + TypedRasterQueryProcessor, + QueryRectangle, + Option, + Option, +)> { let layer = WorkflowId::from_str(&request.layers)?; ensure!( @@ -296,19 +323,9 @@ async fn get_map( ), }; - let query_ctx = ctx.query_context()?; - let colorizer = colorizer_from_style(&request.styles)?; - let image_bytes = call_on_generic_raster_processor!( - processor, - p => - raster_stream_to_png_bytes(p, query_rect, query_ctx, request.width, request.height, request.time, colorizer, no_data_value.map(AsPrimitive::as_)).await - ).map_err(error::Error::from)?; - - Ok(HttpResponse::Ok() - .content_type(mime::IMAGE_PNG) - .body(image_bytes)) + Ok((processor, query_rect, colorizer, no_data_value)) } fn colorizer_from_style(styles: &str) -> Result> { @@ -319,7 +336,7 @@ fn colorizer_from_style(styles: &str) -> Result> { } #[allow(clippy::unnecessary_wraps)] // TODO: remove line once implemented fully -fn get_legend_graphic( +pub(crate) fn get_legend_graphic( _request: &GetLegendGraphic, _ctx: &C, _endpoint: WorkflowId, diff --git a/services/src/pro/contexts/in_memory.rs b/services/src/pro/contexts/in_memory.rs index 9a24888c8..590b81b95 100644 --- a/services/src/pro/contexts/in_memory.rs +++ b/services/src/pro/contexts/in_memory.rs @@ -1,6 +1,6 @@ use crate::contexts::{ExecutionContextImpl, QueryContextImpl}; use crate::error; -use crate::pro::contexts::{Context, Db, ProContext}; +use crate::pro::contexts::{Context, Db, ProContext, TaskManager}; use crate::pro::datasets::{add_datasets_from_directory, ProHashMapDatasetDb}; use crate::pro::projects::ProHashMapProjectDb; use crate::pro::users::{HashMapUserDb, UserDb, UserSession}; @@ -27,6 +27,7 @@ pub struct ProInMemoryContext { thread_pool: Arc, exe_ctx_tiling_spec: TilingSpecification, query_ctx_chunk_size: ChunkByteSize, + task_manager: TaskManager, } impl TestDefault for ProInMemoryContext { @@ -39,6 +40,7 @@ impl TestDefault for ProInMemoryContext { thread_pool: create_rayon_thread_pool(0), exe_ctx_tiling_spec: TestDefault::test_default(), query_ctx_chunk_size: TestDefault::test_default(), + task_manager: Default::default(), } } } @@ -63,6 +65,7 @@ impl ProInMemoryContext { exe_ctx_tiling_spec, query_ctx_chunk_size, dataset_db: Arc::new(RwLock::new(db)), + task_manager: Default::default(), } } @@ -78,6 +81,7 @@ impl ProInMemoryContext { thread_pool: create_rayon_thread_pool(0), exe_ctx_tiling_spec, query_ctx_chunk_size, + task_manager: Default::default(), } } } @@ -95,6 +99,9 @@ impl ProContext for ProInMemoryContext { async fn user_db_ref_mut(&self) -> RwLockWriteGuard<'_, Self::UserDB> { self.user_db.write().await } + fn task_manager(&self) -> TaskManager { + self.task_manager.clone() + } } #[async_trait] diff --git a/services/src/pro/contexts/mod.rs b/services/src/pro/contexts/mod.rs index bbb096209..84c8590e6 100644 --- a/services/src/pro/contexts/mod.rs +++ b/services/src/pro/contexts/mod.rs @@ -6,11 +6,26 @@ mod postgres; pub use in_memory::ProInMemoryContext; #[cfg(feature = "postgres")] pub use postgres::PostgresContext; +use std::sync::Arc; +use std::time::Duration; use crate::contexts::{Context, Db}; +use crate::pro::executor::scheduler::TaskScheduler; +use crate::pro::executor::{ + DataDescription, FeatureCollectionTaskDescription, MultiLinestringDescription, + MultiPointDescription, MultiPolygonDescription, 
PlotDescription, RasterTaskDescription, + STFilterable, +}; use crate::pro::users::{UserDb, UserSession}; - +use crate::util::config::get_config_element; use async_trait::async_trait; +use geoengine_datatypes::collections::FeatureCollection; +use geoengine_datatypes::primitives::{ + Geometry, MultiLineString, MultiPoint, MultiPolygon, NoGeometry, +}; +use geoengine_datatypes::raster::Pixel; +use geoengine_datatypes::util::arrow::ArrowTyped; +use geoengine_operators::pro::executor::Executor; use tokio::sync::{RwLockReadGuard, RwLockWriteGuard}; /// A pro contexts that extends the default context. @@ -22,4 +37,316 @@ pub trait ProContext: Context { fn user_db(&self) -> Db; async fn user_db_ref(&self) -> RwLockReadGuard; async fn user_db_ref_mut(&self) -> RwLockWriteGuard; + fn task_manager(&self) -> TaskManager; +} + +/// The `TaskManager` provides access to [Executors][Executor] for all available +/// result types. +/// It uses an [Arc] internally, so cloning is cheap and there is no +/// need to wrap it by yourself. +#[derive(Clone)] +pub struct TaskManager { + executors: Arc, +} + +/// Holds all executors provided by the [`TaskManager`]. +struct Executors { + plot_executor: Executor, + vector_schedulers: FeatureExecutors, + raster_schedulers: RasterSchedulers, +} + +/// Trait for retrieving an [`Executor`] instance for feature data. +pub trait GetFeatureExecutor +where + G: Geometry + ArrowTyped + 'static, + for<'a> &'a FeatureCollection: STFilterable, +{ + /// Retrieves [Executor] + fn get_feature_executor(&self) -> &Executor>; +} + +/// Trait for retrieving an [`TaskScheduler`] instance for feature data. +pub trait GetFeatureScheduler +where + G: Geometry + ArrowTyped + 'static, + for<'a> &'a FeatureCollection: STFilterable, +{ + /// Retrieves [`TaskScheduler`] + fn get_feature_scheduler(&self) -> &TaskScheduler>; +} + +/// Summarizes all [`Executors`][Executor] for feature data. +struct FeatureExecutors { + data: TaskScheduler, + point: TaskScheduler, + line: TaskScheduler, + polygon: TaskScheduler, +} + +impl FeatureExecutors { + /// Creates a new instance with each [`Executor`] having the given `queue_size`. + fn new(queue_size: usize, merge_dead_space_threshold: f64, timeout: Duration) -> Self { + Self { + data: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + point: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + line: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + polygon: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + } + } +} + +/// Trait for retrieving an [`Executor`] instance for raster data. +pub trait GetRasterExecutor { + /// Retrieves [`Executor`] + fn get_raster_executor(&self) -> &Executor>; +} + +/// Trait for retrieving an [`TaskScheduler`] instance for raster data. +pub trait GetRasterScheduler { + /// Retrieves [`TaskScheduler`] + fn get_raster_scheduler(&self) -> &TaskScheduler>; +} + +/// Summarizes all [`Executors`][Executor] for feature data. +struct RasterSchedulers { + ru8: TaskScheduler>, + ru16: TaskScheduler>, + ru32: TaskScheduler>, + ru64: TaskScheduler>, + ri8: TaskScheduler>, + ri16: TaskScheduler>, + ri32: TaskScheduler>, + ri64: TaskScheduler>, + rf32: TaskScheduler>, + rf64: TaskScheduler>, +} + +impl RasterSchedulers { + /// Creates a new instance with each [`Executor`] having the given `queue_size`. 
+ fn new(queue_size: usize, merge_dead_space_threshold: f64, timeout: Duration) -> Self { + Self { + ru8: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + ru16: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + ru32: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + ru64: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + ri8: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + ri16: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + ri32: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + ri64: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + rf32: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + rf64: TaskScheduler::new(queue_size, merge_dead_space_threshold, timeout), + } + } +} + +impl TaskManager { + pub fn plot_executor(&self) -> &Executor { + &self.executors.plot_executor + } +} + +impl GetFeatureExecutor for TaskManager { + fn get_feature_executor(&self) -> &Executor { + self.executors.vector_schedulers.data.executor() + } +} + +impl GetFeatureExecutor for TaskManager { + fn get_feature_executor(&self) -> &Executor { + self.executors.vector_schedulers.point.executor() + } +} + +impl GetFeatureExecutor for TaskManager { + fn get_feature_executor(&self) -> &Executor { + self.executors.vector_schedulers.line.executor() + } +} + +impl GetFeatureExecutor for TaskManager { + fn get_feature_executor(&self) -> &Executor { + self.executors.vector_schedulers.polygon.executor() + } +} + +impl GetFeatureScheduler for TaskManager { + fn get_feature_scheduler(&self) -> &TaskScheduler { + &self.executors.vector_schedulers.data + } +} + +impl GetFeatureScheduler for TaskManager { + fn get_feature_scheduler(&self) -> &TaskScheduler { + &self.executors.vector_schedulers.point + } +} + +impl GetFeatureScheduler for TaskManager { + fn get_feature_scheduler(&self) -> &TaskScheduler { + &self.executors.vector_schedulers.line + } +} + +impl GetFeatureScheduler for TaskManager { + fn get_feature_scheduler(&self) -> &TaskScheduler { + &self.executors.vector_schedulers.polygon + } +} + +impl GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.ru8.executor() + } +} + +impl GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.ru16.executor() + } +} + +impl GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.ru32.executor() + } +} + +impl GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.ru64.executor() + } +} + +impl GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.ri8.executor() + } +} + +impl GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.ri16.executor() + } +} + +impl GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.ri32.executor() + } +} + +impl GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.ri64.executor() + } +} + +impl GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.rf32.executor() + } +} + +impl 
GetRasterExecutor for TaskManager { + fn get_raster_executor(&self) -> &Executor> { + self.executors.raster_schedulers.rf64.executor() + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.ru8 + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.ru16 + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.ru32 + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.ru64 + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.ri8 + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.ri16 + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.ri32 + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.ri64 + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.rf32 + } +} + +impl GetRasterScheduler for TaskManager { + fn get_raster_scheduler(&self) -> &TaskScheduler> { + &self.executors.raster_schedulers.rf64 + } +} + +impl Default for TaskManager { + fn default() -> Self { + let queue_size = + get_config_element::().map_or(5, |it| it.queue_size); + let raster_timeout = Duration::from_millis( + get_config_element::() + .map_or(0, |it| it.raster_scheduler_timeout_ms), + ); + let feature_timeout = Duration::from_millis( + get_config_element::() + .map_or(0, |it| it.feature_scheduler_timeout_ms), + ); + + let raster_threshold = get_config_element::() + .map_or(0.01, |it| it.raster_scheduler_merge_threshold); + + let feature_threshold = get_config_element::() + .map_or(0.01, |it| it.feature_scheduler_merge_threshold); + + TaskManager { + executors: Arc::new(Executors { + plot_executor: Executor::new(queue_size), + vector_schedulers: FeatureExecutors::new( + queue_size, + feature_threshold, + feature_timeout, + ), + raster_schedulers: RasterSchedulers::new( + queue_size, + raster_threshold, + raster_timeout, + ), + }), + } + } } diff --git a/services/src/pro/contexts/postgres.rs b/services/src/pro/contexts/postgres.rs index dafaf0ee2..6acda4b11 100644 --- a/services/src/pro/contexts/postgres.rs +++ b/services/src/pro/contexts/postgres.rs @@ -1,5 +1,6 @@ use crate::datasets::add_from_directory::add_providers_from_directory; use crate::error::{self, Result}; +use crate::pro::contexts::TaskManager; use crate::pro::datasets::{add_datasets_from_directory, PostgresDatasetDb, Role}; use crate::pro::projects::ProjectPermission; use crate::pro::users::{UserDb, UserId, UserSession}; @@ -50,6 +51,7 @@ where thread_pool: Arc, exe_ctx_tiling_spec: TilingSpecification, query_ctx_chunk_size: ChunkByteSize, + task_manager: TaskManager, } impl PostgresContext @@ -77,6 +79,7 @@ where workflow_registry: Arc::new(RwLock::new(PostgresWorkflowRegistry::new(pool.clone()))), dataset_db: Arc::new(RwLock::new(PostgresDatasetDb::new(pool.clone()))), thread_pool: create_rayon_thread_pool(0), + task_manager: Default::default(), exe_ctx_tiling_spec, 
            query_ctx_chunk_size,
        })
@@ -108,6 +111,7 @@ where
            workflow_registry: Arc::new(RwLock::new(PostgresWorkflowRegistry::new(pool.clone()))),
            dataset_db: Arc::new(RwLock::new(dataset_db)),
            thread_pool: create_rayon_thread_pool(0),
+           task_manager: Default::default(),
            exe_ctx_tiling_spec,
            query_ctx_chunk_size,
        })
@@ -442,6 +446,10 @@ where
     async fn user_db_ref_mut(&self) -> RwLockWriteGuard<'_, Self::UserDB> {
         self.user_db.write().await
     }
+
+    fn task_manager(&self) -> TaskManager {
+        self.task_manager.clone()
+    }
 }

 #[async_trait]
diff --git a/services/src/pro/executor/mod.rs b/services/src/pro/executor/mod.rs
new file mode 100644
index 000000000..ecd29abfd
--- /dev/null
+++ b/services/src/pro/executor/mod.rs
@@ -0,0 +1,872 @@
+use crate::handlers::plots::WrappedPlotOutput;
+use crate::pro::executor::scheduler::MergableTaskDescription;
+use crate::workflows::workflow::WorkflowId;
+use float_cmp::approx_eq;
+use futures_util::future::BoxFuture;
+use futures_util::stream::BoxStream;
+use futures_util::FutureExt;
+use geo::prelude::Intersects;
+use geoengine_datatypes::collections::{
+    FeatureCollection, FeatureCollectionInfos, FeatureCollectionModifications,
+};
+use geoengine_datatypes::primitives::{
+    AxisAlignedRectangle, BoundingBox2D, Coordinate2D, Geometry, MultiLineString, MultiPoint,
+    MultiPolygon, NoGeometry, QueryRectangle, RasterQueryRectangle, SpatialPartition2D,
+    SpatialPartitioned, TemporalBounded, TimeInstance, TimeInterval, VectorQueryRectangle,
+};
+use geoengine_datatypes::raster::{Pixel, RasterTile2D};
+use geoengine_datatypes::util::arrow::ArrowTyped;
+use geoengine_operators::engine::{QueryContext, RasterQueryProcessor, VectorQueryProcessor};
+use geoengine_operators::pro::executor::operators::OneshotQueryProcessor;
+use geoengine_operators::pro::executor::ExecutorTaskDescription;
+use std::fmt::{Debug, Formatter};
+use std::sync::Arc;
+
+pub mod scheduler;
+
+/// Merges two time intervals, regardless of whether they actually intersect.
+fn merge_time_interval(l: &TimeInterval, r: &TimeInterval) -> TimeInterval {
+    TimeInterval::new_unchecked(
+        TimeInstance::min(l.start(), r.start()),
+        TimeInstance::max(l.end(), r.end()),
+    )
+}
+
+/// Merges two rectangles, regardless of whether they actually intersect, and returns
+/// the lower-left and upper-right coordinates of the resulting rectangle.
+fn merge_rect<T: AxisAlignedRectangle>(l: &T, r: &T) -> (Coordinate2D, Coordinate2D) {
+    let ll = Coordinate2D::new(
+        f64_min(l.lower_left().x, r.lower_left().x),
+        f64_min(l.lower_left().y, r.lower_left().y),
+    );
+
+    let ur = Coordinate2D::new(
+        f64_max(l.upper_right().x, r.upper_right().x),
+        f64_max(l.upper_right().y, r.upper_right().y),
+    );
+
+    (ll, ur)
+}
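+// A quick sanity check of the dead-space metric defined below, as a sketch
+// with hypothetical numbers: two disjoint 1x1 boxes at opposite corners of a
+// 10x10 area, queried over the same 10 ms interval, merge into a cube of
+// volume 11 * 10 * 10 = 1100, while the "sum cube" is (11 + 11) * (1 + 1) * (1 + 1) = 88.
+// The dead space is thus 1100 / 88 - 1 = 11.5, far above the default
+// threshold of 0.01, so such queries would not be merged.
+#[cfg(test)]
+mod merge_dead_space_example {
+    use super::*;
+    use geoengine_datatypes::primitives::SpatialResolution;
+
+    fn query(ll: (f64, f64), ur: (f64, f64)) -> VectorQueryRectangle {
+        VectorQueryRectangle {
+            spatial_bounds: BoundingBox2D::new_unchecked(ll.into(), ur.into()),
+            time_interval: TimeInterval::new_unchecked(0, 10),
+            spatial_resolution: SpatialResolution::one(),
+        }
+    }
+
+    #[test]
+    fn disjoint_queries_have_high_dead_space() {
+        let l = query((0.0, 0.0), (1.0, 1.0));
+        let r = query((9.0, 9.0), (10.0, 10.0));
+        // vol_merged = 11 * 10 * 10, vol_sum = 22 * 2 * 2
+        assert!((merge_dead_space(&l, &r) - 11.5).abs() < 1e-9);
+        // identical queries produce a negative ratio, which clamps to zero
+        assert!(merge_dead_space(&l, &l).abs() < f64::EPSILON);
+    }
+}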
+/// Calculates the dead space that occurs when merging two query rectangles. A value of 0 indicates
+/// that the merge is perfect in the sense that no superfluous calculations occur.
+fn merge_dead_space<T: AxisAlignedRectangle>(
+    l: &QueryRectangle<T>,
+    r: &QueryRectangle<T>,
+) -> f64 {
+    let t_sum = (l.time_interval.duration_ms() + 1 + r.time_interval.duration_ms() + 1) as f64;
+    let t_merged =
+        (merge_time_interval(&l.time_interval, &r.time_interval).duration_ms() + 1) as f64;
+
+    let (ll, ur) = merge_rect(&l.spatial_bounds, &r.spatial_bounds);
+    let merged_rect = BoundingBox2D::new_unchecked(ll, ur);
+
+    let x_merged = merged_rect.size_x();
+    let y_merged = merged_rect.size_y();
+
+    let x_sum = l.spatial_bounds.size_x() + r.spatial_bounds.size_x();
+    let y_sum = l.spatial_bounds.size_y() + r.spatial_bounds.size_y();
+
+    let vol_merged = t_merged * x_merged * y_merged;
+    let vol_sum = t_sum * x_sum * y_sum;
+
+    f64_max(0.0, vol_merged / vol_sum - 1.0)
+}
+
+/// Computes the minimum of two f64 values.
+/// This function has no proper handling for `f64::NAN` values.
+fn f64_min(l: f64, r: f64) -> f64 {
+    if l < r {
+        l
+    } else {
+        r
+    }
+}
+
+/// Computes the maximum of two f64 values.
+/// This function has no proper handling for `f64::NAN` values.
+fn f64_max(l: f64, r: f64) -> f64 {
+    if l > r {
+        l
+    } else {
+        r
+    }
+}
+
+/// A [description][ExecutorTaskDescription] of an [`Executor`][geoengine_operators::pro::executor::Executor] task
+/// for plot data.
+#[derive(Clone, Debug)]
+pub struct PlotDescription {
+    pub id: WorkflowId,
+    pub spatial_bounds: BoundingBox2D,
+    pub temporal_bounds: TimeInterval,
+}
+
+impl PlotDescription {
+    pub fn new(
+        id: WorkflowId,
+        spatial_bounds: BoundingBox2D,
+        temporal_bounds: TimeInterval,
+    ) -> PlotDescription {
+        PlotDescription {
+            id,
+            spatial_bounds,
+            temporal_bounds,
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl ExecutorTaskDescription for PlotDescription {
+    type KeyType = WorkflowId;
+    type ResultType = crate::error::Result<WrappedPlotOutput>;
+
+    fn primary_key(&self) -> &Self::KeyType {
+        &self.id
+    }
+
+    fn is_contained_in(&self, other: &Self) -> bool {
+        self.temporal_bounds == other.temporal_bounds
+            && approx_eq!(BoundingBox2D, self.spatial_bounds, other.spatial_bounds)
+    }
+
+    fn slice_result(&self, result: &Self::ResultType) -> Option<Self::ResultType> {
+        Some(match result {
+            Ok(wpo) => Ok(wpo.clone()),
+            Err(_e) => Err(crate::error::Error::NotYetImplemented),
+        })
+    }
+}
+
+/// A trait for spatio-temporally filterable [feature collections][FeatureCollection].
+pub trait STFilterable<G> {
+    /// Filters out features that do not intersect the given spatio-temporal bounds
+    /// and returns the resulting [`FeatureCollection`].
+    fn apply_filter(
+        self,
+        t: &TimeInterval,
+        s: &BoundingBox2D,
+    ) -> geoengine_datatypes::util::Result<FeatureCollection<G>>;
+}
+
+impl<'a, G> STFilterable<G> for &'a FeatureCollection<G>
+where
+    G: Geometry + ArrowTyped,
+    Self: IntoIterator,
+    <Self as IntoIterator>::Item: Intersects<TimeInterval> + Intersects<BoundingBox2D>,
+{
+    fn apply_filter(
+        self,
+        t: &TimeInterval,
+        s: &BoundingBox2D,
+    ) -> geoengine_datatypes::util::Result<FeatureCollection<G>> {
+        let mask = self
+            .into_iter()
+            .map(|x| x.intersects(t) && x.intersects(s))
+            .collect::<Vec<bool>>();
+        self.filter(mask)
+    }
+}
+
+pub type DataDescription = FeatureCollectionTaskDescription<NoGeometry>;
+pub type MultiPointDescription = FeatureCollectionTaskDescription<MultiPoint>;
+pub type MultiLinestringDescription = FeatureCollectionTaskDescription<MultiLineString>;
+pub type MultiPolygonDescription = FeatureCollectionTaskDescription<MultiPolygon>;
+
+/// A [description][ExecutorTaskDescription] of an [`Executor`][geoengine_operators::pro::executor::Executor] task
+/// for feature data.
+#[derive(Clone)] +pub struct FeatureCollectionTaskDescription { + pub id: WorkflowId, + pub query_rectangle: VectorQueryRectangle, + processor: Arc>>, + context: Arc, +} + +impl FeatureCollectionTaskDescription { + pub fn new( + id: WorkflowId, + query_rectangle: VectorQueryRectangle, + processor: Box>>, + context: C, + ) -> Self { + Self::new_arced(id, query_rectangle, processor.into(), Arc::new(context)) + } + + pub fn new_arced( + id: WorkflowId, + query_rectangle: VectorQueryRectangle, + processor: Arc>>, + context: Arc, + ) -> Self { + Self { + id, + query_rectangle, + processor, + context, + } + } +} + +impl Debug for FeatureCollectionTaskDescription +where + G: Geometry + ArrowTyped + 'static, +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "FeatureCollectionTaskDescription{{id={:?}, query_rectangle={:?}}}", + &self.id, &self.query_rectangle + ) + } +} + +impl ExecutorTaskDescription for FeatureCollectionTaskDescription +where + G: Geometry + ArrowTyped + 'static, + for<'a> &'a FeatureCollection: STFilterable, +{ + type KeyType = WorkflowId; + type ResultType = geoengine_operators::util::Result>; + + fn primary_key(&self) -> &Self::KeyType { + &self.id + } + + fn is_contained_in(&self, other: &Self) -> bool { + other + .query_rectangle + .time_interval + .contains(&self.query_rectangle.time_interval) + && (!G::IS_GEOMETRY + || other + .query_rectangle + .spatial_bounds + .contains_bbox(&self.query_rectangle.spatial_bounds)) + } + + fn slice_result(&self, result: &Self::ResultType) -> Option { + match result { + Ok(collection) => { + let filtered = collection + .apply_filter( + &self.query_rectangle.time_interval, + &self.query_rectangle.spatial_bounds, + ) + .map_err(geoengine_operators::error::Error::from); + + match filtered { + Ok(f) if f.is_empty() => None, + Ok(f) => Some(Ok(f)), + Err(e) => Some(Err(e)), + } + } + Err(_e) => Some(Err(geoengine_operators::error::Error::NotYetImplemented)), + } + } +} + +#[async_trait::async_trait] +impl MergableTaskDescription for FeatureCollectionTaskDescription +where + G: Geometry + ArrowTyped + 'static, + for<'a> &'a FeatureCollection: STFilterable, +{ + fn merge(&self, other: &Self) -> Self { + let merged_t = merge_time_interval( + &self.query_rectangle.time_interval, + &other.query_rectangle.time_interval, + ); + let (ll, ur) = merge_rect( + &self.query_rectangle.spatial_bounds, + &other.query_rectangle.spatial_bounds, + ); + + Self { + id: self.id, + query_rectangle: VectorQueryRectangle { + spatial_bounds: BoundingBox2D::new_unchecked(ll, ur), + time_interval: merged_t, + spatial_resolution: self.query_rectangle.spatial_resolution, + }, + processor: self.processor.clone(), + context: self.context.clone(), + } + } + + fn merge_dead_space(&self, other: &Self) -> f64 { + merge_dead_space(&self.query_rectangle, &other.query_rectangle) + } + + fn execute( + &self, + ) -> BoxFuture< + 'static, + geoengine_operators::pro::executor::error::Result>, + > { + let proc = self.processor.clone(); + let qr = self.query_rectangle; + let ctx = self.context.clone(); + + let stream_future = proc.into_stream(qr, ctx).map(|result| match result { + Ok(rb) => Ok(Box::pin(rb) as BoxStream<'static, Self::ResultType>), + Err(e) => Err(geoengine_operators::pro::executor::error::ExecutorError::from(e)), + }); + Box::pin(stream_future) + } +} + +/// A [description][ExecutorTaskDescription] of an [`geoengine_operators::pro::executor::Executor`] task +/// for raster data. +#[derive(Clone)] +pub struct RasterTaskDescription

<P>
+where
+    P: Pixel,
+{
+    pub id: WorkflowId,
+    pub query_rectangle: RasterQueryRectangle,
+    processor: Arc<dyn RasterQueryProcessor<RasterType = P>>,
+    context: Arc<dyn QueryContext>,
+}
+
+impl<P> Debug for RasterTaskDescription<P>
+where
+    P: Pixel,
+{
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "RasterTaskDescription{{id={:?}, query_rectangle={:?}}}",
+            &self.id, &self.query_rectangle
+        )
+    }
+}
+
+impl<P> RasterTaskDescription<P>
+where
+    P: Pixel,
+{
+    pub fn new<C: QueryContext + 'static>(
+        id: WorkflowId,
+        query_rectangle: RasterQueryRectangle,
+        processor: Box<dyn RasterQueryProcessor<RasterType = P>>,
+        context: C,
+    ) -> Self {
+        Self::new_arced(id, query_rectangle, processor.into(), Arc::new(context))
+    }
+
+    pub fn new_arced(
+        id: WorkflowId,
+        query_rectangle: RasterQueryRectangle,
+        processor: Arc<dyn RasterQueryProcessor<RasterType = P>>,
+        context: Arc<dyn QueryContext>,
+    ) -> Self {
+        Self {
+            id,
+            query_rectangle,
+            processor,
+            context,
+        }
+    }
+}
+
+impl<P> ExecutorTaskDescription for RasterTaskDescription<P>
+where
+    P: Pixel,
+{
+    type KeyType = WorkflowId;
+    type ResultType = geoengine_operators::util::Result<RasterTile2D<P>>;
+
+    fn primary_key(&self) -> &Self::KeyType {
+        &self.id
+    }
+
+    fn is_contained_in(&self, other: &Self) -> bool {
+        other
+            .query_rectangle
+            .time_interval
+            .contains(&self.query_rectangle.time_interval)
+            && other
+                .query_rectangle
+                .spatial_bounds
+                .contains(&self.query_rectangle.spatial_bounds)
+            && other.query_rectangle.spatial_resolution == self.query_rectangle.spatial_resolution
+    }
+
+    fn slice_result(&self, result: &Self::ResultType) -> Option<Self::ResultType> {
+        match result {
+            Ok(r) => {
+                if r.temporal_bounds()
+                    .intersects(&self.query_rectangle.time_interval)
+                    && r.spatial_partition()
+                        .intersects(&self.query_rectangle.spatial_bounds)
+                {
+                    Some(Ok(r.clone()))
+                } else {
+                    None
+                }
+            }
+            Err(_e) => Some(Err(geoengine_operators::error::Error::NotYetImplemented)),
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl<P> MergableTaskDescription for RasterTaskDescription<P>

+where + P: Pixel, +{ + fn merge(&self, other: &Self) -> Self { + let merged_t = merge_time_interval( + &self.query_rectangle.time_interval, + &other.query_rectangle.time_interval, + ); + let (ll, ur) = merge_rect( + &self.query_rectangle.spatial_bounds, + &other.query_rectangle.spatial_bounds, + ); + + Self { + id: self.id, + query_rectangle: RasterQueryRectangle { + spatial_bounds: SpatialPartition2D::new_unchecked( + Coordinate2D::new(ll.x, ur.y), + Coordinate2D::new(ur.x, ll.y), + ), + time_interval: merged_t, + spatial_resolution: self.query_rectangle.spatial_resolution, + }, + processor: self.processor.clone(), + context: self.context.clone(), + } + } + + fn merge_dead_space(&self, other: &Self) -> f64 { + merge_dead_space(&self.query_rectangle, &other.query_rectangle) + } + + fn execute( + &self, + ) -> BoxFuture< + 'static, + geoengine_operators::pro::executor::error::Result>, + > { + let proc = self.processor.clone(); + let qr = self.query_rectangle; + let ctx = self.context.clone(); + + let stream_future = proc.into_stream(qr, ctx).map(|result| match result { + Ok(rb) => Ok(Box::pin(rb) as BoxStream<'static, Self::ResultType>), + Err(e) => Err(geoengine_operators::pro::executor::error::ExecutorError::from(e)), + }); + Box::pin(stream_future) + } +} + +#[cfg(test)] +mod tests { + use crate::handlers::plots::WrappedPlotOutput; + use crate::pro::executor::{ + DataDescription, MultiPointDescription, PlotDescription, RasterTaskDescription, + }; + use crate::util::Identifier; + use crate::workflows::workflow::WorkflowId; + use geoengine_datatypes::collections::{ + DataCollection, FeatureCollectionInfos, MultiPointCollection, + }; + use geoengine_datatypes::plots::PlotOutputFormat; + use geoengine_datatypes::primitives::{ + BoundingBox2D, Coordinate2D, Measurement, MultiPoint, NoGeometry, QueryRectangle, + SpatialPartition2D, SpatialResolution, TimeInterval, VectorQueryRectangle, + }; + use geoengine_datatypes::raster::{ + EmptyGrid2D, GeoTransform, GridIdx2D, GridOrEmpty, RasterDataType, RasterTile2D, + }; + use geoengine_datatypes::spatial_reference::SpatialReferenceOption; + use geoengine_datatypes::util::test::TestDefault; + use geoengine_operators::engine::{ + MockExecutionContext, MockQueryContext, RasterOperator, RasterQueryProcessor, + RasterResultDescriptor, TypedRasterQueryProcessor, TypedVectorQueryProcessor, + VectorOperator, VectorQueryProcessor, + }; + use geoengine_operators::mock::{ + MockFeatureCollectionSource, MockRasterSource, MockRasterSourceParams, + }; + use geoengine_operators::pro::executor::ExecutorTaskDescription; + use std::collections::HashMap; + use std::sync::Arc; + + #[test] + fn test_plot() { + let id = WorkflowId::new(); + + let pd1 = PlotDescription::new( + id, + BoundingBox2D::new((0.0, 0.0).into(), (10.0, 10.0).into()).unwrap(), + TimeInterval::new(0, 10).unwrap(), + ); + + let pd2 = PlotDescription::new( + id, + BoundingBox2D::new(Coordinate2D::new(4.0, 4.0), Coordinate2D::new(6.0, 6.0)).unwrap(), + TimeInterval::new(4, 6).unwrap(), + ); + + let pd3 = PlotDescription::new( + id, + BoundingBox2D::new(Coordinate2D::new(4.0, 4.0), Coordinate2D::new(6.0, 6.0)).unwrap(), + TimeInterval::new(0, 10).unwrap(), + ); + + let pd4 = PlotDescription::new( + id, + BoundingBox2D::new((0.0, 0.0).into(), (10.0, 10.0).into()).unwrap(), + TimeInterval::new(4, 6).unwrap(), + ); + + let result = Ok(WrappedPlotOutput { + output_format: PlotOutputFormat::JsonPlain, + plot_type: "test", + data: Default::default(), + }); + + assert!(pd1.is_contained_in(&pd1)); + 
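// Note: `PlotDescription::is_contained_in` requires identical temporal bounds
+        // and approximately equal spatial bounds, so among pd1..pd4 only the
+        // identity check above holds; every cross-pair check below must be false.
+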
assert!(!pd1.is_contained_in(&pd2)); + assert!(!pd1.is_contained_in(&pd3)); + assert!(!pd1.is_contained_in(&pd4)); + assert!(!pd2.is_contained_in(&pd1)); + assert!(!pd2.is_contained_in(&pd1)); + assert!(!pd2.is_contained_in(&pd3)); + assert!(!pd2.is_contained_in(&pd4)); + assert!(!pd3.is_contained_in(&pd1)); + assert!(!pd3.is_contained_in(&pd2)); + assert!(!pd3.is_contained_in(&pd4)); + assert!(!pd4.is_contained_in(&pd1)); + assert!(!pd4.is_contained_in(&pd2)); + assert!(!pd4.is_contained_in(&pd3)); + + assert_eq!( + result.as_ref().unwrap(), + pd1.slice_result(&result).unwrap().as_ref().unwrap() + ); + } + + #[tokio::test] + async fn test_data() { + let id = WorkflowId::new(); + + let collection = DataCollection::from_data( + vec![NoGeometry, NoGeometry, NoGeometry], + vec![ + TimeInterval::new(0, 10).unwrap(), + TimeInterval::new(2, 3).unwrap(), + TimeInterval::new(3, 8).unwrap(), + ], + HashMap::new(), + ) + .unwrap(); + + let source = MockFeatureCollectionSource::single(collection.clone()).boxed(); + + let source = source + .initialize(&MockExecutionContext::test_default()) + .await + .unwrap(); + + let proc: Arc> = + if let Ok(TypedVectorQueryProcessor::Data(p)) = source.query_processor() { + p + } else { + panic!() + } + .into(); + + let ctx = Arc::new(MockQueryContext::test_default()); + + let pd1 = DataDescription::new_arced( + id, + VectorQueryRectangle { + spatial_bounds: BoundingBox2D::new_unchecked( + (0.0, 0.0).into(), + (10.0, 10.0).into(), + ), + time_interval: TimeInterval::new_unchecked(0, 10), + spatial_resolution: SpatialResolution::one(), + }, + proc.clone(), + ctx.clone(), + ); + + let pd2 = DataDescription::new_arced( + id, + VectorQueryRectangle { + spatial_bounds: BoundingBox2D::new( + Coordinate2D::new(4.0, 4.0), + Coordinate2D::new(6.0, 6.0), + ) + .unwrap(), + time_interval: TimeInterval::new(4, 6).unwrap(), + spatial_resolution: SpatialResolution::one(), + }, + proc.clone(), + ctx.clone(), + ); + + let pd3 = DataDescription::new_arced( + id, + VectorQueryRectangle { + spatial_bounds: BoundingBox2D::new( + Coordinate2D::new(4.0, 4.0), + Coordinate2D::new(6.0, 6.0), + ) + .unwrap(), + time_interval: TimeInterval::new(0, 10).unwrap(), + spatial_resolution: SpatialResolution::one(), + }, + proc, + ctx, + ); + + assert!(pd2.is_contained_in(&pd1)); + assert!(pd3.is_contained_in(&pd1)); + assert!(pd2.is_contained_in(&pd3)); + assert!(!pd1.is_contained_in(&pd2)); + assert!(!pd3.is_contained_in(&pd2)); + + let sliced = pd2.slice_result(&Ok(collection.clone())).unwrap().unwrap(); + + assert_eq!(2, sliced.len()); + + let sliced = pd3.slice_result(&Ok(collection)).unwrap().unwrap(); + + assert_eq!(3, sliced.len()); + } + + #[tokio::test] + async fn test_vector() { + let id = WorkflowId::new(); + + let collection = MultiPointCollection::from_data( + vec![ + MultiPoint::new(vec![Coordinate2D::new(1.0, 1.0)]).unwrap(), + MultiPoint::new(vec![Coordinate2D::new(5.0, 5.0)]).unwrap(), + MultiPoint::new(vec![Coordinate2D::new(5.0, 5.0)]).unwrap(), + ], + vec![ + TimeInterval::new(0, 10).unwrap(), + TimeInterval::new(2, 3).unwrap(), + TimeInterval::new(3, 8).unwrap(), + ], + HashMap::new(), + ) + .unwrap(); + + let source = MockFeatureCollectionSource::single(collection.clone()).boxed(); + + let source = source + .initialize(&MockExecutionContext::test_default()) + .await + .unwrap(); + + let proc: Arc> = + if let Ok(TypedVectorQueryProcessor::MultiPoint(p)) = source.query_processor() { + p + } else { + panic!() + } + .into(); + + let ctx = 
Arc::new(MockQueryContext::test_default()); + + let pd1 = MultiPointDescription::new_arced( + id, + VectorQueryRectangle { + spatial_bounds: BoundingBox2D::new( + Coordinate2D::new(0.0, 0.0), + Coordinate2D::new(10.0, 10.0), + ) + .unwrap(), + time_interval: TimeInterval::new(0, 10).unwrap(), + spatial_resolution: SpatialResolution::one(), + }, + proc.clone(), + ctx.clone(), + ); + + let pd2 = MultiPointDescription::new_arced( + id, + VectorQueryRectangle { + spatial_bounds: BoundingBox2D::new( + Coordinate2D::new(4.0, 4.0), + Coordinate2D::new(6.0, 6.0), + ) + .unwrap(), + time_interval: TimeInterval::new(4, 6).unwrap(), + spatial_resolution: SpatialResolution::one(), + }, + proc.clone(), + ctx.clone(), + ); + + let pd3 = MultiPointDescription::new_arced( + id, + VectorQueryRectangle { + spatial_bounds: BoundingBox2D::new( + Coordinate2D::new(4.0, 4.0), + Coordinate2D::new(6.0, 6.0), + ) + .unwrap(), + time_interval: TimeInterval::new(0, 10).unwrap(), + spatial_resolution: SpatialResolution::one(), + }, + proc, + ctx, + ); + + assert!(pd2.is_contained_in(&pd1)); + assert!(pd3.is_contained_in(&pd1)); + assert!(pd2.is_contained_in(&pd3)); + assert!(!pd1.is_contained_in(&pd2)); + assert!(!pd1.is_contained_in(&pd3)); + assert!(!pd3.is_contained_in(&pd2)); + + let sliced = pd2.slice_result(&Ok(collection.clone())).unwrap().unwrap(); + + assert_eq!(1, sliced.len()); + + let sliced = pd3.slice_result(&Ok(collection)).unwrap().unwrap(); + + assert_eq!(2, sliced.len()); + } + + #[tokio::test] + #[allow(clippy::too_many_lines)] + async fn test_raster() { + let id = WorkflowId::new(); + + let tile = RasterTile2D::::new( + TimeInterval::new(0, 10).unwrap(), + GridIdx2D::new([0, 0]), + GeoTransform::new((0., 10.).into(), 1.0, -1.0), + GridOrEmpty::Empty(EmptyGrid2D::new([10, 10].into(), 0_u8)), + ); + + let source = MockRasterSource { + params: MockRasterSourceParams { + data: vec![tile.clone()], + result_descriptor: RasterResultDescriptor { + data_type: RasterDataType::U8, + spatial_reference: SpatialReferenceOption::Unreferenced, + measurement: Measurement::Unitless, + no_data_value: None, + }, + }, + } + .boxed(); + + let source = source + .initialize(&MockExecutionContext::test_default()) + .await + .unwrap(); + + let proc: Arc> = + if let Ok(TypedRasterQueryProcessor::U8(p)) = source.query_processor() { + p + } else { + panic!() + } + .into(); + + let ctx = Arc::new(MockQueryContext::test_default()); + + let pd1 = RasterTaskDescription::::new_arced( + id, + QueryRectangle { + spatial_bounds: SpatialPartition2D::new_unchecked( + Coordinate2D::new(0.0, 10.0), + Coordinate2D::new(10.0, 0.0), + ), + time_interval: TimeInterval::new_unchecked(0, 10), + spatial_resolution: SpatialResolution::one(), + }, + proc.clone(), + ctx.clone(), + ); + + let pd2 = RasterTaskDescription::::new_arced( + id, + QueryRectangle { + spatial_bounds: SpatialPartition2D::new( + Coordinate2D::new(0.0, 10.0), + Coordinate2D::new(10.0, 0.0), + ) + .unwrap(), + time_interval: TimeInterval::new(4, 6).unwrap(), + spatial_resolution: SpatialResolution::one(), + }, + proc.clone(), + ctx.clone(), + ); + + let pd3 = RasterTaskDescription::::new_arced( + id, + QueryRectangle { + spatial_bounds: SpatialPartition2D::new( + Coordinate2D::new(4.0, 6.0), + Coordinate2D::new(6.0, 4.0), + ) + .unwrap(), + time_interval: TimeInterval::new(0, 10).unwrap(), + spatial_resolution: SpatialResolution::one(), + }, + proc, + ctx, + ); + + assert!(pd2.is_contained_in(&pd1)); + assert!(pd3.is_contained_in(&pd1)); + assert!(!pd2.is_contained_in(&pd3)); 
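+        // Raster containment additionally requires equal `spatial_resolution`:
+        // pd2 covers the full area but only t in [4, 6), while pd3 covers
+        // t in [0, 10) but only a 2x2 sub-area, so neither contains the other.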
+ assert!(!pd1.is_contained_in(&pd2)); + assert!(!pd1.is_contained_in(&pd3)); + assert!(!pd3.is_contained_in(&pd2)); + + { + let tile = RasterTile2D::::new( + TimeInterval::new(0, 10).unwrap(), + GridIdx2D::new([0, 0]), + GeoTransform::new((0., 10.).into(), 1.0, -1.0), + GridOrEmpty::Empty(EmptyGrid2D::new([10, 10].into(), 0_u8)), + ); + let res = Ok(tile); + + assert!(pd2.slice_result(&res).is_some()); + assert!(pd3.slice_result(&res).is_some()); + } + + { + let tile = RasterTile2D::::new( + TimeInterval::new(0, 3).unwrap(), + GridIdx2D::new([0, 0]), + GeoTransform::new((0., 10.).into(), 1.0, -1.0), + GridOrEmpty::Empty(EmptyGrid2D::new([10, 10].into(), 0_u8)), + ); + let res = Ok(tile); + + assert!(pd2.slice_result(&res).is_none()); + assert!(pd3.slice_result(&res).is_some()); + } + + { + let tile = RasterTile2D::::new( + TimeInterval::new(0, 10).unwrap(), + GridIdx2D::new([0, 0]), + GeoTransform::new((8., 2.).into(), 1.0, -1.0), + GridOrEmpty::Empty(EmptyGrid2D::new([10, 10].into(), 0_u8)), + ); + let res = Ok(tile); + + assert!(pd2.slice_result(&res).is_some()); + assert!(pd3.slice_result(&res).is_none()); + } + } +} diff --git a/services/src/pro/executor/scheduler.rs b/services/src/pro/executor/scheduler.rs new file mode 100644 index 000000000..e9d5d94f5 --- /dev/null +++ b/services/src/pro/executor/scheduler.rs @@ -0,0 +1,386 @@ +use futures_util::future::BoxFuture; +use futures_util::stream::{BoxStream, FuturesUnordered}; +use futures_util::StreamExt; +use geoengine_operators::pro::executor::error::Result; +use geoengine_operators::pro::executor::{Executor, ExecutorTaskDescription, StreamReceiver}; +use log::{debug, error, warn}; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use std::vec; +use tokio::sync::oneshot::{channel, Sender}; +use tokio::task::JoinHandle; + +/// An extension of the `ExecutorTaskDescription` that allows +/// to merge tasks and schedule the resulting query +#[async_trait::async_trait] +pub trait MergableTaskDescription: ExecutorTaskDescription { + /// Merges two descriptions + #[must_use] + fn merge(&self, other: &Self) -> Self; + + /// Calculates the dead space that would result when merging the + /// two given descriptions. + fn merge_dead_space(&self, other: &Self) -> f64; + + /// Executes the query + fn execute( + &self, + ) -> BoxFuture< + 'static, + geoengine_operators::pro::executor::error::Result>, + >; +} + +type TaskMap = HashMap<::KeyType, Vec>>; + +type SharedTaskMap = Arc>>; + +type ShutdownFlag = Arc; + +/// An enum representing either a single task, +/// or a set of merged tasks +enum ScheduledTask { + Single(TaskEntry), + Merged(MergedTasks), +} + +impl ScheduledTask { + /// Calculates the resulting dead space when merging this task with the given one. + fn merge_dead_space(&self, other: &Self) -> f64 { + match (self, other) { + (Self::Merged(l), Self::Merged(r)) => l.description.merge_dead_space(&r.description), + (Self::Merged(l), Self::Single(r)) => l.description.merge_dead_space(&r.description), + (Self::Single(l), Self::Merged(r)) => l.description.merge_dead_space(&r.description), + (Self::Single(l), Self::Single(r)) => l.description.merge_dead_space(&r.description), + } + } + + /// Schedules this task to the executor. 
+    async fn schedule(self, executor: &Executor<Desc>) -> Result<()> {
+        match self {
+            Self::Single(t) => {
+                debug!("Scheduling single task: {:?}", &t.description);
+                let stream_future = t.description.execute();
+                if t.response
+                    .send(
+                        executor
+                            .submit_stream_future(t.description, stream_future)
+                            .await,
+                    )
+                    .is_err()
+                {
+                    warn!("Stream consumer dropped unexpectedly");
+                }
+            }
+            Self::Merged(parent_task) => {
+                debug!("Scheduling merged task: {:?}", &parent_task.description);
+                let parent_stream_future = parent_task.description.execute();
+                // We need to keep this in order to keep the stream alive
+                let _pres = executor
+                    .submit_stream_future(parent_task.description.clone(), parent_stream_future)
+                    .await?;
+                for task in parent_task.covered_tasks {
+                    debug!(
+                        "  Appending task {:?} to {:?}",
+                        &task.description, &parent_task.description
+                    );
+                    let task_future = task.description.execute();
+                    if task
+                        .response
+                        .send(
+                            executor
+                                .submit_stream_future(task.description, task_future)
+                                .await,
+                        )
+                        .is_err()
+                    {
+                        warn!("Stream consumer dropped unexpectedly");
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    /// Merges this task with the given one.
+    fn merge(self, other: Self) -> Self {
+        let (desc, covered_tasks) = match (self, other) {
+            (Self::Merged(l), Self::Merged(mut r)) => {
+                let mut covered_tasks = l.covered_tasks;
+                covered_tasks.append(&mut r.covered_tasks);
+                let desc = l.description.merge(&r.description);
+                (desc, covered_tasks)
+            }
+            (Self::Merged(l), Self::Single(r)) => {
+                let desc = l.description.merge(&r.description);
+                let mut covered_tasks = l.covered_tasks;
+                covered_tasks.push(r);
+                (desc, covered_tasks)
+            }
+            (Self::Single(l), Self::Merged(r)) => {
+                let desc = r.description.merge(&l.description);
+                let mut covered_tasks = r.covered_tasks;
+                covered_tasks.push(l);
+                (desc, covered_tasks)
+            }
+            (Self::Single(l), Self::Single(r)) => (l.description.merge(&r.description), vec![l, r]),
+        };
+        Self::Merged(MergedTasks {
+            description: desc,
+            covered_tasks,
+        })
+    }
+}
+
+impl<Desc: MergableTaskDescription> From<TaskEntry<Desc>> for ScheduledTask<Desc> {
+    fn from(v: TaskEntry<Desc>) -> Self {
+        Self::Single(v)
+    }
+}
+
+/// Represents a set of merged tasks.
+/// The `description` covers the bounds of all merged tasks;
+/// `covered_tasks` contains all original tasks that were merged into this one.
+struct MergedTasks<Desc: MergableTaskDescription> {
+    description: Desc,
+    covered_tasks: Vec<TaskEntry<Desc>>,
+}
+
+/// A task sent to the scheduler; `response` is used
+/// to return the result to the requesting task.
+struct TaskEntry<Desc: MergableTaskDescription> {
+    description: Desc,
+    response: Sender<Result<StreamReceiver<Desc>>>,
+}
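+// How a request travels through the scheduler, in short:
+// `TaskScheduler::schedule_stream` parks a `TaskEntry` in the shared
+// `TaskMap` under its primary key; the `MergeLooper` below drains that map
+// on every tick, merges compatible entries into `ScheduledTask`s, submits
+// them to the `Executor`, and answers each entry through its oneshot
+// `response` channel with a `StreamReceiver`.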
+/// The looper ticks at a fixed interval and collects tasks.
+/// On every tick, all collected tasks are merged as far as possible
+/// and scheduled on the underlying executor.
+///
+/// Tasks are merged if the resulting dead space is below `merge_dead_space_threshold`.
+struct MergeLooper<Desc>
+where
+    Desc: MergableTaskDescription,
+{
+    timeout: Duration,
+    executor: Arc<Executor<Desc>>,
+    tasks: SharedTaskMap<Desc>,
+    merge_dead_space_threshold: f64,
+    shutdown: ShutdownFlag,
+}
+
+impl<Desc> MergeLooper<Desc>
+where
+    Desc: MergableTaskDescription,
+{
+    pub async fn main_loop(&mut self) {
+        log::info!("Starting merger loop.");
+        loop {
+            tokio::time::sleep(self.timeout).await;
+
+            // Check the shutdown flag
+            if self.shutdown.load(Ordering::Relaxed) {
+                break;
+            }
+
+            {
+                let new_tasks = {
+                    let mut tasks = geoengine_operators::util::safe_lock_mutex(&self.tasks);
+                    tasks.drain().collect::<HashMap<_, _>>()
+                };
+
+                if !new_tasks.is_empty() {
+                    let executor = self.executor.clone();
+                    let threshold = self.merge_dead_space_threshold;
+
+                    tokio::spawn(async move {
+                        debug!("Scheduling tasks.");
+                        Self::schedule(executor, new_tasks, threshold).await;
+                        debug!("Finished scheduling tasks.");
+                    });
+                }
+            }
+        }
+        log::info!("Finished merger loop.");
+    }
+
+    /// Schedules the given set of tasks, merging them wherever
+    /// the resulting dead space stays below the given threshold.
+    async fn schedule(executor: Arc<Executor<Desc>>, tasks: TaskMap<Desc>, threshold: f64) {
+        let merged_tasks = tokio::task::spawn_blocking(move || Self::merge_tasks(tasks, threshold))
+            .await
+            .expect("Task merging must complete.");
+
+        let futures = merged_tasks
+            .into_values()
+            .flat_map(std::iter::IntoIterator::into_iter)
+            .map(|task| task.schedule(executor.as_ref()))
+            .collect::<FuturesUnordered<_>>();
+
+        for res in futures.collect::<Vec<_>>().await {
+            if let Err(e) = res {
+                error!("Failed to schedule tasks: {:?}", e);
+            }
+        }
+    }
+
+    /// Merges the tasks per key with the given threshold.
+    fn merge_tasks(tasks: TaskMap<Desc>, threshold: f64) -> TaskMap<Desc> {
+        tasks
+            .into_iter()
+            .map(|(k, v)| {
+                debug!("Merging {} tasks for key {:?}", v.len(), &k);
+                let merged = Self::handle_homogeneous_tasks(v, threshold);
+                debug!(
+                    "Finished merging tasks for key {:?}. Resulted in {} tasks.",
+                    &k,
+                    merged.len()
+                );
+                (k, merged)
+            })
+            .collect::<TaskMap<Desc>>()
+    }
+
+    /// Merges a set of homogeneous tasks. It performs `single_merge_pass` until
+    /// the resulting task list no longer changes.
+    fn handle_homogeneous_tasks(
+        mut tasks: Vec<ScheduledTask<Desc>>,
+        threshold: f64,
+    ) -> Vec<ScheduledTask<Desc>> {
+        // Merge the tasks
+        loop {
+            let old_size = tasks.len();
+            tasks = Self::single_merge_pass(tasks, threshold);
+            if tasks.len() == old_size {
+                break;
+            }
+        }
+        tasks
+    }
+
+    /// A single merge pass is an O(n^2) operation on the given task list.
+    /// It picks one task from the list and tries to merge it
+    /// with all remaining tasks.
+    /// All tasks that could not be merged stay in the list and are processed
+    /// in a subsequent loop. This is repeated until the list of unmerged tasks is empty.
+    fn single_merge_pass(
+        mut tasks: Vec<ScheduledTask<Desc>>,
+        threshold: f64,
+    ) -> Vec<ScheduledTask<Desc>> {
+        let mut merged = Vec::with_capacity(tasks.len());
+
+        while let Some(mut current) = tasks.pop() {
+            let mut tmp = Vec::with_capacity(tasks.len());
+            for t in tasks {
+                if current.merge_dead_space(&t) <= threshold {
+                    current = current.merge(t);
+                } else {
+                    tmp.push(t);
+                }
+            }
+            merged.push(current);
+            tasks = tmp;
+        }
+        merged
+    }
+}
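+// A sketch of one merge pass on hypothetical tasks [A, B, C]: the pass pops
+// C and tests `merge_dead_space` against A and B, folding in every task that
+// stays below the threshold (say B) and carrying the rest (A) over to the
+// next iteration. Because `handle_homogeneous_tasks` repeats passes until
+// the list stops shrinking, transitive merges such as (C+B)+A are still
+// found on a later pass.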
+/// The scheduler ticks with a fixed interval and collects tasks.
+/// Every tick, all collected tasks are merged as far as possible
+/// and scheduled to the underlying executor.
+///
+/// Tasks are merged if the resulting dead space stays below `merge_dead_space_threshold`.
+pub struct TaskScheduler<Desc>
+where
+    Desc: MergableTaskDescription,
+{
+    executor: Arc<Executor<Desc>>,
+    tasks: SharedTaskMap<Desc>,
+    timeout: Duration,
+    _looper_handle: JoinHandle<()>,
+    shutdown: ShutdownFlag,
+}
+
+impl<Desc> Drop for TaskScheduler<Desc>
+where
+    Desc: MergableTaskDescription,
+{
+    fn drop(&mut self) {
+        self.shutdown.store(true, Ordering::Relaxed);
+    }
+}
+
+impl<Desc> TaskScheduler<Desc>
+where
+    Desc: MergableTaskDescription,
+{
+    /// Creates a new scheduler with the given dead-space threshold and timeout.
+    pub fn new(
+        executor_buffer_size: usize,
+        merge_dead_space_threshold: f64,
+        timeout: Duration,
+    ) -> TaskScheduler<Desc> {
+        let tasks = Arc::new(Mutex::new(HashMap::new()));
+        let shutdown = Arc::new(AtomicBool::new(false));
+        let executor = Arc::new(Executor::new(executor_buffer_size));
+
+        let mut looper = MergeLooper {
+            timeout,
+            tasks: tasks.clone(),
+            executor: executor.clone(),
+            merge_dead_space_threshold,
+            shutdown: shutdown.clone(),
+        };
+
+        let looper_handle = tokio::spawn(async move { looper.main_loop().await });
+
+        Self {
+            executor,
+            tasks,
+            timeout,
+            _looper_handle: looper_handle,
+            shutdown,
+        }
+    }
+
+    /// Returns the backing executor.
+    pub fn executor(&self) -> &Executor<Desc> {
+        self.executor.as_ref()
+    }
+
+    /// Directly submits a task to the executor, bypassing scheduling.
+    pub async fn fastpath(&self, key: Desc) -> Result<StreamReceiver<Desc>> {
+        let stream = key.execute().await?;
+        self.executor.submit_stream(key, stream).await
+    }
+
+    /// Schedules the given task.
+    pub async fn schedule_stream(&self, key: Desc) -> Result<StreamReceiver<Desc>> {
+        if self.timeout.is_zero() {
+            return self.fastpath(key).await;
+        }
+
+        let (tx, rx) = channel();
+
+        let task_entry = TaskEntry {
+            description: key.clone(),
+            response: tx,
+        };
+
+        {
+            let mut tasks = geoengine_operators::util::safe_lock_mutex(&self.tasks);
+            match tasks.entry(key.primary_key().clone()) {
+                Entry::Vacant(ve) => {
+                    ve.insert(vec![task_entry.into()]);
+                }
+                Entry::Occupied(mut oe) => {
+                    oe.get_mut().push(task_entry.into());
+                }
+            }
+        }
+        rx.await?
+    }
+}
diff --git a/services/src/pro/handlers.rs b/services/src/pro/handlers.rs
index bd7713ccb..dd3da8894 100644
--- a/services/src/pro/handlers.rs
+++ b/services/src/pro/handlers.rs
@@ -1,4 +1,7 @@
 #[cfg(feature = "odm")]
 pub mod drone_mapping;
+pub mod plots;
 pub mod projects;
 pub mod users;
+pub mod wfs;
+pub mod wms;
diff --git a/services/src/pro/handlers/plots.rs b/services/src/pro/handlers/plots.rs
new file mode 100644
index 000000000..119f23b31
--- /dev/null
+++ b/services/src/pro/handlers/plots.rs
@@ -0,0 +1,41 @@
+use actix_web::{web, FromRequest, Responder};
+use uuid::Uuid;
+
+use crate::error::Result;
+use crate::handlers::plots::GetPlot;
+use crate::pro::contexts::ProContext;
+use crate::pro::executor::PlotDescription;
+use crate::workflows::workflow::WorkflowId;
+
+pub(crate) fn init_plot_routes<C>(cfg: &mut web::ServiceConfig)
+where
+    C: ProContext,
+    C::Session: FromRequest,
+{
+    cfg.service(web::resource("/plot/{id}").route(web::get().to(get_plot_handler::<C>)));
+}
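The pro handlers in this and the following files feed concrete task descriptions into such schedulers. A hedged usage sketch of the `TaskScheduler` API from above; `MyDesc` is a hypothetical stand-in for a real description type such as `RasterTaskDescription`, and the constructor arguments are illustrative:

```rust
use std::time::Duration;

// `MyDesc` is hypothetical; any MergableTaskDescription implementation works.
async fn scheduler_demo(desc: MyDesc) -> Result<()> {
    // Buffer size, dead-space threshold and tick interval; in the server
    // these come from the `[executor]` config section (see below).
    let scheduler = TaskScheduler::<MyDesc>::new(10, 0.01, Duration::from_millis(100));

    // The request is parked in the task map until the next looper tick,
    // where it may be merged with other requests for the same primary key.
    let _stream = scheduler.schedule_stream(desc.clone()).await?;

    // A scheduler created with Duration::ZERO skips merging entirely and
    // takes the fastpath straight to the executor.
    let _direct = scheduler.fastpath(desc).await?;
    Ok(())
}
```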
+/// Generates a [plot](crate::handlers::plots::WrappedPlotOutput).
+/// This handler behaves the same as the standard [plot handler](crate::handlers::plots::get_plot_handler),
+/// except that it uses an [executor](crate::pro::contexts::TaskManager) for query execution.
+async fn get_plot_handler<C: ProContext>(
+    id: web::Path<Uuid>,
+    params: web::Query<GetPlot>,
+    session: C::Session,
+    ctx: web::Data<C>,
+) -> Result<impl Responder> {
+    let workflow_id = WorkflowId(*id.as_ref());
+    let task_manager = ctx.task_manager();
+
+    let desc = PlotDescription {
+        id: workflow_id,
+        spatial_bounds: params.bbox,
+        temporal_bounds: params.time,
+    };
+
+    let task = crate::handlers::plots::process_plot_request(id, params, session, ctx);
+
+    let result = task_manager.plot_executor().submit(desc, task).await?;
+
+    Ok(web::Json(result?))
+}
diff --git a/services/src/pro/handlers/wfs.rs b/services/src/pro/handlers/wfs.rs
new file mode 100644
index 000000000..6dbf3f8a7
--- /dev/null
+++ b/services/src/pro/handlers/wfs.rs
@@ -0,0 +1,62 @@
+use crate::error::Result;
+use crate::ogc::wfs::request::{GetFeature, WfsRequest};
+use crate::pro::contexts::{GetFeatureScheduler, ProContext};
+use crate::pro::executor::FeatureCollectionTaskDescription;
+use crate::util::user_input::QueryEx;
+use crate::workflows::workflow::WorkflowId;
+use actix_web::{web, FromRequest, HttpResponse};
+use geoengine_operators::call_on_generic_vector_processor;
+
+pub(crate) fn init_wfs_routes<C>(cfg: &mut web::ServiceConfig)
+where
+    C: ProContext,
+    C::Session: FromRequest,
+{
+    cfg.service(web::resource("/wfs/{workflow}").route(web::get().to(wfs_handler::<C>)));
+}
+
+async fn wfs_handler<C: ProContext>(
+    workflow: web::Path<WorkflowId>,
+    request: QueryEx<WfsRequest>,
+    ctx: web::Data<C>,
+    session: C::Session,
+) -> Result<HttpResponse> {
+    match request.into_inner() {
+        WfsRequest::GetCapabilities(request) => {
+            crate::handlers::wfs::get_capabilities(
+                &request,
+                ctx.get_ref(),
+                session,
+                workflow.into_inner(),
+            )
+            .await
+        }
+        WfsRequest::GetFeature(request) => {
+            get_feature(&request, ctx.get_ref(), session, workflow.into_inner()).await
+        }
+        _ => Ok(HttpResponse::NotImplemented().finish()),
+    }
+}
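`get_feature` below dispatches over the concrete geometry type of the query processor via `call_on_generic_vector_processor!`. Conceptually, the macro expands to a match of roughly the following shape, running the same body once with `p` bound to the concrete processor (a sketch; the actual macro and variant names live in `geoengine_operators`):

```rust
// Conceptual expansion of call_on_generic_vector_processor!(processor, p => { ... });
// not the actual macro output.
match processor {
    TypedVectorQueryProcessor::Data(p) => { /* body with `p` = plain data processor */ }
    TypedVectorQueryProcessor::MultiPoint(p) => { /* body with `p` = point processor */ }
    TypedVectorQueryProcessor::MultiLineString(p) => { /* body with `p` = line processor */ }
    TypedVectorQueryProcessor::MultiPolygon(p) => { /* body with `p` = polygon processor */ }
}
```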
+/// Retrieves feature data objects. See the default [wfs handler](`crate::handlers::wfs::get_feature`).
+async fn get_feature<C: ProContext>(
+    request: &GetFeature,
+    ctx: &C,
+    session: C::Session,
+    endpoint: WorkflowId,
+) -> Result<HttpResponse> {
+    let (processor, query_rect) =
+        crate::handlers::wfs::extract_operator_and_bounding_box(request, ctx, session, endpoint)
+            .await?;
+
+    let query_ctx = ctx.query_context()?;
+
+    let tm = ctx.task_manager();
+    let json = call_on_generic_vector_processor!(processor, p => {
+        let desc = FeatureCollectionTaskDescription::new(endpoint, query_rect, p, query_ctx);
+        let stream = tm.get_feature_scheduler().schedule_stream(desc).await?;
+        crate::handlers::wfs::vector_stream_to_geojson(Box::pin(stream)).await
+    })?;
+    Ok(HttpResponse::Ok().json(json))
+}
diff --git a/services/src/pro/handlers/wms.rs b/services/src/pro/handlers/wms.rs
new file mode 100644
index 000000000..5608906cd
--- /dev/null
+++ b/services/src/pro/handlers/wms.rs
@@ -0,0 +1,94 @@
+use crate::error::Result;
+use crate::error::{self, Error};
+use crate::ogc::wms::request::{GetMap, WmsRequest};
+use crate::util::user_input::QueryEx;
+use crate::workflows::workflow::WorkflowId;
+use actix_web::{web, FromRequest, HttpResponse};
+
+use crate::pro::contexts::{GetRasterScheduler, ProContext};
+use crate::pro::executor::RasterTaskDescription;
+use geoengine_operators::call_on_generic_raster_processor;
+use geoengine_operators::util::raster_stream_to_png::raster_stream_to_png_bytes_stream;
+use num_traits::AsPrimitive;
+
+pub(crate) fn init_wms_routes<C>(cfg: &mut web::ServiceConfig)
+where
+    C: ProContext,
+    C::Session: FromRequest,
+{
+    cfg.service(web::resource("/wms/{workflow}").route(web::get().to(wms_handler::<C>)));
+}
+
+async fn wms_handler<C: ProContext>(
+    workflow: web::Path<WorkflowId>,
+    request: QueryEx<WmsRequest>,
+    ctx: web::Data<C>,
+    session: C::Session,
+) -> Result<HttpResponse> {
+    match request.into_inner() {
+        WmsRequest::GetCapabilities(request) => {
+            let external_address =
+                crate::util::config::get_config_element::<crate::util::config::Web>()?
+                    .external_address
+                    .ok_or(Error::ExternalAddressNotConfigured)?;
+            crate::handlers::wms::get_capabilities(
+                &request,
+                &external_address,
+                ctx.get_ref(),
+                session,
+                workflow.into_inner(),
+            )
+            .await
+        }
+        WmsRequest::GetMap(request) => {
+            get_map(&request, ctx.get_ref(), session, workflow.into_inner()).await
+        }
+        WmsRequest::GetLegendGraphic(request) => {
+            crate::handlers::wms::get_legend_graphic(&request, ctx.get_ref(), workflow.into_inner())
+        }
+        _ => Ok(HttpResponse::NotImplemented().finish()),
+    }
+}
+/// Renders a map as a raster image.
+///
+/// # Example
+///
+/// ```text
+/// GET /wms/df756642-c5a3-4d72-8ad7-629d312ae993?request=GetMap&service=WMS&version=2.0.0&layers=df756642-c5a3-4d72-8ad7-629d312ae993&bbox=1,2,3,4&width=100&height=100&crs=EPSG%3A4326&styles=ssss&format=image%2Fpng
+/// ```
+/// Response:
+/// PNG image
+async fn get_map<C: ProContext>(
+    request: &GetMap,
+    ctx: &C,
+    session: C::Session,
+    endpoint: WorkflowId,
+) -> Result<HttpResponse> {
+    let (processor, query_rect, colorizer, no_data_value) =
+        crate::handlers::wms::extract_operator_and_bounding_box_and_colorizer(
+            request, ctx, session, endpoint,
+        )
+        .await?;
+
+    let query_ctx = ctx.query_context()?;
+
+    let image_bytes = call_on_generic_raster_processor!(
+        processor,
+        p => {
+            let desc = RasterTaskDescription::new(
+                endpoint,
+                query_rect,
+                p,
+                query_ctx,
+            );
+            let stream = ctx.task_manager().get_raster_scheduler().schedule_stream(desc).await?;
+            raster_stream_to_png_bytes_stream(Box::pin(stream), query_rect, request.width, request.height, request.time, colorizer, no_data_value.map(AsPrimitive::as_)).await
+        }
+    ).map_err(error::Error::from)?;
+
+    Ok(HttpResponse::Ok()
+        .content_type(mime::IMAGE_PNG)
+        .body(image_bytes))
+}
diff --git a/services/src/pro/mod.rs b/services/src/pro/mod.rs
index 0bb96d9a5..186a3fa98 100644
--- a/services/src/pro/mod.rs
+++ b/services/src/pro/mod.rs
@@ -2,6 +2,7 @@
 pub mod contexts;
 pub mod datasets;
+pub mod executor;
 pub mod handlers;
 pub mod projects;
 pub mod server;
diff --git a/services/src/pro/server.rs b/services/src/pro/server.rs
index 4d6f245d4..4cb505675 100644
--- a/services/src/pro/server.rs
+++ b/services/src/pro/server.rs
@@ -43,14 +43,14 @@ where
         .wrap(middleware::NormalizePath::trim())
         .configure(configure_extractors)
         .configure(handlers::datasets::init_dataset_routes::<C>)
-        .configure(handlers::plots::init_plot_routes::<C>)
+        .configure(pro::handlers::plots::init_plot_routes::<C>)
         .configure(pro::handlers::projects::init_project_routes::<C>)
         .configure(pro::handlers::users::init_user_routes::<C>)
         .configure(handlers::spatial_references::init_spatial_reference_routes::<C>)
         .configure(handlers::upload::init_upload_routes::<C>)
         .configure(handlers::wcs::init_wcs_routes::<C>)
-        .configure(handlers::wfs::init_wfs_routes::<C>)
-        .configure(handlers::wms::init_wms_routes::<C>)
+        .configure(pro::handlers::wfs::init_wfs_routes::<C>)
+        .configure(pro::handlers::wms::init_wms_routes::<C>)
         .configure(handlers::workflows::init_workflow_routes::<C>);
     #[cfg(feature = "odm")]
     {
diff --git a/services/src/pro/util/tests.rs b/services/src/pro/util/tests.rs
index 48770454e..ac97c285f 100644
--- a/services/src/pro/util/tests.rs
+++ b/services/src/pro/util/tests.rs
@@ -113,7 +113,7 @@ where
         .wrap(middleware::NormalizePath::trim())
         .configure(configure_extractors)
         .configure(handlers::datasets::init_dataset_routes::<C>)
-        .configure(handlers::plots::init_plot_routes::<C>)
+        .configure(pro::handlers::plots::init_plot_routes::<C>)
         .configure(pro::handlers::projects::init_project_routes::<C>)
         .configure(pro::handlers::users::init_user_routes::<C>)
         .configure(handlers::spatial_references::init_spatial_reference_routes::<C>)
diff --git a/services/src/util/config.rs b/services/src/util/config.rs
index 24411c1da..85bebf873 100644
--- a/services/src/util/config.rs
+++ b/services/src/util/config.rs
@@ -356,3 +356,16 @@ pub struct GFBio {
 impl ConfigElement for GFBio {
     const KEY: &'static str = "gfbio";
 }
+
+#[derive(Debug, Deserialize)]
+pub struct Executor {
+    pub queue_size: usize,
+    pub raster_scheduler_timeout_ms: u64,
+    pub raster_scheduler_merge_threshold: f64,
+    pub feature_scheduler_timeout_ms: u64,
+    pub feature_scheduler_merge_threshold: f64,
+}
+
+impl ConfigElement for Executor {
+    const KEY: &'static str = "executor";
+}
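A sketch of how these values would plausibly be wired into the schedulers; the actual wiring happens in the pro context setup, which is not part of this excerpt:

```rust
use std::time::Duration;

// Illustrative only; the real setup lives in the pro context. The feature
// scheduler would be built analogously from the feature_scheduler_* fields.
fn build_raster_scheduler() -> crate::error::Result<TaskScheduler<RasterTaskDescription>> {
    let cfg = crate::util::config::get_config_element::<Executor>()?;
    Ok(TaskScheduler::new(
        cfg.queue_size,
        cfg.raster_scheduler_merge_threshold,
        // 0 ms disables merging: schedule_stream then takes the fastpath.
        Duration::from_millis(cfg.raster_scheduler_timeout_ms),
    ))
}
```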