diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java index 9deb33c6007..e831ed62369 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatement.java @@ -69,7 +69,8 @@ static BatchStatement newInstance(@NonNull BatchType batchType) { null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + null); } /** @@ -100,7 +101,8 @@ static BatchStatement newInstance( null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + null); } /** @@ -131,7 +133,8 @@ static BatchStatement newInstance( null, null, null, - Statement.NO_NOW_IN_SECONDS); + Statement.NO_NOW_IN_SECONDS, + null); } /** @@ -277,4 +280,13 @@ default int computeSizeInBytes(@NonNull DriverContext context) { return size; } + + /** + * Overrides LWT state to a specific value. If unset or set to {@code null} the {@link + * Statement#isLWT()} method will infer result from the statments in the batch. + * + * @param newIsLWT new Boolean to set + * @return new BatchStatement with updated isLWT field. + */ + BatchStatement setIsLWT(Boolean newIsLWT); } diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java index a8e2b8ab659..26e0aef8ca1 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/BatchStatementBuilder.java @@ -24,6 +24,7 @@ import edu.umd.cs.findbugs.annotations.NonNull; import edu.umd.cs.findbugs.annotations.Nullable; import java.util.Arrays; +import java.util.List; import net.jcip.annotations.NotThreadSafe; /** @@ -38,6 +39,7 @@ public class BatchStatementBuilder extends StatementBuilder> statementsBuilder; private int statementsCount; + @Nullable private Boolean isLWT = null; public BatchStatementBuilder(@NonNull BatchType batchType) { this.batchType = batchType; @@ -74,6 +76,19 @@ public BatchStatementBuilder setKeyspace(@NonNull String keyspaceName) { return setKeyspace(CqlIdentifier.fromCql(keyspaceName)); } + /** + * Forces driver to see this batch as LWT or non-LWT. Note that if never explicitly set or set to + * {@code null}, the resulting {@code DefaultBatchStatement} will decide its LWT state based on + * contained statements. + * + * @return this builder; never {@code null}. + */ + @NonNull + public BatchStatementBuilder setIsLWT(Boolean newIsLWT) { + this.isLWT = newIsLWT; + return this; + } + /** * Adds a new statement to the batch. * @@ -136,9 +151,10 @@ public BatchStatementBuilder clearStatements() { @Override @NonNull public BatchStatement build() { + List> statements = statementsBuilder.build(); return new DefaultBatchStatement( batchType, - statementsBuilder.build(), + statements, executionProfileName, executionProfile, keyspace, @@ -155,7 +171,8 @@ public BatchStatement build() { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } public int getStatementsCount() { diff --git a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java index 77e707332a6..464a0a92a53 100644 --- a/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java +++ b/core/src/main/java/com/datastax/oss/driver/api/core/cql/Statement.java @@ -533,6 +533,9 @@ default SelfT setNowInSeconds(int nowInSeconds) { * Scylla, using too old Scylla version, future changes in driver allowing channels to be created * without sending OPTIONS request. * + *

Provided implementations of BatchStatements will be considered by driver as LWT if any of + * the statements in a batch is LWT on its own. + * *

More information about LWT: * * @see Docs about LWT diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java index 5b5d0a9d5bf..c8cb5b7a084 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatement.java @@ -30,6 +30,7 @@ import com.datastax.oss.driver.api.core.cql.BatchType; import com.datastax.oss.driver.api.core.cql.BatchableStatement; import com.datastax.oss.driver.api.core.cql.SimpleStatement; +import com.datastax.oss.driver.api.core.cql.Statement; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.metadata.token.Token; import com.datastax.oss.driver.shaded.guava.common.collect.ImmutableList; @@ -68,6 +69,7 @@ public class DefaultBatchStatement implements BatchStatement { private final Duration timeout; private final Node node; private final int nowInSeconds; + private final Boolean isLWT; public DefaultBatchStatement( BatchType batchType, @@ -88,7 +90,8 @@ public DefaultBatchStatement( ConsistencyLevel serialConsistencyLevel, Duration timeout, Node node, - int nowInSeconds) { + int nowInSeconds, + Boolean isLWT) { for (BatchableStatement statement : statements) { if (statement != null && (statement.getConsistencyLevel() != null @@ -120,6 +123,7 @@ public DefaultBatchStatement( this.timeout = timeout; this.node = node; this.nowInSeconds = nowInSeconds; + this.isLWT = isLWT; } @NonNull @@ -150,7 +154,8 @@ public BatchStatement setBatchType(@NonNull BatchType newBatchType) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -175,7 +180,8 @@ public BatchStatement setKeyspace(@Nullable CqlIdentifier newKeyspace) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -204,7 +210,8 @@ public BatchStatement add(@NonNull BatchableStatement statement) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } } @@ -237,7 +244,8 @@ public BatchStatement addAll(@NonNull Iterable> serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } } @@ -268,7 +276,8 @@ public BatchStatement clear() { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -304,7 +313,8 @@ public BatchStatement setPagingState(ByteBuffer newPagingState) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -334,7 +344,8 @@ public BatchStatement setPageSize(int newPageSize) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Nullable @@ -365,7 +376,8 @@ public BatchStatement setConsistencyLevel(@Nullable ConsistencyLevel newConsiste serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Nullable @@ -397,7 +409,8 @@ public BatchStatement setSerialConsistencyLevel( newSerialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -427,7 +440,8 @@ public BatchStatement setExecutionProfileName(@Nullable String newConfigProfileN serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -457,7 +471,8 @@ public DefaultBatchStatement setExecutionProfile(@Nullable DriverExecutionProfil serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -522,7 +537,8 @@ public BatchStatement setRoutingKeyspace(CqlIdentifier newRoutingKeyspace) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -547,7 +563,8 @@ public BatchStatement setNode(@Nullable Node newNode) { serialConsistencyLevel, timeout, newNode, - nowInSeconds); + nowInSeconds, + isLWT); } @Nullable @@ -593,7 +610,8 @@ public BatchStatement setRoutingKey(ByteBuffer newRoutingKey) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -633,7 +651,8 @@ public BatchStatement setRoutingToken(Token newRoutingToken) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -664,7 +683,8 @@ public DefaultBatchStatement setCustomPayload(@NonNull Map n serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -700,7 +720,8 @@ public DefaultBatchStatement setIdempotent(Boolean newIdempotence) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -730,7 +751,8 @@ public BatchStatement setTracing(boolean newTracing) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -760,7 +782,8 @@ public BatchStatement setQueryTimestamp(long newTimestamp) { serialConsistencyLevel, timeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @NonNull @@ -785,7 +808,8 @@ public BatchStatement setTimeout(@Nullable Duration newTimeout) { serialConsistencyLevel, newTimeout, node, - nowInSeconds); + nowInSeconds, + isLWT); } @Override @@ -815,11 +839,39 @@ public BatchStatement setNowInSeconds(int newNowInSeconds) { serialConsistencyLevel, timeout, node, - newNowInSeconds); + newNowInSeconds, + isLWT); + } + + @NonNull + @Override + public BatchStatement setIsLWT(Boolean newIsLWT) { + return new DefaultBatchStatement( + batchType, + statements, + executionProfileName, + executionProfile, + keyspace, + routingKeyspace, + routingKey, + routingToken, + customPayload, + idempotent, + tracing, + timestamp, + pagingState, + pageSize, + consistencyLevel, + serialConsistencyLevel, + timeout, + node, + nowInSeconds, + newIsLWT); } @Override public boolean isLWT() { - return false; + if (isLWT != null) return isLWT; + return statements.stream().anyMatch(Statement::isLWT); } } diff --git a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java index 1150841f7a2..2377968b4fc 100644 --- a/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java +++ b/core/src/test/java/com/datastax/oss/driver/internal/core/cql/DefaultBatchStatementTest.java @@ -18,14 +18,18 @@ package com.datastax.oss.driver.internal.core.cql; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import ch.qos.logback.classic.Level; import ch.qos.logback.classic.spi.ILoggingEvent; import com.datastax.oss.driver.api.core.DefaultConsistencyLevel; +import com.datastax.oss.driver.api.core.cql.BatchStatement; import com.datastax.oss.driver.api.core.cql.BatchStatementBuilder; import com.datastax.oss.driver.api.core.cql.BatchType; +import com.datastax.oss.driver.api.core.cql.BoundStatement; import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.internal.core.util.LoggerTest; import org.junit.Test; @@ -97,4 +101,58 @@ public void should_not_issue_log_warn_if_statement_have_no_consistency_level_set verify(logger.appender, times(0)).doAppend(logger.loggingEventCaptor.capture()); } + + @Test + public void should_infer_lwt_status() { + // SELECT is not allowed in practice but is sufficient for unit testing + SimpleStatement simpleStatement = + SimpleStatement.builder("SELECT * FROM some_table WHERE a = ?").build(); + BoundStatement lwtBoundStatement = mock(DefaultBoundStatement.class); + when(lwtBoundStatement.isLWT()).thenReturn(true); + + // Without LWT statements added + BatchStatementBuilder batchStatementBuilder = new BatchStatementBuilder(BatchType.UNLOGGED); + batchStatementBuilder.addStatement(simpleStatement); + assertThat(batchStatementBuilder.build().isLWT()).isFalse(); + + // Check if implicitly set to true after adding LWT bound statement + batchStatementBuilder.addStatement(lwtBoundStatement); + assertThat(batchStatementBuilder.build().isLWT()).isTrue(); + + // Check if explicit set overrides implicit resolution + batchStatementBuilder.setIsLWT(false); + assertThat(batchStatementBuilder.build().isLWT()).isFalse(); + batchStatementBuilder = new BatchStatementBuilder(BatchType.UNLOGGED); + batchStatementBuilder.addStatement(simpleStatement); + batchStatementBuilder.setIsLWT(true); + assertThat(batchStatementBuilder.build().isLWT()).isTrue(); + + // Check if explicit set remains after clear + assertThat(batchStatementBuilder.build().clear().isLWT()).isTrue(); + + // Similar checks without using builder + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + assertThat(batch.isLWT()).isFalse(); + batch = batch.add(simpleStatement); + assertThat(batch.isLWT()).isFalse(); + batch = batch.add(lwtBoundStatement); + assertThat(batch.isLWT()).isTrue(); + batch = batch.setIsLWT(false); + assertThat(batch.isLWT()).isFalse(); + batch = batch.add(lwtBoundStatement); + assertThat(batch.isLWT()).isFalse(); + batch = batch.setIsLWT(true); + assertThat(batch.isLWT()).isTrue(); + batch = batch.clear(); + assertThat(batch.isLWT()).isTrue(); + batch = batch.setIsLWT(null); + assertThat(batch.isLWT()).isFalse(); + + assertThat(BatchStatement.newInstance(BatchType.UNLOGGED).isLWT()).isFalse(); + assertThat(BatchStatement.newInstance(BatchType.LOGGED).isLWT()).isFalse(); + assertThat(BatchStatement.newInstance(BatchType.COUNTER).isLWT()).isFalse(); + assertThat(BatchStatement.newInstance(BatchType.UNLOGGED, lwtBoundStatement).isLWT()).isTrue(); + assertThat(BatchStatement.newInstance(BatchType.LOGGED, lwtBoundStatement).isLWT()).isTrue(); + assertThat(BatchStatement.newInstance(BatchType.COUNTER, lwtBoundStatement).isLWT()).isTrue(); + } } diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java index 5351557bc46..a3c9d5fd3f7 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/loadbalancing/LWTLoadBalancingIT.java @@ -27,8 +27,11 @@ import com.datastax.oss.driver.api.core.CqlIdentifier; import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.ProtocolVersion; +import com.datastax.oss.driver.api.core.cql.BatchStatement; +import com.datastax.oss.driver.api.core.cql.BatchType; import com.datastax.oss.driver.api.core.cql.PreparedStatement; import com.datastax.oss.driver.api.core.cql.ResultSet; +import com.datastax.oss.driver.api.core.cql.SimpleStatement; import com.datastax.oss.driver.api.core.metadata.Node; import com.datastax.oss.driver.api.core.metadata.TokenMap; import com.datastax.oss.driver.api.core.type.codec.TypeCodecs; @@ -108,4 +111,51 @@ public void should_not_use_only_one_node_when_non_lwt() { // Because keyspace RF == 3 assertThat(coordinators.size()).isEqualTo(3); } + + @Test + public void should_use_only_one_node_when_lwt_batch_detected() { + assumeTrue(CcmBridge.SCYLLA_ENABLEMENT); // Functionality only available in Scylla + CqlSession session = SESSION_RULE.session(); + int pk = 1234; + ByteBuffer routingKey = TypeCodecs.INT.encodePrimitive(pk, ProtocolVersion.DEFAULT); + TokenMap tokenMap = SESSION_RULE.session().getMetadata().getTokenMap().get(); + Node owner = tokenMap.getReplicas(session.getKeyspace().get(), routingKey).iterator().next(); + PreparedStatement statement = + SESSION_RULE + .session() + .prepare("INSERT INTO foo (pk, ck, v) VALUES (?, ?, ?) IF NOT EXISTS"); + assertThat(statement.isLWT()).isTrue(); + + for (int i = 0; i < 30; i++) { + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + SimpleStatement simpleStatement = + SimpleStatement.newInstance( + String.format( + "INSERT INTO foo (pk, ck, v) VALUES (%s, %s, %s) IF NOT EXISTS", pk, i, 123)); + assertThat(simpleStatement.isLWT()).isFalse(); + batch = batch.add(simpleStatement); + batch = batch.add(statement.bind(pk, i, 123)); + assertThat(batch.isLWT()).isTrue(); + ResultSet result = session.execute(batch); + assertThat(result.getExecutionInfo().getCoordinator()).isEqualTo(owner); + } + + // Check if multiple coordinators are used when forcibly set to non-LWT + Set coordinators = new HashSet<>(); + for (int i = 0; i < 30; i++) { + BatchStatement batch = BatchStatement.newInstance(BatchType.UNLOGGED); + SimpleStatement simpleStatement = + SimpleStatement.newInstance( + String.format( + "INSERT INTO foo (pk, ck, v) VALUES (%s, %s, %s) IF NOT EXISTS", pk, i, 123)); + assertThat(simpleStatement.isLWT()).isFalse(); + batch = batch.add(simpleStatement); + batch = batch.add(statement.bind(pk, i, 123)); + batch = batch.setIsLWT(false); + assertThat(batch.isLWT()).isFalse(); + ResultSet result = session.execute(batch); + coordinators.add(result.getExecutionInfo().getCoordinator()); + } + assertThat(coordinators.size()).isEqualTo(3); + } }