From 2f81f2412ccec7fc525d23b065b7c5491d8ac3a3 Mon Sep 17 00:00:00 2001 From: Bouncheck <36934780+Bouncheck@users.noreply.github.com> Date: Thu, 3 Apr 2025 17:08:29 +0200 Subject: [PATCH] Allow BatchStatements to be LWT Previously all DefaultBatchStatements would always return `false` when `isLWT()` was called. This would cause the driver to route the batch based on the first non null routing information found in a batch but using regular rules rather than rules for LWT queries, even if a LWT query was inside the batch. Now LWT routing will be used for DefaultBatchStatements if somewhere along the way an LWT query was added to the batch. This can also be controlled explicitly regardless of batch contents with `BatchStatementBuilder#setIsLWT(boolean)`. --- .../driver/api/core/cql/BatchStatement.java | 18 +++- .../api/core/cql/BatchStatementBuilder.java | 21 +++- .../oss/driver/api/core/cql/Statement.java | 3 + .../core/cql/DefaultBatchStatement.java | 98 ++++++++++++++----- .../core/cql/DefaultBatchStatementTest.java | 58 +++++++++++ .../loadbalancing/LWTLoadBalancingIT.java | 50 ++++++++++ 6 files changed, 220 insertions(+), 28 deletions(-) diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java index 9deb33c6007..e831ed62369 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java @@ -69,7 +69,8 @@ static BatchStatement newInstance(@NonNull BatchType batchType) { null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + null); } /** @@ -100,7 +101,8 @@ static BatchStatement newInstance( null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + null); } /** @@ -131,7 +133,8 @@ static BatchStatement newInstance( null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + null); } /** @@ -277,4 +280,13 @@ default int computeSizeInBytes(@NonNull DriverContext context) { return size; } + + /** + * Overrides LWT state to a specific value. If unset or set to {@code null} the {@link + * Statement#isLWT()} method will infer result from the statments in the batch. + * + * @param newIsLWT new Boolean to set + * @return new BatchStatement with updated isLWT field. + */ + BatchStatement setIsLWT(Boolean newIsLWT); } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java index a8e2b8ab659..26e0aef8ca1 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java @@ -24,6 +24,7 @@ import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; import java.util.Arrays; +import java.util.List; import net.jcip.annotations.NotThreadSafe; /** @@ -38,6 +39,7 @@ public class BatchStatementBuilder extends StatementBuilder> statementsBuilder; private int statementsCount; + @Nullable private Boolean isLWT = null; public BatchStatementBuilder(@NonNull BatchType batchType) { this.batchType = batchType; @@ -74,6 +76,19 @@ public BatchStatementBuilder setKeyspace(@NonNull String keyspaceName) { return setKeyspace(CqlIdentifier.fromCql(keyspaceName)); } + /** + * Forces driver to see this batch as LWT or non-LWT. Note that if never explicitly set or set to + * {@code null}, the resulting {@code DefaultBatchStatement} will decide its LWT state based on + * contained statements. + * + * @return this builder; never {@code null}. + */ + @NonNull + public BatchStatementBuilder setIsLWT(Boolean newIsLWT) { + this.isLWT = newIsLWT; + return this; + } + /** * Adds a new statement to the batch. * @@ -136,9 +151,10 @@ public BatchStatementBuilder clearStatements() { @Override @NonNull public BatchStatement build() { + List> statements = statementsBuilder.build(); return new DefaultBatchStatement( batchType, - statementsBuilder.build(), + statements, executionProfileName, executionProfile, keyspace, @@ -155,7 +171,8 @@ public BatchStatement build() { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } public int getStatementsCount() { diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java index 77e707332a6..464a0a92a53 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java @@ -533,6 +533,9 @@ default SelfT setNowInSeconds(int nowInSeconds) { * Scylla, using too old Scylla version, future changes in driver allowing channels to be created * without sending OPTIONS request. * + *

Provided implementations of BatchStatements will be considered by driver as LWT if any of + * the statements in a batch is LWT on its own. + * *

More information about LWT: * * @see Docs about LWT diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java index 5b5d0a9d5bf..c8cb5b7a084 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java @@ -30,6 +30,7 @@ import com.datastax.oss.driver.api.core.cql.BatchType; import com.datastax.oss.driver.api.core.cql.BatchableStatement; import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.core.cql.Statement; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.metadata.token.Token; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; @@ -68,6 +69,7 @@ public class DefaultBatchStatement implements BatchStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; + private final Boolean isLWT; public DefaultBatchStatement( BatchType batchType, @@ -88,7 +90,8 @@ public DefaultBatchStatement( ConsistencyLevel serialConsistencyLevel, Duration timeout, Node node, - int nowInSeconds) { + int nowInSeconds, + Boolean isLWT) { for (BatchableStatement statement : statements) { if (statement != null && (statement.getConsistencyLevel() != null @@ -120,6 +123,7 @@ public DefaultBatchStatement( this.timeout = timeout; this.node = node; this.nowInSeconds = nowInSeconds; + this.isLWT = isLWT; } @NonNull @@ -150,7 +154,8 @@ public BatchStatement setBatchType(@NonNull BatchType newBatchType) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -175,7 +180,8 @@ public BatchStatement setKeyspace(@Nullable CqlIdentifier newKeyspace) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -204,7 +210,8 @@ public BatchStatement add(@NonNull BatchableStatement statement) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } } @@ -237,7 +244,8 @@ public BatchStatement addAll(@NonNull Iterable> serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } } @@ -268,7 +276,8 @@ public BatchStatement clear() { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -304,7 +313,8 @@ public BatchStatement setPagingState(ByteBuffer newPagingState) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -334,7 +344,8 @@ public BatchStatement setPageSize(int newPageSize) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Nullable @@ -365,7 +376,8 @@ public BatchStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsiste serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Nullable @@ -397,7 +409,8 @@ public BatchStatement setSerialConsistencyLevel( newSerialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -427,7 +440,8 @@ public BatchStatement setExecutionProfileName(@Nullable String newConfigProfileN serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -457,7 +471,8 @@ public DefaultBatchStatement setExecutionProfile(@Nullable DriverExecutionProfil serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -522,7 +537,8 @@ public BatchStatement setRoutingKeyspace(CqlIdentifier newRoutingKeyspace) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -547,7 +563,8 @@ public BatchStatement setNode(@Nullable Node newNode) { serialConsistencyLevel, timeout, newNode, - nowInSeconds); + nowInSeconds, + isLWT); } @Nullable @@ -593,7 +610,8 @@ public BatchStatement setRoutingKey(ByteBuffer newRoutingKey) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -633,7 +651,8 @@ public BatchStatement setRoutingToken(Token newRoutingToken) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -664,7 +683,8 @@ public DefaultBatchStatement setCustomPayload(@NonNull Map n serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -700,7 +720,8 @@ public DefaultBatchStatement setIdempotent(Boolean newIdempotence) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -730,7 +751,8 @@ public BatchStatement setTracing(boolean newTracing) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -760,7 +782,8 @@ public BatchStatement setQueryTimestamp(long newTimestamp) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -785,7 +808,8 @@ public BatchStatement setTimeout(@Nullable Duration newTimeout) { serialConsistencyLevel, newTimeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -815,11 +839,39 @@ public BatchStatement setNowInSeconds(int newNowInSeconds) { serialConsistencyLevel, timeout, node, - newNowInSeconds); + newNowInSeconds, + isLWT); + } + + @NonNull + @Override + public BatchStatement setIsLWT(Boolean newIsLWT) { + return new DefaultBatchStatement( + batchType, + statements, + executionProfileName, + executionProfile, + keyspace, + routingKeyspace, + routingKey, + routingToken, + customPayload, + idempotent, + tracing, + timestamp, + pagingState, + pageSize, + consistencyLevel, + serialConsistencyLevel, + timeout, + node, + nowInSeconds, + newIsLWT); } @Override public boolean isLWT() { - return false; + if (isLWT != null) return isLWT; + return statements.stream().anyMatch(Statement::isLWT); } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java index 1150841f7a2..2377968b4fc 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java @@ -18,14 +18,18 @@ package com.datastax.oss.driver.internal.core.cql; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import ch.qos.logback.classic.Level; import ch.qos.logback.classic.spi.ILoggingEvent; import com.datastax.oss.driver.api.core.DefaultConsistencyLevel; +import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BatchStatementBuilder; import com.datastax.oss.driver.api.core.cql.BatchType; +import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.internal.core.util.LoggerTest; import org.junit.Test; @@ -97,4 +101,58 @@ public void should_not_issue_log_warn_if_statement_have_no_consistency_level_set verify(logger.appender, times(0)).doAppend(logger.loggingEventCaptor.capture()); } + + @Test + public void should_infer_lwt_status() { + // SELECT is not allowed in practice but is sufficient for unit testing + SimpleStatement simpleStatement = + SimpleStatement.builder("SELECT * FROM some_table WHERE a = ?").build(); + BoundStatement lwtBoundStatement = mock(DefaultBoundStatement.class); + when(lwtBoundStatement.isLWT()).thenReturn(true); + + // Without LWT statements added + BatchStatementBuilder batchStatementBuilder = new BatchStatementBuilder(BatchType.UNLOGGED); + batchStatementBuilder.addStatement(simpleStatement); + assertThat(batchStatementBuilder.build().isLWT()).isFalse(); + + // Check if implicitly set to true after adding LWT bound statement + batchStatementBuilder.addStatement(lwtBoundStatement); + assertThat(batchStatementBuilder.build().isLWT()).isTrue(); + + // Check if explicit set overrides implicit resolution + batchStatementBuilder.setIsLWT(false); + assertThat(batchStatementBuilder.build().isLWT()).isFalse(); + batchStatementBuilder = new BatchStatementBuilder(BatchType.UNLOGGED); + batchStatementBuilder.addStatement(simpleStatement); + batchStatementBuilder.setIsLWT(true); + assertThat(batchStatementBuilder.build().isLWT()).isTrue(); + + // Check if explicit set remains after clear + assertThat(batchStatementBuilder.build().clear().isLWT()).isTrue(); + + // Similar checks without using builder + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + assertThat(batch.isLWT()).isFalse(); + batch = batch.add(simpleStatement); + assertThat(batch.isLWT()).isFalse(); + batch = batch.add(lwtBoundStatement); + assertThat(batch.isLWT()).isTrue(); + batch = batch.setIsLWT(false); + assertThat(batch.isLWT()).isFalse(); + batch = batch.add(lwtBoundStatement); + assertThat(batch.isLWT()).isFalse(); + batch = batch.setIsLWT(true); + assertThat(batch.isLWT()).isTrue(); + batch = batch.clear(); + assertThat(batch.isLWT()).isTrue(); + batch = batch.setIsLWT(null); + assertThat(batch.isLWT()).isFalse(); + + assertThat(BatchStatement.newInstance(BatchType.UNLOGGED).isLWT()).isFalse(); + assertThat(BatchStatement.newInstance(BatchType.LOGGED).isLWT()).isFalse(); + assertThat(BatchStatement.newInstance(BatchType.COUNTER).isLWT()).isFalse(); + assertThat(BatchStatement.newInstance(BatchType.UNLOGGED, lwtBoundStatement).isLWT()).isTrue(); + assertThat(BatchStatement.newInstance(BatchType.LOGGED, lwtBoundStatement).isLWT()).isTrue(); + assertThat(BatchStatement.newInstance(BatchType.COUNTER, lwtBoundStatement).isLWT()).isTrue(); + } } diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java index 5351557bc46..a3c9d5fd3f7 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java @@ -27,8 +27,11 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.cql.BatchStatement; +import com.datastax.oss.driver.api.core.cql.BatchType; import com.datastax.oss.driver.api.core.cql.PreparedStatement; import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.metadata.TokenMap; import com.datastax.oss.driver.api.core.type.codec.TypeCodecs; @@ -108,4 +111,51 @@ public void should_not_use_only_one_node_when_non_lwt() { // Because keyspace RF == 3 assertThat(coordinators.size()).isEqualTo(3); } + + @Test + public void should_use_only_one_node_when_lwt_batch_detected() { + assumeTrue(CcmBridge.SCYLLA_ENABLEMENT); // Functionality only available in Scylla + CqlSession session = SESSION_RULE.session(); + int pk = 1234; + ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); + TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); + Node owner = tokenMap.getReplicas(session.getKeyspace().get(), routingKey).iterator().next(); + PreparedStatement statement = + SESSION_RULE + .session() + .prepare("INSERT INTO foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); + assertThat(statement.isLWT()).isTrue(); + + for (int i = 0; i < 30; i++) { + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + SimpleStatement simpleStatement = + SimpleStatement.newInstance( + String.format( + "INSERT INTO foo (pk, ck, v) VALUES (%s, %s, %s) IF NOT EXISTS", pk, i, 123)); + assertThat(simpleStatement.isLWT()).isFalse(); + batch = batch.add(simpleStatement); + batch = batch.add(statement.bind(pk, i, 123)); + assertThat(batch.isLWT()).isTrue(); + ResultSet result = session.execute(batch); + assertThat(result.getExecutionInfo().getCoordinator()).isEqualTo(owner); + } + + // Check if multiple coordinators are used when forcibly set to non-LWT + Set coordinators = new HashSet<>(); + for (int i = 0; i < 30; i++) { + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + SimpleStatement simpleStatement = + SimpleStatement.newInstance( + String.format( + "INSERT INTO foo (pk, ck, v) VALUES (%s, %s, %s) IF NOT EXISTS", pk, i, 123)); + assertThat(simpleStatement.isLWT()).isFalse(); + batch = batch.add(simpleStatement); + batch = batch.add(statement.bind(pk, i, 123)); + batch = batch.setIsLWT(false); + assertThat(batch.isLWT()).isFalse(); + ResultSet result = session.execute(batch); + coordinators.add(result.getExecutionInfo().getCoordinator()); + } + assertThat(coordinators.size()).isEqualTo(3); + } }