diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py index e429f03a7200..198a4ad94cde 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py @@ -394,6 +394,37 @@ def sort(self, *orders: stages.Ordering) -> "_BasePipeline": """ return self._append(stages.Sort(*orders)) + def search( + self, query_or_options: str | BooleanExpression | stages.SearchOptions + ) -> "_BasePipeline": + """ + Adds a search stage to the pipeline. + + This stage filters documents based on the provided query expression. + + Example: + >>> from google.cloud.firestore_v1.pipeline_stages import SearchOptions + >>> from google.cloud.firestore_v1.pipeline_expressions import And, DocumentMatches, Field, GeoPoint, Score + >>> # Search for restaurants matching either "waffles" or "pancakes" near a location + >>> pipeline = client.pipeline().collection("restaurants").search( + ... SearchOptions( + ... query=And( + ... DocumentMatches("waffles OR pancakes"), + ... Field.of("location").geo_distance(GeoPoint(38.9, -107.0)).less_than(1000) + ... ), + ... sort=Score().descending() + ... ) + ... ) + + Args: + query_or_options: Either a string or expression representing the search query, or + a `SearchOptions` instance configuring the search. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Search(query_or_options)) + def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline": """ Performs a pseudo-random sampling of the documents from the previous stage. 
diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py index 630258f9cadd..2cfaa38a0454 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py @@ -730,6 +730,53 @@ def less_than_or_equal( [self, self._cast_to_expr_or_convert_to_constant(other)], ) + @expose_as_static + def between( + self, lower: Expression | float, upper: Expression | float + ) -> "BooleanExpression": + """Evaluates if the result of this expression is between + the lower bound (inclusive) and upper bound (inclusive). + + This is functionally equivalent to performing an `And` operation with + `greater_than_or_equal` and `less_than_or_equal`. + + Example: + >>> # Check if the 'age' field is between 18 and 65 + >>> Field.of("age").between(18, 65) + + Args: + lower: Lower bound (inclusive) of the range. + upper: Upper bound (inclusive) of the range. + + Returns: + A new `BooleanExpression` representing the between comparison. + """ + return And( + self.greater_than_or_equal(lower), + self.less_than_or_equal(upper), + ) + + @expose_as_static + def geo_distance(self, other: Expression | GeoPoint) -> "FunctionExpression": + """Evaluates to the distance in meters between the location in the specified + field and the query location. + + Note: This Expression can only be used within a `Search` stage. + + Example: + >>> # Calculate distance between the 'location' field and a target GeoPoint + >>> Field.of("location").geo_distance(target_point) + + Args: + other: Compute distance to this GeoPoint expression or constant value. + + Returns: + A new `FunctionExpression` representing the distance. 
+ """ + return FunctionExpression( + "geo_distance", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) + @expose_as_static def equal_any( self, array: Array | Sequence[Expression | CONSTANT_TYPE] | Expression @@ -2927,6 +2974,56 @@ def __init__(self): super().__init__("rand", [], use_infix_repr=False) +class Score(FunctionExpression): + """Evaluates to the search score that reflects the topicality of the document + to all of the text predicates (`queryMatch`) + in the search query. If `SearchOptions.query` is not set or does not contain + any text predicates, then this topicality score will always be `0`. + + Note: This Expression can only be used within a `Search` stage. + + Example: + >>> # Sort by search score and retrieve it via add_fields + >>> db.pipeline().collection("restaurants").search(SearchOptions( + ... query="tacos", + ... sort=Score().descending(), + ... add_fields=[Score().as_("search_score")] + ... )) + + Returns: + A new `Expression` representing the score operation. + """ + + def __init__(self): + super().__init__("score", [], use_infix_repr=False) + + +class DocumentMatches(BooleanExpression): + """Creates a boolean expression for a document match query. + + Note: This Expression can only be used within a `Search` stage. + + Example: + >>> # Find documents matching the query string + >>> db.pipeline().collection("restaurants").search( + ... DocumentMatches("pizza OR pasta") + ... ) + + Args: + query: The search query string or expression. + + Returns: + A new `BooleanExpression` representing the document match. + """ + + def __init__(self, query: Expression | str): + super().__init__( + "document_matches", + [Expression._cast_to_expr_or_convert_to_constant(query)], + use_infix_repr=False, + ) + + class Variable(Expression): """ Creates an expression that retrieves the value of a variable bound via `Pipeline.define`. 
diff --git a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py index 6c5ac68ddf0d..f6c3b2cc7bf4 100644 --- a/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py +++ b/packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py @@ -30,6 +30,7 @@ AliasedExpression, BooleanExpression, CONSTANT_TYPE, + DocumentMatches, Expression, Field, Ordering, @@ -109,6 +110,79 @@ def percentage(value: float): return SampleOptions(value, mode=SampleOptions.Mode.PERCENT) +class SearchOptions: + """Options for configuring the `Search` pipeline stage.""" + + def __init__( + self, + query: str | BooleanExpression, + *, + limit: Optional[int] = None, + retrieval_depth: Optional[int] = None, + sort: Optional[Sequence[Ordering] | Ordering] = None, + add_fields: Optional[Sequence[Selectable]] = None, + offset: Optional[int] = None, + language_code: Optional[str] = None, + ): + """ + Initializes a SearchOptions instance. + + Args: + query (str | BooleanExpression): Specifies the search query that will be used to query and score documents + by the search stage. The query can be expressed as an `Expression`, which will be used to score + and filter the results. Not all expressions supported by Pipelines are supported in the Search query. + The query can also be expressed as a string in the Search DSL. + limit (Optional[int]): The maximum number of documents to return from the Search stage. + retrieval_depth (Optional[int]): The maximum number of documents for the search stage to score. Documents + will be processed in the pre-sort order specified by the search index. + sort (Optional[Sequence[Ordering] | Ordering]): Orderings specify how the input documents are sorted. + add_fields (Optional[Sequence[Selectable]]): The fields to add to each document, specified as a `Selectable`. + offset (Optional[int]): The number of documents to skip. 
+ language_code (Optional[str]): The BCP-47 language code of text in the search query, such as "en-US" or "sr-Latn". + """ + self.query = DocumentMatches(query) if isinstance(query, str) else query + self.limit = limit + self.retrieval_depth = retrieval_depth + self.sort = [sort] if isinstance(sort, Ordering) else sort + self.add_fields = add_fields + self.offset = offset + self.language_code = language_code + + def __repr__(self): + args = [f"query={self.query!r}"] + if self.limit is not None: + args.append(f"limit={self.limit}") + if self.retrieval_depth is not None: + args.append(f"retrieval_depth={self.retrieval_depth}") + if self.sort is not None: + args.append(f"sort={self.sort}") + if self.add_fields is not None: + args.append(f"add_fields={self.add_fields}") + if self.offset is not None: + args.append(f"offset={self.offset}") + if self.language_code is not None: + args.append(f"language_code={self.language_code!r}") + return f"{self.__class__.__name__}({', '.join(args)})" + + def _to_dict(self) -> dict[str, Value]: + options = {"query": self.query._to_pb()} + if self.limit is not None: + options["limit"] = Value(integer_value=self.limit) + if self.retrieval_depth is not None: + options["retrieval_depth"] = Value(integer_value=self.retrieval_depth) + if self.sort is not None: + options["sort"] = Value( + array_value={"values": [s._to_pb() for s in self.sort]} + ) + if self.add_fields is not None: + options["add_fields"] = Selectable._to_value(self.add_fields) + if self.offset is not None: + options["offset"] = Value(integer_value=self.offset) + if self.language_code is not None: + options["language_code"] = Value(string_value=self.language_code) + return options + + class UnnestOptions: """Options for configuring the `Unnest` pipeline stage. 
@@ -423,6 +497,24 @@ def _pb_args(self): ] +class Search(Stage): + """Search stage.""" + + def __init__(self, query_or_options: str | BooleanExpression | SearchOptions): + super().__init__("search") + if isinstance(query_or_options, SearchOptions): + options = query_or_options + else: + options = SearchOptions(query=query_or_options) + self.options = options + + def _pb_args(self) -> list[Value]: + return [] + + def _pb_options(self) -> dict[str, Value]: + return self.options._to_dict() + + class Select(Stage): """Selects or creates a set of fields.""" diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml index a801481dff4b..f473b24e8477 100644 --- a/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml +++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml @@ -148,8 +148,13 @@ data: cities: city1: name: "San Francisco" + location: GEOPOINT(37.7749,-122.4194) city2: name: "New York" + location: GEOPOINT(40.7128,-74.0060) + city3: + name: "Saskatoon" + location: GEOPOINT(52.1579,-106.6702) "cities/city1/landmarks": lm1: name: "Golden Gate Bridge" @@ -167,4 +172,57 @@ data: rating: 5 rev2: author: "Bob" - rating: 4 \ No newline at end of file + rating: 4 + "cities/city3/landmarks": + lm4: + name: "Western Development Museum" + type: "Museum" + restaurants: + sunnySideUp: + name: "The Sunny Side Up" + description: "A cozy neighborhood diner serving classic breakfast favorites all day long, from fluffy pancakes to savory omelets." + location: GEOPOINT(39.7541,-105.0002) + menu: "

