Skip to content

[Java]Add Support for Named Schemas in SpannerIO.Write by Enhancing Schema Retrieval #34261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,8 @@
public class ReadSpannerSchema extends DoFn<Void, SpannerSchema> {

private final SpannerConfig config;

private final PCollectionView<Dialect> dialectView;

private final Set<String> allowedTableNames;

private transient SpannerAccessor spannerAccessor;

/**
Expand Down Expand Up @@ -93,76 +90,79 @@ public void processElement(ProcessContext c) throws Exception {
ResultSet resultSet = readTableInfo(tx, dialect);

while (resultSet.next()) {
String tableName = resultSet.getString(0);
String columnName = resultSet.getString(1);
String type = resultSet.getString(2);
long cellsMutated = resultSet.getLong(3);
if (allowedTableNames.size() > 0 && !allowedTableNames.contains(tableName)) {
// If we want to filter out table names, and the current table name is not part
// of the allowed names, we exclude it.
String schemaName = resultSet.getString(0); // TABLE_SCHEMA
String tableName = resultSet.getString(1); // TABLE_NAME
String fullTableName = schemaName.isEmpty() ? tableName : schemaName + "." + tableName;
String columnName = resultSet.getString(2); // COLUMN_NAME
String type = resultSet.getString(3); // SPANNER_TYPE
long cellsMutated = resultSet.getLong(4); // CELLS_MUTATED

// Apply allowedTableNames filter on full table name if specified
if (allowedTableNames.size() > 0 && !allowedTableNames.contains(fullTableName)) {
continue;
}
builder.addColumn(tableName, columnName, type, cellsMutated);
builder.addColumn(fullTableName, columnName, type, cellsMutated);
}

resultSet = readPrimaryKeyInfo(tx, dialect);
while (resultSet.next()) {
String tableName = resultSet.getString(0);
String columnName = resultSet.getString(1);
String ordering = resultSet.getString(2);

builder.addKeyPart(tableName, columnName, "DESC".equalsIgnoreCase(ordering));
String schemaName = resultSet.getString(0); // TABLE_SCHEMA
String tableName = resultSet.getString(1); // TABLE_NAME
String fullTableName = schemaName.isEmpty() ? tableName : schemaName + "." + tableName;
String columnName = resultSet.getString(2); // COLUMN_NAME
String ordering = resultSet.getString(3); // COLUMN_ORDERING

// Apply allowedTableNames filter on full table name if specified
if (allowedTableNames.size() > 0 && !allowedTableNames.contains(fullTableName)) {
continue;
}
builder.addKeyPart(fullTableName, columnName, "DESC".equalsIgnoreCase(ordering));
}
}
c.output(builder.build());
}

private ResultSet readTableInfo(ReadOnlyTransaction tx, Dialect dialect) {
// retrieve schema information for all tables, as well as aggregating the
// number of indexes that cover each column. this will be used to estimate
// the number of cells (table column plus indexes) mutated in an upsert operation
// in order to stay below the 20k threshold
// Retrieve schema information for all tables across all schemas, including the number of
// indexes covering each column to estimate cells mutated in upserts.
String statement = "";
switch (dialect) {
case GOOGLE_STANDARD_SQL:
statement =
"SELECT"
+ " c.table_name"
+ " , c.column_name"
+ " , c.spanner_type"
+ " , (1 + COALESCE(t.indices, 0)) AS cells_mutated"
+ " FROM ("
+ " SELECT c.table_name, c.column_name, c.spanner_type, c.ordinal_position"
+ " FROM information_schema.columns as c"
+ " WHERE c.table_catalog = '' AND c.table_schema = '') AS c"
+ " LEFT OUTER JOIN ("
+ " SELECT t.table_name, t.column_name, COUNT(*) AS indices"
+ " FROM information_schema.index_columns AS t "
+ " WHERE t.index_name != 'PRIMARY_KEY' AND t.table_catalog = ''"
+ " AND t.table_schema = ''"
+ " GROUP BY t.table_name, t.column_name) AS t"
+ " USING (table_name, column_name)"
+ " ORDER BY c.table_name, c.ordinal_position";
"SELECT "
+ " c.TABLE_SCHEMA, "
+ " c.TABLE_NAME, "
+ " c.COLUMN_NAME, "
+ " c.SPANNER_TYPE, "
+ " (1 + COALESCE(t.indices, 0)) AS cells_mutated "
+ "FROM INFORMATION_SCHEMA.COLUMNS AS c "
+ "LEFT OUTER JOIN ("
+ " SELECT t.TABLE_SCHEMA, t.TABLE_NAME, t.COLUMN_NAME, COUNT(*) AS indices "
+ " FROM INFORMATION_SCHEMA.INDEX_COLUMNS AS t "
+ " WHERE t.INDEX_NAME != 'PRIMARY_KEY' AND t.TABLE_CATALOG = '' "
+ " GROUP BY t.TABLE_SCHEMA, t.TABLE_NAME, t.COLUMN_NAME "
+ ") AS t "
+ "ON c.TABLE_SCHEMA = t.TABLE_SCHEMA AND c.TABLE_NAME = t.TABLE_NAME AND c.COLUMN_NAME = t.COLUMN_NAME "
+ "WHERE c.TABLE_CATALOG = '' "
+ "ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION";
break;
case POSTGRESQL:
statement =
"SELECT"
+ " c.table_name"
+ " , c.column_name"
+ " , c.spanner_type"
+ " , (1 + COALESCE(t.indices, 0)) AS cells_mutated"
+ " FROM ("
+ " SELECT c.table_name, c.column_name, c.spanner_type, c.ordinal_position"
+ " FROM information_schema.columns as c"
+ " WHERE c.table_schema='public') AS c"
+ " LEFT OUTER JOIN ("
+ " SELECT t.table_name, t.column_name, COUNT(*) AS indices"
+ " FROM information_schema.index_columns AS t "
+ " WHERE t.index_name != 'PRIMARY_KEY'"
+ " AND t.table_schema='public'"
+ " GROUP BY t.table_name, t.column_name) AS t"
+ " USING (table_name, column_name)"
+ " ORDER BY c.table_name, c.ordinal_position";
"SELECT "
+ " c.TABLE_SCHEMA, "
+ " c.TABLE_NAME, "
+ " c.COLUMN_NAME, "
+ " c.SPANNER_TYPE, "
+ " (1 + COALESCE(t.indices, 0)) AS cells_mutated "
+ "FROM INFORMATION_SCHEMA.COLUMNS AS c "
+ "LEFT OUTER JOIN ("
+ " SELECT t.TABLE_SCHEMA, t.TABLE_NAME, t.COLUMN_NAME, COUNT(*) AS indices "
+ " FROM INFORMATION_SCHEMA.INDEX_COLUMNS AS t "
+ " WHERE t.INDEX_NAME != 'PRIMARY_KEY' "
+ " GROUP BY t.TABLE_SCHEMA, t.TABLE_NAME, t.COLUMN_NAME "
+ ") AS t "
+ "ON c.TABLE_SCHEMA = t.TABLE_SCHEMA AND c.TABLE_NAME = t.TABLE_NAME AND c.COLUMN_NAME = t.COLUMN_NAME "
+ "ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION";
break;
default:
throw new IllegalArgumentException("Unrecognized dialect: " + dialect.name());
Expand All @@ -175,19 +175,25 @@ private ResultSet readPrimaryKeyInfo(ReadOnlyTransaction tx, Dialect dialect) {
switch (dialect) {
case GOOGLE_STANDARD_SQL:
statement =
"SELECT t.table_name, t.column_name, t.column_ordering"
+ " FROM information_schema.index_columns AS t "
+ " WHERE t.index_name = 'PRIMARY_KEY' AND t.table_catalog = ''"
+ " AND t.table_schema = ''"
+ " ORDER BY t.table_name, t.ordinal_position";
"SELECT "
+ " t.TABLE_SCHEMA, "
+ " t.TABLE_NAME, "
+ " t.COLUMN_NAME, "
+ " t.COLUMN_ORDERING "
+ "FROM INFORMATION_SCHEMA.INDEX_COLUMNS AS t "
+ "WHERE t.INDEX_NAME = 'PRIMARY_KEY' AND t.TABLE_CATALOG = '' "
+ "ORDER BY t.TABLE_SCHEMA, t.TABLE_NAME, t.ORDINAL_POSITION";
break;
case POSTGRESQL:
statement =
"SELECT t.table_name, t.column_name, t.column_ordering"
+ " FROM information_schema.index_columns AS t "
+ " WHERE t.index_name = 'PRIMARY_KEY'"
+ " AND t.table_schema='public'"
+ " ORDER BY t.table_name, t.ordinal_position";
"SELECT "
+ " t.TABLE_SCHEMA, "
+ " t.TABLE_NAME, "
+ " t.COLUMN_NAME, "
+ " t.COLUMN_ORDERING "
+ "FROM INFORMATION_SCHEMA.INDEX_COLUMNS AS t "
+ "WHERE t.INDEX_NAME = 'PRIMARY_KEY' "
+ "ORDER BY t.TABLE_SCHEMA, t.TABLE_NAME, t.ORDINAL_POSITION";
break;
default:
throw new IllegalArgumentException("Unrecognized dialect: " + dialect.name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.TimestampBound;
import com.google.cloud.spanner.Value;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -1142,6 +1144,11 @@ public enum FailureMode {
/**
* A {@link PTransform} that writes {@link Mutation} objects to Google Cloud Spanner.
*
* <p>Input {@link Mutation} objects should specify table names that include schema prefixes
* (e.g., "my_schema.my_table") when using named schemas. If unqualified names are provided (e.g.,
* "my_table"), the transform will attempt to rewrite them to match the schema-qualified names
* retrieved from the database, but this may fail if the table is ambiguous or not found.
*
* @see SpannerIO
*/
@AutoValue
Expand Down Expand Up @@ -1931,6 +1938,94 @@ public void processElement(ProcessContext c) {
}
}

