Skip to content

Adding project and database support in write transform for firestoreIO #35017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.io.gcp.firestore;

import static java.util.Objects.requireNonNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;

import com.google.firestore.v1.BatchGetDocumentsRequest;
import com.google.firestore.v1.BatchGetDocumentsResponse;
Expand Down Expand Up @@ -67,6 +68,7 @@
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Instant;

Expand Down Expand Up @@ -502,6 +504,9 @@ public PartitionQuery.Builder partitionQuery() {
*/
@Immutable
public static final class Write {
private @Nullable String projectId;
private @Nullable String databaseId;

private static final Write INSTANCE = new Write();

private Write() {}
Expand Down Expand Up @@ -537,8 +542,20 @@ private Write() {}
* @see <a target="_blank" rel="noopener noreferrer"
* href="https://cloud.google.com/firestore/docs/reference/rpc/google.firestore.v1#google.firestore.v1.BatchWriteResponse">google.firestore.v1.BatchWriteResponse</a>
*/
public Write withProjectId(String projectId) {
checkArgument(projectId != null, "projectId can not be null");
this.projectId = projectId;
return this;
}

public Write withDatabaseId(String databaseId) {
checkArgument(databaseId != null, "databaseId can not be null");
this.databaseId = databaseId;
return this;
}

public BatchWriteWithSummary.Builder batchWrite() {
return new BatchWriteWithSummary.Builder();
return new BatchWriteWithSummary.Builder().setProjectId(projectId).setDatabaseId(databaseId);
}
}

Expand Down Expand Up @@ -1348,11 +1365,18 @@ public static final class BatchWriteWithSummary
BatchWriteWithSummary,
BatchWriteWithSummary.Builder> {

private BatchWriteWithSummary(
private final @Nullable String projectId;
private final @Nullable String databaseId;

public BatchWriteWithSummary(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
RpcQosOptions rpcQosOptions,
@Nullable String projectId,
@Nullable String databaseId) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions);
this.projectId = projectId;
this.databaseId = databaseId;
}

@Override
Expand All @@ -1365,7 +1389,9 @@ public PCollection<WriteSuccessSummary> expand(
clock,
firestoreStatefulComponentFactory,
rpcQosOptions,
CounterFactory.DEFAULT)));
CounterFactory.DEFAULT,
projectId,
databaseId)));
}

@Override
Expand Down Expand Up @@ -1403,6 +1429,9 @@ public static final class Builder
BatchWriteWithSummary,
BatchWriteWithSummary.Builder> {

private @Nullable String projectId;
private @Nullable String databaseId;

private Builder() {
super();
}
Expand All @@ -1414,9 +1443,35 @@ private Builder(
super(clock, firestoreStatefulComponentFactory, rpcQosOptions);
}

/** Set the GCP project ID to be used by the Firestore client. */
private Builder setProjectId(@Nullable String projectId) {
this.projectId = projectId;
return this;
}

/** Set the Firestore database ID (e.g., "(default)"). */
private Builder setDatabaseId(@Nullable String databaseId) {
this.databaseId = databaseId;
return this;
}

@VisibleForTesting
@Nullable
String getProjectId() {
return this.projectId;
}

@VisibleForTesting
@Nullable
String getDatabaseId() {
return this.databaseId;
}

public BatchWriteWithDeadLetterQueue.Builder withDeadLetterQueue() {
return new BatchWriteWithDeadLetterQueue.Builder(
clock, firestoreStatefulComponentFactory, rpcQosOptions);
clock, firestoreStatefulComponentFactory, rpcQosOptions)
.setProjectId(projectId)
.setDatabaseId(databaseId);
}

@Override
Expand All @@ -1429,7 +1484,8 @@ BatchWriteWithSummary buildSafe(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
return new BatchWriteWithSummary(clock, firestoreStatefulComponentFactory, rpcQosOptions);
return new BatchWriteWithSummary(
clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId);
}
}
}
Expand Down Expand Up @@ -1474,11 +1530,18 @@ public static final class BatchWriteWithDeadLetterQueue
BatchWriteWithDeadLetterQueue,
BatchWriteWithDeadLetterQueue.Builder> {

private final @Nullable String projectId;
private final @Nullable String databaseId;

private BatchWriteWithDeadLetterQueue(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
RpcQosOptions rpcQosOptions,
@Nullable String projectId,
@Nullable String databaseId) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions);
this.projectId = projectId;
this.databaseId = databaseId;
}

@Override
Expand All @@ -1490,7 +1553,9 @@ public PCollection<WriteFailure> expand(PCollection<com.google.firestore.v1.Writ
clock,
firestoreStatefulComponentFactory,
rpcQosOptions,
CounterFactory.DEFAULT)));
CounterFactory.DEFAULT,
projectId,
databaseId)));
}