Breakfast Classics

Sides

" + average_price_per_person: 15 + goldenWaffle: + name: "The Golden Waffle" + description: "Specializing exclusively in Belgian-style waffles. Open daily from 6:00 AM to 11:00 AM." + location: GEOPOINT(39.7183,-104.9621) + menu: "

Signature Waffles

Drinks

" + average_price_per_person: 13 + lotusBlossomThai: + name: "Lotus Blossom Thai" + description: "Authentic Thai cuisine featuring hand-crushed spices and traditional family recipes from the Chiang Mai region." + location: GEOPOINT(39.7315,-104.9847) + menu: "

Appetizers

Main Course

" + average_price_per_person: 22 + mileHighCatch: + name: "Mile High Catch" + description: "Freshly sourced seafood offering a wide variety of Pacific fish and Atlantic shellfish in an upscale atmosphere." + location: GEOPOINT(39.7401,-104.9903) + menu: "

From the Raw Bar

Entrees

" + average_price_per_person: 45 + peakBurgers: + name: "Peak Burgers" + description: "Casual burger joint focused on locally sourced Colorado beef and hand-cut fries." + location: GEOPOINT(39.7622,-105.0125) + menu: "

Burgers

Sides

" + average_price_per_person: 18 + solTacos: + name: "El Sol Tacos" + description: "A vibrant street-side taco stand serving up quick, delicious, and traditional Mexican street food." + location: GEOPOINT(39.6952,-105.0274) + menu: "

