@@ -30,6 +30,7 @@ import org.apache.wayang.basic.data.{Record, Tuple2 => RT2}
3030import org .apache .wayang .basic .model .{DLModel , Model , LogisticRegressionModel ,DecisionTreeRegressionModel }
3131import org .apache .wayang .basic .operators .{DLTrainingOperator , GlobalReduceOperator , LocalCallbackSink , MapOperator , SampleOperator , LogisticRegressionOperator ,DecisionTreeRegressionOperator , LinearSVCOperator }
3232import org .apache .wayang .commons .util .profiledb .model .Experiment
33+ import org .apache .wayang .core .api .spatial .{SpatialGeometry , SpatialPredicate }
3334import org .apache .wayang .core .function .FunctionDescriptor .{SerializableBiFunction , SerializableBinaryOperator , SerializableFunction , SerializableIntUnaryOperator , SerializablePredicate }
3435import org .apache .wayang .core .optimizer .ProbabilisticDoubleInterval
3536import org .apache .wayang .core .optimizer .cardinality .CardinalityEstimator
@@ -281,6 +282,57 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
281282 thatKeyUdf : SerializableFunction [ThatOut , Key ]) =
282283 new JoinDataQuantaBuilder (this , that, thisKeyUdf, thatKeyUdf)
283284
285+ /**
286+ * Feed the built [[DataQuanta ]] into a spatial filter operator.
287+ * Requires the wayang-spatial plugin to be loaded.
288+ *
289+ * @param keyUdf function to extract geometry from elements
290+ * @param predicate the spatial predicate type
291+ * @param filterGeometry the geometry to filter against
292+ * @return a [[SpatialFilterDataQuantaBuilder]] representing the filtered output
293+ */
def spatialFilter(keyUdf: SerializableFunction[Out, _ <: SpatialGeometry],
                  predicate: SpatialPredicate,
                  filterGeometry: SpatialGeometry): SpatialFilterDataQuantaBuilder[Out] = {
  // This overload leaves the builder's SQL geometry column name unset.
  new SpatialFilterDataQuantaBuilder[Out](this, keyUdf, predicate, filterGeometry)
}
300+
301+ /**
302+ * Feed the built [[DataQuanta ]] into a spatial filter operator with SQL pushdown support.
303+ *
304+ * @param keyUdf function to extract geometry from elements
305+ * @param predicate the spatial predicate type
306+ * @param filterGeometry the geometry to filter against
307+ * @param sqlGeometryColumn the name of the geometry column in the database for SQL pushdown
308+ * @return a [[SpatialFilterDataQuantaBuilder ]] representing the filtered output
309+ */
def spatialFilter(keyUdf: SerializableFunction[Out, _ <: SpatialGeometry],
                  predicate: SpatialPredicate,
                  filterGeometry: SpatialGeometry,
                  sqlGeometryColumn: String): SpatialFilterDataQuantaBuilder[Out] = {
  val filterBuilder =
    new SpatialFilterDataQuantaBuilder[Out](this, keyUdf, predicate, filterGeometry)
  // Record the geometry column name on the freshly created builder and hand it back.
  filterBuilder.withSqlGeometryColumnName(sqlGeometryColumn)
}
318+
319+ /**
320+ * Feed the built [[DataQuanta ]] of this and the given instance into a spatial join operator.
321+ *
322+ * @param thisKeyUdf function to extract geometry from this instance's elements
323+ * @param that the other [[DataQuantaBuilder ]] to join with
324+ * @param thatKeyUdf function to extract geometry from `that` instance's elements
325+ * @param predicate the spatial predicate type
326+ * @return a [[SpatialJoinDataQuantaBuilder ]] representing the joined output as Tuple2
327+ */
def spatialJoin[ThatOut](thisKeyUdf: SerializableFunction[Out, _ <: SpatialGeometry],
                         that: DataQuantaBuilder[_, ThatOut],
                         thatKeyUdf: SerializableFunction[ThatOut, _ <: SpatialGeometry],
                         predicate: SpatialPredicate): SpatialJoinDataQuantaBuilder[Out, ThatOut] = {
  // `this` supplies the first join input, `that` the second.
  new SpatialJoinDataQuantaBuilder[Out, ThatOut](this, that, thisKeyUdf, thatKeyUdf, predicate)
}
335+
284336 /**
285337 * Feed the built [[DataQuanta ]] of this and the given instance into a
286338 * [[org.apache.wayang.basic.operators.DLTrainingOperator ]].
@@ -510,12 +562,12 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
510562 * @param catalog Iceberg Catalog
511563 * @param schema Iceberg Schema of the table to create
512564 * @param tableIdentifier Iceberg Table Identifier of the table to create
513- * @param outputFileFormat File format of the output data files
565+ * @param outputFileFormat File format of the output data files
514566 * @return the collected data quanta
515567 */
516568
517- def writeIcebergTable (catalog : Catalog ,
518- schema : Schema ,
569+ def writeIcebergTable (catalog : Catalog ,
570+ schema : Schema ,
519571 tableIdentifier : TableIdentifier ,
520572 outputFileFormat : FileFormat ,
521573 jobName : String ): Unit = {
@@ -1959,3 +2011,41 @@ class KeyedDataQuantaBuilder[Out, Key](private val dataQuantaBuilder: DataQuanta
19592011 dataQuantaBuilder.coGroup(this .keyExtractor, that.dataQuantaBuilder, that.keyExtractor)
19602012
19612013}
2014+
/**
 * [[DataQuantaBuilder]] implementation that applies a spatial filter to its input.
 *
 * @param inputDataQuanta [[DataQuantaBuilder]] for the input [[DataQuanta]]
 * @param keySelector     extracts the geometry from each input element
 * @param predicateType   the spatial predicate handed to the filter
 * @param filterGeometry  the geometry handed to the filter
 */
class SpatialFilterDataQuantaBuilder[T](inputDataQuanta: DataQuantaBuilder[_, T],
                                        keySelector: SerializableFunction[T, _ <: SpatialGeometry],
                                        predicateType: SpatialPredicate,
                                        filterGeometry: SpatialGeometry)
                                       (implicit javaPlanBuilder: JavaPlanBuilder)
  extends BasicDataQuantaBuilder[SpatialFilterDataQuantaBuilder[T], T] {

  /** Geometry column name forwarded to `spatialFilterJava`; remains `null` until set. */
  private var sqlGeometryColumn: String = _

  /**
   * Set the geometry column name that is forwarded to the filter when it is built.
   *
   * @param columnName the column name
   * @return this instance
   */
  def withSqlGeometryColumnName(columnName: String): SpatialFilterDataQuantaBuilder[T] = {
    sqlGeometryColumn = columnName
    this
  }

  // NOTE(review): SpatialJoinDataQuantaBuilder.build wraps its result in applyTargetPlatforms,
  // this build does not -- confirm whether that difference is intended.
  override protected def build: DataQuanta[T] =
    inputDataQuanta
      .dataQuanta()
      .spatialFilterJava(keySelector, predicateType, filterGeometry, sqlGeometryColumn)
}
2034+
/**
 * [[DataQuantaBuilder]] implementation that spatially joins two inputs into [[RT2]] pairs.
 *
 * @param inputDataQuanta0 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
 * @param inputDataQuanta1 [[DataQuantaBuilder]] for the second input [[DataQuanta]]
 * @param keyUdf0          extracts the geometry from elements of the first input
 * @param keyUdf1          extracts the geometry from elements of the second input
 * @param predicateType    the spatial predicate handed to the join
 */
class SpatialJoinDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[_, In0],
                                             inputDataQuanta1: DataQuantaBuilder[_, In1],
                                             keyUdf0: SerializableFunction[In0, _ <: SpatialGeometry],
                                             keyUdf1: SerializableFunction[In1, _ <: SpatialGeometry],
                                             predicateType: SpatialPredicate)
                                            (implicit javaPlanBuilder: JavaPlanBuilder)
  extends BasicDataQuantaBuilder[SpatialJoinDataQuantaBuilder[In0, In1], RT2[In0, In1]] {

  override protected def build: DataQuanta[RT2[In0, In1]] = {
    // Materialize both inputs first, in declaration order.
    val left = inputDataQuanta0.dataQuanta()
    val right = inputDataQuanta1.dataQuanta()

    // The class tag of the second input is passed explicitly instead of being resolved implicitly.
    val joined = left.spatialJoinJava(keyUdf0, right, keyUdf1, predicateType)(inputDataQuanta1.classTag)

    applyTargetPlatforms(joined, this.getTargetPlatforms())
  }
}
0 commit comments