Skip to content

Adding project and database support in write transform for firestoreIO #35017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.io.gcp.firestore;

import static java.util.Objects.requireNonNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;

import com.google.firestore.v1.BatchGetDocumentsRequest;
import com.google.firestore.v1.BatchGetDocumentsResponse;
Expand Down Expand Up @@ -67,6 +68,7 @@
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Instant;

Expand Down Expand Up @@ -502,9 +504,17 @@ public PartitionQuery.Builder partitionQuery() {
*/
@Immutable
public static final class Write {
private static final Write INSTANCE = new Write();
private @Nullable String projectId;
private @Nullable String databaseId;

private Write() {}
public static final Write INSTANCE = new Write(null, null);

public Write() {}

public Write(@Nullable String projectId, @Nullable String databaseId) {
this.projectId = projectId;
this.databaseId = databaseId;
}

/**
* Factory method to create a new type safe builder for {@link com.google.firestore.v1.Write}
Expand Down Expand Up @@ -537,8 +547,18 @@ private Write() {}
* @see <a target="_blank" rel="noopener noreferrer"
* href="https://cloud.google.com/firestore/docs/reference/rpc/google.firestore.v1#google.firestore.v1.BatchWriteResponse">google.firestore.v1.BatchWriteResponse</a>
*/
public Write withProjectId(String projectId) {
checkArgument(projectId != null, "projectId can not be null");
return new Write(projectId, this.databaseId);
}

public Write withDatabaseId(String databaseId) {
checkArgument(databaseId != null, "databaseId can not be null");
return new Write(this.projectId, databaseId);
}

public BatchWriteWithSummary.Builder batchWrite() {
return new BatchWriteWithSummary.Builder();
return new BatchWriteWithSummary.Builder().setProjectId(projectId).setDatabaseId(databaseId);
}
}

Expand Down Expand Up @@ -1348,11 +1368,18 @@ public static final class BatchWriteWithSummary
BatchWriteWithSummary,
BatchWriteWithSummary.Builder> {

private BatchWriteWithSummary(
private final @Nullable String projectId;
private final @Nullable String databaseId;

public BatchWriteWithSummary(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
RpcQosOptions rpcQosOptions,
@Nullable String projectId,
@Nullable String databaseId) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions);
this.projectId = projectId;
this.databaseId = databaseId;
}

@Override
Expand All @@ -1365,7 +1392,9 @@ public PCollection<WriteSuccessSummary> expand(
clock,
firestoreStatefulComponentFactory,
rpcQosOptions,
CounterFactory.DEFAULT)));
CounterFactory.DEFAULT,
projectId,
databaseId)));
}

@Override
Expand Down Expand Up @@ -1403,6 +1432,9 @@ public static final class Builder
BatchWriteWithSummary,
BatchWriteWithSummary.Builder> {

private @Nullable String projectId;
private @Nullable String databaseId;

private Builder() {
super();
}
Expand All @@ -1414,9 +1446,33 @@ private Builder(
super(clock, firestoreStatefulComponentFactory, rpcQosOptions);
}

/** Set the GCP project ID to be used by the Firestore client. */
public Builder setProjectId(@Nullable String projectId) {
this.projectId = projectId;
return this;
}

/** Set the Firestore database ID (e.g., "(default)"). */
public Builder setDatabaseId(@Nullable String databaseId) {
this.databaseId = databaseId;
return this;
}

@VisibleForTesting
public @Nullable String getProjectId() {
return this.projectId;
}

@VisibleForTesting
public @Nullable String getDatabaseId() {
return this.databaseId;
}

public BatchWriteWithDeadLetterQueue.Builder withDeadLetterQueue() {
return new BatchWriteWithDeadLetterQueue.Builder(
clock, firestoreStatefulComponentFactory, rpcQosOptions);
clock, firestoreStatefulComponentFactory, rpcQosOptions)
.setProjectId(projectId)
.setDatabaseId(databaseId);
}

