diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py
index e429f03a7200..198a4ad94cde 100644
--- a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py
+++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py
@@ -394,6 +394,37 @@ def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
"""
return self._append(stages.Sort(*orders))
+ def search(
+ self, query_or_options: str | BooleanExpression | stages.SearchOptions
+ ) -> "_BasePipeline":
+ """
+ Adds a search stage to the pipeline.
+
+ This stage filters documents based on the provided query expression.
+
+ Example:
+ >>> from google.cloud.firestore_v1.pipeline_stages import SearchOptions
+ >>> from google.cloud.firestore_v1.pipeline_expressions import And, DocumentMatches, Field, GeoPoint
+ >>> # Search for restaurants matching either "waffles" or "pancakes" near a location
+ >>> pipeline = client.pipeline().collection("restaurants").search(
+ ... SearchOptions(
+ ... query=And(
+ ... DocumentMatches("waffles OR pancakes"),
+ ... Field.of("location").geo_distance(GeoPoint(38.9, -107.0)).less_than(1000)
+ ... ),
+ ... sort=Score().descending()
+ ... )
+ ... )
+
+ Args:
+ options: Either a string or expression representing the search query, or
+ A `SearchOptions` instance configuring the search.
+
+ Returns:
+ A new Pipeline object with this stage appended to the stage list
+ """
+ return self._append(stages.Search(query_or_options))
+
def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline":
"""
Performs a pseudo-random sampling of the documents from the previous stage.
diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py
index 630258f9cadd..2cfaa38a0454 100644
--- a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py
@@ -730,6 +730,53 @@ def less_than_or_equal(
[self, self._cast_to_expr_or_convert_to_constant(other)],
)
+ @expose_as_static
+ def between(
+ self, lower: Expression | float, upper: Expression | float
+ ) -> "BooleanExpression":
+ """Evaluates if the result of this expression is between
+ the lower bound (inclusive) and upper bound (inclusive).
+
+ This is functionally equivalent to performing an `And` operation with
+ `greater_than_or_equal` and `less_than_or_equal`.
+
+ Example:
+ >>> # Check if the 'age' field is between 18 and 65
+ >>> Field.of("age").between(18, 65)
+
+ Args:
+ lower: Lower bound (inclusive) of the range.
+ upper: Upper bound (inclusive) of the range.
+
+ Returns:
+ A new `BooleanExpression` representing the between comparison.
+ """
+ return And(
+ self.greater_than_or_equal(lower),
+ self.less_than_or_equal(upper),
+ )
+
+ @expose_as_static
+ def geo_distance(self, other: Expression | GeoPoint) -> "FunctionExpression":
+ """Evaluates to the distance in meters between the location in the specified
+ field and the query location.
+
+ Note: This Expression can only be used within a `Search` stage.
+
+ Example:
+ >>> # Calculate distance between the 'location' field and a target GeoPoint
+ >>> Field.of("location").geo_distance(target_point)
+
+ Args:
+ other: Compute distance to this GeoPoint expression or constant value.
+
+ Returns:
+ A new `FunctionExpression` representing the distance.
+ """
+ return FunctionExpression(
+ "geo_distance", [self, self._cast_to_expr_or_convert_to_constant(other)]
+ )
+
@expose_as_static
def equal_any(
self, array: Array | Sequence[Expression | CONSTANT_TYPE] | Expression
@@ -2927,6 +2974,56 @@ def __init__(self):
super().__init__("rand", [], use_infix_repr=False)
+class Score(FunctionExpression):
+ """Evaluates to the search score that reflects the topicality of the document
+ to all of the text predicates (`queryMatch`)
+ in the search query. If `SearchOptions.query` is not set or does not contain
+ any text predicates, then this topicality score will always be `0`.
+
+ Note: This Expression can only be used within a `Search` stage.
+
+ Example:
+ >>> # Sort by search score and retrieve it via add_fields
+ >>> db.pipeline().collection("restaurants").search(
+ ... query="tacos",
+ ... sort=Score().descending(),
+ ... add_fields=[Score().as_("search_score")]
+ ... )
+
+ Returns:
+ A new `Expression` representing the score operation.
+ """
+
+ def __init__(self):
+ super().__init__("score", [], use_infix_repr=False)
+
+
+class DocumentMatches(BooleanExpression):
+ """Creates a boolean expression for a document match query.
+
+ Note: This Expression can only be used within a `Search` stage.
+
+ Example:
+ >>> # Find documents matching the query string
+ >>> db.pipeline().collection("restaurants").search(
+ ... query=DocumentMatches("pizza OR pasta")
+ ... )
+
+ Args:
+ query: The search query string or expression.
+
+ Returns:
+ A new `BooleanExpression` representing the document match.
+ """
+
+ def __init__(self, query: Expression | str):
+ super().__init__(
+ "document_matches",
+ [Expression._cast_to_expr_or_convert_to_constant(query)],
+ use_infix_repr=False,
+ )
+
+
class Variable(Expression):
"""
Creates an expression that retrieves the value of a variable bound via `Pipeline.define`.
diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py
index 6c5ac68ddf0d..f6c3b2cc7bf4 100644
--- a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py
+++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py
@@ -30,6 +30,7 @@
AliasedExpression,
BooleanExpression,
CONSTANT_TYPE,
+ DocumentMatches,
Expression,
Field,
Ordering,
@@ -109,6 +110,79 @@ def percentage(value: float):
return SampleOptions(value, mode=SampleOptions.Mode.PERCENT)
+class SearchOptions:
+ """Options for configuring the `Search` pipeline stage."""
+
+ def __init__(
+ self,
+ query: str | BooleanExpression,
+ *,
+ limit: Optional[int] = None,
+ retrieval_depth: Optional[int] = None,
+ sort: Optional[Sequence[Ordering] | Ordering] = None,
+ add_fields: Optional[Sequence[Selectable]] = None,
+ offset: Optional[int] = None,
+ language_code: Optional[str] = None,
+ ):
+ """
+ Initializes a SearchOptions instance.
+
+ Args:
+ query (str | BooleanExpression): Specifies the search query that will be used to query and score documents
+ by the search stage. The query can be expressed as an `Expression`, which will be used to score
+ and filter the results. Not all expressions supported by Pipelines are supported in the Search query.
+ The query can also be expressed as a string in the Search DSL.
+ limit (Optional[int]): The maximum number of documents to return from the Search stage.
+ retrieval_depth (Optional[int]): The maximum number of documents for the search stage to score. Documents
+ will be processed in the pre-sort order specified by the search index.
+ sort (Optional[Sequence[Ordering] | Ordering]): Orderings specify how the input documents are sorted.
+ add_fields (Optional[Sequence[Selectable]]): The fields to add to each document, specified as a `Selectable`.
+ offset (Optional[int]): The number of documents to skip.
+ language_code (Optional[str]): The BCP-47 language code of text in the search query, such as "en-US" or "sr-Latn".
+ """
+ self.query = DocumentMatches(query) if isinstance(query, str) else query
+ self.limit = limit
+ self.retrieval_depth = retrieval_depth
+ self.sort = [sort] if isinstance(sort, Ordering) else sort
+ self.add_fields = add_fields
+ self.offset = offset
+ self.language_code = language_code
+
+ def __repr__(self):
+ args = [f"query={self.query!r}"]
+ if self.limit is not None:
+ args.append(f"limit={self.limit}")
+ if self.retrieval_depth is not None:
+ args.append(f"retrieval_depth={self.retrieval_depth}")
+ if self.sort is not None:
+ args.append(f"sort={self.sort}")
+ if self.add_fields is not None:
+ args.append(f"add_fields={self.add_fields}")
+ if self.offset is not None:
+ args.append(f"offset={self.offset}")
+ if self.language_code is not None:
+ args.append(f"language_code={self.language_code!r}")
+ return f"{self.__class__.__name__}({', '.join(args)})"
+
+ def _to_dict(self) -> dict[str, Value]:
+ options = {"query": self.query._to_pb()}
+ if self.limit is not None:
+ options["limit"] = Value(integer_value=self.limit)
+ if self.retrieval_depth is not None:
+ options["retrieval_depth"] = Value(integer_value=self.retrieval_depth)
+ if self.sort is not None:
+ options["sort"] = Value(
+ array_value={"values": [s._to_pb() for s in self.sort]}
+ )
+ if self.add_fields is not None:
+ options["add_fields"] = Selectable._to_value(self.add_fields)
+ if self.offset is not None:
+ options["offset"] = Value(integer_value=self.offset)
+ if self.language_code is not None:
+ options["language_code"] = Value(string_value=self.language_code)
+ return options
+
+
class UnnestOptions:
"""Options for configuring the `Unnest` pipeline stage.
@@ -423,6 +497,24 @@ def _pb_args(self):
]
+class Search(Stage):
+ """Search stage."""
+
+ def __init__(self, query_or_options: str | BooleanExpression | SearchOptions):
+ super().__init__("search")
+ if isinstance(query_or_options, SearchOptions):
+ options = query_or_options
+ else:
+ options = SearchOptions(query=query_or_options)
+ self.options = options
+
+ def _pb_args(self) -> list[Value]:
+ return []
+
+ def _pb_options(self) -> dict[str, Value]:
+ return self.options._to_dict()
+
+
class Select(Stage):
"""Selects or creates a set of fields."""
diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml
index a801481dff4b..f473b24e8477 100644
--- a/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml
+++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml
@@ -148,8 +148,13 @@ data:
cities:
city1:
name: "San Francisco"
+ location: GEOPOINT(37.7749,-122.4194)
city2:
name: "New York"
+ location: GEOPOINT(40.7128,-74.0060)
+ city3:
+ name: "Saskatoon"
+ location: GEOPOINT(52.1579,-106.6702)
"cities/city1/landmarks":
lm1:
name: "Golden Gate Bridge"
@@ -167,4 +172,57 @@ data:
rating: 5
rev2:
author: "Bob"
- rating: 4
\ No newline at end of file
+ rating: 4
+ "cities/city3/landmarks":
+ lm4:
+ name: "Western Development Museum"
+ type: "Museum"
+ restaurants:
+ sunnySideUp:
+ name: "The Sunny Side Up"
+ description: "A cozy neighborhood diner serving classic breakfast favorites all day long, from fluffy pancakes to savory omelets."
+ location: GEOPOINT(39.7541,-105.0002)
+ menu: "
Breakfast Classics
- Denver Omelet - $12
- Buttermilk Pancakes - $10
- Steak and Eggs - $16
Sides
- Hash Browns - $4
- Thick-cut Bacon - $5
- Drip Coffee - $2
"
+ average_price_per_person: 15
+ goldenWaffle:
+ name: "The Golden Waffle"
+ description: "Specializing exclusively in Belgian-style waffles. Open daily from 6:00 AM to 11:00 AM."
+ location: GEOPOINT(39.7183,-104.9621)
+ menu: "Signature Waffles
- Strawberry Delight - $11
- Chicken and Waffles - $14
- Chocolate Chip Crunch - $10
Drinks
- Fresh OJ - $4
- Artisan Coffee - $3
"
+ average_price_per_person: 13
+ lotusBlossomThai:
+ name: "Lotus Blossom Thai"
+ description: "Authentic Thai cuisine featuring hand-crushed spices and traditional family recipes from the Chiang Mai region."
+ location: GEOPOINT(39.7315,-104.9847)
+ menu: "Appetizers
- Spring Rolls - $7
- Chicken Satay - $9
Main Course
- Pad Thai - $15
- Green Curry - $16
- Drunken Noodles - $15
"
+ average_price_per_person: 22
+ mileHighCatch:
+ name: "Mile High Catch"
+ description: "Freshly sourced seafood offering a wide variety of Pacific fish and Atlantic shellfish in an upscale atmosphere."
+ location: GEOPOINT(39.7401,-104.9903)
+ menu: "From the Raw Bar
- Oysters (Half Dozen) - $18
- Lobster Cocktail - $22
Entrees
- Pan-Seared Salmon - $28
- King Crab Legs - $45
- Fish and Chips - $19
"
+ average_price_per_person: 45
+ peakBurgers:
+ name: "Peak Burgers"
+ description: "Casual burger joint focused on locally sourced Colorado beef and hand-cut fries."
+ location: GEOPOINT(39.7622,-105.0125)
+ menu: "Burgers
- The Peak Double - $12
- Bison Burger - $15
- Veggie Stack - $11
Sides
- Truffle Fries - $6
- Onion Rings - $5
"
+ average_price_per_person: 18
+ solTacos:
+ name: "El Sol Tacos"
+ description: "A vibrant street-side taco stand serving up quick, delicious, and traditional Mexican street food."
+ location: GEOPOINT(39.6952,-105.0274)
+ menu: "Tacos ($3.50 each)
- Al Pastor
- Carne Asada
- Pollo Asado
- Nopales (Cactus)
Beverages
- Horchata - $4
- Mexican Coke - $3
"
+ average_price_per_person: 12
+ eastsideTacos:
+ name: "Eastside Cantina"
+ description: "Authentic street tacos and hand-shaken margaritas on the vibrant east side of the city."
+ location: GEOPOINT(39.735,-104.885)
+ menu: "Tacos
- Carnitas Tacos - $4
- Barbacoa Tacos - $4.50
- Shrimp Tacos - $5
Drinks
- House Margarita - $9
- Jarritos - $3
"
+ average_price_per_person: 18
+ eastsideChicken:
+ name: "Eastside Chicken"
+ description: "Fried chicken to go - next to Eastside Cantina."
+ location: GEOPOINT(39.735,-104.885)
+ menu: "Fried Chicken
- Drumstick - $4
- Wings - $1
- Sandwich - $9
Drinks
- House Margarita - $9
- Jarritos - $3
"
+ average_price_per_person: 12
diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/logical.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/logical.yaml
index d9f96cd3cd65..253ffcd89a09 100644
--- a/packages/google-cloud-firestore/tests/system/pipeline_e2e/logical.yaml
+++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/logical.yaml
@@ -760,3 +760,56 @@ tests:
- "value_or_default"
assert_results:
- value_or_default: "1984"
+ - description: expression_between
+ pipeline:
+ - Collection: restaurants
+ - Where:
+ - FunctionExpression.between:
+ - Field: average_price_per_person
+ - Constant: 15
+ - Constant: 20
+ - Select:
+ - name
+ - Sort:
+ - Ordering:
+ - Field: name
+ - ASCENDING
+ assert_results:
+ - name: Eastside Cantina
+ - name: Peak Burgers
+ - name: The Sunny Side Up
+ assert_proto:
+ pipeline:
+ stages:
+ - args:
+ - referenceValue: /restaurants
+ name: collection
+ - args:
+ - functionValue:
+ args:
+ - functionValue:
+ args:
+ - fieldReferenceValue: average_price_per_person
+ - integerValue: '15'
+ name: greater_than_or_equal
+ - functionValue:
+ args:
+ - fieldReferenceValue: average_price_per_person
+ - integerValue: '20'
+ name: less_than_or_equal
+ name: and
+ name: where
+ - args:
+ - mapValue:
+ fields:
+ name:
+ fieldReferenceValue: name
+ name: select
+ - args:
+ - mapValue:
+ fields:
+ direction:
+ stringValue: ascending
+ expression:
+ fieldReferenceValue: name
+ name: sort
diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml
new file mode 100644
index 000000000000..1e3568857c4b
--- /dev/null
+++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml
@@ -0,0 +1,450 @@
+tests:
+ - description: search_stage_basic
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query: waffles
+ limit: 2
+ assert_results:
+ - name: The Golden Waffle
+ description: Specializing exclusively in Belgian-style waffles. Open daily from
+ 6:00 AM to 11:00 AM.
+ location: GEOPOINT(39.7183, -104.9621)
+ menu: Signature Waffles
- Strawberry Delight - $11
- Chicken
+ and Waffles - $14
- Chocolate Chip Crunch - $10
Drinks
- Fresh
+ OJ - $4
- Artisan Coffee - $3
+ average_price_per_person: 13
+ assert_proto:
+ pipeline:
+ stages:
+ - args:
+ - referenceValue: /restaurants
+ name: collection
+ - name: search
+ options:
+ limit:
+ integerValue: '2'
+ query:
+ functionValue:
+ args:
+ - stringValue: waffles
+ name: document_matches
+ - description: search_stage_full_options
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query:
+ DocumentMatches:
+ - Constant: tacos
+ limit: 5
+ retrieval_depth: 10
+ offset: 1
+ language_code: en
+ assert_proto:
+ pipeline:
+ stages:
+ - args:
+ - referenceValue: /restaurants
+ name: collection
+ - name: search
+ options:
+ limit:
+ integerValue: '5'
+ retrieval_depth:
+ integerValue: '10'
+ offset:
+ integerValue: '1'
+ language_code:
+ stringValue: en
+ query:
+ functionValue:
+ args:
+ - stringValue: tacos
+ name: document_matches
+ assert_count: 1
+ - description: search_stage_with_sort_and_add_fields
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query: tacos
+ sort:
+ Ordering:
+ - Score: []
+ - DESCENDING
+ add_fields:
+ - AliasedExpression:
+ - Score: []
+ - search_score
+ - Select:
+ - name
+ - search_score
+ assert_results_approximate:
+ config:
+ # be flexible in score values, but should be > 0
+ absolute_tolerance: 0.99
+ data:
+ - name: Eastside Cantina
+ search_score: 1.0
+ - name: El Sol Tacos
+ search_score: 1.0
+ assert_proto:
+ pipeline:
+ stages:
+ - name: collection
+ args:
+ - referenceValue: /restaurants
+ - name: search
+ options:
+ query:
+ functionValue:
+ name: document_matches
+ args:
+ - stringValue: tacos
+ sort:
+ arrayValue:
+ values:
+ - mapValue:
+ fields:
+ direction:
+ stringValue: descending
+ expression:
+ functionValue:
+ name: score
+ add_fields:
+ mapValue:
+ fields:
+ search_score:
+ functionValue:
+ name: score
+ - name: select
+ args:
+ - mapValue:
+ fields:
+ name:
+ fieldReferenceValue: name
+ search_score:
+ fieldReferenceValue: search_score
+ - description: expression_geo_distance
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query:
+ FunctionExpression.less_than_or_equal:
+ - FunctionExpression.geo_distance:
+ - Field: location
+ - GeoPoint:
+ - 39.6985
+ - -105.024
+ - Constant: 1000.0
+ - Select:
+ - name
+ - Sort:
+ - Ordering:
+ - Field: name
+ - ASCENDING
+ assert_results:
+ - name: El Sol Tacos
+ assert_proto:
+ pipeline:
+ stages:
+ - args:
+ - referenceValue: /restaurants
+ name: collection
+ - name: search
+ options:
+ query:
+ functionValue:
+ args:
+ - functionValue:
+ args:
+ - fieldReferenceValue: location
+ - geoPointValue:
+ latitude: 39.6985
+ longitude: -105.024
+ name: geo_distance
+ - doubleValue: 1000.0
+ name: less_than_or_equal
+ - args:
+ - mapValue:
+ fields:
+ name:
+ fieldReferenceValue: name
+ name: select
+ - args:
+ - mapValue:
+ fields:
+ direction:
+ stringValue: ascending
+ expression:
+ fieldReferenceValue: name
+ name: sort
+ - description: search_full_document
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query: waffles
+ assert_results:
+ - name: The Golden Waffle
+ description: Specializing exclusively in Belgian-style waffles. Open daily from
+ 6:00 AM to 11:00 AM.
+ location: GEOPOINT(39.7183, -104.9621)
+ menu: Signature Waffles
- Strawberry Delight - $11
- Chicken
+ and Waffles - $14
- Chocolate Chip Crunch - $10
Drinks
- Fresh
+ OJ - $4
- Artisan Coffee - $3
+ average_price_per_person: 13
+ assert_proto:
+ pipeline:
+ stages:
+ - args:
+ - referenceValue: /restaurants
+ name: collection
+ - name: search
+ options:
+ query:
+ functionValue:
+ args:
+ - stringValue: waffles
+ name: document_matches
+ - description: search_negate_match
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query:
+ DocumentMatches:
+ - Constant: coffee -waffles
+ assert_results:
+ - name: The Sunny Side Up
+ description: A cozy neighborhood diner serving classic breakfast favorites all
+ day long, from fluffy pancakes to savory omelets.
+ location: GEOPOINT(39.7541, -105.0002)
+ menu: Breakfast Classics
- Denver Omelet - $12
- Buttermilk
+ Pancakes - $10
- Steak and Eggs - $16
Sides
- Hash
+ Browns - $4
- Thick-cut Bacon - $5
- Drip Coffee - $2
+ average_price_per_person: 15
+ assert_proto:
+ pipeline:
+ stages:
+ - args:
+ - referenceValue: /restaurants
+ name: collection
+ - name: search
+ options:
+ query:
+ functionValue:
+ args:
+ - stringValue: coffee -waffles
+ name: document_matches
+ - description: search_rquery_as_query_param
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query: chicken wings
+ assert_results:
+ - name: Eastside Chicken
+ description: Fried chicken to go - next to Eastside Cantina.
+ location: GEOPOINT(39.735, -104.885)
+ menu: Fried Chicken
- Drumstick - $4
- Wings - $1
- Sandwich
+ - $9
Drinks
- House Margarita - $9
- Jarritos -
+ $3
+ average_price_per_person: 12
+ assert_proto:
+ pipeline:
+ stages:
+ - args:
+ - referenceValue: /restaurants
+ name: collection
+ - name: search
+ options:
+ query:
+ functionValue:
+ args:
+ - stringValue: chicken wings
+ name: document_matches
+ - description: search_sort_by_distance
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query:
+ FunctionExpression.less_than_or_equal:
+ - FunctionExpression.geo_distance:
+ - Field: location
+ - GeoPoint:
+ - 39.6985
+ - -105.024
+ - Constant: 5600.0
+ sort:
+ Ordering:
+ - FunctionExpression.geo_distance:
+ - Field: location
+ - GeoPoint:
+ - 39.6985
+ - -105.024
+ - ASCENDING
+ assert_results:
+ - name: El Sol Tacos
+ description: A vibrant street-side taco stand serving up quick, delicious, and
+ traditional Mexican street food.
+ location: GEOPOINT(39.6952, -105.0274)
+ menu: Tacos ($3.50 each)
- Al Pastor
- Carne Asada
- Pollo
+ Asado
- Nopales (Cactus)
Beverages
- Horchata -
+ $4
- Mexican Coke - $3
+ average_price_per_person: 12
+ - name: Lotus Blossom Thai
+ description: Authentic Thai cuisine featuring hand-crushed spices and traditional
+ family recipes from the Chiang Mai region.
+ location: GEOPOINT(39.7315, -104.9847)
+ menu: Appetizers
- Spring Rolls - $7
- Chicken Satay - $9
Main
+ Course
- Pad Thai - $15
- Green Curry - $16
- Drunken
+ Noodles - $15
+ average_price_per_person: 22
+ - name: Mile High Catch
+ description: Freshly sourced seafood offering a wide variety of Pacific fish and
+ Atlantic shellfish in an upscale atmosphere.
+ location: GEOPOINT(39.7401, -104.9903)
+ menu: From the Raw Bar
- Oysters (Half Dozen) - $18
- Lobster
+ Cocktail - $22
Entrees
- Pan-Seared Salmon - $28
- King
+ Crab Legs - $45
- Fish and Chips - $19
+ average_price_per_person: 45
+ assert_proto:
+ pipeline:
+ stages:
+ - args:
+ - referenceValue: /restaurants
+ name: collection
+ - name: search
+ options:
+ query:
+ functionValue:
+ args:
+ - functionValue:
+ args:
+ - fieldReferenceValue: location
+ - geoPointValue:
+ latitude: 39.6985
+ longitude: -105.024
+ name: geo_distance
+ - doubleValue: 5600.0
+ name: less_than_or_equal
+ sort:
+ arrayValue:
+ values:
+ - mapValue:
+ fields:
+ direction:
+ stringValue: ascending
+ expression:
+ functionValue:
+ args:
+ - fieldReferenceValue: location
+ - geoPointValue:
+ latitude: 39.6985
+ longitude: -105.024
+ name: geo_distance
+ - description: search_add_fields_score
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query:
+ DocumentMatches:
+ - Constant: waffles
+ add_fields:
+ - AliasedExpression:
+ - Score: []
+ - searchScore
+ - Select:
+ - name
+ - searchScore
+ assert_proto:
+ pipeline:
+ stages:
+ - name: collection
+ args:
+ - referenceValue: /restaurants
+ - name: search
+ options:
+ query:
+ functionValue:
+ name: document_matches
+ args:
+ - stringValue: waffles
+ add_fields:
+ mapValue:
+ fields:
+ searchScore:
+ functionValue:
+ name: score
+ - name: select
+ args:
+ - mapValue:
+ fields:
+ name:
+ fieldReferenceValue: name
+ searchScore:
+ fieldReferenceValue: searchScore
+ assert_results_approximate:
+ config:
+ absolute_tolerance: 0.99
+ data:
+ - name: The Golden Waffle
+ searchScore: 1.0
+ - description: search_sort_by_score
+ pipeline:
+ - Collection: restaurants
+ - Search:
+ - SearchOptions:
+ query:
+ DocumentMatches:
+ - Constant: tacos
+ sort:
+ Ordering:
+ - Score: []
+ - DESCENDING
+ assert_results:
+ - name: Eastside Cantina
+ description: Authentic street tacos and hand-shaken margaritas on the vibrant
+ east side of the city.
+ location: GEOPOINT(39.735, -104.885)
+ menu: Tacos
- Carnitas Tacos - $4
- Barbacoa Tacos - $4.50
- Shrimp
+ Tacos - $5
Drinks
- House Margarita - $9
- Jarritos
+ - $3
+ average_price_per_person: 18
+ - name: El Sol Tacos
+ description: A vibrant street-side taco stand serving up quick, delicious, and
+ traditional Mexican street food.
+ location: GEOPOINT(39.6952, -105.0274)
+ menu: Tacos ($3.50 each)
- Al Pastor
- Carne Asada
- Pollo
+ Asado
- Nopales (Cactus)
Beverages
- Horchata -
+ $4
- Mexican Coke - $3
+ average_price_per_person: 12
+ assert_proto:
+ pipeline:
+ stages:
+ - args:
+ - referenceValue: /restaurants
+ name: collection
+ - name: search
+ options:
+ query:
+ functionValue:
+ args:
+ - stringValue: tacos
+ name: document_matches
+ sort:
+ arrayValue:
+ values:
+ - mapValue:
+ fields:
+ direction:
+ stringValue: descending
+ expression:
+ functionValue:
+ name: score
diff --git a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py
index d66767822ee9..6881279c665b 100644
--- a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py
+++ b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py
@@ -28,7 +28,7 @@
from google.protobuf.json_format import MessageToDict
from test__helpers import FIRESTORE_EMULATOR, FIRESTORE_ENTERPRISE_DB, system_test_lock
-from google.cloud.firestore import AsyncClient, Client
+from google.cloud.firestore import AsyncClient, Client, GeoPoint
from google.cloud.firestore_v1 import pipeline_expressions
from google.cloud.firestore_v1 import pipeline_expressions as expr
from google.cloud.firestore_v1 import pipeline_stages as stages
@@ -72,6 +72,28 @@ def yaml_loader(field="tests", dir_name="pipeline_e2e", attach_file_name=True):
combined_yaml.update(extracted)
elif isinstance(combined_yaml, list) and extracted:
combined_yaml.extend(extracted)
+
+ # Validate test keys
+ allowed_keys = {
+ "description",
+ "pipeline",
+ "assert_proto",
+ "assert_error",
+ "assert_results",
+ "assert_count",
+ "assert_results_approximate",
+ "assert_end_state",
+ "file_name",
+ }
+ if field == "tests" and isinstance(combined_yaml, list):
+ for item in combined_yaml:
+ if isinstance(item, dict):
+ for key in item:
+ if key not in allowed_keys:
+ raise ValueError(
+ f"Unrecognized key '{key}' in test '{item.get('description', 'Unknown')}' in file '{item.get('file_name', 'Unknown')}'"
+ )
+
return combined_yaml
@@ -111,6 +133,34 @@ def test_pipeline_expected_errors(test_dict, client):
assert match, f"error '{found_error}' does not match '{error_regex}'"
+def _assert_pipeline_results(
+ got_results, expected_results, expected_approximate_results, expected_count
+):
+ if expected_results:
+ assert got_results == expected_results
+ if expected_approximate_results is not None:
+ tolerance = 1e-4
+ if (
+ isinstance(expected_approximate_results, dict)
+ and "data" in expected_approximate_results
+ ):
+ if (
+ "config" in expected_approximate_results
+ and "absolute_tolerance" in expected_approximate_results["config"]
+ ):
+ tolerance = expected_approximate_results["config"]["absolute_tolerance"]
+ expected_approximate_results = expected_approximate_results["data"]
+
+ assert len(got_results) == len(expected_approximate_results), (
+ "got unexpected result count"
+ )
+ for idx in range(len(got_results)):
+ expected = expected_approximate_results[idx]
+ assert got_results[idx] == pytest.approx(expected, abs=tolerance)
+ if expected_count is not None:
+ assert len(got_results) == expected_count
+
+
@pytest.mark.parametrize(
"test_dict",
[
@@ -136,18 +186,9 @@ def test_pipeline_results(test_dict, client):
pipeline = parse_pipeline(client, test_dict["pipeline"])
# check if server responds as expected
got_results = [snapshot.data() for snapshot in pipeline.stream()]
- if expected_results:
- assert got_results == expected_results
- if expected_approximate_results:
- assert len(got_results) == len(expected_approximate_results), (
- "got unexpected result count"
- )
- for idx in range(len(got_results)):
- assert got_results[idx] == pytest.approx(
- expected_approximate_results[idx], abs=1e-4
- )
- if expected_count is not None:
- assert len(got_results) == expected_count
+ _assert_pipeline_results(
+ got_results, expected_results, expected_approximate_results, expected_count
+ )
if expected_end_state:
for doc_path, expected_content in expected_end_state.items():
doc_ref = client.document(doc_path)
@@ -209,18 +250,9 @@ async def test_pipeline_results_async(test_dict, async_client):
pipeline = parse_pipeline(async_client, test_dict["pipeline"])
# check if server responds as expected
got_results = [snapshot.data() async for snapshot in pipeline.stream()]
- if expected_results:
- assert got_results == expected_results
- if expected_approximate_results:
- assert len(got_results) == len(expected_approximate_results), (
- "got unexpected result count"
- )
- for idx in range(len(got_results)):
- assert got_results[idx] == pytest.approx(
- expected_approximate_results[idx], abs=1e-4
- )
- if expected_count is not None:
- assert len(got_results) == expected_count
+ _assert_pipeline_results(
+ got_results, expected_results, expected_approximate_results, expected_count
+ )
if expected_end_state:
for doc_path, expected_content in expected_end_state.items():
doc_ref = async_client.document(doc_path)
@@ -395,12 +427,16 @@ def _parse_yaml_types(data):
else:
return [_parse_yaml_types(value) for value in data]
# detect timestamps
- if isinstance(data, str) and ":" in data:
+ if isinstance(data, str) and ":" in data and not data.startswith("GEOPOINT("):
try:
parsed_datetime = datetime.datetime.fromisoformat(data)
return parsed_datetime
except ValueError:
pass
+ if isinstance(data, str) and data.startswith("GEOPOINT("):
+ match = re.match(r"GEOPOINT\(([^,]+),\s*([^)]+)\)", data)
+ if match:
+ return GeoPoint(float(match.group(1)), float(match.group(2)))
if data == "NaN":
return float("NaN")
return data
diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py
index 82d89a12f978..f1408be240a7 100644
--- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py
+++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py
@@ -403,6 +403,8 @@ def test_pipeline_execute_stream_equivalence():
("replace_with", (Field.of("n"),), stages.ReplaceWith),
("sort", (Field.of("n").descending(),), stages.Sort),
("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort),
+ ("search", ("my query",), stages.Search),
+ ("search", (stages.SearchOptions(query="my query"),), stages.Search),
("sample", (10,), stages.Sample),
("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample),
("union", (_make_pipeline(),), stages.Union),
diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py
index b285e8e4b614..d2c935a5f25d 100644
--- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py
+++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py
@@ -790,6 +790,45 @@ def test_equal(self):
infix_instance = arg1.equal(arg2)
assert infix_instance == instance
+ def test_between(self):
+ arg1 = self._make_arg("Left")
+ arg2 = self._make_arg("Lower")
+ arg3 = self._make_arg("Upper")
+ instance = Expression.between(arg1, arg2, arg3)
+ assert instance.name == "and"
+ assert len(instance.params) == 2
+ assert instance.params[0].name == "greater_than_or_equal"
+ assert instance.params[1].name == "less_than_or_equal"
+ assert (
+ repr(instance)
+ == "And(Left.greater_than_or_equal(Lower), Left.less_than_or_equal(Upper))"
+ )
+ infix_instance = arg1.between(arg2, arg3)
+ assert infix_instance == instance
+
+ def test_geo_distance(self):
+ arg1 = self._make_arg("Left")
+ arg2 = self._make_arg("Right")
+ instance = Expression.geo_distance(arg1, arg2)
+ assert instance.name == "geo_distance"
+ assert instance.params == [arg1, arg2]
+ assert repr(instance) == "Left.geo_distance(Right)"
+ infix_instance = arg1.geo_distance(arg2)
+ assert infix_instance == instance
+
+ def test_document_matches(self):
+ arg1 = self._make_arg("Query")
+ instance = expr.DocumentMatches(arg1)
+ assert instance.name == "document_matches"
+ assert instance.params == [arg1]
+ assert repr(instance) == "DocumentMatches(Query)"
+
+ def test_score(self):
+ instance = expr.Score()
+ assert instance.name == "score"
+ assert instance.params == []
+ assert repr(instance) == "Score()"
+
def test_greater_than_or_equal(self):
arg1 = self._make_arg("Left")
arg2 = self._make_arg("Right")
diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py
index b9ab603b713b..064c41c37b70 100644
--- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py
+++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py
@@ -24,6 +24,7 @@
Constant,
Field,
Ordering,
+ DocumentMatches,
)
from google.cloud.firestore_v1.types.document import Value
from google.cloud.firestore_v1.vector import Vector
@@ -778,6 +779,71 @@ def test_to_pb_percent_mode(self):
assert len(result_percent.options) == 0
+class TestSearch:
+ def test_search_defaults(self):
+ options = stages.SearchOptions(query="technology")
+ assert options.query.name == "document_matches"
+ assert options.limit is None
+ assert options.retrieval_depth is None
+ assert options.sort is None
+ assert options.add_fields is None
+ assert options.offset is None
+ assert options.language_code is None
+
+ stage = stages.Search(options)
+ pb_opts = stage._pb_options()
+ assert "query" in pb_opts
+ assert "limit" not in pb_opts
+ assert "retrieval_depth" not in pb_opts
+
+ def test_search_full_options(self):
+ options = stages.SearchOptions(
+ query=DocumentMatches("tech"),
+ limit=10,
+ retrieval_depth=2,
+ sort=Ordering("score", Ordering.Direction.DESCENDING),
+ add_fields=[Field("extra")],
+ offset=5,
+ language_code="en",
+ )
+ assert options.limit == 10
+ assert options.retrieval_depth == 2
+ assert len(options.sort) == 1
+ assert options.offset == 5
+ assert options.language_code == "en"
+
+ stage = stages.Search(options)
+ pb_opts = stage._pb_options()
+
+ assert pb_opts["limit"].integer_value == 10
+ assert pb_opts["retrieval_depth"].integer_value == 2
+ assert len(pb_opts["sort"].array_value.values) == 1
+ assert pb_opts["offset"].integer_value == 5
+ assert pb_opts["language_code"].string_value == "en"
+ assert "query" in pb_opts
+
+ def test_search_string_query_wrapping(self):
+ options = stages.SearchOptions(query="science")
+ assert options.query.name == "document_matches"
+ assert options.query.params[0].value == "science"
+
+ def test_search_with_string(self):
+ stage = stages.Search("technology")
+ assert isinstance(stage.options, stages.SearchOptions)
+ assert stage.options.query.name == "document_matches"
+ assert stage.options.query.params[0].value == "technology"
+ pb_opts = stage._pb_options()
+ assert "query" in pb_opts
+
+ def test_search_with_boolean_expression(self):
+ expr = DocumentMatches("tech")
+ stage = stages.Search(expr)
+ assert isinstance(stage.options, stages.SearchOptions)
+ assert stage.options.query is expr
+ pb_opts = stage._pb_options()
+ assert "query" in pb_opts
+
+
class TestSelect:
def _make_one(self, *args, **kwargs):
return stages.Select(*args, **kwargs)