/**
 * Rewrites the table name of every {@link Mutation} in a {@link MutationGroup} so that it matches
 * a table name listed in the {@link SpannerSchema} side input.
 *
 * <p>Resolution rules, applied per mutation (all comparisons are case-insensitive):
 *
 * <ol>
 *   <li>If the mutation's table name equals a table name in the schema, the mutation is passed
 *       through unchanged. An exact match always takes precedence over a partial one, so a table
 *       {@code users} is never redirected to {@code other_schema.users}.
 *   <li>Otherwise, if exactly one schema table has a last dot-separated segment equal to the
 *       mutation's table name, the mutation is rebuilt against that schema table name.
 *   <li>Otherwise (no candidate, or more than one candidate) a warning is logged and the mutation
 *       is passed through unchanged, rather than guessing which table was meant.
 * </ol>
 */
@VisibleForTesting
static class RewriteMutationTableNamesFn extends DoFn<MutationGroup, MutationGroup> {
  private final PCollectionView<SpannerSchema> schemaView;

  /**
   * @param schemaView side input holding the schema whose table names mutations are resolved
   *     against
   */
  RewriteMutationTableNamesFn(PCollectionView<SpannerSchema> schemaView) {
    this.schemaView = schemaView;
  }

  /**
   * Resolves the table name of each mutation in the input group and outputs a new group holding
   * the (possibly rewritten) mutations in their original order.
   */
  @ProcessElement
  public void processElement(ProcessContext c) {
    SpannerSchema spannerSchema = c.sideInput(schemaView);
    List<Mutation> rewrittenMutations = new ArrayList<>();

    for (Mutation m : c.element()) {
      String tableName = m.getTable();
      String resolvedName = resolveTableName(spannerSchema, tableName);
      // resolveTableName hands back the very same name when nothing has to change, so the
      // original mutation object can be reused without rebuilding it.
      rewrittenMutations.add(
          resolvedName.equals(tableName) ? m : rewriteMutation(m, resolvedName));
    }

    // NOTE(review): assumes a MutationGroup always yields at least one mutation (its primary);
    // get(0) would throw on an empty group -- confirm against MutationGroup.
    c.output(
        MutationGroup.create(
            rewrittenMutations.get(0), rewrittenMutations.subList(1, rewrittenMutations.size())));
  }