Tacos ($3.50 each)

Beverages

" + average_price_per_person: 12 + eastsideTacos: + name: "Eastside Cantina" + description: "Authentic street tacos and hand-shaken margaritas on the vibrant east side of the city." + location: GEOPOINT(39.735,-104.885) + menu: "

Tacos

Drinks

" + average_price_per_person: 18 + eastsideChicken: + name: "Eastside Chicken" + description: "Fried chicken to go - next to Eastside Cantina." + location: GEOPOINT(39.735,-104.885) + menu: "

Fried Chicken

Drinks

" + average_price_per_person: 12 diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/logical.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/logical.yaml index d9f96cd3cd65..253ffcd89a09 100644 --- a/packages/google-cloud-firestore/tests/system/pipeline_e2e/logical.yaml +++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/logical.yaml @@ -760,3 +760,56 @@ tests: - "value_or_default" assert_results: - value_or_default: "1984" + - description: expression_between + pipeline: + - Collection: restaurants + - Where: + - FunctionExpression.between: + - Field: average_price_per_person + - Constant: 15 + - Constant: 20 + - Select: + - name + - Sort: + - Ordering: + - Field: name + - ASCENDING + assert_results: + - name: Eastside Cantina + - name: Peak Burgers + - name: The Sunny Side Up + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /restaurants + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: average_price_per_person + - integerValue: '15' + name: greater_than_or_equal + - functionValue: + args: + - fieldReferenceValue: average_price_per_person + - integerValue: '20' + name: less_than_or_equal + name: and + name: where + - args: + - mapValue: + fields: + name: + fieldReferenceValue: name + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: name + name: sort diff --git a/packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml b/packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml new file mode 100644 index 000000000000..1e3568857c4b --- /dev/null +++ b/packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml @@ -0,0 +1,450 @@ +tests: + - description: search_stage_basic + pipeline: + - Collection: restaurants + - Search: + - SearchOptions: + query: waffles + limit: 2 + assert_results: + - name: The Golden Waffle + description: 
Specializing exclusively in Belgian-style waffles. Open daily from + 6:00 AM to 11:00 AM. + location: GEOPOINT(39.7183, -104.9621) + menu:

Signature Waffles

Drinks

+ average_price_per_person: 13 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /restaurants + name: collection + - name: search + options: + limit: + integerValue: '2' + query: + functionValue: + args: + - stringValue: waffles + name: document_matches + - description: search_stage_full_options + pipeline: + - Collection: restaurants + - Search: + - SearchOptions: + query: + DocumentMatches: + - Constant: tacos + limit: 5 + retrieval_depth: 10 + offset: 1 + language_code: en + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /restaurants + name: collection + - name: search + options: + limit: + integerValue: '5' + retrieval_depth: + integerValue: '10' + offset: + integerValue: '1' + language_code: + stringValue: en + query: + functionValue: + args: + - stringValue: tacos + name: document_matches + assert_count: 1 + - description: search_stage_with_sort_and_add_fields + pipeline: + - Collection: restaurants + - Search: + - SearchOptions: + query: tacos + sort: + Ordering: + - Score: [] + - DESCENDING + add_fields: + - AliasedExpression: + - Score: [] + - search_score + - Select: + - name + - search_score + assert_results_approximate: + config: + # be flexible in score values, but should be > 0 + absolute_tolerance: 0.99 + data: + - name: Eastside Cantina + search_score: 1.0 + - name: El Sol Tacos + search_score: 1.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /restaurants + - name: search + options: + query: + functionValue: + name: document_matches + args: + - stringValue: tacos + sort: + arrayValue: + values: + - mapValue: + fields: + direction: + stringValue: descending + expression: + functionValue: + name: score + add_fields: + mapValue: + fields: + search_score: + functionValue: + name: score + - name: select + args: + - mapValue: + fields: + name: + fieldReferenceValue: name + search_score: + fieldReferenceValue: search_score + - description: expression_geo_distance + pipeline: + - 
Collection: restaurants + - Search: + - SearchOptions: + query: + FunctionExpression.less_than_or_equal: + - FunctionExpression.geo_distance: + - Field: location + - GeoPoint: + - 39.6985 + - -105.024 + - Constant: 1000.0 + - Select: + - name + - Sort: + - Ordering: + - Field: name + - ASCENDING + assert_results: + - name: El Sol Tacos + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /restaurants + name: collection + - name: search + options: + query: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: location + - geoPointValue: + latitude: 39.6985 + longitude: -105.024 + name: geo_distance + - doubleValue: 1000.0 + name: less_than_or_equal + - args: + - mapValue: + fields: + name: + fieldReferenceValue: name + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: name + name: sort + - description: search_full_document + pipeline: + - Collection: restaurants + - Search: + - SearchOptions: + query: waffles + assert_results: + - name: The Golden Waffle + description: Specializing exclusively in Belgian-style waffles. Open daily from + 6:00 AM to 11:00 AM. + location: GEOPOINT(39.7183, -104.9621) + menu:

Signature Waffles

Drinks

+ average_price_per_person: 13 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /restaurants + name: collection + - name: search + options: + query: + functionValue: + args: + - stringValue: waffles + name: document_matches + - description: search_negate_match + pipeline: + - Collection: restaurants + - Search: + - SearchOptions: + query: + DocumentMatches: + - Constant: coffee -waffles + assert_results: + - name: The Sunny Side Up + description: A cozy neighborhood diner serving classic breakfast favorites all + day long, from fluffy pancakes to savory omelets. + location: GEOPOINT(39.7541, -105.0002) + menu:

Breakfast Classics

Sides

+ average_price_per_person: 15 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /restaurants + name: collection + - name: search + options: + query: + functionValue: + args: + - stringValue: coffee -waffles + name: document_matches + - description: search_rquery_as_query_param + pipeline: + - Collection: restaurants + - Search: + - SearchOptions: + query: chicken wings + assert_results: + - name: Eastside Chicken + description: Fried chicken to go - next to Eastside Cantina. + location: GEOPOINT(39.735, -104.885) + menu:

