From 40cc50cff4b2ecc762463ca6ec3856b1e6c7be68 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Mon, 24 Mar 2025 06:58:49 +0200 Subject: [PATCH 01/18] sdks/python: enrich data with CloudSQL --- .../enrichment_handlers/cloudsql.py | 169 ++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py new file mode 100644 index 000000000000..a1a19d664777 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py @@ -0,0 +1,169 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from collections.abc import Callable +from enum import Enum +from typing import Any +from typing import Optional + +from google.cloud.sql.connector import Connector + +import apache_beam as beam +from apache_beam.transforms.enrichment import EnrichmentSourceHandler +from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel + +__all__ = [ + 'CloudSQLEnrichmentHandler', +] + +# RowKeyFn takes beam.Row and returns tuple of (key_id, key_value). 
+RowKeyFn = Callable[[beam.Row], tuple[str]] + +_LOGGER = logging.getLogger(__name__) + + +class DatabaseTypeAdapter(Enum): + POSTGRESQL = "pg8000" + MYSQL = "pymysql" + SQLSERVER = "pytds" + + +class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): + """A handler for :class:`apache_beam.transforms.enrichment.Enrichment` + transform to interact with Google Cloud SQL databases. + + Args: + project_id (str): GCP project-id of the Cloud SQL instance. + region_id (str): GCP region-id of the Cloud SQL instance. + instance_id (str): GCP instance-id of the Cloud SQL instance. + database_type_adapter (DatabaseTypeAdapter): The type of database adapter to use. + Supported adapters are: POSTGRESQL (pg8000), MYSQL (pymysql), and SQLSERVER (pytds). + database_id (str): The id of the database to connect to. + database_user (str): The username for connecting to the database. + database_password (str): The password for connecting to the database. + table_id (str): The name of the table to query. + row_key (str): Field name from the input `beam.Row` object to use as + identifier for database querying. + row_key_fn: A lambda function that returns a string key from the + input row. Used to build/extract the identifier for the database query. + exception_level: A `enum.Enum` value from + ``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`` + to set the level when no matching record is found from the database query. + Defaults to ``ExceptionLevel.WARN``. 
+ """ + def __init__( + self, + region_id: str, + project_id: str, + instance_id: str, + database_type_adapter: DatabaseTypeAdapter, + database_id: str, + database_user: str, + database_password: str, + table_id: str, + row_key: str = "", + *, + row_key_fn: Optional[RowKeyFn] = None, + exception_level: ExceptionLevel = ExceptionLevel.WARN, + ): + self._project_id = project_id + self._region_id = region_id + self._instance_id = instance_id + self._database_type_adapter = database_type_adapter + self._database_id = database_id + self._database_user = database_user + self._database_password = database_password + self._table_id = table_id + self._row_key = row_key + self._row_key_fn = row_key_fn + self._exception_level = exception_level + if ((not self._row_key_fn and not self._row_key) or + bool(self._row_key_fn and self._row_key)): + raise ValueError( + "Please specify exactly one of `row_key` or a lambda " + "function with `row_key_fn` to extract the row key " + "from the input row.") + + def __enter__(self): + """Connect to the the Cloud SQL instance.""" + self.connector = Connector() + self.client = self.connector.connect( + f"{self._project_id}:{self._region_id}:{self._instance_id}", + driver=self._database_type_adapter.value, + db=self._database_id, + user=self._database_user, + password=self._database_password, + ) + self.cursor = self.client.cursor() + + def __call__(self, request: beam.Row, *args, **kwargs): + """ + Executes a query to the Cloud SQL instance and returns + a `Tuple` of request and response. + + Args: + request: the input `beam.Row` to enrich. 
+ """ + response_dict: dict[str, Any] = {} + row_key_str: str = "" + + try: + if self._row_key_fn: + self._row_key, row_key = self._row_key_fn(request) + else: + request_dict = request._asdict() + row_key_str = str(request_dict[self._row_key]) + row_key = row_key_str + + query = f"SELECT * FROM {self._table_id} WHERE {self._row_key} = %s" + self.cursor.execute(query, (row_key, )) + result = self.cursor.fetchone() + + if result: + columns = [col[0] for col in self.cursor.description] + for i, value in enumerate(result): + response_dict[columns[i]] = value + elif self._exception_level == ExceptionLevel.WARN: + _LOGGER.warning( + 'No matching record found for row_key: %s in table: %s', + row_key_str, + self._table_id) + elif self._exception_level == ExceptionLevel.RAISE: + raise ValueError( + 'No matching record found for row_key: %s in table: %s' % + (row_key_str, self._table_id)) + except KeyError: + raise KeyError('row_key %s not found in input PCollection.' % row_key_str) + except Exception as e: + raise e + + return request, beam.Row(**response_dict) + + def __exit__(self, exc_type, exc_val, exc_tb): + """Clean the instantiated Cloud SQL client.""" + self.cursor.close() + self.client.close() + self.connector.close() + self.cursor, self.client, self.connector = None, None, None + + def get_cache_key(self, request: beam.Row) -> str: + """Returns a string formatted with row key since it is unique to + a request made to the Cloud SQL instance.""" + if self._row_key_fn: + id, value = self._row_key_fn(request) + return f"{id}: {value}" + return f"{self._row_key}: {request._asdict()[self._row_key]}" From 507d3ec5461965783fd89363b431a4a862de918e Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Mon, 24 Mar 2025 21:15:47 +0200 Subject: [PATCH 02/18] sdks/python: unit test `CloudSQLEnrichmentHandler` --- .../enrichment_handlers/cloudsql_test.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 
sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py new file mode 100644 index 000000000000..4f2547b67fee --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from parameterized import parameterized + +try: + from apache_beam.transforms.enrichment_handlers.cloudsql import CloudSQLEnrichmentHandler, DatabaseTypeAdapter + from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import _row_key_fn +except ImportError: + raise unittest.SkipTest('Cloud SQL test dependencies are not installed.') + + +class TestCloudSQLEnrichmentHandler(unittest.TestCase): + @parameterized.expand([('product_id', _row_key_fn), ('', None)]) + def test_cloud_sql_enrichment_invalid_args(self, row_key, row_key_fn): + with self.assertRaises(ValueError): + _ = CloudSQLEnrichmentHandler( + project_id='apache-beam-testing', + region_id='us-east1', + instance_id='beam-test', + table_id='cloudsql-enrichment-test', + database_type_adapter=DatabaseTypeAdapter.POSTGRESQL, + database_id='', + database_user='', + database_password='', + row_key=row_key, + row_key_fn=row_key_fn) + + +if __name__ == '__main__': + unittest.main() From b4c11d3a93b177652c0b86b344ab22773346ecf1 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Tue, 25 Mar 2025 23:19:09 +0200 Subject: [PATCH 03/18] sdks/python: itest `CloudSQLEnrichmentHandler` --- .../enrichment_handlers/cloudsql_it_test.py | 396 ++++++++++++++++++ 1 file changed, 396 insertions(+) create mode 100644 sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py new file mode 100644 index 000000000000..63b9ee2d17b9 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py @@ -0,0 +1,396 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +import unittest +from unittest.mock import MagicMock +import pytest +import apache_beam as beam +from apache_beam.coders import coders +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import BeamAssertException +from apache_beam.transforms.enrichment import Enrichment +from apache_beam.transforms.enrichment_handlers.cloudsql import ( + CloudSQLEnrichmentHandler, + DatabaseTypeAdapter, + ExceptionLevel, +) +from testcontainers.redis import RedisContainer +from google.cloud.sql.connector import Connector +import os + +_LOGGER = logging.getLogger(__name__) + + +def _row_key_fn(request: beam.Row, key_id="product_id") -> tuple[str]: + key_value = str(getattr(request, key_id)) + return (key_id, key_value) + + +class ValidateResponse(beam.DoFn): + """ValidateResponse validates if a PCollection of `beam.Row` + has the required fields.""" + def __init__( + self, + n_fields: int, + fields: list[str], + enriched_fields: dict[str, list[str]], + ): + self.n_fields = n_fields + self._fields = fields + self._enriched_fields = enriched_fields + + def process(self, element: beam.Row, *args, **kwargs): + element_dict = element.as_dict() + if len(element_dict.keys()) != self.n_fields: + raise BeamAssertException( + "Expected %d fields in enriched PCollection:" % self.n_fields) + + for field in self._fields: + if field not in element_dict or 
element_dict[field] is None: + raise BeamAssertException(f"Expected a not None field: {field}") + + for key in self._enriched_fields: + if key not in element_dict: + raise BeamAssertException( + f"Response from Cloud SQL should contain {key} column.") + + +def create_rows(cursor): + """Insert test rows into the Cloud SQL database table.""" + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS products ( + product_id SERIAL PRIMARY KEY, + product_name VARCHAR(255), + product_stock INT + ) + """) + cursor.execute( + """ + INSERT INTO products (product_name, product_stock) + VALUES + ('pixel 5', 2), + ('pixel 6', 4), + ('pixel 7', 20), + ('pixel 8', 10), + ('iphone 11', 3), + ('iphone 12', 7), + ('iphone 13', 8), + ('iphone 14', 3) + ON CONFLICT DO NOTHING + """) + + +@pytest.mark.uses_testcontainer +class TestCloudSQLEnrichment(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.project_id = "apache-beam-testing" + cls.region_id = "us-central1" + cls.instance_id = "beam-test" + cls.database_id = "postgres" + cls.database_user = os.getenv("BEAM_TEST_CLOUDSQL_PG_USER") + cls.database_password = os.getenv("BEAM_TEST_CLOUDSQL_PG_PASSWORD") + cls.table_id = "products" + cls.row_key = "product_id" + cls.database_type_adapter = DatabaseTypeAdapter.POSTGRESQL + cls.req = [ + beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), + beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3), + beam.Row(sale_id=5, customer_id=5, product_id=3, quantity=2), + beam.Row(sale_id=7, customer_id=7, product_id=4, quantity=1), + ] + cls.connector = Connector() + cls.client = cls.connector.connect( + f"{cls.project_id}:{cls.region_id}:{cls.instance_id}", + driver=cls.database_type_adapter.value, + db=cls.database_id, + user=cls.database_user, + password=cls.database_password, + ) + cls.cursor = cls.client.cursor() + create_rows(cls.cursor) + cls.cache_client_retries = 3 + + def _start_cache_container(self): + for i in range(self.cache_client_retries): + try: + 
self.container = RedisContainer(image="redis:7.2.4") + self.container.start() + self.host = self.container.get_container_host_ip() + self.port = self.container.get_exposed_port(6379) + self.cache_client = self.container.get_client() + break + except Exception as e: + if i == self.cache_client_retries - 1: + _LOGGER.error( + f"Unable to start redis container for RRIO tests after {self.cache_client_retries} retries." + ) + raise e + + @classmethod + def tearDownClass(cls): + cls.cursor.close() + cls.client.close() + cls.connector.close() + cls.cursor, cls.client, cls.connector = None, None, None + + def test_enrichment_with_cloudsql(self): + expected_fields = [ + "sale_id", + "customer_id", + "product_id", + "quantity", + "product_name", + "product_stock", + ] + expected_enriched_fields = ["product_id", "product_name", "product_stock"] + cloudsql = CloudSQLEnrichmentHandler( + region_id=self.region_id, + project_id=self.project_id, + instance_id=self.instance_id, + database_type_adapter=self.database_type_adapter, + database_id=self.database_id, + database_user=self.database_user, + database_password=self.database_password, + table_id=self.table_id, + row_key=self.row_key, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ CloudSQL" >> Enrichment(cloudsql) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields, + ))) + + def test_enrichment_with_cloudsql_no_enrichment(self): + expected_fields = ["sale_id", "customer_id", "product_id", "quantity"] + expected_enriched_fields = {} + cloudsql = CloudSQLEnrichmentHandler( + region_id=self.region_id, + project_id=self.project_id, + instance_id=self.instance_id, + database_type_adapter=self.database_type_adapter, + database_id=self.database_id, + database_user=self.database_user, + database_password=self.database_password, + table_id=self.table_id, + 
row_key=self.row_key, + ) + req = [beam.Row(sale_id=1, customer_id=1, product_id=99, quantity=1)] + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create" >> beam.Create(req) + | "Enrich W/ CloudSQL" >> Enrichment(cloudsql) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields, + ))) + + def test_enrichment_with_cloudsql_raises_key_error(self): + cloudsql = CloudSQLEnrichmentHandler( + region_id=self.region_id, + project_id=self.project_id, + instance_id=self.instance_id, + database_type_adapter=self.database_type_adapter, + database_id=self.database_id, + database_user=self.database_user, + database_password=self.database_password, + table_id=self.table_id, + row_key="car_name", + ) + with self.assertRaises(KeyError): + test_pipeline = TestPipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_enrichment_with_cloudsql_raises_not_found(self): + """Raises a database error when the GCP Cloud SQL table doesn't exist.""" + table_id = "invalid_table" + cloudsql = CloudSQLEnrichmentHandler( + region_id=self.region_id, + project_id=self.project_id, + instance_id=self.instance_id, + database_type_adapter=self.database_type_adapter, + database_id=self.database_id, + database_user=self.database_user, + database_password=self.database_password, + table_id=table_id, + row_key=self.row_key, + ) + try: + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)) + res = test_pipeline.run() + res.wait_until_finish() + except (PgDatabaseError, RuntimeError) as e: + self.assertIn(f'relation "{table_id}" does not exist', str(e)) + + def test_enrichment_with_cloudsql_exception_level(self): + """raises a `ValueError` exception when the GCP Cloud SQL 
query returns + an empty row.""" + cloudsql = CloudSQLEnrichmentHandler( + region_id=self.region_id, + project_id=self.project_id, + instance_id=self.instance_id, + database_type_adapter=self.database_type_adapter, + database_id=self.database_id, + database_user=self.database_user, + database_password=self.database_password, + table_id=self.table_id, + row_key=self.row_key, + exception_level=ExceptionLevel.RAISE, + ) + req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)] + with self.assertRaises(ValueError): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(req) + | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_cloudsql_enrichment_with_lambda(self): + expected_fields = [ + "sale_id", + "customer_id", + "product_id", + "quantity", + "product_name", + "product_stock", + ] + expected_enriched_fields = ["product_id", "product_name", "product_stock"] + cloudsql = CloudSQLEnrichmentHandler( + region_id=self.region_id, + project_id=self.project_id, + instance_id=self.instance_id, + database_type_adapter=self.database_type_adapter, + database_id=self.database_id, + database_user=self.database_user, + database_password=self.database_password, + table_id=self.table_id, + row_key_fn=_row_key_fn, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create" >> beam.Create(self.req) + | "Enrich W/ CloudSQL" >> Enrichment(cloudsql) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + + @pytest.fixture + def cache_container(self): + # Setup phase: start the container. + self._start_cache_container() + + # Hand control to the test. + yield + + # Cleanup phase: stop the container. It runs after the test completion + # even if it failed. 
+ self.container.stop() + self.container = None + + @pytest.mark.usefixtures("cache_container") + def test_cloudsql_enrichment_with_redis(self): + expected_fields = [ + "sale_id", + "customer_id", + "product_id", + "quantity", + "product_name", + "product_stock", + ] + expected_enriched_fields = ["product_id", "product_name", "product_stock"] + cloudsql = CloudSQLEnrichmentHandler( + region_id=self.region_id, + project_id=self.project_id, + instance_id=self.instance_id, + database_type_adapter=self.database_type_adapter, + database_id=self.database_id, + database_user=self.database_user, + database_password=self.database_password, + table_id=self.table_id, + row_key_fn=_row_key_fn, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create1" >> beam.Create(self.req) + | "Enrich W/ CloudSQL1" >> Enrichment(cloudsql).with_redis_cache( + self.host, self.port, 300) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields, + ))) + + # Manually check cache entry to verify entries were correctly stored. + c = coders.StrUtf8Coder() + for req in self.req: + key = cloudsql.get_cache_key(req) + response = self.cache_client.get(c.encode(key)) + if not response: + raise ValueError("No cache entry found for %s" % key) + + # Mock the CloudSQL handler to avoid actual database calls. + # This simulates a cache hit scenario by returning predefined data. + actual = CloudSQLEnrichmentHandler.__call__ + CloudSQLEnrichmentHandler.__call__ = MagicMock( + return_value=( + beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), + beam.Row(), + )) + + # Run a second pipeline to verify cache is being used. 
+ with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | "Create2" >> beam.Create(self.req) + | "Enrich W/ CloudSQL2" >> Enrichment(cloudsql).with_redis_cache( + self.host, self.port) + | "Validate Response" >> beam.ParDo( + ValidateResponse( + len(expected_fields), + expected_fields, + expected_enriched_fields))) + CloudSQLEnrichmentHandler.__call__ = actual + + +if __name__ == "__main__": + unittest.main() From 43943b1adef6a9ddcbdce245ab416ff4a853679f Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Wed, 26 Mar 2025 00:59:47 +0200 Subject: [PATCH 04/18] website+sdks: doc `CloudSQLEnrichmentHandler` --- .../transforms/elementwise/enrichment.py | 40 +++++++++++++++++++ .../transforms/elementwise/enrichment_test.py | 14 +++++++ .../python/elementwise/enrichment.md | 1 + 3 files changed, 55 insertions(+) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py index acee633b6f67..c88696bbc24d 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py @@ -50,6 +50,46 @@ def enrichment_with_bigtable(): | "Print" >> beam.Map(print)) # [END enrichment_with_bigtable] +def enrichment_with_cloudsql(): + # [START enrichment_with_cloudsql] + import apache_beam as beam + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.cloudsql import CloudSQLEnrichmentHandler, DatabaseTypeAdapter + import os + + project_id = 'apache-beam-testing' + region_id = 'us-east1' + instance_id = 'beam-test' + table_id = 'cloudsql-enrichment-test' + database_id = 'test-database' + database_user = os.getenv("BEAM_TEST_CLOUDSQL_PG_USER") + database_password = os.getenv("BEAM_TEST_CLOUDSQL_PG_PASSWORD") + row_key = 'product_id' + + data = [ + beam.Row(sale_id=1, customer_id=1, 
product_id=1, quantity=1), + beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3), + beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2), + ] + + cloudsql_handler = CloudSQLEnrichmentHandler( + project_id=project_id, + region_id=region_id, + instance_id=instance_id, + table_id=table_id, + database_type_adapter=DatabaseTypeAdapter.POSTGRESQL, + database_id=database_id, + database_user=database_user, + database_password=database_password, + row_key=row_key + ) + with beam.Pipeline() as p: + _ = ( + p + | "Create" >> beam.Create(data) + | "Enrich W/ CloudSQL" >> Enrichment(cloudsql_handler) + | "Print" >> beam.Map(print)) + # [END enrichment_with_cloudsql] def enrichment_with_vertex_ai(): # [START enrichment_with_vertex_ai] diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index 8a7cdfbe9263..7086261c254e 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -41,6 +41,13 @@ def validate_enrichment_with_bigtable(): [END enrichment_with_bigtable]'''.splitlines()[1:-1] return expected +def validate_enrichment_with_cloudsql(): + expected = '''[START enrichment_with_cloudsql] +Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'}) +Row(sale_id=3, customer_id=3, product_id=2, quantity=3, product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'}) +Row(sale_id=5, customer_id=5, product_id=4, quantity=2, product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'}) + [END enrichment_with_cloudsql]'''.splitlines()[1:-1] + return expected def validate_enrichment_with_vertex_ai(): expected = '''[START enrichment_with_vertex_ai] @@ -68,6 +75,13 @@ def test_enrichment_with_bigtable(self, 
mock_stdout): expected = validate_enrichment_with_bigtable() self.assertEqual(output, expected) + def test_enrichment_with_cloudsql(self, mock_stdout): + from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_cloudsql + enrichment_with_cloudsql() + output = mock_stdout.getvalue().splitlines() + expected = validate_enrichment_with_cloudsql() + self.assertEqual(output, expected) + def test_enrichment_with_vertex_ai(self, mock_stdout): enrichment_with_vertex_ai() output = mock_stdout.getvalue().splitlines() diff --git a/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment.md b/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment.md index 6c05b6b515a4..0993963ec057 100644 --- a/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment.md +++ b/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment.md @@ -42,6 +42,7 @@ The following examples demonstrate how to create a pipeline that use the enrichm | Service | Example | |:-----------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Cloud Bigtable | [Enrichment with Bigtable](/documentation/transforms/python/elementwise/enrichment-bigtable/#example) | +| Cloud SQL | [Enrichment with CloudSQL](/documentation/transforms/python/elementwise/enrichment-cloudsql/#example) | | Vertex AI Feature Store | [Enrichment with Vertex AI Feature Store](/documentation/transforms/python/elementwise/enrichment-vertexai/#example-1-enrichment-with-vertex-ai-feature-store) | | Vertex AI Feature Store (Legacy) | [Enrichment with Legacy Vertex AI Feature Store](/documentation/transforms/python/elementwise/enrichment-vertexai/#example-2-enrichment-with-vertex-ai-feature-store-legacy) | {{< /table >}} From 846c30dbd7b9b6ed57fdd768460ad1c17ca39ee0 
Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 11 Apr 2025 14:10:03 +0200 Subject: [PATCH 05/18] sdks/python: address claudevdm feedback (1) --- .../enrichment_handlers/cloudsql.py | 382 +++++++++++++----- 1 file changed, 270 insertions(+), 112 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py index a1a19d664777..77ab9889c8da 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py @@ -14,156 +14,314 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import logging from collections.abc import Callable +from collections.abc import Mapping from enum import Enum from typing import Any from typing import Optional +from typing import Union -from google.cloud.sql.connector import Connector +from sqlalchemy import create_engine, text import apache_beam as beam from apache_beam.transforms.enrichment import EnrichmentSourceHandler from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel -__all__ = [ - 'CloudSQLEnrichmentHandler', -] +QueryFn = Callable[[beam.Row], str] +ConditionValueFn = Callable[[beam.Row], list[Any]] -# RowKeyFn takes beam.Row and returns tuple of (key_id, key_value). 
-RowKeyFn = Callable[[beam.Row], tuple[str]] -_LOGGER = logging.getLogger(__name__) +def _validate_cloudsql_metadata( + table_id, + where_clause_template, + where_clause_fields, + where_clause_value_fn, + query_fn): + if query_fn: + if any([table_id, + where_clause_template, + where_clause_fields, + where_clause_value_fn]): + raise ValueError( + "Please provide either `query_fn` or the parameters `table_id`, " + "`where_clause_template`, and `where_clause_fields/where_clause_value_fn` " + "together.") + else: + if not (table_id and where_clause_template): + raise ValueError( + "Please provide either `query_fn` or the parameters " + "`table_id` and `where_clause_template` together.") + if (bool(where_clause_fields) == bool(where_clause_value_fn)): + raise ValueError( + "Please provide exactly one of `where_clause_fields` or " + "`where_clause_value_fn`.") class DatabaseTypeAdapter(Enum): - POSTGRESQL = "pg8000" + POSTGRESQL = "psycopg2" MYSQL = "pymysql" SQLSERVER = "pytds" + def to_sqlalchemy_dialect(self): + """ + Map the adapter type to its corresponding SQLAlchemy dialect. + Returns: + str: SQLAlchemy dialect string. + """ + if self == DatabaseTypeAdapter.POSTGRESQL: + return f"postgresql+{self.value}" + elif self == DatabaseTypeAdapter.MYSQL: + return f"mysql+{self.value}" + elif self == DatabaseTypeAdapter.SQLSERVER: + return f"mssql+{self.value}" + else: + raise ValueError(f"Unsupported adapter type: {self.name}") + class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): - """A handler for :class:`apache_beam.transforms.enrichment.Enrichment` - transform to interact with Google Cloud SQL databases. - - Args: - project_id (str): GCP project-id of the Cloud SQL instance. - region_id (str): GCP region-id of the Cloud SQL instance. - instance_id (str): GCP instance-id of the Cloud SQL instance. - database_type_adapter (DatabaseTypeAdapter): The type of database adapter to use. 
- Supported adapters are: POSTGRESQL (pg8000), MYSQL (pymysql), and SQLSERVER (pytds). - database_id (str): The id of the database to connect to. - database_user (str): The username for connecting to the database. - database_password (str): The password for connecting to the database. - table_id (str): The name of the table to query. - row_key (str): Field name from the input `beam.Row` object to use as - identifier for database querying. - row_key_fn: A lambda function that returns a string key from the - input row. Used to build/extract the identifier for the database query. - exception_level: A `enum.Enum` value from - ``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`` - to set the level when no matching record is found from the database query. - Defaults to ``ExceptionLevel.WARN``. + """ + Enrichment handler for Cloud SQL databases. + + This handler is designed to work with the + :class:`apache_beam.transforms.enrichment.Enrichment` transform. + + To use this handler, you need to provide one of the following combinations: + * `table_id`, `where_clause_template`, `where_clause_fields` + * `table_id`, `where_clause_template`, `where_clause_value_fn` + * `query_fn` + + By default, the handler retrieves all columns from the specified table. + To limit the columns, use the `column_names` parameter to specify + the desired column names. + + This handler queries the Cloud SQL database per element by default. + To enable batching, set the `min_batch_size` and `max_batch_size` parameters. + These values control the batching behavior in the + :class:`apache_beam.transforms.util.BatchElements` transform. + + NOTE: Batching is not supported when using the `query_fn` parameter.
""" def __init__( self, - region_id: str, - project_id: str, - instance_id: str, database_type_adapter: DatabaseTypeAdapter, - database_id: str, + database_address: str, database_user: str, database_password: str, - table_id: str, - row_key: str = "", + database_id: str, *, - row_key_fn: Optional[RowKeyFn] = None, - exception_level: ExceptionLevel = ExceptionLevel.WARN, + table_id: str = "", + where_clause_template: str = "", + where_clause_fields: Optional[list[str]] = None, + where_clause_value_fn: Optional[ConditionValueFn] = None, + query_fn: Optional[QueryFn] = None, + column_names: Optional[list[str]] = None, + min_batch_size: int = 1, + max_batch_size: int = 10000, + **kwargs, ): - self._project_id = project_id - self._region_id = region_id - self._instance_id = instance_id + """ + Example Usage: + handler = CloudSQLEnrichmentHandler( + database_type_adapter=adapter, + database_address='127.0.0.1:5432', + database_user='user', + database_password='password', + database_id='my_database', + table_id='my_table', + where_clause_template="id = '{}'", + where_clause_fields=['id'], + min_batch_size=2, + max_batch_size=100 + ) + + Args: + database_type_adapter: Adapter to handle specific database type operations + (e.g., MySQL, PostgreSQL). + database_address (str): Address or hostname of the Cloud SQL database, in + the form `host:port`. The port is optional if the database uses + the default port. + database_user (str): Username for accessing the database. + database_password (str): Password for accessing the database. + database_id (str): Identifier for the database to query. + table_id (str): Name of the table to query in the Cloud SQL database. + where_clause_template (str): A template string for the `WHERE` clause + in the SQL query with placeholders (`{}`) for dynamic filtering + based on input data.
+ where_clause_fields (Optional[list[str]]): List of field names from the input + `beam.Row` used to construct the `WHERE` clause if `where_clause_value_fn` + is not provided. + where_clause_value_fn (Optional[Callable[[beam.Row], Any]]): Function that + takes a `beam.Row` and returns a list of values to populate the + placeholders `{}` in the `WHERE` clause. + query_fn (Optional[Callable[[beam.Row], str]]): Function that takes a + `beam.Row` and returns a complete SQL query string. + column_names (Optional[list[str]]): List of column names to select from the + Cloud SQL table. If not provided, all columns (`*`) are selected. + min_batch_size (int): Minimum number of rows to batch together when + querying the database. Defaults to 1 if `query_fn` is not used. + max_batch_size (int): Maximum number of rows to batch together. Defaults + to 10,000 if `query_fn` is not used. + **kwargs: Additional keyword arguments for database connection or query handling. + + Note: + * `min_batch_size` and `max_batch_size` cannot be used if `query_fn` is provided. + * Either `where_clause_fields` or `where_clause_value_fn` must be provided + for query construction if `query_fn` is not provided. + * Ensure that the database user has the necessary permissions to query the + specified table. 
+ """ + _validate_cloudsql_metadata( + table_id, + where_clause_template, + where_clause_fields, + where_clause_value_fn, + query_fn) self._database_type_adapter = database_type_adapter self._database_id = database_id self._database_user = database_user self._database_password = database_password + self._database_address = database_address self._table_id = table_id - self._row_key = row_key - self._row_key_fn = row_key_fn - self._exception_level = exception_level - if ((not self._row_key_fn and not self._row_key) or - bool(self._row_key_fn and self._row_key)): - raise ValueError( - "Please specify exactly one of `row_key` or a lambda " - "function with `row_key_fn` to extract the row key " - "from the input row.") + self._where_clause_template = where_clause_template + self._where_clause_fields = where_clause_fields + self._where_clause_value_fn = where_clause_value_fn + self._query_fn = query_fn + self._column_names = ",".join(column_names) if column_names else "*" + self.query_template = f"SELECT {self._column_names} FROM {self._table_id} WHERE {self._where_clause_template}" + self.kwargs = kwargs + self._batching_kwargs = {} + if not query_fn: + self._batching_kwargs['min_batch_size'] = min_batch_size + self._batching_kwargs['max_batch_size'] = max_batch_size def __enter__(self): - """Connect to the the Cloud SQL instance.""" - self.connector = Connector() - self.client = self.connector.connect( - f"{self._project_id}:{self._region_id}:{self._instance_id}", - driver=self._database_type_adapter.value, - db=self._database_id, - user=self._database_user, - password=self._database_password, - ) - self.cursor = self.client.cursor() - - def __call__(self, request: beam.Row, *args, **kwargs): - """ - Executes a query to the Cloud SQL instance and returns - a `Tuple` of request and response. + db_url = self._get_db_url() + self._engine = create_engine(db_url) + self._connection = self._engine.connect() - Args: - request: the input `beam.Row` to enrich. 
- """ - response_dict: dict[str, Any] = {} - row_key_str: str = "" + def _get_db_url(self) -> str: + dialect = self._database_type_adapter.to_sqlalchemy_dialect() + string = f"{dialect}://{self._database_user}:{self._database_password}@{self._database_address}/{self._database_id}" + return string + def _execute_query(self, query: str, is_batch: bool, **params): try: - if self._row_key_fn: - self._row_key, row_key = self._row_key_fn(request) + result = self._connection.execute(text(query), **params) + if is_batch: + return [row._asdict() for row in result] else: - request_dict = request._asdict() - row_key_str = str(request_dict[self._row_key]) - row_key = row_key_str - - query = f"SELECT * FROM {self._table_id} WHERE {self._row_key} = %s" - self.cursor.execute(query, (row_key, )) - result = self.cursor.fetchone() - - if result: - columns = [col[0] for col in self.cursor.description] - for i, value in enumerate(result): - response_dict[columns[i]] = value - elif self._exception_level == ExceptionLevel.WARN: - _LOGGER.warning( - 'No matching record found for row_key: %s in table: %s', - row_key_str, - self._table_id) - elif self._exception_level == ExceptionLevel.RAISE: - raise ValueError( - 'No matching record found for row_key: %s in table: %s' % - (row_key_str, self._table_id)) - except KeyError: - raise KeyError('row_key %s not found in input PCollection.' % row_key_str) - except Exception as e: - raise e - - return request, beam.Row(**response_dict) + return result.first()._asdict() + except RuntimeError as e: + raise RuntimeError( + f'Could not execute the query: {query}. Please check if ' + f'the query is properly formatted and the BigQuery ' + f'table exists. 
{e}') + + def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): + if isinstance(request, list): + values, responses = [], [] + requests_map: dict[Any, Any] = {} + batch_size = len(request) + raw_query = self.query_template + + # For multiple requests in the batch, combine the WHERE clause conditions + # using 'OR' and update the query template to handle all requests. + if batch_size > 1: + where_clause_template_batched = ' OR '.join( + [fr'({self._where_clause_template})'] * batch_size) + raw_query = self.query_template.replace( + self._where_clause_template, where_clause_template_batched) + + # Extract where_clause_fields values and map the generated request key to + # the original request object. + for req in request: + request_dict = req._asdict() + try: + current_values = ( + self._where_clause_value_fn(req) if self._where_clause_value_fn + else [request_dict[field] for field in self._where_clause_fields]) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `where_clause_fields` are the " + "keys in the input `beam.Row`." + str(e)) + values.extend(current_values) + requests_map[self.create_row_key(req)] = req + + # Formulate the query, execute it, and return a list of original requests + # paired with their responses. 
+ query = raw_query.format(*values) + responses_dict = self._execute_query(query, is_batch=True) + for response in responses_dict: + response_row = beam.Row(**response) + response_key = self.create_row_key(response_row) + if response_key in requests_map: + responses.append((requests_map[response_key], response_row)) + return responses + else: + request_dict = request._asdict() + if self._query_fn: + query = self._query_fn(request) + else: + try: + values = ( + self._where_clause_value_fn(request) + if self._where_clause_value_fn else + [request_dict[field] for field in self._where_clause_fields]) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `where_clause_fields` are the " + "keys in the input `beam.Row`." + str(e)) + query = self.query_template.format(*values) + response_dict = self._execute_query(query, is_batch=False) + return request, beam.Row(**response_dict) + + def create_row_key(self, row: beam.Row): + if self._where_clause_value_fn: + return tuple(self._where_clause_value_fn(row)) + if self._where_clause_fields: + row_dict = row._asdict() + return ( + tuple( + row_dict[where_clause_field] + for where_clause_field in self._where_clause_fields)) + raise ValueError( + "Either where_clause_fields or where_clause_value_fn must be specified") def __exit__(self, exc_type, exc_val, exc_tb): - """Clean the instantiated Cloud SQL client.""" - self.cursor.close() - self.client.close() - self.connector.close() - self.cursor, self.client, self.connector = None, None, None - - def get_cache_key(self, request: beam.Row) -> str: - """Returns a string formatted with row key since it is unique to - a request made to the Cloud SQL instance.""" - if self._row_key_fn: - id, value = self._row_key_fn(request) - return f"{id}: {value}" - return f"{self._row_key}: {request._asdict()[self._row_key]}" + self._connection.close() + self._engine.dispose(close=True) + self._engine, self._connection = None, None + + def get_cache_key(self, request: 
Union[beam.Row, list[beam.Row]]): + if isinstance(request, list): + cache_keys = [] + for req in request: + req_dict = req._asdict() + try: + current_values = ( + self._where_clause_value_fn(req) if self._where_clause_value_fn + else [req_dict[field] for field in self._where_clause_fields]) + key = ";".join(["%s"] * len(current_values)) + cache_keys.extend([key % tuple(current_values)]) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `where_clause_fields` are the " + "keys in the input `beam.Row`." + str(e)) + return cache_keys + else: + req_dict = request._asdict() + try: + current_values = ( + self._where_clause_value_fn(request) if self._where_clause_value_fn + else [req_dict[field] for field in self._where_clause_fields]) + key = ";".join(["%s"] * len(current_values)) + cache_key = key % tuple(current_values) + except KeyError as e: + raise KeyError( + "Make sure the values passed in `where_clause_fields` are the " + "keys in the input `beam.Row`." + str(e)) + return cache_key + + def batch_elements_kwargs(self) -> Mapping[str, Any]: + """Returns a kwargs suitable for `beam.BatchElements`.""" + return self._batching_kwargs From 9a1ed453fa9bf5f1fb2dcc7fb47bf05b21e03fae Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 11 Apr 2025 14:12:22 +0200 Subject: [PATCH 06/18] sdks/python: address claudevdm feedback (1)(test) --- .../enrichment_handlers/cloudsql_test.py | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py index 4f2547b67fee..2fa74bf33233 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py @@ -18,28 +18,52 @@ from parameterized import parameterized +# pylint: disable=ungrouped-imports try: from 
apache_beam.transforms.enrichment_handlers.cloudsql import CloudSQLEnrichmentHandler, DatabaseTypeAdapter - from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import _row_key_fn + from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import where_clause_value_fn + from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import query_fn except ImportError: - raise unittest.SkipTest('Cloud SQL test dependencies are not installed.') + raise unittest.SkipTest('Google Cloud SQL dependencies are not installed.') -class TestCloudSQLEnrichmentHandler(unittest.TestCase): - @parameterized.expand([('product_id', _row_key_fn), ('', None)]) - def test_cloud_sql_enrichment_invalid_args(self, row_key, row_key_fn): +class TestCloudSQLEnrichment(unittest.TestCase): + @parameterized.expand([ + ("", "", [], None, None, 1, 2), + ("table", "", ["id"], where_clause_value_fn, None, 2, 10), + ("table", "id='{}'", ["id"], where_clause_value_fn, None, 2, 10), + ("table", "id='{}'", ["id"], None, query_fn, 2, 10), + ]) + def test_valid_params( + self, + table_id, + where_clause_template, + where_clause_fields, + where_clause_value_fn, + query_fn, + min_batch_size, + max_batch_size): + """ + TC 1: Only batch size are provided. It should raise an error. + TC 2: Either of `where_clause_template` or `query_fn` is not provided. + TC 3: Both `where_clause_fields` and `where_clause_value_fn` are provided. + TC 4: Query construction details are provided along with `query_fn`. 
+ """ with self.assertRaises(ValueError): _ = CloudSQLEnrichmentHandler( - project_id='apache-beam-testing', - region_id='us-east1', - instance_id='beam-test', - table_id='cloudsql-enrichment-test', database_type_adapter=DatabaseTypeAdapter.POSTGRESQL, - database_id='', + database_address='', database_user='', database_password='', - row_key=row_key, - row_key_fn=row_key_fn) + database_id='', + table_id=table_id, + where_clause_template=where_clause_template, + where_clause_fields=where_clause_fields, + where_clause_value_fn=where_clause_value_fn, + query_fn=query_fn, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + ) if __name__ == '__main__': From 9261459ae1d5d43777d3fd9b9289c0ff81547200 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 11 Apr 2025 14:13:10 +0200 Subject: [PATCH 07/18] sdks/python: address claudevdm feedback (1)(itest) --- .../enrichment_handlers/cloudsql_it_test.py | 639 +++++++++--------- 1 file changed, 330 insertions(+), 309 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py index 63b9ee2d17b9..c7939b09c07d 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py @@ -14,381 +14,402 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import functools import logging import unittest from unittest.mock import MagicMock + import pytest + import apache_beam as beam from apache_beam.coders import coders from apache_beam.testing.test_pipeline import TestPipeline -from apache_beam.testing.util import BeamAssertException -from apache_beam.transforms.enrichment import Enrichment -from apache_beam.transforms.enrichment_handlers.cloudsql import ( - CloudSQLEnrichmentHandler, - DatabaseTypeAdapter, - ExceptionLevel, -) -from testcontainers.redis import RedisContainer -from google.cloud.sql.connector import Connector -import os +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +# pylint: disable=ungrouped-imports +try: + from testcontainers.postgres import PostgresContainer + from testcontainers.redis import RedisContainer + from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.cloudsql import ( + CloudSQLEnrichmentHandler, DatabaseTypeAdapter, DatabaseTypeAdapter) +except ImportError: + raise unittest.SkipTest('Google Cloud SQL dependencies are not installed.') _LOGGER = logging.getLogger(__name__) -def _row_key_fn(request: beam.Row, key_id="product_id") -> tuple[str]: - key_value = str(getattr(request, key_id)) - return (key_id, key_value) - - -class ValidateResponse(beam.DoFn): - """ValidateResponse validates if a PCollection of `beam.Row` - has the required fields.""" - def __init__( - self, - n_fields: int, - fields: list[str], - enriched_fields: dict[str, list[str]], - ): - self.n_fields = n_fields - self._fields = fields - self._enriched_fields = enriched_fields - - def process(self, element: beam.Row, *args, **kwargs): - element_dict = element.as_dict() - if len(element_dict.keys()) != self.n_fields: - raise BeamAssertException( - "Expected %d fields in enriched PCollection:" % self.n_fields) - - for field in 
self._fields: - if field not in element_dict or element_dict[field] is None: - raise BeamAssertException(f"Expected a not None field: {field}") - - for key in self._enriched_fields: - if key not in element_dict: - raise BeamAssertException( - f"Response from Cloud SQL should contain {key} column.") - - -def create_rows(cursor): - """Insert test rows into the Cloud SQL database table.""" - cursor.execute( - """ - CREATE TABLE IF NOT EXISTS products ( - product_id SERIAL PRIMARY KEY, - product_name VARCHAR(255), - product_stock INT +def where_clause_value_fn(row: beam.Row): + return [row.id] # type: ignore[attr-defined] + + +def query_fn(table, row: beam.Row): + return f"SELECT * FROM `{table}` WHERE id = {row.id}" # type: ignore[attr-defined] + + +@pytest.mark.uses_testcontainer +class CloudSQLEnrichmentIT(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._sql_client_retries = 3 + cls._start_sql_db_container() + + @classmethod + def tearDownClass(cls): + cls._stop_sql_db_container() + + @classmethod + def _start_sql_db_container(cls): + for i in range(cls._sql_client_retries): + try: + cls._sql_db_container = PostgresContainer(image="postgres:16") + cls._sql_db_container.start() + cls.sql_db_container_host = cls._sql_db_container.get_container_host_ip( ) - """) - cursor.execute( - """ - INSERT INTO products (product_name, product_stock) - VALUES - ('pixel 5', 2), - ('pixel 6', 4), - ('pixel 7', 20), - ('pixel 8', 10), - ('iphone 11', 3), - ('iphone 12', 7), - ('iphone 13', 8), - ('iphone 14', 3) - ON CONFLICT DO NOTHING - """) + cls.sql_db_container_port = cls._sql_db_container.get_exposed_port(5432) + cls.database_type_adapter = DatabaseTypeAdapter.POSTGRESQL + cls.sql_db_user, cls.sql_db_password, cls.sql_db_id = "test", "test", "test" + _LOGGER.info( + f"PostgreSQL container started successfully on {cls.get_db_address()}." 
+ ) + break + except Exception as e: + _LOGGER.warning( + f"Retry {i + 1}/{cls._sql_client_retries}: Failed to start PostgreSQL container. Reason: {e}" + ) + if i == cls._sql_client_retries - 1: + _LOGGER.error( + f"Unable to start PostgreSQL container for IO tests after {cls._sql_client_retries} retries. Tests cannot proceed." + ) + raise e + + @classmethod + def _stop_sql_db_container(cls): + try: + _LOGGER.info("Stopping PostgreSQL container.") + cls._sql_db_container.stop() + cls._sql_db_container = None + _LOGGER.info("PostgreSQL container stopped successfully.") + except Exception as e: + _LOGGER.warning( + f"Error encountered while stopping PostgreSQL container: {e}") + + @classmethod + def get_db_address(cls): + return f"{cls.sql_db_container_host}:{cls.sql_db_container_port}" @pytest.mark.uses_testcontainer -class TestCloudSQLEnrichment(unittest.TestCase): +class TestCloudSQLEnrichment(CloudSQLEnrichmentIT): + _table_id = "product_details" + _table_data = [ + { + "id": 1, "name": "A", 'quantity': 2, 'distribution_center_id': 3 + }, + { + "id": 2, "name": "B", 'quantity': 3, 'distribution_center_id': 1 + }, + { + "id": 3, "name": "C", 'quantity': 10, 'distribution_center_id': 4 + }, + { + "id": 4, "name": "D", 'quantity': 1, 'distribution_center_id': 3 + }, + { + "id": 5, "name": "C", 'quantity': 100, 'distribution_center_id': 4 + }, + { + "id": 6, "name": "D", 'quantity': 11, 'distribution_center_id': 3 + }, + { + "id": 7, "name": "C", 'quantity': 7, 'distribution_center_id': 1 + }, + { + "id": 8, "name": "D", 'quantity': 4, 'distribution_center_id': 1 + }, + ] + @classmethod def setUpClass(cls): - cls.project_id = "apache-beam-testing" - cls.region_id = "us-central1" - cls.instance_id = "beam-test" - cls.database_id = "postgres" - cls.database_user = os.getenv("BEAM_TEST_CLOUDSQL_PG_USER") - cls.database_password = os.getenv("BEAM_TEST_CLOUDSQL_PG_PASSWORD") - cls.table_id = "products" - cls.row_key = "product_id" - cls.database_type_adapter = 
DatabaseTypeAdapter.POSTGRESQL - cls.req = [ - beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), - beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3), - beam.Row(sale_id=5, customer_id=5, product_id=3, quantity=2), - beam.Row(sale_id=7, customer_id=7, product_id=4, quantity=1), - ] - cls.connector = Connector() - cls.client = cls.connector.connect( - f"{cls.project_id}:{cls.region_id}:{cls.instance_id}", - driver=cls.database_type_adapter.value, - db=cls.database_id, - user=cls.database_user, - password=cls.database_password, + super(TestCloudSQLEnrichment, cls).setUpClass() + cls.create_table(cls._table_id) + cls._cache_client_retries = 3 + + @classmethod + def create_table(cls, table_id): + cls._engine = create_engine(cls._get_db_url()) + + # Define the table schema. + metadata = MetaData() + table = Table( + table_id, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String, nullable=False), + Column("quantity", Integer, nullable=False), + Column("distribution_center_id", Integer, nullable=False), ) - cls.cursor = cls.client.cursor() - create_rows(cls.cursor) - cls.cache_client_retries = 3 + + # Create the table in the database. + metadata.create_all(cls._engine) + + # Insert data into the table. + with cls._engine.connect() as connection: + transaction = connection.begin() + try: + connection.execute(table.insert(), cls._table_data) + transaction.commit() + except Exception as e: + transaction.rollback() + raise e + + @classmethod + def _get_db_url(cls): + dialect = cls.database_type_adapter.to_sqlalchemy_dialect() + db_url = f"{dialect}://{cls.sql_db_user}:{cls.sql_db_password}@{cls.get_db_address()}/{cls.sql_db_id}" + return db_url + + @pytest.fixture + def cache_container(self): + self._start_cache_container() + + # Hand control to the test. 
+ yield + + self._cache_container.stop() + self._cache_container = None def _start_cache_container(self): - for i in range(self.cache_client_retries): + for i in range(self._cache_client_retries): try: - self.container = RedisContainer(image="redis:7.2.4") - self.container.start() - self.host = self.container.get_container_host_ip() - self.port = self.container.get_exposed_port(6379) - self.cache_client = self.container.get_client() + self._cache_container = RedisContainer(image="redis:7.2.4") + self._cache_container.start() + self._cache_container_host = self._cache_container.get_container_host_ip( + ) + self._cache_container_port = self._cache_container.get_exposed_port( + 6379) + self._cache_client = self._cache_container.get_client() break except Exception as e: - if i == self.cache_client_retries - 1: + if i == self._cache_client_retries - 1: _LOGGER.error( - f"Unable to start redis container for RRIO tests after {self.cache_client_retries} retries." + f"Unable to start redis container for RRIO tests after {self._cache_client_retries} retries." 
) raise e @classmethod def tearDownClass(cls): - cls.cursor.close() - cls.client.close() - cls.connector.close() - cls.cursor, cls.client, cls.connector = None, None, None - - def test_enrichment_with_cloudsql(self): - expected_fields = [ - "sale_id", - "customer_id", - "product_id", - "quantity", - "product_name", - "product_stock", + cls._engine.dispose(close=True) + super(TestCloudSQLEnrichment, cls).tearDownClass() + cls._engine = None + + def test_cloudsql_enrichment(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] - expected_enriched_fields = ["product_id", "product_name", "product_stock"] - cloudsql = CloudSQLEnrichmentHandler( - region_id=self.region_id, - project_id=self.project_id, - instance_id=self.instance_id, + fields = ['id'] + requests = [ + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.database_type_adapter, - database_id=self.database_id, - database_user=self.database_user, - database_password=self.database_password, - table_id=self.table_id, - row_key=self.row_key, + database_address=self.get_db_address(), + database_user=self.sql_db_user, + database_password=self.sql_db_password, + database_id=self.sql_db_id, + table_id=self._table_id, + where_clause_template="id = {}", + where_clause_fields=fields, + min_batch_size=1, + max_batch_size=100, ) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - | "Create" >> beam.Create(self.req) - | "Enrich W/ CloudSQL" >> Enrichment(cloudsql) - | "Validate Response" >> beam.ParDo( - ValidateResponse( - len(expected_fields), - expected_fields, - expected_enriched_fields, - ))) - - def test_enrichment_with_cloudsql_no_enrichment(self): - expected_fields = ["sale_id", "customer_id", "product_id", "quantity"] - expected_enriched_fields = {} - cloudsql = CloudSQLEnrichmentHandler( - 
region_id=self.region_id, - project_id=self.project_id, - instance_id=self.instance_id, + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) + + assert_that(pcoll, equal_to(expected_rows)) + + def test_cloudsql_enrichment_batched(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) + ] + fields = ['id'] + requests = [ + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.database_type_adapter, - database_id=self.database_id, - database_user=self.database_user, - database_password=self.database_password, - table_id=self.table_id, - row_key=self.row_key, + database_address=self.get_db_address(), + database_user=self.sql_db_user, + database_password=self.sql_db_password, + database_id=self.sql_db_id, + table_id=self._table_id, + where_clause_template="id = {}", + where_clause_fields=fields, + min_batch_size=2, + max_batch_size=100, ) - req = [beam.Row(sale_id=1, customer_id=1, product_id=99, quantity=1)] with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( - test_pipeline - | "Create" >> beam.Create(req) - | "Enrich W/ CloudSQL" >> Enrichment(cloudsql) - | "Validate Response" >> beam.ParDo( - ValidateResponse( - len(expected_fields), - expected_fields, - expected_enriched_fields, - ))) - - def test_enrichment_with_cloudsql_raises_key_error(self): - cloudsql = CloudSQLEnrichmentHandler( - region_id=self.region_id, - project_id=self.project_id, - instance_id=self.instance_id, + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) + + assert_that(pcoll, equal_to(expected_rows)) + + def test_cloudsql_enrichment_batched_multiple_fields(self): + expected_rows = [ + beam.Row(id=1, distribution_center_id=3, name="A", quantity=2), + beam.Row(id=2, distribution_center_id=1, name="B", quantity=3) + ] + fields = ['id', 'distribution_center_id'] + 
requests = [ + beam.Row(id=1, distribution_center_id=3), + beam.Row(id=2, distribution_center_id=1), + ] + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.database_type_adapter, - database_id=self.database_id, - database_user=self.database_user, - database_password=self.database_password, - table_id=self.table_id, - row_key="car_name", + database_address=self.get_db_address(), + database_user=self.sql_db_user, + database_password=self.sql_db_password, + database_id=self.sql_db_id, + table_id=self._table_id, + where_clause_template="id = {} AND distribution_center_id = {}", + where_clause_fields=fields, + min_batch_size=8, + max_batch_size=100, ) - with self.assertRaises(KeyError): - test_pipeline = TestPipeline() - _ = ( - test_pipeline - | "Create" >> beam.Create(self.req) - | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)) - res = test_pipeline.run() - res.wait_until_finish() + with TestPipeline(is_integration_test=True) as test_pipeline: + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) + + assert_that(pcoll, equal_to(expected_rows)) - def test_enrichment_with_cloudsql_raises_not_found(self): - """Raises a database error when the GCP Cloud SQL table doesn't exist.""" - table_id = "invalid_table" - cloudsql = CloudSQLEnrichmentHandler( - region_id=self.region_id, - project_id=self.project_id, - instance_id=self.instance_id, + def test_cloudsql_enrichment_with_query_fn(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) + ] + requests = [ + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + fn = functools.partial(query_fn, self._table_id) + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.database_type_adapter, - database_id=self.database_id, - database_user=self.database_user, - database_password=self.database_password, - table_id=table_id, - row_key=self.row_key, - ) - try: - test_pipeline = 
beam.Pipeline() - _ = ( - test_pipeline - | "Create" >> beam.Create(self.req) - | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)) - res = test_pipeline.run() - res.wait_until_finish() - except (PgDatabaseError, RuntimeError) as e: - self.assertIn(f'relation "{table_id}" does not exist', str(e)) - - def test_enrichment_with_cloudsql_exception_level(self): - """raises a `ValueError` exception when the GCP Cloud SQL query returns - an empty row.""" - cloudsql = CloudSQLEnrichmentHandler( - region_id=self.region_id, - project_id=self.project_id, - instance_id=self.instance_id, + database_address=self.get_db_address(), + database_user=self.sql_db_user, + database_password=self.sql_db_password, + database_id=self.sql_db_id, + query_fn=fn) + with TestPipeline(is_integration_test=True) as test_pipeline: + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) + + assert_that(pcoll, equal_to(expected_rows)) + + def test_cloudsql_enrichment_with_condition_value_fn(self): + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) + ] + requests = [ + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.database_type_adapter, - database_id=self.database_id, - database_user=self.database_user, - database_password=self.database_password, - table_id=self.table_id, - row_key=self.row_key, - exception_level=ExceptionLevel.RAISE, - ) - req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)] - with self.assertRaises(ValueError): - test_pipeline = beam.Pipeline() - _ = ( - test_pipeline - | "Create" >> beam.Create(req) - | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)) - res = test_pipeline.run() - res.wait_until_finish() + database_address=self.get_db_address(), + database_user=self.sql_db_user, + database_password=self.sql_db_password, + database_id=self.sql_db_id, + table_id=self._table_id, + 
where_clause_template="id = {}", + where_clause_value_fn=where_clause_value_fn, + min_batch_size=2, + max_batch_size=100) + with TestPipeline(is_integration_test=True) as test_pipeline: + pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) - def test_cloudsql_enrichment_with_lambda(self): - expected_fields = [ - "sale_id", - "customer_id", - "product_id", - "quantity", - "product_name", - "product_stock", + assert_that(pcoll, equal_to(expected_rows)) + + def test_cloudsql_enrichment_table_nonexistent_runtime_error_raised(self): + requests = [ + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), ] - expected_enriched_fields = ["product_id", "product_name", "product_stock"] - cloudsql = CloudSQLEnrichmentHandler( - region_id=self.region_id, - project_id=self.project_id, - instance_id=self.instance_id, + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.database_type_adapter, - database_id=self.database_id, - database_user=self.database_user, - database_password=self.database_password, - table_id=self.table_id, - row_key_fn=_row_key_fn, + database_address=self.get_db_address(), + database_user=self.sql_db_user, + database_password=self.sql_db_password, + database_id=self.sql_db_id, + table_id=self._table_id, + where_clause_template="id = {}", + where_clause_value_fn=where_clause_value_fn, + column_names=["wrong_column"], ) - with TestPipeline(is_integration_test=True) as test_pipeline: + with self.assertRaises(RuntimeError): + test_pipeline = beam.Pipeline() _ = ( test_pipeline - | "Create" >> beam.Create(self.req) - | "Enrich W/ CloudSQL" >> Enrichment(cloudsql) - | "Validate Response" >> beam.ParDo( - ValidateResponse( - len(expected_fields), - expected_fields, - expected_enriched_fields))) - - @pytest.fixture - def cache_container(self): - # Setup phase: start the container. - self._start_cache_container() - - # Hand control to the test. - yield - - # Cleanup phase: stop the container. 
It runs after the test completion - # even if it failed. - self.container.stop() - self.container = None + | "Create" >> beam.Create(requests) + | "Enrichment" >> Enrichment(handler)) + res = test_pipeline.run() + res.wait_until_finish() @pytest.mark.usefixtures("cache_container") def test_cloudsql_enrichment_with_redis(self): - expected_fields = [ - "sale_id", - "customer_id", - "product_id", - "quantity", - "product_name", - "product_stock", + requests = [ + beam.Row(id=1, name='A'), + beam.Row(id=2, name='B'), + ] + expected_rows = [ + beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), + beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] - expected_enriched_fields = ["product_id", "product_name", "product_stock"] - cloudsql = CloudSQLEnrichmentHandler( - region_id=self.region_id, - project_id=self.project_id, - instance_id=self.instance_id, + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.database_type_adapter, - database_id=self.database_id, - database_user=self.database_user, - database_password=self.database_password, - table_id=self.table_id, - row_key_fn=_row_key_fn, - ) + database_address=self.get_db_address(), + database_user=self.sql_db_user, + database_password=self.sql_db_password, + database_id=self.sql_db_id, + table_id=self._table_id, + where_clause_template="id = {}", + where_clause_value_fn=where_clause_value_fn, + min_batch_size=2, + max_batch_size=100) with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( + pcoll_populate_cache = ( test_pipeline - | "Create1" >> beam.Create(self.req) - | "Enrich W/ CloudSQL1" >> Enrichment(cloudsql).with_redis_cache( - self.host, self.port, 300) - | "Validate Response" >> beam.ParDo( - ValidateResponse( - len(expected_fields), - expected_fields, - expected_enriched_fields, - ))) + | beam.Create(requests) + | Enrichment(handler).with_redis_cache(self.host, self.port)) + + assert_that(pcoll_populate_cache, equal_to(expected_rows)) # Manually check cache 
entry to verify entries were correctly stored. c = coders.StrUtf8Coder() - for req in self.req: - key = cloudsql.get_cache_key(req) - response = self.cache_client.get(c.encode(key)) + for req in requests: + key = handler.get_cache_key(req) + response = self._cache_client.get(c.encode(key)) if not response: raise ValueError("No cache entry found for %s" % key) - # Mock the CloudSQL handler to avoid actual database calls. + # Mock the CloudSQL enrichment handler to avoid actual database calls. # This simulates a cache hit scenario by returning predefined data. actual = CloudSQLEnrichmentHandler.__call__ - CloudSQLEnrichmentHandler.__call__ = MagicMock( - return_value=( - beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), - beam.Row(), - )) + CloudSQLEnrichmentHandler.__call__ = MagicMock(return_value=(beam.Row())) # Run a second pipeline to verify cache is being used. with TestPipeline(is_integration_test=True) as test_pipeline: - _ = ( + pcoll_cached = ( test_pipeline - | "Create2" >> beam.Create(self.req) - | "Enrich W/ CloudSQL2" >> Enrichment(cloudsql).with_redis_cache( - self.host, self.port) - | "Validate Response" >> beam.ParDo( - ValidateResponse( - len(expected_fields), - expected_fields, - expected_enriched_fields))) + | beam.Create(requests) + | Enrichment(handler).with_redis_cache(self.host, self.port)) + + assert_that(pcoll_cached, equal_to(expected_rows)) + + # Restore the original CloudSQL enrichment handler implementation. 
CloudSQLEnrichmentHandler.__call__ = actual From ecbb61da2a5762b14bf19212f749e15b63012709 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 11 Apr 2025 14:27:47 +0200 Subject: [PATCH 08/18] sdks: update doc `CloudSQLEnrichmentHandler` --- .../transforms/elementwise/enrichment.py | 32 +++++++++---------- .../transforms/elementwise/enrichment_test.py | 2 ++ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py index c88696bbc24d..41b22863e81c 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py @@ -50,6 +50,7 @@ def enrichment_with_bigtable(): | "Print" >> beam.Map(print)) # [END enrichment_with_bigtable] + def enrichment_with_cloudsql(): # [START enrichment_with_cloudsql] import apache_beam as beam @@ -57,14 +58,14 @@ def enrichment_with_cloudsql(): from apache_beam.transforms.enrichment_handlers.cloudsql import CloudSQLEnrichmentHandler, DatabaseTypeAdapter import os - project_id = 'apache-beam-testing' - region_id = 'us-east1' - instance_id = 'beam-test' - table_id = 'cloudsql-enrichment-test' - database_id = 'test-database' - database_user = os.getenv("BEAM_TEST_CLOUDSQL_PG_USER") - database_password = os.getenv("BEAM_TEST_CLOUDSQL_PG_PASSWORD") - row_key = 'product_id' + database_type_adapter = DatabaseTypeAdapter.POSTGRESQL + database_address = "10.0.0.42:5432" + database_user = "test" + database_password = os.getenv("DB_PASSWORD") + database_id = "test" + table_id = "products" + where_clause_template = "product_id = {}" + where_clause_fields = ["id"] data = [ beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), @@ -73,16 +74,14 @@ def enrichment_with_cloudsql(): ] cloudsql_handler = CloudSQLEnrichmentHandler( - project_id=project_id, - 
region_id=region_id, - instance_id=instance_id, - table_id=table_id, - database_type_adapter=DatabaseTypeAdapter.POSTGRESQL, - database_id=database_id, + database_type_adapter=database_type_adapter, + database_address=database_address, database_user=database_user, database_password=database_password, - row_key=row_key - ) + database_id=database_id, + table_id=table_id, + where_clause_template=where_clause_template, + where_clause_fields=where_clause_fields) with beam.Pipeline() as p: _ = ( p @@ -91,6 +90,7 @@ def enrichment_with_cloudsql(): | "Print" >> beam.Map(print)) # [END enrichment_with_cloudsql] + def enrichment_with_vertex_ai(): # [START enrichment_with_vertex_ai] import apache_beam as beam diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index 7086261c254e..e217bb3b49da 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -41,6 +41,7 @@ def validate_enrichment_with_bigtable(): [END enrichment_with_bigtable]'''.splitlines()[1:-1] return expected + def validate_enrichment_with_cloudsql(): expected = '''[START enrichment_with_cloudsql] Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'}) @@ -49,6 +50,7 @@ def validate_enrichment_with_cloudsql(): [END enrichment_with_cloudsql]'''.splitlines()[1:-1] return expected + def validate_enrichment_with_vertex_ai(): expected = '''[START enrichment_with_vertex_ai] Row(user_id='2963', product_id=14235, sale_price=15.0, age=12.0, state='1', gender='1', country='1') From 83e715405bed89b42aacbfdf269de4fcb52c9703 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 11 Apr 2025 20:05:54 +0200 Subject: [PATCH 09/18] sdks/python: fix linting issues --- 
.../transforms/elementwise/enrichment.py | 23 +- .../transforms/elementwise/enrichment_test.py | 81 +++++- .../enrichment_handlers/cloudsql.py | 37 ++- .../enrichment_handlers/cloudsql_it_test.py | 263 ++++++++++-------- .../enrichment_handlers/cloudsql_test.py | 9 +- 5 files changed, 255 insertions(+), 158 deletions(-) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py index 41b22863e81c..13a0ad2e226f 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py @@ -55,22 +55,23 @@ def enrichment_with_cloudsql(): # [START enrichment_with_cloudsql] import apache_beam as beam from apache_beam.transforms.enrichment import Enrichment - from apache_beam.transforms.enrichment_handlers.cloudsql import CloudSQLEnrichmentHandler, DatabaseTypeAdapter + from apache_beam.transforms.enrichment_handlers.cloudsql import ( + CloudSQLEnrichmentHandler, DatabaseTypeAdapter) import os - database_type_adapter = DatabaseTypeAdapter.POSTGRESQL - database_address = "10.0.0.42:5432" - database_user = "test" - database_password = os.getenv("DB_PASSWORD") - database_id = "test" - table_id = "products" + database_type_adapter = DatabaseTypeAdapter[os.environ.get("SQL_DB_TYPE")] + database_address = os.environ.get("SQL_DB_ADDRESS") + database_user = os.environ.get("SQL_DB_USER") + database_password = os.environ.get("SQL_DB_PASSWORD") + database_id = os.environ.get("SQL_DB_ID") + table_id = os.environ.get("SQL_TABLE_ID") where_clause_template = "product_id = {}" - where_clause_fields = ["id"] + where_clause_fields = ["product_id"] data = [ - beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), - beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3), - beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2), + beam.Row(product_id=1, 
name='A'), + beam.Row(product_id=2, name='B'), + beam.Row(product_id=3, name='C'), ] cloudsql_handler = CloudSQLEnrichmentHandler( diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index e217bb3b49da..82d9caaf3a6b 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -18,16 +18,22 @@ # pytype: skip-file # pylint: disable=line-too-long +import os import unittest from io import StringIO import mock +import pytest # pylint: disable=unused-import try: - from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_bigtable, \ - enrichment_with_vertex_ai_legacy - from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_vertex_ai + from sqlalchemy import Column, Integer, String, Engine + from apache_beam.examples.snippets.transforms.elementwise.enrichment import ( + enrichment_with_bigtable, enrichment_with_vertex_ai_legacy) + from apache_beam.examples.snippets.transforms.elementwise.enrichment import ( + enrichment_with_vertex_ai, enrichment_with_cloudsql) + from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import ( + CloudSQLEnrichmentTestHelper, SQLDBContainerInfo) from apache_beam.io.requestresponse import RequestResponseIO except ImportError: raise unittest.SkipTest('RequestResponseIO dependencies are not installed') @@ -44,9 +50,9 @@ def validate_enrichment_with_bigtable(): def validate_enrichment_with_cloudsql(): expected = '''[START enrichment_with_cloudsql] -Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'}) -Row(sale_id=3, customer_id=3, product_id=2, quantity=3, product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': 
'4'}) -Row(sale_id=5, customer_id=5, product_id=4, quantity=2, product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'}) +Row(product_id=1, name='A', quantity=2, region_id=3) +Row(product_id=2, name='B', quantity=3, region_id=1) +Row(product_id=3, name='C', quantity=10, region_id=4) [END enrichment_with_cloudsql]'''.splitlines()[1:-1] return expected @@ -70,6 +76,7 @@ def validate_enrichment_with_vertex_ai_legacy(): @mock.patch('sys.stdout', new_callable=StringIO) +@pytest.mark.uses_testcontainer class EnrichmentTest(unittest.TestCase): def test_enrichment_with_bigtable(self, mock_stdout): enrichment_with_bigtable() @@ -78,11 +85,18 @@ def test_enrichment_with_bigtable(self, mock_stdout): self.assertEqual(output, expected) def test_enrichment_with_cloudsql(self, mock_stdout): - from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_cloudsql - enrichment_with_cloudsql() - output = mock_stdout.getvalue().splitlines() - expected = validate_enrichment_with_cloudsql() - self.assertEqual(output, expected) + db, engine = None, None + try: + db, engine = self.pre_cloudsql_enrichment_test() + enrichment_with_cloudsql() + output = mock_stdout.getvalue().splitlines() + expected = validate_enrichment_with_cloudsql() + self.assertEqual(output, expected) + except Exception as e: + self.fail(f"Test failed with unexpected error: {e}") + finally: + if db and engine: + self.post_cloudsql_enrichment_test(db, engine) def test_enrichment_with_vertex_ai(self, mock_stdout): enrichment_with_vertex_ai() @@ -99,6 +113,51 @@ def test_enrichment_with_vertex_ai_legacy(self, mock_stdout): self.maxDiff = None self.assertEqual(output, expected) + def pre_cloudsql_enrichment_test(self): + columns = [ + Column("product_id", Integer, primary_key=True), + Column("name", String, nullable=False), + Column("quantity", Integer, nullable=False), + Column("region_id", Integer, nullable=False), + ] + table_data = [ + { + "product_id": 1, "name": "A", 
'quantity': 2, 'region_id': 3 + }, + { + "product_id": 2, "name": "B", 'quantity': 3, 'region_id': 1 + }, + { + "product_id": 3, "name": "C", 'quantity': 10, 'region_id': 4 + }, + ] + db = CloudSQLEnrichmentTestHelper.start_sql_db_container() + os.environ['SQL_DB_TYPE'] = db.adapter.name + os.environ['SQL_DB_ADDRESS'] = db.address + os.environ['SQL_DB_USER'] = db.user + os.environ['SQL_DB_PASSWORD'] = db.password + os.environ['SQL_DB_ID'] = db.id + os.environ['SQL_DB_URL'] = db.url + os.environ['SQL_TABLE_ID'] = "products" + engine = CloudSQLEnrichmentTestHelper.create_table( + table_id=os.environ.get("SQL_TABLE_ID"), + db_url=os.environ.get("SQL_DB_URL"), + columns=columns, + table_data=table_data) + return db, engine + + def post_cloudsql_enrichment_test( + self, db: SQLDBContainerInfo, engine: Engine): + engine.dispose(close=True) + CloudSQLEnrichmentTestHelper.stop_sql_db_container(db.container) + os.environ.pop('SQL_DB_TYPE', None) + os.environ.pop('SQL_DB_ADDRESS', None) + os.environ.pop('SQL_DB_USER', None) + os.environ.pop('SQL_DB_PASSWORD', None) + os.environ.pop('SQL_DB_ID', None) + os.environ.pop('SQL_DB_URL', None) + os.environ.pop('SQL_TABLE_ID', None) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py index 77ab9889c8da..4f2ccea3575c 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py @@ -21,11 +21,11 @@ from typing import Optional from typing import Union -from sqlalchemy import create_engine, text +from sqlalchemy import create_engine +from sqlalchemy import text import apache_beam as beam from apache_beam.transforms.enrichment import EnrichmentSourceHandler -from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel QueryFn = Callable[[beam.Row], str] ConditionValueFn = Callable[[beam.Row], 
list[Any]] @@ -44,8 +44,8 @@ def _validate_cloudsql_metadata( where_clause_value_fn]): raise ValueError( "Please provide either `query_fn` or the parameters `table_id`, " - "`where_clause_template`, and `where_clause_fields/where_clause_value_fn` " - "together.") + "`where_clause_template`, and " + "`where_clause_fields/where_clause_value_fn` together.") else: if not (table_id and where_clause_template): raise ValueError( @@ -147,24 +147,25 @@ def __init__( where_clause_template (str): A template string for the `WHERE` clause in the SQL query with placeholders (`{}`) for dynamic filtering based on input data. - where_clause_fields (Optional[list[str]]): List of field names from the input - `beam.Row` used to construct the `WHERE` clause if `where_clause_value_fn` - is not provided. + where_clause_fields (Optional[list[str]]): List of field names from the + input `beam.Row` used to construct the `WHERE` clause if + `where_clause_value_fn` is not provided. where_clause_value_fn (Optional[Callable[[beam.Row], Any]]): Function that takes a `beam.Row` and returns a list of values to populate the placeholders `{}` in the `WHERE` clause. query_fn (Optional[Callable[[beam.Row], str]]): Function that takes a `beam.Row` and returns a complete SQL query string. - column_names (Optional[list[str]]): List of column names to select from the - Cloud SQL table. If not provided, all columns (`*`) are selected. + column_names (Optional[list[str]]): List of column names to select from + the Cloud SQL table. If not provided, all columns (`*`) are selected. min_batch_size (int): Minimum number of rows to batch together when querying the database. Defaults to 1 if `query_fn` is not used. max_batch_size (int): Maximum number of rows to batch together. Defaults to 10,000 if `query_fn` is not used. - **kwargs: Additional keyword arguments for database connection or query handling. + **kwargs: Additional keyword arguments for database connection or query + handling. 
Note: - * `min_batch_size` and `max_batch_size` cannot be used if `query_fn` is provided. + * Cannot use `min_batch_size` or `max_batch_size` with `query_fn`. * Either `where_clause_fields` or `where_clause_value_fn` must be provided for query construction if `query_fn` is not provided. * Ensure that the database user has the necessary permissions to query the @@ -183,11 +184,15 @@ def __init__( self._database_address = database_address self._table_id = table_id self._where_clause_template = where_clause_template - self._where_clause_fields = where_clause_fields self._where_clause_value_fn = where_clause_value_fn self._query_fn = query_fn + fields = where_clause_fields if where_clause_fields else [] + self._where_clause_fields = fields self._column_names = ",".join(column_names) if column_names else "*" - self.query_template = f"SELECT {self._column_names} FROM {self._table_id} WHERE {self._where_clause_template}" + self.query_template = ( + f"SELECT {self._column_names} " + f"FROM {self._table_id} " + f"WHERE {self._where_clause_template}") self.kwargs = kwargs self._batching_kwargs = {} if not query_fn: @@ -201,8 +206,10 @@ def __enter__(self): def _get_db_url(self) -> str: dialect = self._database_type_adapter.to_sqlalchemy_dialect() - string = f"{dialect}://{self._database_user}:{self._database_password}@{self._database_address}/{self._database_id}" - return string + url = ( + f"{dialect}://{self._database_user}:{self._database_password}" + f"@{self._database_address}/{self._database_id}") + return url def _execute_query(self, query: str, is_batch: bool, **params): try: diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py index c7939b09c07d..58ce8dce5263 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py @@ -17,6 +17,8 @@ import 
functools import logging import unittest +from dataclasses import dataclass +from typing import Optional from unittest.mock import MagicMock import pytest @@ -29,12 +31,16 @@ # pylint: disable=ungrouped-imports try: + from testcontainers.core.generic import DbContainer from testcontainers.postgres import PostgresContainer from testcontainers.redis import RedisContainer - from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String + from sqlalchemy import ( + create_engine, MetaData, Table, Column, Integer, String, Engine) from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.cloudsql import ( - CloudSQLEnrichmentHandler, DatabaseTypeAdapter, DatabaseTypeAdapter) + CloudSQLEnrichmentHandler, + DatabaseTypeAdapter, + ) except ImportError: raise unittest.SkipTest('Google Cloud SQL dependencies are not installed.') @@ -49,61 +55,114 @@ def query_fn(table, row: beam.Row): return f"SELECT * FROM `{table}` WHERE id = {row.id}" # type: ignore[attr-defined] -@pytest.mark.uses_testcontainer -class CloudSQLEnrichmentIT(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls._sql_client_retries = 3 - cls._start_sql_db_container() - - @classmethod - def tearDownClass(cls): - cls._stop_sql_db_container() - - @classmethod - def _start_sql_db_container(cls): - for i in range(cls._sql_client_retries): +@dataclass +class SQLDBContainerInfo: + adapter: DatabaseTypeAdapter + container: DbContainer + host: str + port: int + user: str + password: str + id: str + + @property + def address(self) -> str: + return f"{self.host}:{self.port}" + + @property + def url(self) -> str: + dialect = self.adapter.to_sqlalchemy_dialect() + return f"{dialect}://{self.user}:{self.password}@{self.address}/{self.id}" + + +class CloudSQLEnrichmentTestHelper: + @staticmethod + def start_sql_db_container( + sql_client_retries=3) -> Optional[SQLDBContainerInfo]: + info = None + for i in range(sql_client_retries): try: - 
cls._sql_db_container = PostgresContainer(image="postgres:16") - cls._sql_db_container.start() - cls.sql_db_container_host = cls._sql_db_container.get_container_host_ip( - ) - cls.sql_db_container_port = cls._sql_db_container.get_exposed_port(5432) - cls.database_type_adapter = DatabaseTypeAdapter.POSTGRESQL - cls.sql_db_user, cls.sql_db_password, cls.sql_db_id = "test", "test", "test" + database_type_adapter = DatabaseTypeAdapter.POSTGRESQL + sql_db_container = PostgresContainer(image="postgres:16") + sql_db_container.start() + host = sql_db_container.get_container_host_ip() + port = sql_db_container.get_exposed_port(5432) + user, password, db_id = "test", "test", "test" + info = SQLDBContainerInfo( + adapter=database_type_adapter, + container=sql_db_container, + host=host, + port=port, + user=user, + password=password, + id=db_id) _LOGGER.info( - f"PostgreSQL container started successfully on {cls.get_db_address()}." - ) + "PostgreSQL container started successfully on %s.", info.address) break except Exception as e: _LOGGER.warning( - f"Retry {i + 1}/{cls._sql_client_retries}: Failed to start PostgreSQL container. Reason: {e}" - ) - if i == cls._sql_client_retries - 1: + "Retry %d/%d: Failed to start PostgreSQL container. Reason: %s", + i + 1, + sql_client_retries, + e) + if i == sql_client_retries - 1: _LOGGER.error( - f"Unable to start PostgreSQL container for IO tests after {cls._sql_client_retries} retries. Tests cannot proceed." - ) + "Unable to start PostgreSQL container for IO tests after %d " + "retries. 
Tests cannot proceed.", + sql_client_retries) raise e - @classmethod - def _stop_sql_db_container(cls): + return info + + @staticmethod + def stop_sql_db_container(sql_db: DbContainer): try: - _LOGGER.info("Stopping PostgreSQL container.") - cls._sql_db_container.stop() - cls._sql_db_container = None + _LOGGER.debug("Stopping PostgreSQL container.") + sql_db.stop() _LOGGER.info("PostgreSQL container stopped successfully.") except Exception as e: _LOGGER.warning( - f"Error encountered while stopping PostgreSQL container: {e}") + "Error encountered while stopping PostgreSQL container: %s", e) + + @staticmethod + def create_table( + table_id: str, + db_url: str, + columns: list[Column], + table_data: list[dict], + metadata: MetaData = MetaData()) -> Engine: + engine = create_engine(db_url) + table = Table(table_id, metadata, *columns) + + # metadata = MetaData() + # Column("id", Integer, primary_key=True), + # Column("name", String, nullable=False), + # Column("quantity", Integer, nullable=False), + # Column("distribution_center_id", Integer, nullable=False), + # Create the table in the database. + metadata.create_all(engine) - @classmethod - def get_db_address(cls): - return f"{cls.sql_db_container_host}:{cls.sql_db_container_port}" + # Insert data into the table. 
+ with engine.connect() as connection: + transaction = connection.begin() + try: + connection.execute(table.insert(), table_data) + transaction.commit() + return engine + except Exception as e: + transaction.rollback() + raise e @pytest.mark.uses_testcontainer -class TestCloudSQLEnrichment(CloudSQLEnrichmentIT): +class TestCloudSQLEnrichment(unittest.TestCase): _table_id = "product_details" + _columns = [ + Column("id", Integer, primary_key=True), + Column("name", String, nullable=False), + Column("quantity", Integer, nullable=False), + Column("distribution_center_id", Integer, nullable=False), + ] _table_data = [ { "id": 1, "name": "A", 'quantity': 2, 'distribution_center_id': 3 @@ -133,44 +192,11 @@ class TestCloudSQLEnrichment(CloudSQLEnrichmentIT): @classmethod def setUpClass(cls): - super(TestCloudSQLEnrichment, cls).setUpClass() - cls.create_table(cls._table_id) + cls.db = CloudSQLEnrichmentTestHelper.start_sql_db_container() + cls._engine = CloudSQLEnrichmentTestHelper.create_table( + cls._table_id, cls.db.url, cls._columns, cls._table_data) cls._cache_client_retries = 3 - @classmethod - def create_table(cls, table_id): - cls._engine = create_engine(cls._get_db_url()) - - # Define the table schema. - metadata = MetaData() - table = Table( - table_id, - metadata, - Column("id", Integer, primary_key=True), - Column("name", String, nullable=False), - Column("quantity", Integer, nullable=False), - Column("distribution_center_id", Integer, nullable=False), - ) - - # Create the table in the database. - metadata.create_all(cls._engine) - - # Insert data into the table. 
- with cls._engine.connect() as connection: - transaction = connection.begin() - try: - connection.execute(table.insert(), cls._table_data) - transaction.commit() - except Exception as e: - transaction.rollback() - raise e - - @classmethod - def _get_db_url(cls): - dialect = cls.database_type_adapter.to_sqlalchemy_dialect() - db_url = f"{dialect}://{cls.sql_db_user}:{cls.sql_db_password}@{cls.get_db_address()}/{cls.sql_db_id}" - return db_url - @pytest.fixture def cache_container(self): self._start_cache_container() @@ -186,23 +212,24 @@ def _start_cache_container(self): try: self._cache_container = RedisContainer(image="redis:7.2.4") self._cache_container.start() - self._cache_container_host = self._cache_container.get_container_host_ip( - ) - self._cache_container_port = self._cache_container.get_exposed_port( - 6379) + host = self._cache_container.get_container_host_ip() + port = self._cache_container.get_exposed_port(6379) + self._cache_container_host = host + self._cache_container_port = port self._cache_client = self._cache_container.get_client() break except Exception as e: if i == self._cache_client_retries - 1: _LOGGER.error( - f"Unable to start redis container for RRIO tests after {self._cache_client_retries} retries." 
- ) + "Unable to start redis container for RRIO tests after " + "%d retries.", + self._cache_client_retries) raise e @classmethod def tearDownClass(cls): cls._engine.dispose(close=True) - super(TestCloudSQLEnrichment, cls).tearDownClass() + CloudSQLEnrichmentTestHelper.stop_sql_db_container(cls.db.container) cls._engine = None def test_cloudsql_enrichment(self): @@ -216,11 +243,11 @@ def test_cloudsql_enrichment(self): beam.Row(id=2, name='B'), ] handler = CloudSQLEnrichmentHandler( - database_type_adapter=self.database_type_adapter, - database_address=self.get_db_address(), - database_user=self.sql_db_user, - database_password=self.sql_db_password, - database_id=self.sql_db_id, + database_type_adapter=self.db.adapter, + database_address=self.db.address, + database_user=self.db.user, + database_password=self.db.id, + database_id=self.db.id, table_id=self._table_id, where_clause_template="id = {}", where_clause_fields=fields, @@ -243,11 +270,11 @@ def test_cloudsql_enrichment_batched(self): beam.Row(id=2, name='B'), ] handler = CloudSQLEnrichmentHandler( - database_type_adapter=self.database_type_adapter, - database_address=self.get_db_address(), - database_user=self.sql_db_user, - database_password=self.sql_db_password, - database_id=self.sql_db_id, + database_type_adapter=self.db.adapter, + database_address=self.db.address, + database_user=self.db.user, + database_password=self.db.password, + database_id=self.db.id, table_id=self._table_id, where_clause_template="id = {}", where_clause_fields=fields, @@ -270,11 +297,11 @@ def test_cloudsql_enrichment_batched_multiple_fields(self): beam.Row(id=2, distribution_center_id=1), ] handler = CloudSQLEnrichmentHandler( - database_type_adapter=self.database_type_adapter, - database_address=self.get_db_address(), - database_user=self.sql_db_user, - database_password=self.sql_db_password, - database_id=self.sql_db_id, + database_type_adapter=self.db.adapter, + database_address=self.db.address, + database_user=self.db.user, + 
database_password=self.db.password, + database_id=self.db.id, table_id=self._table_id, where_clause_template="id = {} AND distribution_center_id = {}", where_clause_fields=fields, @@ -297,11 +324,11 @@ def test_cloudsql_enrichment_with_query_fn(self): ] fn = functools.partial(query_fn, self._table_id) handler = CloudSQLEnrichmentHandler( - database_type_adapter=self.database_type_adapter, - database_address=self.get_db_address(), - database_user=self.sql_db_user, - database_password=self.sql_db_password, - database_id=self.sql_db_id, + database_type_adapter=self.db.adapter, + database_address=self.db.address, + database_user=self.db.user, + database_password=self.db.password, + database_id=self.db.id, query_fn=fn) with TestPipeline(is_integration_test=True) as test_pipeline: pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) @@ -318,11 +345,11 @@ def test_cloudsql_enrichment_with_condition_value_fn(self): beam.Row(id=2, name='B'), ] handler = CloudSQLEnrichmentHandler( - database_type_adapter=self.database_type_adapter, - database_address=self.get_db_address(), - database_user=self.sql_db_user, - database_password=self.sql_db_password, - database_id=self.sql_db_id, + database_type_adapter=self.db.adapter, + database_address=self.db.address, + database_user=self.db.user, + database_password=self.db.password, + database_id=self.db.id, table_id=self._table_id, where_clause_template="id = {}", where_clause_value_fn=where_clause_value_fn, @@ -339,11 +366,11 @@ def test_cloudsql_enrichment_table_nonexistent_runtime_error_raised(self): beam.Row(id=2, name='B'), ] handler = CloudSQLEnrichmentHandler( - database_type_adapter=self.database_type_adapter, - database_address=self.get_db_address(), - database_user=self.sql_db_user, - database_password=self.sql_db_password, - database_id=self.sql_db_id, + database_type_adapter=self.db.adapter, + database_address=self.db.address, + database_user=self.db.user, + database_password=self.db.password, + 
database_id=self.db.id, table_id=self._table_id, where_clause_template="id = {}", where_clause_value_fn=where_clause_value_fn, @@ -369,11 +396,11 @@ def test_cloudsql_enrichment_with_redis(self): beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] handler = CloudSQLEnrichmentHandler( - database_type_adapter=self.database_type_adapter, - database_address=self.get_db_address(), - database_user=self.sql_db_user, - database_password=self.sql_db_password, - database_id=self.sql_db_id, + database_type_adapter=self.db.adapter, + database_address=self.db.address, + database_user=self.db.user, + database_password=self.db.password, + database_id=self.db.id, table_id=self._table_id, where_clause_template="id = {}", where_clause_value_fn=where_clause_value_fn, diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py index 2fa74bf33233..888886479c75 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py @@ -20,9 +20,12 @@ # pylint: disable=ungrouped-imports try: - from apache_beam.transforms.enrichment_handlers.cloudsql import CloudSQLEnrichmentHandler, DatabaseTypeAdapter - from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import where_clause_value_fn - from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import query_fn + from apache_beam.transforms.enrichment_handlers.cloudsql import ( + CloudSQLEnrichmentHandler, DatabaseTypeAdapter) + from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import ( + query_fn, + where_clause_value_fn, + ) except ImportError: raise unittest.SkipTest('Google Cloud SQL dependencies are not installed.') From 4fa1830837153c5969861c59e724bcbc661fe641 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Sun, 13 Apr 2025 00:08:02 +0000 Subject: [PATCH 10/18] website: add missing 
`enrichment-cloudsql.md` --- .../python/elementwise/enrichment-cloudsql.md | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 website/www/site/content/en/documentation/transforms/python/elementwise/enrichment-cloudsql.md diff --git a/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment-cloudsql.md b/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment-cloudsql.md new file mode 100644 index 000000000000..a8c5de1c2694 --- /dev/null +++ b/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment-cloudsql.md @@ -0,0 +1,63 @@ +--- +title: "Enrichment with CloudSQL" +--- + + +# Use CloudSQL to enrich data + +{{< localstorage language language-py >}} + + + + + +
+ + {{< button-pydoc path="apache_beam.transforms.enrichment_handlers.cloudsql" class="CloudSQLEnrichmentHandler" >}} + +
+ +In Apache Beam and later versions, the enrichment transform includes +a built-in enrichment handler for +[CloudSQL](https://cloud.google.com/sql/docs). +The following example demonstrates how to create a pipeline that uses the enrichment transform with the [`CloudSQLEnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.cloudsql.html#apache_beam.transforms.enrichment_handlers.cloudsql.CloudSQLEnrichmentHandler) handler. + +The data in the `products` table of the CloudSQL PostgreSQL instance follows this format: + +{{< table >}} +| product_id | name | quantity | region_id | +|:----------:|:----:|:--------:|:---------:| +| 1 | A | 2 | 3 | +| 2 | B | 3 | 1 | +| 3 | C | 10 | 4 | +{{< /table >}} + + +{{< highlight language="py" >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py" enrichment_with_cloudsql >}} +{{< /highlight >}} + +{{< paragraph class="notebook-skip" >}} +Output: +{{< /paragraph >}} +{{< highlight class="notebook-skip" >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py" enrichment_with_cloudsql >}} +{{< /highlight >}} + +## Related transforms + +Not applicable. 
+ +{{< button-pydoc path="apache_beam.transforms.enrichment_handlers.cloudsql" class="CloudSQLEnrichmentHandler" >}} \ No newline at end of file From 4251a7677b9514b324db0c7afa87c6636fb6df4d Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Sun, 13 Apr 2025 00:19:26 +0000 Subject: [PATCH 11/18] nits: remove commented code --- .../transforms/enrichment_handlers/cloudsql_it_test.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py index 58ce8dce5263..599528198431 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py @@ -133,13 +133,6 @@ def create_table( metadata: MetaData = MetaData()) -> Engine: engine = create_engine(db_url) table = Table(table_id, metadata, *columns) - - # metadata = MetaData() - # Column("id", Integer, primary_key=True), - # Column("name", String, nullable=False), - # Column("quantity", Integer, nullable=False), - # Column("distribution_center_id", Integer, nullable=False), - # Create the table in the database. metadata.create_all(engine) # Insert data into the table. 
From ab7cf7b6711430adacb97c6740d48049f15d439e Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Sun, 13 Apr 2025 00:28:58 +0000 Subject: [PATCH 12/18] sdks/python: remove `SQL_TABLE_ID` env variable --- .../examples/snippets/transforms/elementwise/enrichment.py | 2 +- .../snippets/transforms/elementwise/enrichment_test.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py index 13a0ad2e226f..05ecddbef54d 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py @@ -64,7 +64,7 @@ def enrichment_with_cloudsql(): database_user = os.environ.get("SQL_DB_USER") database_password = os.environ.get("SQL_DB_PASSWORD") database_id = os.environ.get("SQL_DB_ID") - table_id = os.environ.get("SQL_TABLE_ID") + table_id = "products" where_clause_template = "product_id = {}" where_clause_fields = ["product_id"] diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index 82d9caaf3a6b..c62da660551a 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -114,6 +114,7 @@ def test_enrichment_with_vertex_ai_legacy(self, mock_stdout): self.assertEqual(output, expected) def pre_cloudsql_enrichment_test(self): + table_id ="products" columns = [ Column("product_id", Integer, primary_key=True), Column("name", String, nullable=False), @@ -138,10 +139,9 @@ def pre_cloudsql_enrichment_test(self): os.environ['SQL_DB_PASSWORD'] = db.password os.environ['SQL_DB_ID'] = db.id os.environ['SQL_DB_URL'] = db.url - os.environ['SQL_TABLE_ID'] = 
"products" engine = CloudSQLEnrichmentTestHelper.create_table( - table_id=os.environ.get("SQL_TABLE_ID"), db_url=os.environ.get("SQL_DB_URL"), + table_id=table_id, columns=columns, table_data=table_data) return db, engine @@ -156,7 +156,6 @@ def post_cloudsql_enrichment_test( os.environ.pop('SQL_DB_PASSWORD', None) os.environ.pop('SQL_DB_ID', None) os.environ.pop('SQL_DB_URL', None) - os.environ.pop('SQL_TABLE_ID', None) if __name__ == '__main__': From 669d2d4d88b45378c2c80c31e67a37c7959d1dd1 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Sun, 13 Apr 2025 00:47:56 +0000 Subject: [PATCH 13/18] sdks/python: fix formatting issues --- .../examples/snippets/transforms/elementwise/enrichment_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index c62da660551a..7f2987b240a1 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -114,7 +114,7 @@ def test_enrichment_with_vertex_ai_legacy(self, mock_stdout): self.assertEqual(output, expected) def pre_cloudsql_enrichment_test(self): - table_id ="products" + table_id = "products" columns = [ Column("product_id", Integer, primary_key=True), Column("name", String, nullable=False), From 720933542cad6a63c19c684ef497bbdfe43283cd Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Wed, 21 May 2025 04:05:44 +0000 Subject: [PATCH 14/18] sdks/python: address claudevdm feedback (2) --- .../transforms/elementwise/enrichment.py | 10 +- .../transforms/elementwise/enrichment_test.py | 7 +- .../enrichment_handlers/cloudsql.py | 223 ++++++++++-------- .../enrichment_handlers/cloudsql_it_test.py | 184 +++++++++++---- .../enrichment_handlers/cloudsql_test.py | 135 +++++++++-- 5 files changed, 389 insertions(+), 
170 deletions(-) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py index 05ecddbef54d..0190b12b700d 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py @@ -56,7 +56,7 @@ def enrichment_with_cloudsql(): import apache_beam as beam from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.cloudsql import ( - CloudSQLEnrichmentHandler, DatabaseTypeAdapter) + CloudSQLEnrichmentHandler, DatabaseTypeAdapter, TableFieldsQueryConfig) import os database_type_adapter = DatabaseTypeAdapter[os.environ.get("SQL_DB_TYPE")] @@ -74,6 +74,11 @@ def enrichment_with_cloudsql(): beam.Row(product_id=3, name='C'), ] + query_config = TableFieldsQueryConfig( + table_id=table_id, + where_clause_template=where_clause_template, + where_clause_fields=where_clause_fields) + cloudsql_handler = CloudSQLEnrichmentHandler( database_type_adapter=database_type_adapter, database_address=database_address, @@ -81,8 +86,7 @@ def enrichment_with_cloudsql(): database_password=database_password, database_id=database_id, table_id=table_id, - where_clause_template=where_clause_template, - where_clause_fields=where_clause_fields) + query_config=query_config) with beam.Pipeline() as p: _ = ( p diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index 7f2987b240a1..426f96d58b8e 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -32,6 +32,8 @@ enrichment_with_bigtable, enrichment_with_vertex_ai_legacy) from 
apache_beam.examples.snippets.transforms.elementwise.enrichment import ( enrichment_with_vertex_ai, enrichment_with_cloudsql) + from apache_beam.transforms.enrichment_handlers.cloudsql import ( + DatabaseTypeAdapter) from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import ( CloudSQLEnrichmentTestHelper, SQLDBContainerInfo) from apache_beam.io.requestresponse import RequestResponseIO @@ -132,7 +134,8 @@ def pre_cloudsql_enrichment_test(self): "product_id": 3, "name": "C", 'quantity': 10, 'region_id': 4 }, ] - db = CloudSQLEnrichmentTestHelper.start_sql_db_container() + db_adapter = DatabaseTypeAdapter.POSTGRESQL + db = CloudSQLEnrichmentTestHelper.start_sql_db_container(db_adapter) os.environ['SQL_DB_TYPE'] = db.adapter.name os.environ['SQL_DB_ADDRESS'] = db.address os.environ['SQL_DB_USER'] = db.user @@ -149,7 +152,7 @@ def pre_cloudsql_enrichment_test(self): def post_cloudsql_enrichment_test( self, db: SQLDBContainerInfo, engine: Engine): engine.dispose(close=True) - CloudSQLEnrichmentTestHelper.stop_sql_db_container(db.container) + CloudSQLEnrichmentTestHelper.stop_sql_db_container(db) os.environ.pop('SQL_DB_TYPE', None) os.environ.pop('SQL_DB_ADDRESS', None) os.environ.pop('SQL_DB_USER', None) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py index 4f2ccea3575c..c676f92ecdbd 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py @@ -16,10 +16,13 @@ # from collections.abc import Callable from collections.abc import Mapping +from dataclasses import dataclass from enum import Enum from typing import Any +from typing import List from typing import Optional from typing import Union +from typing import cast from sqlalchemy import create_engine from sqlalchemy import text @@ -31,43 +34,75 @@ ConditionValueFn = Callable[[beam.Row], list[Any]] -def 
_validate_cloudsql_metadata( - table_id, - where_clause_template, - where_clause_fields, - where_clause_value_fn, - query_fn): - if query_fn: - if any([table_id, - where_clause_template, - where_clause_fields, - where_clause_value_fn]): +@dataclass +class CustomQueryConfig: + """Configuration for using a custom query function.""" + query_fn: QueryFn + + +@dataclass +class TableFieldsQueryConfig: + """Configuration for using table name, where clause, and field names.""" + table_id: str + where_clause_template: str + where_clause_fields: List[str] + + +@dataclass +class TableFunctionQueryConfig: + """Configuration for using table name, where clause, and a value function.""" + table_id: str + where_clause_template: str + where_clause_value_fn: ConditionValueFn + + +QueryConfig = Union[CustomQueryConfig, + TableFieldsQueryConfig, + TableFunctionQueryConfig] + + +def _validate_query_config(query_config: QueryConfig): + """Validates the provided query configuration.""" + if isinstance(query_config, CustomQueryConfig): + if not query_config.query_fn: + raise ValueError("CustomQueryConfig must provide a valid query_fn") + elif isinstance(query_config, + (TableFieldsQueryConfig, TableFunctionQueryConfig)): + if not query_config.table_id or not query_config.where_clause_template: raise ValueError( - "Please provide either `query_fn` or the parameters `table_id`, " - "`where_clause_template`, and " - "`where_clause_fields/where_clause_value_fn` together.") + "TableFieldsQueryConfig and " + + "TableFunctionQueryConfig must provide table_id " + + "and where_clause_template") + + is_table_fields = isinstance(query_config, TableFieldsQueryConfig) + if is_table_fields: + table_fields_config = cast(TableFieldsQueryConfig, query_config) + if not table_fields_config.where_clause_fields: + raise ValueError( + "TableFieldsQueryConfig must provide non-empty " + + "where_clause_fields") + + is_table_function = isinstance(query_config, TableFunctionQueryConfig) + if is_table_function: + 
table_function_config = cast(TableFunctionQueryConfig, query_config) + if not table_function_config.where_clause_value_fn: + raise ValueError( + "TableFunctionQueryConfig must provide " + "where_clause_value_fn") else: - if not (table_id and where_clause_template): - raise ValueError( - "Please provide either `query_fn` or the parameters " - "`table_id` and `where_clause_template` together.") - if (bool(where_clause_fields) == bool(where_clause_value_fn)): - raise ValueError( - "Please provide exactly one of `where_clause_fields` or " - "`where_clause_value_fn`.") + raise ValueError("Invalid query_config type provided") class DatabaseTypeAdapter(Enum): POSTGRESQL = "psycopg2" MYSQL = "pymysql" - SQLSERVER = "pytds" + SQLSERVER = "pymysql" def to_sqlalchemy_dialect(self): + """Map the adapter type to its corresponding SQLAlchemy dialect. + + Returns: + str: SQLAlchemy dialect string. """ - Map the adapter type to its corresponding SQLAlchemy dialect. - Returns: - str: SQLAlchemy dialect string. - """ if self == DatabaseTypeAdapter.POSTGRESQL: return f"postgresql+{self.value}" elif self == DatabaseTypeAdapter.MYSQL: @@ -79,16 +114,15 @@ def to_sqlalchemy_dialect(self): class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): - """ - Enrichment handler for Cloud SQL databases. + """Enrichment handler for Cloud SQL databases. This handler is designed to work with the :class:`apache_beam.transforms.enrichment.Enrichment` transform. 
- To use this handler, you need to provide either of the following combinations: - * `table_id`, `where_clause_template`, `where_clause_fields` - * `table_id`, `where_clause_template`, `where_clause_value_fn` - * `query_fn` + To use this handler, you need to provide one of the following query configs: + * CustomQueryConfig - For providing a custom query function + * TableFieldsQueryConfig - For specifying table, where clause, and fields + * TableFunctionQueryConfig - For specifying table, where clause, and val fn By default, the handler retrieves all columns from the specified table. To limit the columns, use the `column_names` parameter to specify @@ -99,7 +133,7 @@ class CloudSQLEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): These values control the batching behavior in the :class:`apache_beam.transforms.utils.BatchElements` transform. - NOTE: Batching is not supported when using the `query_fn` parameter. + NOTE: Batching is not supported when using the CustomQueryConfig. """ def __init__( self, @@ -109,11 +143,7 @@ def __init__( database_password: str, database_id: str, *, - table_id: str = "", - where_clause_template: str = "", - where_clause_fields: Optional[list[str]] = None, - where_clause_value_fn: Optional[ConditionValueFn] = None, - query_fn: Optional[QueryFn] = None, + query_config: QueryConfig, column_names: Optional[list[str]] = None, min_batch_size: int = 1, max_batch_size: int = 10000, @@ -127,12 +157,9 @@ def __init__( database_user='user', database_password='password', database_id='my_database', - table_id='my_table', - where_clause_template="id = '{}'", - where_clause_fields=['id'], + query_config=TableFieldsQueryConfig('my_table',"id = '{}'",['id']), min_batch_size=2, - max_batch_size=100 - ) + max_batch_size=100) Args: database_type_adapter: Adapter to handle specific database type operations @@ -143,18 +170,10 @@ def __init__( database_user (str): Username for accessing the database. 
database_password (str): Password for accessing the database. database_id (str): Identifier for the database to query. - table_id (str): Name of the table to query in the Cloud SQL database. - where_clause_template (str): A template string for the `WHERE` clause - in the SQL query with placeholders (`{}`) for dynamic filtering - based on input data. - where_clause_fields (Optional[list[str]]): List of field names from the - input `beam.Row` used to construct the `WHERE` clause if - `where_clause_value_fn` is not provided. - where_clause_value_fn (Optional[Callable[[beam.Row], Any]]): Function that - takes a `beam.Row` and returns a list of values to populate the - placeholders `{}` in the `WHERE` clause. - query_fn (Optional[Callable[[beam.Row], str]]): Function that takes a - `beam.Row` and returns a complete SQL query string. + query_config: Configuration for database queries. Must be one of: + * CustomQueryConfig: For providing a custom query function + * TableFieldsQueryConfig: specifies table, where clause, and field names + * TableFunctionQueryConfig: specifies table, where clause, and val func column_names (Optional[list[str]]): List of column names to select from the Cloud SQL table. If not provided, all columns (`*`) are selected. min_batch_size (int): Minimum number of rows to batch together when @@ -171,31 +190,22 @@ def __init__( * Ensure that the database user has the necessary permissions to query the specified table. 
""" - _validate_cloudsql_metadata( - table_id, - where_clause_template, - where_clause_fields, - where_clause_value_fn, - query_fn) + _validate_query_config(query_config) self._database_type_adapter = database_type_adapter self._database_id = database_id self._database_user = database_user self._database_password = database_password self._database_address = database_address - self._table_id = table_id - self._where_clause_template = where_clause_template - self._where_clause_value_fn = where_clause_value_fn - self._query_fn = query_fn - fields = where_clause_fields if where_clause_fields else [] - self._where_clause_fields = fields + self._query_config = query_config self._column_names = ",".join(column_names) if column_names else "*" - self.query_template = ( - f"SELECT {self._column_names} " - f"FROM {self._table_id} " - f"WHERE {self._where_clause_template}") self.kwargs = kwargs self._batching_kwargs = {} - if not query_fn: + table_query_configs = (TableFieldsQueryConfig, TableFunctionQueryConfig) + if isinstance(query_config, table_query_configs): + self.query_template = ( + f"SELECT {self._column_names} " + f"FROM {query_config.table_id} " + f"WHERE {query_config.where_clause_template}") self._batching_kwargs['min_batch_size'] = min_batch_size self._batching_kwargs['max_batch_size'] = max_batch_size @@ -233,20 +243,26 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): # For multiple requests in the batch, combine the WHERE clause conditions # using 'OR' and update the query template to handle all requests. 
- if batch_size > 1: + table_query_configs = (TableFieldsQueryConfig, TableFunctionQueryConfig) + if batch_size > 1 and isinstance(self._query_config, table_query_configs): where_clause_template_batched = ' OR '.join( - [fr'({self._where_clause_template})'] * batch_size) + [fr'({self._query_config.where_clause_template})'] * batch_size) raw_query = self.query_template.replace( - self._where_clause_template, where_clause_template_batched) + self._query_config.where_clause_template, + where_clause_template_batched) # Extract where_clause_fields values and map the generated request key to # the original request object. for req in request: request_dict = req._asdict() try: - current_values = ( - self._where_clause_value_fn(req) if self._where_clause_value_fn - else [request_dict[field] for field in self._where_clause_fields]) + if isinstance(self._query_config, TableFunctionQueryConfig): + current_values = self._query_config.where_clause_value_fn(req) + elif isinstance(self._query_config, TableFieldsQueryConfig): + current_values = [ + request_dict[field] + for field in self._query_config.where_clause_fields + ] except KeyError as e: raise KeyError( "Make sure the values passed in `where_clause_fields` are the " @@ -266,14 +282,17 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): return responses else: request_dict = request._asdict() - if self._query_fn: - query = self._query_fn(request) + if isinstance(self._query_config, CustomQueryConfig): + query = self._query_config.query_fn(request) else: try: - values = ( - self._where_clause_value_fn(request) - if self._where_clause_value_fn else - [request_dict[field] for field in self._where_clause_fields]) + if isinstance(self._query_config, TableFunctionQueryConfig): + values = self._query_config.where_clause_value_fn(request) + elif isinstance(self._query_config, TableFieldsQueryConfig): + values = [ + request_dict[field] + for field in self._query_config.where_clause_fields + ] except 
KeyError as e: raise KeyError( "Make sure the values passed in `where_clause_fields` are the " @@ -283,14 +302,14 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): return request, beam.Row(**response_dict) def create_row_key(self, row: beam.Row): - if self._where_clause_value_fn: - return tuple(self._where_clause_value_fn(row)) - if self._where_clause_fields: + if isinstance(self._query_config, TableFunctionQueryConfig): + return tuple(self._query_config.where_clause_value_fn(row)) + if isinstance(self._query_config, TableFieldsQueryConfig): row_dict = row._asdict() return ( tuple( row_dict[where_clause_field] - for where_clause_field in self._where_clause_fields)) + for where_clause_field in self._query_config.where_clause_fields)) raise ValueError( "Either where_clause_fields or where_clause_value_fn must be specified") @@ -300,14 +319,24 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._engine, self._connection = None, None def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]): + if isinstance(self._query_config, CustomQueryConfig): + raise NotImplementedError( + "Caching is not supported for CustomQueryConfig. 
" + "Consider using TableFieldsQueryConfig or " + + "TableFunctionQueryConfig instead.") + if isinstance(request, list): cache_keys = [] for req in request: req_dict = req._asdict() try: - current_values = ( - self._where_clause_value_fn(req) if self._where_clause_value_fn - else [req_dict[field] for field in self._where_clause_fields]) + if isinstance(self._query_config, TableFunctionQueryConfig): + current_values = self._query_config.where_clause_value_fn(req) + elif isinstance(self._query_config, TableFieldsQueryConfig): + current_values = [ + req_dict[field] + for field in self._query_config.where_clause_fields + ] key = ";".join(["%s"] * len(current_values)) cache_keys.extend([key % tuple(current_values)]) except KeyError as e: @@ -318,9 +347,13 @@ def get_cache_key(self, request: Union[beam.Row, list[beam.Row]]): else: req_dict = request._asdict() try: - current_values = ( - self._where_clause_value_fn(request) if self._where_clause_value_fn - else [req_dict[field] for field in self._where_clause_fields]) + if isinstance(self._query_config, TableFunctionQueryConfig): + current_values = self._query_config.where_clause_value_fn(request) + else: # TableFieldsQueryConfig + current_values = [ + req_dict[field] + for field in self._query_config.where_clause_fields + ] key = ";".join(["%s"] * len(current_values)) cache_key = key % tuple(current_values) except KeyError as e: diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py index 599528198431..e989be7079ba 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py @@ -33,6 +33,8 @@ try: from testcontainers.core.generic import DbContainer from testcontainers.postgres import PostgresContainer + from testcontainers.mysql import MySqlContainer + from testcontainers.mssql import SqlServerContainer 
from testcontainers.redis import RedisContainer from sqlalchemy import ( create_engine, MetaData, Table, Column, Integer, String, Engine) @@ -40,7 +42,9 @@ from apache_beam.transforms.enrichment_handlers.cloudsql import ( CloudSQLEnrichmentHandler, DatabaseTypeAdapter, - ) + CustomQueryConfig, + TableFieldsQueryConfig, + TableFunctionQueryConfig) except ImportError: raise unittest.SkipTest('Google Cloud SQL dependencies are not installed.') @@ -78,18 +82,36 @@ def url(self) -> str: class CloudSQLEnrichmentTestHelper: @staticmethod def start_sql_db_container( + database_type: DatabaseTypeAdapter, sql_client_retries=3) -> Optional[SQLDBContainerInfo]: info = None for i in range(sql_client_retries): try: - database_type_adapter = DatabaseTypeAdapter.POSTGRESQL - sql_db_container = PostgresContainer(image="postgres:16") - sql_db_container.start() - host = sql_db_container.get_container_host_ip() - port = sql_db_container.get_exposed_port(5432) - user, password, db_id = "test", "test", "test" + if database_type == DatabaseTypeAdapter.POSTGRESQL: + sql_db_container = PostgresContainer(image="postgres:16") + sql_db_container.start() + host = sql_db_container.get_container_host_ip() + port = sql_db_container.get_exposed_port(5432) + user, password, db_id = "test", "test", "test" + + elif database_type == DatabaseTypeAdapter.MYSQL: + sql_db_container = MySqlContainer(image="mysql:8.0") + sql_db_container.start() + host = sql_db_container.get_container_host_ip() + port = sql_db_container.get_exposed_port(3306) + user, password, db_id = "test", "test", "test" + + elif database_type == DatabaseTypeAdapter.SQLSERVER: + sql_db_container = SqlServerContainer() + sql_db_container.start() + host = sql_db_container.get_container_host_ip() + port = sql_db_container.get_exposed_port(1433) + user, password, db_id = "sa", "A_Str0ng_Required_Password", "tempdb" + else: + raise ValueError(f"Unsupported database type: {database_type}") + info = SQLDBContainerInfo( - 
adapter=database_type_adapter, + adapter=database_type, container=sql_db_container, host=host, port=port, @@ -97,32 +119,38 @@ def start_sql_db_container( password=password, id=db_id) _LOGGER.info( - "PostgreSQL container started successfully on %s.", info.address) + "%s container started successfully on %s.", + database_type.name, + info.address) break except Exception as e: _LOGGER.warning( - "Retry %d/%d: Failed to start PostgreSQL container. Reason: %s", + "Retry %d/%d: Failed to start %s container. Reason: %s", i + 1, sql_client_retries, + database_type.name, e) if i == sql_client_retries - 1: _LOGGER.error( - "Unable to start PostgreSQL container for IO tests after %d " + "Unable to start %s container for I/O tests after %d " "retries. Tests cannot proceed.", + database_type.name, sql_client_retries) raise e return info @staticmethod - def stop_sql_db_container(sql_db: DbContainer): + def stop_sql_db_container(db_info: SQLDBContainerInfo): try: - _LOGGER.debug("Stopping PostgreSQL container.") - sql_db.stop() - _LOGGER.info("PostgreSQL container stopped successfully.") + _LOGGER.debug("Stopping %s container.", db_info.adapter.name) + db_info.container.stop() + _LOGGER.info("%s container stopped successfully.", db_info.adapter.name) except Exception as e: _LOGGER.warning( - "Error encountered while stopping PostgreSQL container: %s", e) + "Error encountered while stopping %s container: %s", + db_info.adapter.name, + e) @staticmethod def create_table( @@ -141,21 +169,23 @@ def create_table( try: connection.execute(table.insert(), table_data) transaction.commit() - return engine except Exception as e: transaction.rollback() raise e + return engine + + +def init_db_type(db_type): + def wrapper(cls): + cls.db_type = db_type + return cls + + return wrapper + @pytest.mark.uses_testcontainer -class TestCloudSQLEnrichment(unittest.TestCase): - _table_id = "product_details" - _columns = [ - Column("id", Integer, primary_key=True), - Column("name", String, 
nullable=False), - Column("quantity", Integer, nullable=False), - Column("distribution_center_id", Integer, nullable=False), - ] +class BaseTestCloudSQLEnrichment(unittest.TestCase): _table_data = [ { "id": 1, "name": "A", 'quantity': 2, 'distribution_center_id': 3 @@ -182,14 +212,29 @@ class TestCloudSQLEnrichment(unittest.TestCase): "id": 8, "name": "D", 'quantity': 4, 'distribution_center_id': 1 }, ] + db = None + _engine = None @classmethod def setUpClass(cls): - cls.db = CloudSQLEnrichmentTestHelper.start_sql_db_container() + if not hasattr(cls, 'db_type'): + # Skip setup for the base class. + raise unittest.SkipTest("Base class - no db_type defined") + cls.db = CloudSQLEnrichmentTestHelper.start_sql_db_container(cls.db_type) cls._engine = CloudSQLEnrichmentTestHelper.create_table( - cls._table_id, cls.db.url, cls._columns, cls._table_data) + cls._table_id, cls.db.url, cls.get_columns(), cls._table_data) cls._cache_client_retries = 3 + @classmethod + def get_columns(cls): + """Returns fresh column objects each time it's called.""" + return [ + Column("id", Integer, primary_key=True), + Column("name", String(255), nullable=False), + Column("quantity", Integer, nullable=False), + Column("distribution_center_id", Integer, nullable=False), + ] + @pytest.fixture def cache_container(self): self._start_cache_container() @@ -222,7 +267,7 @@ def _start_cache_container(self): @classmethod def tearDownClass(cls): cls._engine.dispose(close=True) - CloudSQLEnrichmentTestHelper.stop_sql_db_container(cls.db.container) + CloudSQLEnrichmentTestHelper.stop_sql_db_container(cls.db) cls._engine = None def test_cloudsql_enrichment(self): @@ -235,15 +280,19 @@ def test_cloudsql_enrichment(self): beam.Row(id=1, name='A'), beam.Row(id=2, name='B'), ] + + query_config = TableFieldsQueryConfig( + table_id=self._table_id, + where_clause_template="id = {}", + where_clause_fields=fields) + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.db.adapter, 
database_address=self.db.address, database_user=self.db.user, database_password=self.db.id, database_id=self.db.id, - table_id=self._table_id, - where_clause_template="id = {}", - where_clause_fields=fields, + query_config=query_config, min_batch_size=1, max_batch_size=100, ) @@ -262,15 +311,19 @@ def test_cloudsql_enrichment_batched(self): beam.Row(id=1, name='A'), beam.Row(id=2, name='B'), ] + + query_config = TableFieldsQueryConfig( + table_id=self._table_id, + where_clause_template="id = {}", + where_clause_fields=fields) + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.db.adapter, database_address=self.db.address, database_user=self.db.user, database_password=self.db.password, database_id=self.db.id, - table_id=self._table_id, - where_clause_template="id = {}", - where_clause_fields=fields, + query_config=query_config, min_batch_size=2, max_batch_size=100, ) @@ -289,15 +342,19 @@ def test_cloudsql_enrichment_batched_multiple_fields(self): beam.Row(id=1, distribution_center_id=3), beam.Row(id=2, distribution_center_id=1), ] + + query_config = TableFieldsQueryConfig( + table_id=self._table_id, + where_clause_template="id = {} AND distribution_center_id = {}", + where_clause_fields=fields) + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.db.adapter, database_address=self.db.address, database_user=self.db.user, database_password=self.db.password, database_id=self.db.id, - table_id=self._table_id, - where_clause_template="id = {} AND distribution_center_id = {}", - where_clause_fields=fields, + query_config=query_config, min_batch_size=8, max_batch_size=100, ) @@ -316,13 +373,16 @@ def test_cloudsql_enrichment_with_query_fn(self): beam.Row(id=2, name='B'), ] fn = functools.partial(query_fn, self._table_id) + + query_config = CustomQueryConfig(query_fn=fn) + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.db.adapter, database_address=self.db.address, database_user=self.db.user, database_password=self.db.password, 
database_id=self.db.id, - query_fn=fn) + query_config=query_config) with TestPipeline(is_integration_test=True) as test_pipeline: pcoll = (test_pipeline | beam.Create(requests) | Enrichment(handler)) @@ -337,15 +397,19 @@ def test_cloudsql_enrichment_with_condition_value_fn(self): beam.Row(id=1, name='A'), beam.Row(id=2, name='B'), ] + + query_config = TableFunctionQueryConfig( + table_id=self._table_id, + where_clause_template="id = {}", + where_clause_value_fn=where_clause_value_fn) + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.db.adapter, database_address=self.db.address, database_user=self.db.user, database_password=self.db.password, database_id=self.db.id, - table_id=self._table_id, - where_clause_template="id = {}", - where_clause_value_fn=where_clause_value_fn, + query_config=query_config, min_batch_size=2, max_batch_size=100) with TestPipeline(is_integration_test=True) as test_pipeline: @@ -358,15 +422,19 @@ def test_cloudsql_enrichment_table_nonexistent_runtime_error_raised(self): beam.Row(id=1, name='A'), beam.Row(id=2, name='B'), ] + + query_config = TableFunctionQueryConfig( + table_id=self._table_id, + where_clause_template="id = {}", + where_clause_value_fn=where_clause_value_fn) + handler = CloudSQLEnrichmentHandler( database_type_adapter=self.db.adapter, database_address=self.db.address, database_user=self.db.user, database_password=self.db.password, database_id=self.db.id, - table_id=self._table_id, - where_clause_template="id = {}", - where_clause_value_fn=where_clause_value_fn, + query_config=query_config, column_names=["wrong_column"], ) with self.assertRaises(RuntimeError): @@ -388,15 +456,19 @@ def test_cloudsql_enrichment_with_redis(self): beam.Row(id=1, name="A", quantity=2, distribution_center_id=3), beam.Row(id=2, name="B", quantity=3, distribution_center_id=1) ] + + query_config = TableFunctionQueryConfig( + table_id=self._table_id, + where_clause_template="id = {}", + where_clause_value_fn=where_clause_value_fn) + 
handler = CloudSQLEnrichmentHandler( database_type_adapter=self.db.adapter, database_address=self.db.address, database_user=self.db.user, database_password=self.db.password, database_id=self.db.id, - table_id=self._table_id, - where_clause_template="id = {}", - where_clause_value_fn=where_clause_value_fn, + query_config=query_config, min_batch_size=2, max_batch_size=100) with TestPipeline(is_integration_test=True) as test_pipeline: @@ -433,5 +505,23 @@ def test_cloudsql_enrichment_with_redis(self): CloudSQLEnrichmentHandler.__call__ = actual +@init_db_type(DatabaseTypeAdapter.POSTGRESQL) +@pytest.mark.uses_testcontainer +class TestCloudSQLEnrichmentPostgres(BaseTestCloudSQLEnrichment): + _table_id = "product_details_pg" + + +@init_db_type(DatabaseTypeAdapter.MYSQL) +@pytest.mark.uses_testcontainer +class TestCloudSQLEnrichmentMySQL(BaseTestCloudSQLEnrichment): + _table_id = "product_details_mysql" + + +@init_db_type(DatabaseTypeAdapter.SQLSERVER) +@pytest.mark.uses_testcontainer +class TestCloudSQLEnrichmentSQLServer(BaseTestCloudSQLEnrichment): + _table_id = "product_details_mssql" + + if __name__ == "__main__": unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py index 888886479c75..f3da541cd89a 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py @@ -21,7 +21,11 @@ # pylint: disable=ungrouped-imports try: from apache_beam.transforms.enrichment_handlers.cloudsql import ( - CloudSQLEnrichmentHandler, DatabaseTypeAdapter) + CloudSQLEnrichmentHandler, + DatabaseTypeAdapter, + CustomQueryConfig, + TableFieldsQueryConfig, + TableFunctionQueryConfig) from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import ( query_fn, where_clause_value_fn, @@ -32,25 +36,40 @@ class TestCloudSQLEnrichment(unittest.TestCase): 
@parameterized.expand([ - ("", "", [], None, None, 1, 2), - ("table", "", ["id"], where_clause_value_fn, None, 2, 10), - ("table", "id='{}'", ["id"], where_clause_value_fn, None, 2, 10), - ("table", "id='{}'", ["id"], None, query_fn, 2, 10), + # Empty TableFieldsQueryConfig. + ( + TableFieldsQueryConfig( + table_id="", where_clause_template="", where_clause_fields=[]), + 1, + 2), + # Missing where_clause_template in TableFieldsQueryConfig. + ( + TableFieldsQueryConfig( + table_id="table", + where_clause_template="", + where_clause_fields=["id"]), + 2, + 10), + # Invalid CustomQueryConfig with None query_fn. + (CustomQueryConfig(query_fn=None), 2, 10), # type: ignore[arg-type] + # Missing table_id in TableFunctionQueryConfig. + ( + TableFunctionQueryConfig( + table_id="", + where_clause_template="id='{}'", + where_clause_value_fn=where_clause_value_fn), + 2, + 10), ]) - def test_valid_params( - self, - table_id, - where_clause_template, - where_clause_fields, - where_clause_value_fn, - query_fn, - min_batch_size, - max_batch_size): + def test_invalid_query_config( + self, query_config, min_batch_size, max_batch_size): """ - TC 1: Only batch size are provided. It should raise an error. - TC 2: Either of `where_clause_template` or `query_fn` is not provided. - TC 3: Both `where_clause_fields` and `where_clause_value_fn` are provided. - TC 4: Query construction details are provided along with `query_fn`. + TC 1: Empty TableFieldsQueryConfig. + + It should raise an error. + TC 2: Missing where_clause_template in TableFieldsQueryConfig. + TC 3: Invalid CustomQueryConfig with None query_fn. + TC 4: Missing table_id in TableFunctionQueryConfig. 
""" with self.assertRaises(ValueError): _ = CloudSQLEnrichmentHandler( @@ -59,15 +78,85 @@ def test_valid_params( database_user='', database_password='', database_id='', - table_id=table_id, - where_clause_template=where_clause_template, - where_clause_fields=where_clause_fields, - where_clause_value_fn=where_clause_value_fn, - query_fn=query_fn, + query_config=query_config, min_batch_size=min_batch_size, max_batch_size=max_batch_size, ) + def test_valid_query_configs(self): + """Test valid query configuration cases.""" + # Valid TableFieldsQueryConfig. + table_fields_config = TableFieldsQueryConfig( + table_id="my_table", + where_clause_template="id = '{}'", + where_clause_fields=["id"]) + + handler1 = CloudSQLEnrichmentHandler( + database_type_adapter=DatabaseTypeAdapter.POSTGRESQL, + database_address='localhost', + database_user='user', + database_password='password', + database_id='db', + query_config=table_fields_config, + min_batch_size=1, + max_batch_size=10) + + self.assertEqual( + handler1.query_template, "SELECT * FROM my_table WHERE id = '{}'") + + # Valid TableFunctionQueryConfig. + table_function_config = TableFunctionQueryConfig( + table_id="my_table", + where_clause_template="id = '{}'", + where_clause_value_fn=where_clause_value_fn) + + handler2 = CloudSQLEnrichmentHandler( + database_type_adapter=DatabaseTypeAdapter.POSTGRESQL, + database_address='localhost', + database_user='user', + database_password='password', + database_id='db', + query_config=table_function_config, + min_batch_size=1, + max_batch_size=10) + + self.assertEqual( + handler2.query_template, "SELECT * FROM my_table WHERE id = '{}'") + + # Valid CustomQueryConfig. 
+ custom_config = CustomQueryConfig(query_fn=query_fn) + + handler3 = CloudSQLEnrichmentHandler( + database_type_adapter=DatabaseTypeAdapter.POSTGRESQL, + database_address='localhost', + database_user='user', + database_password='password', + database_id='db', + query_config=custom_config) + + # Verify that batching kwargs are empty for CustomQueryConfig. + self.assertEqual(handler3.batch_elements_kwargs(), {}) + + def test_custom_query_config_cache_key_error(self): + """Test get_cache_key raises NotImplementedError with CustomQueryConfig.""" + custom_config = CustomQueryConfig(query_fn=query_fn) + + handler = CloudSQLEnrichmentHandler( + database_type_adapter=DatabaseTypeAdapter.POSTGRESQL, + database_address='localhost', + database_user='user', + database_password='password', + database_id='db', + query_config=custom_config) + + # Create a dummy request + import apache_beam as beam + request = beam.Row(id=1) + + # Verify that get_cache_key raises NotImplementedError + with self.assertRaises(NotImplementedError): + handler.get_cache_key(request) + if __name__ == '__main__': unittest.main() From abb89b73f3c4c370898eecf371c20ab05545b6d0 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Wed, 28 May 2025 19:38:38 +0000 Subject: [PATCH 15/18] sdks/python: address claudevdm feedback (3) --- .../enrichment_handlers/cloudsql.py | 193 +++++++++--------- .../enrichment_handlers/cloudsql_it_test.py | 26 ++- .../enrichment_handlers/cloudsql_test.py | 64 ++++-- .../py310/base_image_requirements.txt | 1 + .../py311/base_image_requirements.txt | 1 + .../py312/base_image_requirements.txt | 1 + .../py39/base_image_requirements.txt | 1 + sdks/python/setup.py | 1 + 8 files changed, 171 insertions(+), 117 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py index c676f92ecdbd..c4491b43dacd 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py +++ 
b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py @@ -22,7 +22,6 @@ from typing import List from typing import Optional from typing import Union -from typing import cast from sqlalchemy import create_engine from sqlalchemy import text @@ -39,6 +38,10 @@ class CustomQueryConfig: """Configuration for using a custom query function.""" query_fn: QueryFn + def __post_init__(self): + if not self.query_fn: + raise ValueError("CustomQueryConfig must provide a valid query_fn") + @dataclass class TableFieldsQueryConfig: @@ -47,6 +50,18 @@ class TableFieldsQueryConfig: where_clause_template: str where_clause_fields: List[str] + def __post_init__(self): + if not self.table_id or not self.where_clause_template: + raise ValueError( + "TableFieldsQueryConfig and " + + "TableFunctionQueryConfig must provide table_id " + + "and where_clause_template") + + if not self.where_clause_fields: + raise ValueError( + "TableFieldsQueryConfig must provide non-empty " + + "where_clause_fields") + @dataclass class TableFunctionQueryConfig: @@ -55,47 +70,27 @@ class TableFunctionQueryConfig: where_clause_template: str where_clause_value_fn: ConditionValueFn - -QueryConfig = Union[CustomQueryConfig, - TableFieldsQueryConfig, - TableFunctionQueryConfig] - - -def _validate_query_config(query_config: QueryConfig): - """Validates the provided query configuration.""" - if isinstance(query_config, CustomQueryConfig): - if not query_config.query_fn: - raise ValueError("CustomQueryConfig must provide a valid query_fn") - elif isinstance(query_config, - (TableFieldsQueryConfig, TableFunctionQueryConfig)): - if not query_config.table_id or not query_config.where_clause_template: + def __post_init__(self): + if not self.table_id or not self.where_clause_template: raise ValueError( "TableFieldsQueryConfig and " + "TableFunctionQueryConfig must provide table_id " + "and where_clause_template") - is_table_fields = isinstance(query_config, TableFieldsQueryConfig) - if is_table_fields: - 
table_fields_config = cast(TableFieldsQueryConfig, query_config) - if not table_fields_config.where_clause_fields: - raise ValueError( - "TableFieldsQueryConfig must provide non-empty " + - "where_clause_fields") + if not self.where_clause_value_fn: + raise ValueError( + "TableFunctionQueryConfig must provide " + "where_clause_value_fn") + - is_table_function = isinstance(query_config, TableFunctionQueryConfig) - if is_table_function: - table_function_config = cast(TableFunctionQueryConfig, query_config) - if not table_function_config.where_clause_value_fn: - raise ValueError( - "TableFunctionQueryConfig must provide " + "where_clause_value_fn") - else: - raise ValueError("Invalid query_config type provided") +QueryConfig = Union[CustomQueryConfig, + TableFieldsQueryConfig, + TableFunctionQueryConfig] class DatabaseTypeAdapter(Enum): POSTGRESQL = "psycopg2" MYSQL = "pymysql" - SQLSERVER = "pymysql" + SQLSERVER = "pymssql" def to_sqlalchemy_dialect(self): """Map the adapter type to its corresponding SQLAlchemy dialect. @@ -190,7 +185,6 @@ def __init__( * Ensure that the database user has the necessary permissions to query the specified table. """ - _validate_query_config(query_config) self._database_type_adapter = database_type_adapter self._database_id = database_id self._database_user = database_user @@ -235,71 +229,80 @@ def _execute_query(self, query: str, is_batch: bool, **params): f'table exists. {e}') def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): + """Handle requests by delegating to single or batch processing.""" if isinstance(request, list): - values, responses = [], [] - requests_map: dict[Any, Any] = {} - batch_size = len(request) - raw_query = self.query_template - - # For multiple requests in the batch, combine the WHERE clause conditions - # using 'OR' and update the query template to handle all requests. 
- table_query_configs = (TableFieldsQueryConfig, TableFunctionQueryConfig) - if batch_size > 1 and isinstance(self._query_config, table_query_configs): - where_clause_template_batched = ' OR '.join( - [fr'({self._query_config.where_clause_template})'] * batch_size) - raw_query = self.query_template.replace( - self._query_config.where_clause_template, - where_clause_template_batched) - - # Extract where_clause_fields values and map the generated request key to - # the original request object. - for req in request: - request_dict = req._asdict() - try: - if isinstance(self._query_config, TableFunctionQueryConfig): - current_values = self._query_config.where_clause_value_fn(req) - elif isinstance(self._query_config, TableFieldsQueryConfig): - current_values = [ - request_dict[field] - for field in self._query_config.where_clause_fields - ] - except KeyError as e: - raise KeyError( - "Make sure the values passed in `where_clause_fields` are the " - "keys in the input `beam.Row`." + str(e)) - values.extend(current_values) - requests_map[self.create_row_key(req)] = req - - # Formulate the query, execute it, and return a list of original requests - # paired with their responses. 
- query = raw_query.format(*values) - responses_dict = self._execute_query(query, is_batch=True) - for response in responses_dict: - response_row = beam.Row(**response) - response_key = self.create_row_key(response_row) - if response_key in requests_map: - responses.append((requests_map[response_key], response_row)) - return responses + return self._process_batch_request(request) else: - request_dict = request._asdict() - if isinstance(self._query_config, CustomQueryConfig): - query = self._query_config.query_fn(request) - else: - try: - if isinstance(self._query_config, TableFunctionQueryConfig): - values = self._query_config.where_clause_value_fn(request) - elif isinstance(self._query_config, TableFieldsQueryConfig): - values = [ - request_dict[field] - for field in self._query_config.where_clause_fields - ] - except KeyError as e: - raise KeyError( - "Make sure the values passed in `where_clause_fields` are the " - "keys in the input `beam.Row`." + str(e)) - query = self.query_template.format(*values) - response_dict = self._execute_query(query, is_batch=False) - return request, beam.Row(**response_dict) + return self._process_single_request(request) + + def _process_batch_request(self, requests: list[beam.Row]): + """Process batch requests and match responses to original requests.""" + values, responses = [], [] + requests_map: dict[Any, Any] = {} + batch_size = len(requests) + raw_query = self.query_template + + # For multiple requests in the batch, combine the WHERE clause conditions + # using 'OR' and update the query template to handle all requests. 
+ table_query_configs = (TableFieldsQueryConfig, TableFunctionQueryConfig) + if batch_size > 1 and isinstance(self._query_config, table_query_configs): + where_clause_template_batched = ' OR '.join( + [fr'({self._query_config.where_clause_template})'] * batch_size) + raw_query = self.query_template.replace( + self._query_config.where_clause_template, + where_clause_template_batched) + + # Extract where_clause_fields values and map the generated request key to + # the original request object. + for req in requests: + request_dict = req._asdict() + try: + if isinstance(self._query_config, TableFunctionQueryConfig): + current_values = self._query_config.where_clause_value_fn(req) + elif isinstance(self._query_config, TableFieldsQueryConfig): + current_values = [ + request_dict[field] + for field in self._query_config.where_clause_fields + ] + except KeyError as e: + raise KeyError( + "Make sure the values passed in `where_clause_fields` are " + " thekeys in the input `beam.Row`." + str(e)) + values.extend(current_values) + requests_map[self.create_row_key(req)] = req + + # Formulate the query, execute it, and return a list of original requests + # paired with their responses. 
+ query = raw_query.format(*values) + responses_dict = self._execute_query(query, is_batch=True) + for response in responses_dict: + response_row = beam.Row(**response) + response_key = self.create_row_key(response_row) + if response_key in requests_map: + responses.append((requests_map[response_key], response_row)) + return responses + + def _process_single_request(self, request: beam.Row): + """Process a single request and return with its response.""" + request_dict = request._asdict() + if isinstance(self._query_config, CustomQueryConfig): + query = self._query_config.query_fn(request) + else: + try: + if isinstance(self._query_config, TableFunctionQueryConfig): + values = self._query_config.where_clause_value_fn(request) + elif isinstance(self._query_config, TableFieldsQueryConfig): + values = [ + request_dict[field] + for field in self._query_config.where_clause_fields + ] + except KeyError as e: + raise KeyError( + "Make sure the values passed in `where_clause_fields` are " + "the keys in the input `beam.Row`." 
+ str(e)) + query = self.query_template.format(*values) + response_dict = self._execute_query(query, is_batch=False) + return request, beam.Row(**response_dict) def create_row_key(self, row: beam.Row): if isinstance(self._query_config, TableFunctionQueryConfig): diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py index e989be7079ba..6c12e5888c05 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py @@ -88,25 +88,39 @@ def start_sql_db_container( for i in range(sql_client_retries): try: if database_type == DatabaseTypeAdapter.POSTGRESQL: - sql_db_container = PostgresContainer(image="postgres:16") + user, password, db_id = "test", "test", "test" + sql_db_container = PostgresContainer( + image="postgres:16", + user=user, + password=password, + dbname=db_id, + driver=database_type.value) sql_db_container.start() host = sql_db_container.get_container_host_ip() port = sql_db_container.get_exposed_port(5432) - user, password, db_id = "test", "test", "test" elif database_type == DatabaseTypeAdapter.MYSQL: - sql_db_container = MySqlContainer(image="mysql:8.0") + user, password, db_id = "test", "test", "test" + sql_db_container = MySqlContainer( + image="mysql:8.0", + MYSQL_USER=user, + MYSQL_ROOT_PASSWORD=password, + MYSQL_PASSWORD=password, + MYSQL_DATABASE=db_id) sql_db_container.start() host = sql_db_container.get_container_host_ip() port = sql_db_container.get_exposed_port(3306) - user, password, db_id = "test", "test", "test" elif database_type == DatabaseTypeAdapter.SQLSERVER: - sql_db_container = SqlServerContainer() + user, password, db_id = "SA", "A_Str0ng_Required_Password", "tempdb" + sql_db_container = SqlServerContainer( + image="mcr.microsoft.com/mssql/server:2022-latest", + user=user, + password=password, + dbname=db_id) 
sql_db_container.start() host = sql_db_container.get_container_host_ip() port = sql_db_container.get_exposed_port(1433) - user, password, db_id = "sa", "A_Str0ng_Required_Password", "tempdb" else: raise ValueError(f"Unsupported database type: {database_type}") diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py index f3da541cd89a..30aeca4d04f6 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_test.py @@ -38,40 +38,70 @@ class TestCloudSQLEnrichment(unittest.TestCase): @parameterized.expand([ # Empty TableFieldsQueryConfig. ( - TableFieldsQueryConfig( + lambda: TableFieldsQueryConfig( table_id="", where_clause_template="", where_clause_fields=[]), 1, - 2), + 2, + "must provide table_id and where_clause_template" + ), # Missing where_clause_template in TableFieldsQueryConfig. ( - TableFieldsQueryConfig( + lambda: TableFieldsQueryConfig( table_id="table", where_clause_template="", where_clause_fields=["id"]), 2, - 10), + 10, + "must provide table_id and where_clause_template" + ), # Invalid CustomQueryConfig with None query_fn. - (CustomQueryConfig(query_fn=None), 2, 10), # type: ignore[arg-type] + ( + lambda: CustomQueryConfig(query_fn=None), # type: ignore[arg-type] + 2, + 10, + "must provide a valid query_fn" + ), # Missing table_id in TableFunctionQueryConfig. ( - TableFunctionQueryConfig( + lambda: TableFunctionQueryConfig( table_id="", where_clause_template="id='{}'", where_clause_value_fn=where_clause_value_fn), 2, - 10), + 10, + "must provide table_id and where_clause_template" + ), + # Missing where_clause_fields in TableFieldsQueryConfig. 
+ ( + lambda: TableFieldsQueryConfig( + table_id="table", + where_clause_template="id = '{}'", + where_clause_fields=[]), + 1, + 10, + "must provide non-empty where_clause_fields" + ), + # Missing where_clause_value_fn in TableFunctionQueryConfig. + ( + lambda: TableFunctionQueryConfig( + table_id="table", + where_clause_template="id = '{}'", + where_clause_value_fn=None), # type: ignore[arg-type] + 1, + 10, + "must provide where_clause_value_fn" + ), ]) def test_invalid_query_config( - self, query_config, min_batch_size, max_batch_size): - """ - TC 1: Empty TableFieldsQueryConfig. - - It should raise an error. - TC 2: Missing where_clause_template in TableFieldsQueryConfig. - TC 3: Invalid CustomQueryConfig with None query_fn. - TC 4: Missing table_id in TableFunctionQueryConfig. + self, create_config, min_batch_size, max_batch_size, expected_error_msg): + """Test that validation errors are raised for invalid query configs. + + The test verifies both that the appropriate ValueError is raised and that + the error message contains the expected text. """ - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as context: + # Call the lambda to create the config. + query_config = create_config() _ = CloudSQLEnrichmentHandler( database_type_adapter=DatabaseTypeAdapter.POSTGRESQL, database_address='', @@ -82,6 +112,8 @@ def test_invalid_query_config( min_batch_size=min_batch_size, max_batch_size=max_batch_size, ) + # Verify the error message contains the expected text. 
+ self.assertIn(expected_error_msg, str(context.exception)) def test_valid_query_configs(self): """Test valid query configuration cases.""" diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index d8c84479b50f..08ec0b990773 100644 --- a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -119,6 +119,7 @@ pluggy==1.5.0 proto-plus==1.26.1 protobuf==5.29.4 psycopg2-binary==2.9.9 +pymssql==2.3.4 pyarrow==16.1.0 pyarrow-hotfix==0.6 pyasn1==0.6.1 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index dc0c3f2b95a6..6bc29d658707 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -117,6 +117,7 @@ pluggy==1.5.0 proto-plus==1.26.1 protobuf==5.29.4 psycopg2-binary==2.9.9 +pymssql==2.3.4 pyarrow==16.1.0 pyarrow-hotfix==0.6 pyasn1==0.6.1 diff --git a/sdks/python/container/py312/base_image_requirements.txt b/sdks/python/container/py312/base_image_requirements.txt index a34f4ccef489..ef04011e46e6 100644 --- a/sdks/python/container/py312/base_image_requirements.txt +++ b/sdks/python/container/py312/base_image_requirements.txt @@ -116,6 +116,7 @@ pluggy==1.5.0 proto-plus==1.26.1 protobuf==5.29.4 psycopg2-binary==2.9.9 +pymssql==2.3.4 pyarrow==16.1.0 pyarrow-hotfix==0.6 pyasn1==0.6.1 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index 6be1fdd3b0d4..7cb9eae1b12f 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ b/sdks/python/container/py39/base_image_requirements.txt @@ -119,6 +119,7 @@ pluggy==1.5.0 proto-plus==1.26.1 protobuf==5.29.4 psycopg2-binary==2.9.9 +pymssql==2.3.4 pyarrow==16.1.0 pyarrow-hotfix==0.6 pyasn1==0.6.1 diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 
e1756176093b..1afc48cb8967 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -418,6 +418,7 @@ def get_portability_package_data(): 'sqlalchemy>=1.3,<3.0', 'psycopg2-binary>=2.8.5,<3.0.0,!=2.9.10', 'testcontainers[mysql]>=3.0.3,<4.0.0', + 'pymssql>=2.3.4,<3.0.0', 'cryptography>=41.0.2', 'hypothesis>5.0.0,<7.0.0', 'virtualenv-clone>=0.5,<1.0', From bc84430bdc14e102fb1a23222b58236d14024eb0 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 6 Jun 2025 14:55:04 +0000 Subject: [PATCH 16/18] multi: fix pymssql mac os deps --- sdks/python/scripts/install_macos_deps.sh | 45 +++++++++++++++++++++++ sdks/python/setup.py | 2 +- sdks/python/tox.ini | 1 + 3 files changed, 47 insertions(+), 1 deletion(-) create mode 100755 sdks/python/scripts/install_macos_deps.sh diff --git a/sdks/python/scripts/install_macos_deps.sh b/sdks/python/scripts/install_macos_deps.sh new file mode 100755 index 000000000000..a8210048fcc2 --- /dev/null +++ b/sdks/python/scripts/install_macos_deps.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -euo pipefail + +# Create temporary directory. 
+TMPDIR=$(mktemp -d -t install_macos_deps.XXXXXX) +cleanup() { + rm -rf "$TMPDIR" +} +trap cleanup EXIT + +echo "Using temporary directory: $TMPDIR" +cd "$TMPDIR" + +# Download and extract FreeTDS. +curl -LO https://www.freetds.org/files/stable/freetds-1.5.2.tar.gz +tar -xzf freetds-1.5.2.tar.gz +cd freetds-1.5.2 + +# Configure, build and install FreeTDS. +./configure --prefix="$HOME/freetds" --with-tdsver=7.4 +make +make install + +# Set environment variables for pymssql installation. +export CFLAGS="-I$HOME/freetds/include" +export LDFLAGS="-L$HOME/freetds/lib" + +echo "FreeTDS installed to $HOME/freetds" diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 656983042a96..76b50d71337a 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -431,7 +431,7 @@ def get_portability_package_data(): 'psycopg2-binary>=2.8.5,<2.9.10; python_version <= "3.9"', 'psycopg2-binary>=2.8.5,<3.0; python_version >= "3.10"', 'testcontainers[mysql,kafka]>=3.0.3,<4.0.0', - 'pymssql>=2.3.4,<3.0.0', + 'pymssql>=2.3.4,<3.0.0; python_version >= "3.8"', 'cryptography>=41.0.2', 'hypothesis>5.0.0,<7.0.0', 'virtualenv-clone>=0.5,<1.0', diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index 5131769509d9..741a5d74b215 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -84,6 +84,7 @@ commands_pre = python --version pip --version # pip check + bash {toxinidir}/scripts/install_macos_deps.sh bash {toxinidir}/scripts/run_tox_cleanup.sh commands = python apache_beam/examples/complete/autocomplete_test.py From 7a56f0e73b3196bfcc7a17d50c494620a815daac Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 6 Jun 2025 16:36:29 +0000 Subject: [PATCH 17/18] multi: use pytds sql server database adapter In this commit, we use the pytds Microsoft SQL Server adapter instead of pymssql for one main reason: it supports cross-platform compatibility, especially with macOS, so we don't need to install FreeTDS C dependencies. 
--- .../enrichment_handlers/cloudsql.py | 2 +- .../enrichment_handlers/cloudsql_it_test.py | 3 +- .../py310/base_image_requirements.txt | 1 - .../py311/base_image_requirements.txt | 1 - .../py312/base_image_requirements.txt | 1 - .../py39/base_image_requirements.txt | 1 - sdks/python/scripts/install_macos_deps.sh | 45 ------------------- sdks/python/setup.py | 1 - sdks/python/tox.ini | 1 - 9 files changed, 3 insertions(+), 53 deletions(-) delete mode 100755 sdks/python/scripts/install_macos_deps.sh diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py index c4491b43dacd..39ef2eb50fe7 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql.py @@ -90,7 +90,7 @@ def __post_init__(self): class DatabaseTypeAdapter(Enum): POSTGRESQL = "psycopg2" MYSQL = "pymysql" - SQLSERVER = "pymssql" + SQLSERVER = "pytds" def to_sqlalchemy_dialect(self): """Map the adapter type to its corresponding SQLAlchemy dialect. 
diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py index 6c12e5888c05..f44819c115af 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py @@ -117,7 +117,8 @@ def start_sql_db_container( image="mcr.microsoft.com/mssql/server:2022-latest", user=user, password=password, - dbname=db_id) + dbname=db_id, + dialect=database_type.to_sqlalchemy_dialect()) sql_db_container.start() host = sql_db_container.get_container_host_ip() port = sql_db_container.get_exposed_port(1433) diff --git a/sdks/python/container/py310/base_image_requirements.txt b/sdks/python/container/py310/base_image_requirements.txt index 3699d37e32e2..71dbfb4dd9cb 100644 --- a/sdks/python/container/py310/base_image_requirements.txt +++ b/sdks/python/container/py310/base_image_requirements.txt @@ -124,7 +124,6 @@ protobuf==5.29.4 psycopg2-binary==2.9.10 pyarrow==18.1.0 pyarrow-hotfix==0.7 -pymssql==2.3.4 pyasn1==0.6.1 pyasn1_modules==0.4.2 pycparser==2.22 diff --git a/sdks/python/container/py311/base_image_requirements.txt b/sdks/python/container/py311/base_image_requirements.txt index 3f08c4816855..4a818db73073 100644 --- a/sdks/python/container/py311/base_image_requirements.txt +++ b/sdks/python/container/py311/base_image_requirements.txt @@ -122,7 +122,6 @@ protobuf==5.29.4 psycopg2-binary==2.9.10 pyarrow==18.1.0 pyarrow-hotfix==0.7 -pymssql==2.3.4 pyasn1==0.6.1 pyasn1_modules==0.4.2 pycparser==2.22 diff --git a/sdks/python/container/py312/base_image_requirements.txt b/sdks/python/container/py312/base_image_requirements.txt index d146010af0c4..44a3e8d21046 100644 --- a/sdks/python/container/py312/base_image_requirements.txt +++ b/sdks/python/container/py312/base_image_requirements.txt @@ -121,7 +121,6 @@ protobuf==5.29.4 psycopg2-binary==2.9.10 pyarrow==18.1.0 pyarrow-hotfix==0.7 
-pymssql==2.3.4 pyasn1==0.6.1 pyasn1_modules==0.4.2 pycparser==2.22 diff --git a/sdks/python/container/py39/base_image_requirements.txt b/sdks/python/container/py39/base_image_requirements.txt index 1fab7e2d2e28..6591f108a99e 100644 --- a/sdks/python/container/py39/base_image_requirements.txt +++ b/sdks/python/container/py39/base_image_requirements.txt @@ -124,7 +124,6 @@ protobuf==5.29.4 psycopg2-binary==2.9.9 pyarrow==18.1.0 pyarrow-hotfix==0.7 -pymssql==2.3.4 pyasn1==0.6.1 pyasn1_modules==0.4.2 pycparser==2.22 diff --git a/sdks/python/scripts/install_macos_deps.sh b/sdks/python/scripts/install_macos_deps.sh deleted file mode 100755 index a8210048fcc2..000000000000 --- a/sdks/python/scripts/install_macos_deps.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -set -euo pipefail - -# Create temporary directory. -TMPDIR=$(mktemp -d -t install_macos_deps.XXXXXX) -cleanup() { - rm -rf "$TMPDIR" -} -trap cleanup EXIT - -echo "Using temporary directory: $TMPDIR" -cd "$TMPDIR" - -# Download and extract FreeTDS. -curl -LO https://www.freetds.org/files/stable/freetds-1.5.2.tar.gz -tar -xzf freetds-1.5.2.tar.gz -cd freetds-1.5.2 - -# Configure, build and install FreeTDS. 
-./configure --prefix="$HOME/freetds" --with-tdsver=7.4 -make -make install - -# Set environment variables for pymssql installation. -export CFLAGS="-I$HOME/freetds/include" -export LDFLAGS="-L$HOME/freetds/lib" - -echo "FreeTDS installed to $HOME/freetds" diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 76b50d71337a..a98eaab33361 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -431,7 +431,6 @@ def get_portability_package_data(): 'psycopg2-binary>=2.8.5,<2.9.10; python_version <= "3.9"', 'psycopg2-binary>=2.8.5,<3.0; python_version >= "3.10"', 'testcontainers[mysql,kafka]>=3.0.3,<4.0.0', - 'pymssql>=2.3.4,<3.0.0; python_version >= "3.8"', 'cryptography>=41.0.2', 'hypothesis>5.0.0,<7.0.0', 'virtualenv-clone>=0.5,<1.0', diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index 741a5d74b215..5131769509d9 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -84,7 +84,6 @@ commands_pre = python --version pip --version # pip check - bash {toxinidir}/scripts/install_macos_deps.sh bash {toxinidir}/scripts/run_tox_cleanup.sh commands = python apache_beam/examples/complete/autocomplete_test.py From cc7c7f095c5124f66b22e5fbb1c7ec63033c22cc Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 6 Jun 2025 18:10:28 +0000 Subject: [PATCH 18/18] sdks/python: use `VARCHAR` datatype --- .../transforms/enrichment_handlers/cloudsql_it_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py index f44819c115af..8c6c1f9feb37 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/cloudsql_it_test.py @@ -37,7 +37,7 @@ from testcontainers.mssql import SqlServerContainer from testcontainers.redis import RedisContainer from sqlalchemy import ( - create_engine, MetaData, Table, 
Column, Integer, String, Engine) + create_engine, MetaData, Table, Column, Integer, VARCHAR, Engine) from apache_beam.transforms.enrichment import Enrichment from apache_beam.transforms.enrichment_handlers.cloudsql import ( CloudSQLEnrichmentHandler, @@ -245,7 +245,7 @@ def get_columns(cls): """Returns fresh column objects each time it's called.""" return [ Column("id", Integer, primary_key=True), - Column("name", String(255), nullable=False), + Column("name", VARCHAR(255), nullable=False), Column("quantity", Integer, nullable=False), Column("distribution_center_id", Integer, nullable=False), ]