  /**
   * Builds a copy of {@code m} that targets {@code newTableName}, keeping the operation, columns
   * and values (or the key set, for deletes).
   *
   * @throws IllegalArgumentException if the mutation's operation is not one of the handled kinds
   */
  private Mutation rewriteMutation(Mutation m, String newTableName) {
    Op op = m.getOperation();
    Mutation.WriteBuilder builder;
    switch (op) {
      case INSERT:
        builder = Mutation.newInsertBuilder(newTableName);
        break;
      case UPDATE:
        builder = Mutation.newUpdateBuilder(newTableName);
        break;
      case INSERT_OR_UPDATE:
        builder = Mutation.newInsertOrUpdateBuilder(newTableName);
        break;
      case REPLACE:
        builder = Mutation.newReplaceBuilder(newTableName);
        break;
      case DELETE:
        // Deletes carry a key set instead of column values.
        return Mutation.delete(newTableName, m.getKeySet());
      default:
        throw new IllegalArgumentException("Unsupported mutation operation: " + op);
    }
    setValues(builder, m.getColumns(), m.getValues());
    return builder.build();
  }

  /**
   * Copies the column/value pairs onto {@code builder}, pairing them by position.
   *
   * @throws IllegalStateException if the two iterables do not have the same number of elements
   */
  private void setValues(
      Mutation.WriteBuilder builder, Iterable<String> columns, Iterable<Value> values) {
    Iterator<String> columnIter = columns.iterator();
    Iterator<Value> valueIter = values.iterator();
    while (columnIter.hasNext() && valueIter.hasNext()) {
      String column = columnIter.next();
      Value value = valueIter.next();
      builder.set(column).to(value);
    }
    if (columnIter.hasNext() || valueIter.hasNext()) {
      throw new IllegalStateException("Mismatch between number of columns and values");
    }
  }

