@@ -167,7 +167,7 @@ public Table createTable(
     DatabaseClient dbClient = getDatabaseClient();
     Dialect dialect = dbClient.getDialect();
     SpannerInformationSchema schemaInfo = createSchemaInfo(dialect);
-    String ddl = schemaInfo.toDdl(ident, schema);
+    String ddl = schemaInfo.createTableDdl(ident, schema);
     DatabaseAdminClient dbAdminClient = spanner.getDatabaseAdminClient();
     OperationFuture<Void, UpdateDatabaseDdlMetadata> op =
         dbAdminClient.updateDatabaseDdl(
@@ -216,7 +216,7 @@ public Table alterTable(Identifier ident, TableChange... changes) {
   public boolean dropTable(Identifier ident) {
     DatabaseClient dbClient = getDatabaseClient();
     SpannerInformationSchema schemaInfo = createSchemaInfo(dbClient.getDialect());
-    String ddl = "DROP TABLE " + schemaInfo.quoteIdentifier(ident.name());
+    String ddl = schemaInfo.dropTableDdl(ident.name());
 
     DatabaseAdminClient dbAdminClient = spanner.getDatabaseAdminClient();
     OperationFuture<Void, UpdateDatabaseDdlMetadata> op =
@@ -37,7 +37,7 @@ public interface SpannerInformationSchema {
 
   String sparkTypeToSpannerType(StructField field);
 
-  default String toDdl(Identifier ident, StructType schema) {
+  default String createTableDdl(Identifier ident, StructType schema) {
     StringBuilder ddl = new StringBuilder();
     ddl.append("CREATE TABLE ").append(quoteIdentifier(ident.name())).append(" (");
     for (StructField field : schema.fields()) {
@@ -70,6 +70,14 @@ default String toDdl(Identifier ident, StructType schema) {
     return ddl.toString();
   }
 
+  default Statement truncateTableDml(String tableName) {
+    return Statement.of("DELETE FROM " + quoteIdentifier(tableName) + " WHERE true");
+  }
+
+  default String dropTableDdl(String tableName) {
+    return "DROP TABLE " + quoteIdentifier(tableName);
+  }
+
   static SpannerInformationSchema create(Dialect dialect) {
     switch (dialect) {
       case POSTGRESQL:
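For reference, a minimal sketch of what the two new defaults emit, for a hypothetical table named singers. The exact quoting comes from each dialect's quoteIdentifier implementation, which is not shown in this hunk; the backticks below assume the GoogleSQL dialect, and the PostgreSQL dialect would quote with double quotes instead:

import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.Statement;

public class DdlHelperSketch {
  public static void main(String[] args) {
    SpannerInformationSchema schemaInfo =
        SpannerInformationSchema.create(Dialect.GOOGLE_STANDARD_SQL);

    // Cloud Spanner rejects a DELETE without a WHERE clause, so "WHERE true"
    // is the idiom for deleting every row in the table.
    Statement truncate = schemaInfo.truncateTableDml("singers");
    // truncate.getSql() -> DELETE FROM `singers` WHERE true

    String drop = schemaInfo.dropTableDdl("singers");
    // drop -> DROP TABLE `singers`
  }
}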
@@ -54,7 +54,8 @@ public class SpannerTable implements Table, SupportsRead, SupportsWrite {
   private final SpannerTableSchema dbSchema;
   private final @Nullable StructType dfSchema;
   private static final ImmutableSet<TableCapability> tableCapabilities =
-      ImmutableSet.of(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE);
+      ImmutableSet.of(
+          TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE);
   private final CaseInsensitiveStringMap properties;
 
   private static final Logger log = LoggerFactory.getLogger(SpannerTable.class);
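Declaring TableCapability.TRUNCATE is what routes SaveMode.Overwrite into the truncate path added below. Roughly, and assuming Spark 3.x DataSource V2 semantics (the dispatch lives inside Spark, not in this PR; OverwriteDispatchSketch is illustrative only):

import org.apache.spark.sql.connector.write.BatchWrite;
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
import org.apache.spark.sql.connector.write.SupportsTruncate;
import org.apache.spark.sql.connector.write.WriteBuilder;

final class OverwriteDispatchSketch {
  // Sketch: for a full-table overwrite of a table that reports
  // TableCapability.TRUNCATE, Spark asks the write builder for a
  // truncating write before building the batch write.
  static BatchWrite planOverwrite(LogicalWriteInfo info) {
    WriteBuilder builder = new SpannerWriteBuilder(info);
    if (builder instanceof SupportsTruncate) {
      builder = ((SupportsTruncate) builder).truncate();
    }
    return builder.buildForBatch();
  }
}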
@@ -14,19 +14,127 @@

 package com.google.cloud.spark.spanner;
 
+import com.google.cloud.spanner.DatabaseAdminClient;
+import com.google.cloud.spanner.DatabaseClient;
+import com.google.cloud.spanner.DatabaseId;
+import com.google.cloud.spanner.Dialect;
+import com.google.cloud.spanner.Spanner;
+import com.google.cloud.spanner.SpannerException;
+import com.google.cloud.spanner.Statement;
+import com.google.cloud.spanner.connection.Connection;
+import java.util.Collections;
+import java.util.concurrent.ExecutionException;
+import org.apache.spark.sql.connector.catalog.Identifier;
 import org.apache.spark.sql.connector.write.BatchWrite;
 import org.apache.spark.sql.connector.write.LogicalWriteInfo;
+import org.apache.spark.sql.connector.write.SupportsTruncate;
 import org.apache.spark.sql.connector.write.WriteBuilder;
 import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-public class SpannerWriteBuilder implements WriteBuilder {
+public class SpannerWriteBuilder implements WriteBuilder, SupportsTruncate {
 
+  private static final Logger log = LoggerFactory.getLogger(SpannerWriteBuilder.class);
   private final LogicalWriteInfo info;
   private final StructType schema;
 
   public SpannerWriteBuilder(LogicalWriteInfo info) {
     this.info = info;
     this.schema = info.schema();
   }
 
   @Override
   public BatchWrite buildForBatch() {
     return new SpannerBatchWrite(info);
   }

+  @Override
+  public WriteBuilder truncate() {
+    CaseInsensitiveStringMap opts = new CaseInsensitiveStringMap(this.info.options());
+    String overwriteMode = opts.getOrDefault("overwriteMode", "truncate");
+
+    if (overwriteMode.equalsIgnoreCase("recreate")) {
+      recreateTable(opts);
+    } else if (overwriteMode.equalsIgnoreCase("truncate")) {
+      truncateTable(opts);
+    } else {
+      throw new SpannerConnectorException(
+          SpannerErrorCode.INVALID_ARGUMENT,
+          "Unsupported overwriteMode '"
+              + overwriteMode
+              + "'. Supported modes are 'recreate' and 'truncate'.");
+    }
+
+    return this;
+  }
+
+  private void recreateTable(CaseInsensitiveStringMap opts) {
+    String instanceId = SpannerUtils.getRequiredOption(opts, "instanceId");
+    String databaseId = SpannerUtils.getRequiredOption(opts, "databaseId");
+    String tableName = SpannerUtils.getRequiredOption(opts, "table");
+
+    try (Spanner spanner = SpannerUtils.buildSpannerOptions(opts).getService()) {
+      DatabaseAdminClient dbAdminClient = spanner.getDatabaseAdminClient();
+      Dialect dialect;
+      try (Connection conn = SpannerUtils.connectionFromProperties(opts.asCaseSensitiveMap())) {
+        dialect = conn.getDialect();
+      }
+      SpannerInformationSchema schemaInfo = SpannerInformationSchema.create(dialect);
+
+      // Drop the existing table.
+      String dropDdl = schemaInfo.dropTableDdl(tableName);
+      dbAdminClient
+          .updateDatabaseDdl(instanceId, databaseId, Collections.singletonList(dropDdl), null)
+          .get();
+
+      // Create the table.
+      Identifier ident = Identifier.of(new String[0], tableName);
+      String createDdl = schemaInfo.createTableDdl(ident, this.schema);
+      dbAdminClient
+          .updateDatabaseDdl(instanceId, databaseId, Collections.singletonList(createDdl), null)
+          .get();
+    } catch (InterruptedException | ExecutionException e) {
+      throw new SpannerConnectorException(
+          SpannerErrorCode.DDL_EXCEPTION, "Error recreating table " + tableName, e);
+    }
Comment on lines +98 to +101 (review comment, severity: medium):

When an InterruptedException is caught, it's a best practice to restore the interrupted status of the thread by calling Thread.currentThread().interrupt(). This allows code higher up the call stack to be aware of the interruption and handle it appropriately.

Suggested change:
-    } catch (InterruptedException | ExecutionException e) {
-      throw new SpannerConnectorException(
-          SpannerErrorCode.DDL_EXCEPTION, "Error recreating table " + tableName, e);
-    }
+    } catch (InterruptedException | ExecutionException e) {
+      if (e instanceof InterruptedException) {
+        Thread.currentThread().interrupt();
+      }
+      throw new SpannerConnectorException(
+          SpannerErrorCode.DDL_EXCEPTION, "Error recreating table " + tableName, e);
+    }

+  }
+
+  private void truncateTable(CaseInsensitiveStringMap opts) {
+    String projectId = SpannerUtils.getRequiredOption(opts, "projectId");
+    String instanceId = SpannerUtils.getRequiredOption(opts, "instanceId");
+    String databaseId = SpannerUtils.getRequiredOption(opts, "databaseId");
+    String tableName = SpannerUtils.getRequiredOption(opts, "table");
+
+    try (Spanner spanner = SpannerUtils.buildSpannerOptions(opts).getService()) {
+      DatabaseClient dbClient =
+          spanner.getDatabaseClient(DatabaseId.of(projectId, instanceId, databaseId));
+      Dialect dialect = dbClient.getDialect();
+
+      SpannerInformationSchema informationSchema = SpannerInformationSchema.create(dialect);
+
+      truncateTable(dbClient, tableName, informationSchema);
+    } catch (Exception e) {
+      throw new SpannerConnectorException(
+          SpannerErrorCode.DDL_EXCEPTION, "Error truncating table " + tableName, e);
+    }
+  }
+
+  private long truncateTable(
+      DatabaseClient dbClient, String tableName, SpannerInformationSchema informationSchema) {
+
+    Statement statement = informationSchema.truncateTableDml(tableName);
+
+    try {
+      // Execute partitioned update. This is a blocking call. Partitioned DML
+      // does not run as a single atomic transaction, and the returned count
+      // is a lower bound on the number of rows deleted.
+      long deletedRowCount = dbClient.executePartitionedUpdate(statement);
+      log.info("Successfully deleted " + deletedRowCount + " rows.");
+      return deletedRowCount;
+    } catch (SpannerException e) {
+      log.error("Failed to execute Partitioned DML on table: " + tableName, e);
+      throw e;
+    }
+  }
+}
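Callers pick the behavior through the overwriteMode write option handled in truncate() above; it defaults to truncate. A hypothetical invocation, with placeholder resource IDs and the same format name and option keys the integration tests below use:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;

public class OverwriteUsageSketch {
  static void overwrite(Dataset<Row> df) {
    df.write()
        .format("cloud-spanner")
        .option("projectId", "my-project")   // placeholder
        .option("instanceId", "my-instance") // placeholder
        .option("databaseId", "my-database") // placeholder
        .option("table", "my_table")         // placeholder
        .option("overwriteMode", "recreate") // omit to get the default "truncate"
        .mode(SaveMode.Overwrite)
        .save();
  }
}

Note the tradeoff between the modes: recreate drops the table and re-issues CREATE TABLE from the Spark schema, so anything not captured in that schema (secondary indexes, for instance) is lost, while truncate keeps the table definition and deletes rows with partitioned DML.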
@@ -207,7 +207,7 @@ public void createTableShouldThrowExceptionOnNoPrimaryKey() {
     thrown.expectMessage(
         "No primary key found for table no_pk_table. Please specify at least one primary key column.");
 
-    SpannerInformationSchema.create(dialect).toDdl(ident, schema);
+    SpannerInformationSchema.create(dialect).createTableDdl(ident, schema);
   }
 
   @Test
@@ -279,7 +279,7 @@ public void testToDdl() {
new StructField("price", DataTypes.createDecimalType(10, 2), true, Metadata.empty()),
});

String ddl = SpannerInformationSchema.create(dialect).toDdl(ident, schema);
String ddl = SpannerInformationSchema.create(dialect).createTableDdl(ident, schema);

if (dialect == Dialect.POSTGRESQL) {
assertEquals(
Expand Down
@@ -114,6 +114,110 @@ public WriteIntegrationTest(boolean usePostgresSql) {
     this.usePostgresSql = usePostgresSql;
   }
 
+  @Test
+  public void testOverwriteRecreateMode() {
+    String tableName = TestData.WRITE_TABLE_NAME + "_RECREATE";
+    spark.sql("DROP TABLE IF EXISTS spanner." + tableName);
+
+    // 1. Define schema with primary key metadata (needed for table recreation)
+    StructType schema =
+        new StructType(
+            new StructField[] {
+              DataTypes.createStructField(
+                  "long_col", DataTypes.LongType, false, SpannerCatalog.PRIMARY_KEY_METADATA),
+              DataTypes.createStructField("string_col", DataTypes.StringType, true),
+            });
+
+    Map<String, String> props = connectionProperties(usePostgresSql);
+    props.put("table", tableName);
+
+    // 2. Write initial data (creates the table via ErrorIfExists)
+    List<Row> initialRows =
+        Arrays.asList(RowFactory.create(1L, "initial-one"), RowFactory.create(2L, "initial-two"));
+    Dataset<Row> initialDf = spark.createDataFrame(initialRows, schema);
+    initialDf.write().format("cloud-spanner").options(props).mode(SaveMode.ErrorIfExists).save();
+
+    // 3. Verify initial data
+    Dataset<Row> dfAfterInitialWrite = spark.read().format("cloud-spanner").options(props).load();
+    assertEquals(2, dfAfterInitialWrite.count());
+
+    // 4. Overwrite with recreate mode
+    List<Row> newRows =
+        Arrays.asList(
+            RowFactory.create(3L, "new-three"),
+            RowFactory.create(4L, "new-four"),
+            RowFactory.create(5L, "new-five"));
+    Dataset<Row> newDf = spark.createDataFrame(newRows, schema);
+
+    Map<String, String> overwriteProps = connectionProperties(usePostgresSql);
+    overwriteProps.put("table", tableName);
+    overwriteProps.put("overwriteMode", "recreate");
+
+    newDf.write().format("cloud-spanner").options(overwriteProps).mode(SaveMode.Overwrite).save();
+
+    // 5. Verify only new data exists
+    Dataset<Row> finalDf = spark.read().format("cloud-spanner").options(props).load();
+    assertEquals(3, finalDf.count());
+
+    Map<Long, Row> finalRows =
+        finalDf.collectAsList().stream()
+            .collect(java.util.stream.Collectors.toMap(r -> r.getLong(0), r -> r));
+
+    assertThat(finalRows.get(3L).getString(1)).isEqualTo("new-three");
+    assertThat(finalRows.get(4L).getString(1)).isEqualTo("new-four");
+    assertThat(finalRows.get(5L).getString(1)).isEqualTo("new-five");
+  }
+
+  @Test
+  public void testOverwriteTruncateMode() {
+    String tableName = TestData.WRITE_TABLE_NAME + "_TRUNCATE";
+    spark.sql("DROP TABLE IF EXISTS spanner." + tableName);
+
+    // 1. Define schema with primary key metadata (needed for initial table creation)
+    StructType schema =
+        new StructType(
+            new StructField[] {
+              DataTypes.createStructField(
+                  "long_col", DataTypes.LongType, false, SpannerCatalog.PRIMARY_KEY_METADATA),
+              DataTypes.createStructField("string_col", DataTypes.StringType, true),
+            });
+
+    Map<String, String> props = connectionProperties(usePostgresSql);
+    props.put("table", tableName);
+
+    // 2. Write initial data (creates the table via ErrorIfExists)
+    List<Row> initialRows =
+        Arrays.asList(RowFactory.create(1L, "initial-one"), RowFactory.create(2L, "initial-two"));
+    Dataset<Row> initialDf = spark.createDataFrame(initialRows, schema);
+    initialDf.write().format("cloud-spanner").options(props).mode(SaveMode.ErrorIfExists).save();
+
+    // 3. Verify initial data
+    Dataset<Row> dfAfterInitialWrite = spark.read().format("cloud-spanner").options(props).load();
+    assertEquals(2, dfAfterInitialWrite.count());
+
+    // 4. Overwrite with default truncate mode
+    List<Row> newRows =
+        Arrays.asList(
+            RowFactory.create(3L, "new-three"),
+            RowFactory.create(4L, "new-four"),
+            RowFactory.create(5L, "new-five"));
+    Dataset<Row> newDf = spark.createDataFrame(newRows, schema);
+
+    newDf.write().format("cloud-spanner").options(props).mode(SaveMode.Overwrite).save();
+
+    // 5. Verify only new data exists
+    Dataset<Row> finalDf = spark.read().format("cloud-spanner").options(props).load();
+    assertEquals(3, finalDf.count());
+
+    Map<Long, Row> finalRows =
+        finalDf.collectAsList().stream()
+            .collect(java.util.stream.Collectors.toMap(r -> r.getLong(0), r -> r));
+
+    assertThat(finalRows.get(3L).getString(1)).isEqualTo("new-three");
+    assertThat(finalRows.get(4L).getString(1)).isEqualTo("new-four");
+    assertThat(finalRows.get(5L).getString(1)).isEqualTo("new-five");
+  }
+
   @Override
   protected boolean getUsePostgreSql() {
     return usePostgresSql;