Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,10 @@ public TableInfo getReadTable(ReadTableOptions options) {
validateViewsEnabled(options);
String sql = query.get();
return materializeQueryToTable(
sql, options.expirationTimeInMinutes(), options.getQueryParameterHelper());
sql,
options.expirationTimeInMinutes(),
options.getQueryParameterHelper(),
options.getKmsKeyName());
}

TableInfo table = getTable(options.tableId());
Expand Down Expand Up @@ -504,7 +507,8 @@ public Schema getReadTableSchema(ReadTableOptions options) {
if (query.isPresent()) {
validateViewsEnabled(options);
String sql = query.get();
return getQueryResultSchema(sql, Collections.emptyMap(), options.getQueryParameterHelper());
return getQueryResultSchema(
sql, Collections.emptyMap(), options.getQueryParameterHelper(), options.getKmsKeyName());
}
TableInfo table = getReadTable(options);
return table != null ? table.getDefinition().getSchema() : null;
Expand Down Expand Up @@ -694,13 +698,20 @@ private long getNumberOfRows(String sql) {
* Runs the provided query on BigQuery and saves the result in a temporary table.
*
* @param querySql the query to be run
* @param expirationTimeInMinutes the expiration time of the temporary table
* @param queryParameterHelper the query parameter helper
* @param kmsKeyName optional KMS key name to be used for encrypting the temporary table
* @return a reference to the table
*/
public TableInfo materializeQueryToTable(
String querySql, int expirationTimeInMinutes, QueryParameterHelper queryParameterHelper) {
String querySql,
int expirationTimeInMinutes,
QueryParameterHelper queryParameterHelper,
Optional<String> kmsKeyName) {
Optional<TableId> tableId =
materializationDataset.map(ignored -> createDestinationTableWithoutReference());
return materializeTable(querySql, tableId, expirationTimeInMinutes, queryParameterHelper);
return materializeTable(
querySql, tableId, expirationTimeInMinutes, queryParameterHelper, kmsKeyName);
}

TableId createDestinationTableWithoutReference() {
Expand Down Expand Up @@ -729,6 +740,27 @@ public TableInfo materializeQueryToTable(
int expirationTimeInMinutes,
Map<String, String> additionalQueryJobLabels,
QueryParameterHelper queryParameterHelper) {
return materializeQueryToTable(
querySql,
expirationTimeInMinutes,
additionalQueryJobLabels,
queryParameterHelper,
Optional.empty());
}

/**
* Runs the provided query on BigQuery and saves the result in a temporary table.
*
* @param querySql the query to be run
* @param additionalQueryJobLabels the labels to insert on the query job
* @return a reference to the table
*/
public TableInfo materializeQueryToTable(
String querySql,
int expirationTimeInMinutes,
Map<String, String> additionalQueryJobLabels,
QueryParameterHelper queryParameterHelper,
Optional<String> kmsKeyName) {
Optional<TableId> destinationTableId =
materializationDataset.map(ignored -> createDestinationTableWithoutReference());
TempTableBuilder tableBuilder =
Expand All @@ -739,7 +771,8 @@ public TableInfo materializeQueryToTable(
expirationTimeInMinutes,
jobConfigurationFactory,
additionalQueryJobLabels,
queryParameterHelper);
queryParameterHelper,
kmsKeyName);

return materializeTable(querySql, tableBuilder);
}
Expand All @@ -760,20 +793,36 @@ public TableInfo materializeViewToTable(
createDestinationTable(
Optional.ofNullable(viewId.getProject()), Optional.ofNullable(viewId.getDataset()));
return materializeTable(
querySql, Optional.of(tableId), expirationTimeInMinutes, QueryParameterHelper.none());
querySql,
Optional.of(tableId),
expirationTimeInMinutes,
QueryParameterHelper.none(),
Optional.empty());
}

public Schema getQueryResultSchema(
String querySql,
Map<String, String> additionalQueryJobLabels,
QueryParameterHelper queryParameterHelper) {
JobInfo jobInfo =
JobInfo.of(
jobConfigurationFactory
.createParameterizedQueryJobConfigurationBuilder(
querySql, additionalQueryJobLabels, queryParameterHelper)
.setDryRun(true)
.build());
return getQueryResultSchema(
querySql, additionalQueryJobLabels, queryParameterHelper, Optional.empty());
}

public Schema getQueryResultSchema(
String querySql,
Map<String, String> additionalQueryJobLabels,
QueryParameterHelper queryParameterHelper,
Optional<String> kmsKeyName) {
QueryJobConfiguration.Builder builder =
jobConfigurationFactory
.createParameterizedQueryJobConfigurationBuilder(
querySql, additionalQueryJobLabels, queryParameterHelper)
.setDryRun(true);
kmsKeyName.ifPresent(
k ->
builder.setDestinationEncryptionConfiguration(
EncryptionConfiguration.newBuilder().setKmsKeyName(k).build()));
JobInfo jobInfo = JobInfo.of(builder.build());

log.info("running query dryRun {}", querySql);
JobInfo completedJobInfo = create(jobInfo);
Expand All @@ -789,6 +838,20 @@ private TableInfo materializeTable(
Optional<TableId> destinationTableId,
int expirationTimeInMinutes,
QueryParameterHelper queryParameterHelper) {
return materializeTable(
querySql,
destinationTableId,
expirationTimeInMinutes,
queryParameterHelper,
Optional.empty());
}

private TableInfo materializeTable(
String querySql,
Optional<TableId> destinationTableId,
int expirationTimeInMinutes,
QueryParameterHelper queryParameterHelper,
Optional<String> kmsKeyName) {
try {
return destinationTableCache.get(
querySql,
Expand All @@ -799,7 +862,8 @@ private TableInfo materializeTable(
expirationTimeInMinutes,
jobConfigurationFactory,
Collections.emptyMap(),
queryParameterHelper));
queryParameterHelper,
kmsKeyName));
} catch (Exception e) {
throw new BigQueryConnectorException(
BigQueryErrorCode.BIGQUERY_VIEW_DESTINATION_TABLE_CREATION_FAILED,
Expand Down Expand Up @@ -1006,6 +1070,10 @@ public interface ReadTableOptions {
int expirationTimeInMinutes();

QueryParameterHelper getQueryParameterHelper();

default Optional<String> getKmsKeyName() {
return Optional.empty();
}
}

public interface LoadDataOptions {
Expand Down Expand Up @@ -1083,6 +1151,7 @@ static class TempTableBuilder implements Callable<TableInfo> {
final JobConfigurationFactory jobConfigurationFactory;
final Map<String, String> additionalQueryJobLabels;
final QueryParameterHelper queryParameterHelper;
final Optional<String> kmsKeyName;

TempTableBuilder(
BigQueryClient bigQueryClient,
Expand All @@ -1092,13 +1161,34 @@ static class TempTableBuilder implements Callable<TableInfo> {
JobConfigurationFactory jobConfigurationFactory,
Map<String, String> additionalQueryJobLabels,
QueryParameterHelper queryParameterHelper) {
this(
bigQueryClient,
querySql,
tempTable,
expirationTimeInMinutes,
jobConfigurationFactory,
additionalQueryJobLabels,
queryParameterHelper,
Optional.empty());
}

TempTableBuilder(
BigQueryClient bigQueryClient,
String querySql,
Optional<TableId> tempTable,
int expirationTimeInMinutes,
JobConfigurationFactory jobConfigurationFactory,
Map<String, String> additionalQueryJobLabels,
QueryParameterHelper queryParameterHelper,
Optional<String> kmsKeyName) {
this.bigQueryClient = bigQueryClient;
this.querySql = querySql;
this.tempTable = tempTable;
this.expirationTimeInMinutes = expirationTimeInMinutes;
this.jobConfigurationFactory = jobConfigurationFactory;
this.additionalQueryJobLabels = additionalQueryJobLabels;
this.queryParameterHelper = queryParameterHelper;
this.kmsKeyName = kmsKeyName;
}

@Override
Expand All @@ -1116,6 +1206,10 @@ TableInfo createTableFromQuery() {
jobConfigurationFactory.createParameterizedQueryJobConfigurationBuilder(
querySql, additionalQueryJobLabels, queryParameterHelper);
tempTable.ifPresent(queryJobConfigurationBuilder::setDestinationTable);
kmsKeyName.ifPresent(
k ->
queryJobConfigurationBuilder.setDestinationEncryptionConfiguration(
EncryptionConfiguration.newBuilder().setKmsKeyName(k).build()));

JobInfo jobInfo = JobInfo.of(queryJobConfigurationBuilder.build());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,11 @@ public QueryParameterHelper getQueryParameterHelper() {
public int expirationTimeInMinutes() {
return SparkBigQueryConfig.this.getMaterializationExpirationTimeInMinutes();
}

@Override
public Optional<String> getKmsKeyName() {
return SparkBigQueryConfig.this.getKmsKeyName();
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,28 @@ public void testGetAnyOptionWithFallbackOnlyNewConfigExist() {
assertThat(config.getMaterializationProject()).isEqualTo(Optional.of("foo"));
}

@Test
public void testKmsKeyPropagationToReadTableOptions() {
String kmsKeyName = "projects/p/locations/l/keyRings/k/cryptoKeys/c";
DataSourceOptions options =
new DataSourceOptions(
ImmutableMap.of("table", "dataset.table", "destinationTableKmsKeyName", kmsKeyName));
SparkBigQueryConfig config =
SparkBigQueryConfig.from(
options.asMap(),
defaultGlobalOptions,
new Configuration(),
ImmutableMap.of(),
DEFAULT_PARALLELISM,
new SQLConf(),
SPARK_VERSION,
Optional.empty(),
true);

assertThat(config.getKmsKeyName()).isEqualTo(Optional.of(kmsKeyName));
assertThat(config.toReadTableOptions().getKmsKeyName()).isEqualTo(Optional.of(kmsKeyName));
}

@Test
public void testGetAnyOptionWithFallbackBothConfigsExist() {
SparkBigQueryConfig config =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,35 @@ public void testReadWithMixedParametersFails() {

assertThat(listener.getJobInfos()).isEmpty();
}

@Test
public void testReadFromQueryWithKmsKey() {
String random = String.valueOf(System.nanoTime());
String query =
String.format(
"SELECT corpus, word_count FROM `bigquery-public-data.samples.shakespeare` WHERE word='spark' AND '%s'='%s'",
random, random);
String envKmsKey = System.getenv("BIGQUERY_KMS_KEY_NAME");
String kmsKeyName =
envKmsKey != null ? envKmsKey : "projects/p/locations/l/keyRings/k/cryptoKeys/c";
spark
.read()
.format("bigquery")
.option("viewsEnabled", true)
.option("materializationDataset", testDataset.toString())
.option("destinationTableKmsKeyName", kmsKeyName)
.load(query)
.collect();
// validate event publishing
List<JobInfo> jobInfos = listener.getJobInfos();
assertThat(jobInfos).hasSize(1);
JobInfo jobInfo = jobInfos.iterator().next();
assertThat(
((QueryJobConfiguration) jobInfo.getConfiguration())
.getDestinationEncryptionConfiguration()
.getKmsKeyName())
.isEqualTo(kmsKeyName + "/cryptoKeyVersions/1");
Comment on lines +404 to +408
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The assertion for the KMS key name assumes a specific suffix (/cryptoKeyVersions/1). This might be too rigid. If the provided kmsKeyName is already a full resource name including a version, or if BigQuery's API returns the key name without a version suffix in some cases, this test could fail unnecessarily. It would be more robust to assert that the getKmsKeyName() value contains the provided kmsKeyName or matches it exactly if no version is expected to be appended by the API. Consider using startsWith or a more flexible comparison.

Suggested change
assertThat(
((QueryJobConfiguration) jobInfo.getConfiguration())
.getDestinationEncryptionConfiguration()
.getKmsKeyName())
.isEqualTo(kmsKeyName + "/cryptoKeyVersions/1");
assertThat(
((QueryJobConfiguration) jobInfo.getConfiguration())
.getDestinationEncryptionConfiguration()
.getKmsKeyName())
.startsWith(kmsKeyName);

}
}

class TestBigQueryJobCompletionListener extends SparkListener {
Expand Down
Loading