@Override
Expand All @@ -1429,7 +1485,8 @@ BatchWriteWithSummary buildSafe(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
return new BatchWriteWithSummary(clock, firestoreStatefulComponentFactory, rpcQosOptions);
return new BatchWriteWithSummary(
clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId);
}
}
}
Expand Down Expand Up @@ -1474,11 +1531,18 @@ public static final class BatchWriteWithDeadLetterQueue
BatchWriteWithDeadLetterQueue,
BatchWriteWithDeadLetterQueue.Builder> {

private final @Nullable String projectId;
private final @Nullable String databaseId;

private BatchWriteWithDeadLetterQueue(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
RpcQosOptions rpcQosOptions,
@Nullable String projectId,
@Nullable String databaseId) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions);
this.projectId = projectId;
this.databaseId = databaseId;
}

@Override
Expand All @@ -1490,7 +1554,9 @@ public PCollection<WriteFailure> expand(PCollection<com.google.firestore.v1.Writ
clock,
firestoreStatefulComponentFactory,
rpcQosOptions,
CounterFactory.DEFAULT)));
CounterFactory.DEFAULT,
projectId,
databaseId)));
}

@Override
Expand Down Expand Up @@ -1528,10 +1594,33 @@ public static final class Builder
BatchWriteWithDeadLetterQueue,
BatchWriteWithDeadLetterQueue.Builder> {

private @Nullable String projectId;
private @Nullable String databaseId;

private Builder() {
super();
}

public Builder setProjectId(@Nullable String projectId) {
this.projectId = projectId;
return this;
}

public Builder setDatabaseId(@Nullable String databaseId) {
this.databaseId = databaseId;
return this;
}

@VisibleForTesting
public @Nullable String getProjectId() {
return this.projectId;
}

@VisibleForTesting
public @Nullable String getDatabaseId() {
return this.databaseId;
}

private Builder(
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
Expand All @@ -1550,7 +1639,7 @@ BatchWriteWithDeadLetterQueue buildSafe(
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions) {
return new BatchWriteWithDeadLetterQueue(
clock, firestoreStatefulComponentFactory, rpcQosOptions);
clock, firestoreStatefulComponentFactory, rpcQosOptions, projectId, databaseId);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,16 @@ static final class BatchWriteFnWithSummary extends BaseBatchWriteFn<WriteSuccess
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions, counterFactory);
CounterFactory counterFactory,
@Nullable String projectId,
@Nullable String databaseId) {
super(
clock,
firestoreStatefulComponentFactory,
rpcQosOptions,
counterFactory,
projectId,
databaseId);
}

@Override
Expand Down Expand Up @@ -102,8 +110,16 @@ static final class BatchWriteFnWithDeadLetterQueue extends BaseBatchWriteFn<Writ
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory) {
super(clock, firestoreStatefulComponentFactory, rpcQosOptions, counterFactory);
CounterFactory counterFactory,
@Nullable String projectId,
@Nullable String databaseId) {
super(
clock,
firestoreStatefulComponentFactory,
rpcQosOptions,
counterFactory,
projectId,
databaseId);
}

@Override
Expand Down Expand Up @@ -158,6 +174,8 @@ abstract static class BaseBatchWriteFn<OutT> extends ExplicitlyWindowedFirestore
// bundle scoped state
private transient FirestoreStub firestoreStub;
private transient DatabaseRootName databaseRootName;
private final @Nullable String configuredProjectId;
private final @Nullable String configuredDatabaseId;

@VisibleForTesting
transient Queue<@NonNull WriteElement> writes = new PriorityQueue<>(WriteElement.COMPARATOR);
Expand All @@ -171,12 +189,16 @@ abstract static class BaseBatchWriteFn<OutT> extends ExplicitlyWindowedFirestore
JodaClock clock,
FirestoreStatefulComponentFactory firestoreStatefulComponentFactory,
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory) {
CounterFactory counterFactory,
@Nullable String configuredProjectId,
@Nullable String configuredDatabaseId) {
this.clock = clock;
this.firestoreStatefulComponentFactory = firestoreStatefulComponentFactory;
this.rpcQosOptions = rpcQosOptions;
this.counterFactory = counterFactory;
this.rpcAttemptContext = V1FnRpcAttemptContext.BatchWrite;
this.configuredProjectId = configuredProjectId;
this.configuredDatabaseId = configuredDatabaseId;
}