  /**
   * Returns the table name a mutation addressed to {@code tableName} should use.
   *
   * <p>The result is {@code tableName} itself when the schema contains that name (ignoring case),
   * or when the name cannot be resolved to exactly one schema table; in the latter case a
   * warning is logged. Otherwise the result is the single schema table name whose last
   * dot-separated segment equals {@code tableName} (ignoring case).
   */
  private String resolveTableName(SpannerSchema schema, String tableName) {
    String suffixMatch = null;
    int suffixMatchCount = 0;

    for (String fullTableName : schema.getTables()) {
      if (fullTableName.equalsIgnoreCase(tableName)) {
        // Exact match wins regardless of where it appears in the schema's table list.
        return tableName;
      }
      String lastSegment = fullTableName.substring(fullTableName.lastIndexOf('.') + 1);
      if (lastSegment.equalsIgnoreCase(tableName)) {
        if (suffixMatch == null) {
          suffixMatch = fullTableName;
        }
        suffixMatchCount++;
      }
    }

    if (suffixMatchCount == 1) {
      return suffixMatch;
    }
    if (suffixMatchCount == 0) {
      LOG.warn("Table {} not found in Spanner schema; using as-is", tableName);
    } else {
      // Picking one of several candidates could silently write to the wrong table.
      LOG.warn(
          "Table {} is ambiguous: it matches {} schema-qualified tables; using as-is."
              + " Qualify the table name with its schema to disambiguate",
          tableName,
          suffixMatchCount);
    }
    return tableName;
  }
}

/**
* Gathers a set of mutations together, gets the keys, encodes them to byte[], sorts them and then
* outputs the encoded sorted list.
Expand Down
Loading
Loading