@Override
Expand Down Expand Up @@ -1528,10 +1593,35 @@ public static final class Builder
BatchWriteWithDeadLetterQueue,
BatchWriteWithDeadLetterQueue.Builder> {

private @Nullable String projectId;
private @Nullable String databaseId;

private Builder() {
super();
}

private Builder setProjectId(@Nullable String projectId) {
this.projectId = projectId;
return this;
}

private Builder setDatabaseId(@Nullable String databaseId) {
this.databaseId = databaseId;
return this;
}

@VisibleForTesting
@Nullable
String getProjectId() {
return this.projectId;
}

@VisibleForTesting
@Nullable
String getDatabaseId() {
return this.databaseId;
}

private Builder(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
Expand All @@ -1550,7 +1640,7 @@ BatchWriteWithDeadLetterQueue buildSafe(
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
return new BatchWriteWithDeadLetterQueue(
clock, firestoreStatefulComponentFactory, rpcQosOptions);
clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,16 @@ static final class BatchWriteFnWithSummary extends BaseBatchWriteFn<WriteSuccess
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions, counterFactory);
CounterFactory counterFactory,
@Nullable String projectId,
@Nullable String databaseId) {
super(
clock,
firestoreStatefulComponentFactory,
rpcQosOptions,
counterFactory,
projectId,
databaseId);
}

@Override
Expand Down Expand Up @@ -102,8 +110,16 @@ static final class BatchWriteFnWithDeadLetterQueue extends BaseBatchWriteFn<Writ
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions, counterFactory);
CounterFactory counterFactory,
@Nullable String projectId,
@Nullable String databaseId) {
super(
clock,
firestoreStatefulComponentFactory,
rpcQosOptions,
counterFactory,
projectId,
databaseId);
}

@Override
Expand Down Expand Up @@ -158,6 +174,8 @@ abstract static class BaseBatchWriteFn<OutT> extends ExplicitlyWindowedFirestore
// bundle scoped state
private transient FirestoreStub firestoreStub;
private transient DatabaseRootName databaseRootName;
private final @Nullable String configuredProjectId;
private final @Nullable String configuredDatabaseId;

@VisibleForTesting
transient Queue<@NonNull WriteElement> writes = new PriorityQueue<>(WriteElement.COMPARATOR);
Expand All @@ -171,12 +189,16 @@ abstract static class BaseBatchWriteFn<OutT> extends ExplicitlyWindowedFirestore
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory) {
CounterFactory counterFactory,
@Nullable String configuredProjectId,
@Nullable String configuredDatabaseId) {
this.clock = clock;
this.firestoreStatefulComponentFactory = firestoreStatefulComponentFactory;
this.rpcQosOptions = rpcQosOptions;
this.counterFactory = counterFactory;
this.rpcAttemptContext = V1FnRpcAttemptContext.BatchWrite;
this.configuredProjectId = configuredProjectId;
this.configuredDatabaseId = configuredDatabaseId;
}

@Override
Expand All @@ -202,11 +224,19 @@ public void setup() {

@Override
public final void startBundle(StartBundleContext c) {
String project = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject();
String project =
configuredProjectId != null
? configuredProjectId
: c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject();

if (project == null) {
project = c.getPipelineOptions().as(GcpOptions.class).getProject();
}
String databaseId = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb();

String databaseId =
configuredDatabaseId != null
? configuredDatabaseId
: c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb();
databaseRootName =
DatabaseRootName.of(
requireNonNull(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ protected BatchWriteFnWithDeadLetterQueue getFn(
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory,
DistributionFactory distributionFactory) {
return new BatchWriteFnWithDeadLetterQueue(clock, ff, rpcQosOptions, counterFactory);
return new BatchWriteFnWithDeadLetterQueue(
clock, ff, rpcQosOptions, counterFactory, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception {
when(callable.call(requestCaptor1.capture())).thenReturn(response1);

BaseBatchWriteFn<WriteSuccessSummary> fn =
new BatchWriteFnWithSummary(clock, ff, options, CounterFactory.DEFAULT);
new BatchWriteFnWithSummary(
clock, ff, options, CounterFactory.DEFAULT, "testing-project", "(default)");
fn.setup();
fn.startBundle(startBundleContext);
fn.processElement(processContext, window); // write0
Expand Down Expand Up @@ -238,13 +239,22 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception {
verifyNoMoreInteractions(callable);
}

@Test
public void testWithProjectId_thenWithDatabaseId() {
FirestoreV1.Write beamWrite =
FirestoreIO.v1().write().withProjectId("my-project").withDatabaseId("(default)");

assertEquals("my-project", beamWrite.batchWrite().getProjectId());
assertEquals("(default)", beamWrite.batchWrite().getDatabaseId());
}

@Override
protected BatchWriteFnWithSummary getFn(
JodaClock clock,
FirestoreStatefulComponentFactory ff,
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory,
DistributionFactory distributionFactory) {
return new BatchWriteFnWithSummary(clock, ff, rpcQosOptions, counterFactory);
return new BatchWriteFnWithSummary(clock, ff, rpcQosOptions, counterFactory, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,12 @@ abstract class BaseFirestoreIT {
.build();

protected static String project;
protected static String databaseId;

@Before
public void setup() {
project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject();
databaseId = "firestoredb";
}

private static Instant toWriteTime(WriteResult result) {
Expand Down Expand Up @@ -441,7 +443,14 @@ protected final void runWriteTest(
testPipeline
.apply(Create.of(Collections.singletonList(documentIds)))
.apply(createWrite)
.apply(FirestoreIO.v1().write().batchWrite().withRpcQosOptions(RPC_QOS_OPTIONS).build());
.apply(
FirestoreIO.v1()
.write()
.withProjectId(project)
.withDatabaseId(databaseId)
.batchWrite()
.withRpcQosOptions(RPC_QOS_OPTIONS)
.build());

testPipeline.run(TestPipeline.testingPipelineOptions());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ public void batchWrite_partialFailureOutputsToDeadLetterQueue()
.apply(
FirestoreIO.v1()
.write()
.withProjectId(project)
.withDatabaseId(databaseId)
.batchWrite()
.withDeadLetterQueue()
.withRpcQosOptions(options)
Expand Down
Loading