Skip to content

Commit 76708d5

Browse files
authored
added ddl to SqlTransform (#34614)
* added ddl to SqlTransform * uncommented tests
1 parent bd47dc9 commit 76708d5

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

sdks/java/extensions/sql/expansion-service/src/main/java/org/apache/beam/sdk/extensions/sql/expansion/ExternalSqlTransformRegistrar.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ public class ExternalSqlTransformRegistrar implements ExternalTransformRegistrar
5050
public static class Configuration {
5151
String query = "";
5252
@Nullable String dialect;
53+
@Nullable String ddl;
5354

5455
public void setQuery(String query) {
5556
this.query = query;
@@ -58,6 +59,10 @@ public void setQuery(String query) {
5859
public void setDialect(@Nullable String dialect) {
5960
this.dialect = dialect;
6061
}
62+
63+
public void setDdl(@Nullable String ddl) {
64+
this.ddl = ddl;
65+
}
6166
}
6267

6368
private static class Builder
@@ -76,6 +81,10 @@ public PTransform<PInput, PCollection<Row>> buildExternal(Configuration configur
7681
}
7782
transform = transform.withQueryPlannerClass(queryPlanner);
7883
}
84+
// Add any DDL string
85+
if (configuration.ddl != null) {
86+
transform = transform.withDdlString(configuration.ddl);
87+
}
7988
return transform;
8089
}
8190
}

sdks/python/apache_beam/transforms/sql.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
__all__ = ['SqlTransform']
2929

3030
SqlTransformSchema = typing.NamedTuple(
31-
'SqlTransformSchema', [('query', str), ('dialect', typing.Optional[str])])
31+
'SqlTransformSchema',
32+
[('query', str), ('dialect', typing.Optional[str]),
33+
('ddl', typing.Optional[str])])
3234

3335

3436
class SqlTransform(ExternalTransform):
@@ -75,18 +77,19 @@ class SqlTransform(ExternalTransform):
7577
"""
7678
URN = 'beam:external:java:sql:v1'
7779

78-
def __init__(self, query, dialect=None, expansion_service=None):
80+
def __init__(self, query, dialect=None, ddl=None, expansion_service=None):
7981
"""
8082
Creates a SqlTransform which will be expanded to Java's SqlTransform.
8183
(See class docs).
8284
:param query: The SQL query.
8385
:param dialect: (optional) The dialect, e.g. use 'zetasql' for ZetaSQL.
86+
:param ddl: (optional) The DDL statement.
8487
:param expansion_service: (optional) The URL of the expansion service to use
8588
"""
8689
expansion_service = expansion_service or BeamJarExpansionService(
8790
':sdks:java:extensions:sql:expansion-service:shadowJar')
8891
super().__init__(
8992
self.URN,
9093
NamedTupleBasedPayloadBuilder(
91-
SqlTransformSchema(query=query, dialect=dialect)),
94+
SqlTransformSchema(query=query, dialect=dialect, ddl=ddl)),
9295
expansion_service=expansion_service)

sdks/python/apache_beam/transforms/sql_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,30 @@ def test_map(self):
209209
| SqlTransform("SELECT * FROM PCOLLECTION WHERE shopper = 'alice'"))
210210
assert_that(out, equal_to([('alice', {'apples': 2, 'bananas': 3})]))
211211

212+
def test_sql_ddl_set_option(self):
213+
with TestPipeline() as p:
214+
input_data = [
215+
beam.Row(id=1, value=10),
216+
beam.Row(id=2, value=20),
217+
beam.Row(id=3, value=30)
218+
]
219+
# DDL uses SET to modify a session option (tests DDL parsing)
220+
# Using a known Calcite option like sqlConformance
221+
ddl_statement = """
222+
SET sqlConformance = 'LENIENT'
223+
"""
224+
# Query still operates on the implicit PCOLLECTION
225+
query_statement = "SELECT * FROM PCOLLECTION WHERE id > 2"
226+
227+
# Input PCollection is piped directly
228+
out = (
229+
p | beam.Create(input_data)
230+
# Pass both the query and the DDL
231+
| SqlTransform(query=query_statement, ddl=ddl_statement))
232+
233+
# Verify the output matches the query (unaffected by the SET DDL)
234+
assert_that(out, equal_to([(3, 30)]))
235+
212236

213237
if __name__ == "__main__":
214238
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)