@Override
Expand All @@ -202,11 +224,19 @@ public void setup() {

@Override
public final void startBundle(StartBundleContext c) {
String project = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject();
String project =
configuredProjectId != null
? configuredProjectId
: c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreProject();

if (project == null) {
project = c.getPipelineOptions().as(GcpOptions.class).getProject();
}
String databaseId = c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb();

String databaseId =
configuredDatabaseId != null
? configuredDatabaseId
: c.getPipelineOptions().as(FirestoreOptions.class).getFirestoreDb();
databaseRootName =
DatabaseRootName.of(
requireNonNull(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ protected BatchWriteFnWithDeadLetterQueue getFn(
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory,
DistributionFactory distributionFactory) {
return new BatchWriteFnWithDeadLetterQueue(clock, ff, rpcQosOptions, counterFactory);
return new BatchWriteFnWithDeadLetterQueue(
clock, ff, rpcQosOptions, counterFactory, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception {
when(callable.call(requestCaptor1.capture())).thenReturn(response1);

BaseBatchWriteFn<WriteSuccessSummary> fn =
new BatchWriteFnWithSummary(clock, ff, options, CounterFactory.DEFAULT);
new BatchWriteFnWithSummary(
clock, ff, options, CounterFactory.DEFAULT, "testing-project", "(default)");
fn.setup();
fn.startBundle(startBundleContext);
fn.processElement(processContext, window); // write0
Expand Down Expand Up @@ -238,13 +239,22 @@ public void nonRetryableWriteResultStopsAttempts() throws Exception {
verifyNoMoreInteractions(callable);
}

@Test
public void testWithProjectId_thenWithDatabaseId() {
FirestoreV1.Write beamWrite =
new FirestoreV1.Write().withProjectId("my-project").withDatabaseId("(default)");

assertEquals("my-project", beamWrite.batchWrite().getProjectId());
assertEquals("(default)", beamWrite.batchWrite().getDatabaseId());
}

@Override
protected BatchWriteFnWithSummary getFn(
JodaClock clock,
FirestoreStatefulComponentFactory ff,
RpcQosOptions rpcQosOptions,
CounterFactory counterFactory,
DistributionFactory distributionFactory) {
return new BatchWriteFnWithSummary(clock, ff, rpcQosOptions, counterFactory);
return new BatchWriteFnWithSummary(clock, ff, rpcQosOptions, counterFactory, null, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,12 @@ abstract class BaseFirestoreIT {
.build();

protected static String project;
protected static String databaseId;

@Before
public void setup() {
project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject();
databaseId = "firestoredb";
}

private static Instant toWriteTime(WriteResult result) {
Expand Down Expand Up @@ -441,7 +443,14 @@ protected final void runWriteTest(
testPipeline
.apply(Create.of(Collections.singletonList(documentIds)))
.apply(createWrite)
.apply(FirestoreIO.v1().write().batchWrite().withRpcQosOptions(RPC_QOS_OPTIONS).build());
.apply(
FirestoreIO.v1()
.write()
.withProjectId(project)
.withDatabaseId(databaseId)
.batchWrite()
.withRpcQosOptions(RPC_QOS_OPTIONS)
.build());

testPipeline.run(TestPipeline.testingPipelineOptions());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ public void batchWrite_partialFailureOutputsToDeadLetterQueue()
.apply(
FirestoreIO.v1()
.write()
.withProjectId(project)
.withDatabaseId(databaseId)
.batchWrite()
.withDeadLetterQueue()
.withRpcQosOptions(options)
Expand Down
Loading