Fried Chicken

Drinks

+ average_price_per_person: 12 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /restaurants + name: collection + - name: search + options: + query: + functionValue: + args: + - stringValue: chicken wings + name: document_matches + - description: search_sort_by_distance + pipeline: + - Collection: restaurants + - Search: + - SearchOptions: + query: + FunctionExpression.less_than_or_equal: + - FunctionExpression.geo_distance: + - Field: location + - GeoPoint: + - 39.6985 + - -105.024 + - Constant: 5600.0 + sort: + Ordering: + - FunctionExpression.geo_distance: + - Field: location + - GeoPoint: + - 39.6985 + - -105.024 + - ASCENDING + assert_results: + - name: El Sol Tacos + description: A vibrant street-side taco stand serving up quick, delicious, and + traditional Mexican street food. + location: GEOPOINT(39.6952, -105.0274) + menu:

Tacos ($3.50 each)

Beverages

+ average_price_per_person: 12 + - name: Lotus Blossom Thai + description: Authentic Thai cuisine featuring hand-crushed spices and traditional + family recipes from the Chiang Mai region. + location: GEOPOINT(39.7315, -104.9847) + menu:

Appetizers

Main + Course

+ average_price_per_person: 22 + - name: Mile High Catch + description: Freshly sourced seafood offering a wide variety of Pacific fish and + Atlantic shellfish in an upscale atmosphere. + location: GEOPOINT(39.7401, -104.9903) + menu:

From the Raw Bar

Entrees

+ average_price_per_person: 45 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /restaurants + name: collection + - name: search + options: + query: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: location + - geoPointValue: + latitude: 39.6985 + longitude: -105.024 + name: geo_distance + - doubleValue: 5600.0 + name: less_than_or_equal + sort: + arrayValue: + values: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + functionValue: + args: + - fieldReferenceValue: location + - geoPointValue: + latitude: 39.6985 + longitude: -105.024 + name: geo_distance + - description: search_add_fields_score + pipeline: + - Collection: restaurants + - Search: + - SearchOptions: + query: + DocumentMatches: + - Constant: waffles + add_fields: + - AliasedExpression: + - Score: [] + - searchScore + - Select: + - name + - searchScore + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /restaurants + - name: search + options: + query: + functionValue: + name: document_matches + args: + - stringValue: waffles + add_fields: + mapValue: + fields: + searchScore: + functionValue: + name: score + - name: select + args: + - mapValue: + fields: + name: + fieldReferenceValue: name + searchScore: + fieldReferenceValue: searchScore + assert_results_approximate: + config: + absolute_tolerance: 0.99 + data: + - name: The Golden Waffle + searchScore: 1.0 + - description: search_sort_by_score + pipeline: + - Collection: restaurants + - Search: + - SearchOptions: + query: + DocumentMatches: + - Constant: tacos + sort: + Ordering: + - Score: [] + - DESCENDING + assert_results: + - name: Eastside Cantina + description: Authentic street tacos and hand-shaken margaritas on the vibrant + east side of the city. + location: GEOPOINT(39.735, -104.885) + menu:

Tacos

Drinks

+ average_price_per_person: 18 + - name: El Sol Tacos + description: A vibrant street-side taco stand serving up quick, delicious, and + traditional Mexican street food. + location: GEOPOINT(39.6952, -105.0274) + menu:

Tacos ($3.50 each)

Beverages

+ average_price_per_person: 12 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /restaurants + name: collection + - name: search + options: + query: + functionValue: + args: + - stringValue: tacos + name: document_matches + sort: + arrayValue: + values: + - mapValue: + fields: + direction: + stringValue: descending + expression: + functionValue: + name: score diff --git a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py index d66767822ee9..6881279c665b 100644 --- a/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py +++ b/packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py @@ -28,7 +28,7 @@ from google.protobuf.json_format import MessageToDict from test__helpers import FIRESTORE_EMULATOR, FIRESTORE_ENTERPRISE_DB, system_test_lock -from google.cloud.firestore import AsyncClient, Client +from google.cloud.firestore import AsyncClient, Client, GeoPoint from google.cloud.firestore_v1 import pipeline_expressions from google.cloud.firestore_v1 import pipeline_expressions as expr from google.cloud.firestore_v1 import pipeline_stages as stages @@ -72,6 +72,28 @@ def yaml_loader(field="tests", dir_name="pipeline_e2e", attach_file_name=True): combined_yaml.update(extracted) elif isinstance(combined_yaml, list) and extracted: combined_yaml.extend(extracted) + + # Validate test keys + allowed_keys = { + "description", + "pipeline", + "assert_proto", + "assert_error", + "assert_results", + "assert_count", + "assert_results_approximate", + "assert_end_state", + "file_name", + } + if field == "tests" and isinstance(combined_yaml, list): + for item in combined_yaml: + if isinstance(item, dict): + for key in item: + if key not in allowed_keys: + raise ValueError( + f"Unrecognized key '{key}' in test '{item.get('description', 'Unknown')}' in file '{item.get('file_name', 'Unknown')}'" + ) + return combined_yaml @@ -111,6 
+133,34 @@ def test_pipeline_expected_errors(test_dict, client): assert match, f"error '{found_error}' does not match '{error_regex}'" +def _assert_pipeline_results( + got_results, expected_results, expected_approximate_results, expected_count +): + if expected_results: + assert got_results == expected_results + if expected_approximate_results is not None: + tolerance = 1e-4 + if ( + isinstance(expected_approximate_results, dict) + and "data" in expected_approximate_results + ): + if ( + "config" in expected_approximate_results + and "absolute_tolerance" in expected_approximate_results["config"] + ): + tolerance = expected_approximate_results["config"]["absolute_tolerance"] + expected_approximate_results = expected_approximate_results["data"] + + assert len(got_results) == len(expected_approximate_results), ( + "got unexpected result count" + ) + for idx in range(len(got_results)): + expected = expected_approximate_results[idx] + assert got_results[idx] == pytest.approx(expected, abs=tolerance) + if expected_count is not None: + assert len(got_results) == expected_count + + @pytest.mark.parametrize( "test_dict", [ @@ -136,18 +186,9 @@ def test_pipeline_results(test_dict, client): pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() for snapshot in pipeline.stream()] - if expected_results: - assert got_results == expected_results - if expected_approximate_results: - assert len(got_results) == len(expected_approximate_results), ( - "got unexpected result count" - ) - for idx in range(len(got_results)): - assert got_results[idx] == pytest.approx( - expected_approximate_results[idx], abs=1e-4 - ) - if expected_count is not None: - assert len(got_results) == expected_count + _assert_pipeline_results( + got_results, expected_results, expected_approximate_results, expected_count + ) if expected_end_state: for doc_path, expected_content in expected_end_state.items(): doc_ref = 
client.document(doc_path) @@ -209,18 +250,9 @@ async def test_pipeline_results_async(test_dict, async_client): pipeline = parse_pipeline(async_client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() async for snapshot in pipeline.stream()] - if expected_results: - assert got_results == expected_results - if expected_approximate_results: - assert len(got_results) == len(expected_approximate_results), ( - "got unexpected result count" - ) - for idx in range(len(got_results)): - assert got_results[idx] == pytest.approx( - expected_approximate_results[idx], abs=1e-4 - ) - if expected_count is not None: - assert len(got_results) == expected_count + _assert_pipeline_results( + got_results, expected_results, expected_approximate_results, expected_count + ) if expected_end_state: for doc_path, expected_content in expected_end_state.items(): doc_ref = async_client.document(doc_path) @@ -395,12 +427,16 @@ def _parse_yaml_types(data): else: return [_parse_yaml_types(value) for value in data] # detect timestamps - if isinstance(data, str) and ":" in data: + if isinstance(data, str) and ":" in data and not data.startswith("GEOPOINT("): try: parsed_datetime = datetime.datetime.fromisoformat(data) return parsed_datetime except ValueError: pass + if isinstance(data, str) and data.startswith("GEOPOINT("): + match = re.match(r"GEOPOINT\(([^,]+),\s*([^)]+)\)", data) + if match: + return GeoPoint(float(match.group(1)), float(match.group(2))) if data == "NaN": return float("NaN") return data diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py index 82d89a12f978..f1408be240a7 100644 --- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py +++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py @@ -403,6 +403,8 @@ def test_pipeline_execute_stream_equivalence(): ("replace_with", (Field.of("n"),), stages.ReplaceWith), ("sort", 
(Field.of("n").descending(),), stages.Sort), ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), + ("search", ("my query",), stages.Search), + ("search", (stages.SearchOptions(query="my query"),), stages.Search), ("sample", (10,), stages.Sample), ("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample), ("union", (_make_pipeline(),), stages.Union), diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py index b285e8e4b614..d2c935a5f25d 100644 --- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py +++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py @@ -790,6 +790,45 @@ def test_equal(self): infix_instance = arg1.equal(arg2) assert infix_instance == instance + def test_between(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Lower") + arg3 = self._make_arg("Upper") + instance = Expression.between(arg1, arg2, arg3) + assert instance.name == "and" + assert len(instance.params) == 2 + assert instance.params[0].name == "greater_than_or_equal" + assert instance.params[1].name == "less_than_or_equal" + assert ( + repr(instance) + == "And(Left.greater_than_or_equal(Lower), Left.less_than_or_equal(Upper))" + ) + infix_instance = arg1.between(arg2, arg3) + assert infix_instance == instance + + def test_geo_distance(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = Expression.geo_distance(arg1, arg2) + assert instance.name == "geo_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.geo_distance(Right)" + infix_instance = arg1.geo_distance(arg2) + assert infix_instance == instance + + def test_document_matches(self): + arg1 = self._make_arg("Query") + instance = expr.DocumentMatches(arg1) + assert instance.name == "document_matches" + assert instance.params == [arg1] + assert repr(instance) == 
"DocumentMatches(Query)" + + def test_score(self): + instance = expr.Score() + assert instance.name == "score" + assert instance.params == [] + assert repr(instance) == "Score()" + def test_greater_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") diff --git a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py index b9ab603b713b..064c41c37b70 100644 --- a/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py +++ b/packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py @@ -24,6 +24,7 @@ Constant, Field, Ordering, + DocumentMatches, ) from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.vector import Vector @@ -778,6 +779,71 @@ def test_to_pb_percent_mode(self): assert len(result_percent.options) == 0 +class TestSearch: + def test_search_defaults(self): + options = stages.SearchOptions(query="technology") + assert options.query.name == "document_matches" + assert options.limit is None + assert options.retrieval_depth is None + assert options.sort is None + assert options.add_fields is None + assert options.offset is None + assert options.language_code is None + + stage = stages.Search(options) + pb_opts = stage._pb_options() + assert "query" in pb_opts + assert "limit" not in pb_opts + assert "retrieval_depth" not in pb_opts + + def test_search_full_options(self): + options = stages.SearchOptions( + query=DocumentMatches("tech"), + limit=10, + retrieval_depth=2, + sort=Ordering("score", Ordering.Direction.DESCENDING), + add_fields=[Field("extra")], + offset=5, + language_code="en", + ) + assert options.limit == 10 + assert options.retrieval_depth == 2 + assert len(options.sort) == 1 + assert options.offset == 5 + assert options.language_code == "en" + + stage = stages.Search(options) + pb_opts = stage._pb_options() + + assert pb_opts["limit"].integer_value == 10 + assert 
pb_opts["retrieval_depth"].integer_value == 2 + assert len(pb_opts["sort"].array_value.values) == 1 + assert pb_opts["offset"].integer_value == 5 + assert pb_opts["language_code"].string_value == "en" + assert "query" in pb_opts + + def test_search_string_query_wrapping(self): + options = stages.SearchOptions(query="science") + assert options.query.name == "document_matches" + assert options.query.params[0].value == "science" + + def test_search_with_string(self): + stage = stages.Search("technology") + assert isinstance(stage.options, stages.SearchOptions) + assert stage.options.query.name == "document_matches" + assert stage.options.query.params[0].value == "technology" + pb_opts = stage._pb_options() + assert "query" in pb_opts + + def test_search_with_boolean_expression(self): + expr = DocumentMatches("tech") + stage = stages.Search(expr) + assert isinstance(stage.options, stages.SearchOptions) + assert stage.options.query is expr + pb_opts = stage._pb_options() + assert "query" in pb_opts + + class TestSelect: def _make_one(self, *args, **kwargs): return stages.Select(*args, **kwargs)