Skip to content

Commit 3c63df7

Browse files
committed
[misc] batch correction
1 parent 3b25c89 commit 3c63df7

File tree

7 files changed

+174
-100
lines changed

7 files changed

+174
-100
lines changed

src/main/java/org/mariadb/r2dbc/MariadbClientParameterizedQueryStatement.java

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,13 @@ public MariadbClientParameterizedQueryStatement bind(
8080
@SuppressWarnings({"rawtypes", "unchecked"})
8181
@Override
8282
public MariadbClientParameterizedQueryStatement bind(int index, @Nullable Object value) {
83+
if (value == null) return bindNull(index, null);
8384
if (index >= prepareResult.getParamCount() || index < 0) {
8485
throw new IndexOutOfBoundsException(
8586
String.format(
8687
"index must be in 0-%d range but value is %d",
8788
prepareResult.getParamCount() - 1, index));
8889
}
89-
if (value == null) return bindNull(index, null);
9090

9191
for (Codec<?> codec : Codecs.LIST) {
9292
if (codec.canEncode(value)) {
@@ -142,28 +142,36 @@ public Flux<org.mariadb.r2dbc.api.MariadbResult> execute() {
142142
return execute(this.sql, this.prepareResult, parameters, this.generatedColumns);
143143
} else {
144144
add();
145-
Flux<Flux<ServerMessage>> fluxMsg =
146-
Flux.create(
147-
sink -> {
148-
for (Parameter<?>[] parameters : this.batchingParameters) {
149-
Flux<ServerMessage> in =
150-
this.client.sendCommand(
151-
new QueryWithParametersPacket(
152-
prepareResult,
153-
parameters,
154-
generatedColumns != null
155-
&& client.getVersion().isMariaDBServer()
156-
&& client.getVersion().versionGreaterOrEqual(10, 5, 1)
157-
? generatedColumns
158-
: null));
159-
sink.next(in);
160-
in.subscribe();
161-
}
162-
sink.complete();
163-
});
145+
146+
Flux<ServerMessage> fluxMsg =
147+
this.client.sendCommand(
148+
new QueryWithParametersPacket(
149+
prepareResult,
150+
this.batchingParameters.get(0),
151+
generatedColumns != null
152+
&& client.getVersion().isMariaDBServer()
153+
&& client.getVersion().versionGreaterOrEqual(10, 5, 1)
154+
? generatedColumns
155+
: null));
156+
int index = 1;
157+
while (index < this.batchingParameters.size()) {
158+
fluxMsg =
159+
fluxMsg.concatWith(
160+
this.client.sendCommand(
161+
new QueryWithParametersPacket(
162+
prepareResult,
163+
this.batchingParameters.get(index++),
164+
generatedColumns != null
165+
&& client.getVersion().isMariaDBServer()
166+
&& client.getVersion().versionGreaterOrEqual(10, 5, 1)
167+
? generatedColumns
168+
: null)));
169+
}
170+
171+
this.batchingParameters.clear();
172+
this.parameters = new Parameter<?>[prepareResult.getParamCount()];
164173

165174
return fluxMsg
166-
.flatMap(Flux::from)
167175
.windowUntil(it -> it.resultSetEnd())
168176
.map(
169177
dataRow ->

src/main/java/org/mariadb/r2dbc/MariadbServerParameterizedQueryStatement.java

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.Arrays;
21+
import java.util.HashMap;
2122
import java.util.List;
23+
import java.util.Map;
2224
import org.mariadb.r2dbc.api.MariadbStatement;
2325
import org.mariadb.r2dbc.client.Client;
2426
import org.mariadb.r2dbc.client.DecoderState;
@@ -40,8 +42,8 @@ final class MariadbServerParameterizedQueryStatement implements MariadbStatement
4042
private final Client client;
4143
private final String sql;
4244
private final MariadbConnectionConfiguration configuration;
43-
private List<Parameter<?>> parameters;
44-
private List<List<Parameter<?>>> batchingParameters;
45+
private Map<Integer, Parameter<?>> parameters;
46+
private List<Map<Integer, Parameter<?>>> batchingParameters;
4547
private String[] generatedColumns;
4648
private ServerPrepareResult prepareResult;
4749

@@ -50,7 +52,7 @@ final class MariadbServerParameterizedQueryStatement implements MariadbStatement
5052
this.client = client;
5153
this.configuration = configuration;
5254
this.sql = Assert.requireNonNull(sql, "sql must not be null");
53-
this.parameters = new ArrayList<>();
55+
this.parameters = new HashMap<>();
5456
this.prepareResult = client.getPrepareCache().get(sql);
5557
}
5658

@@ -72,7 +74,7 @@ public MariadbServerParameterizedQueryStatement add() {
7274
}
7375
if (batchingParameters == null) batchingParameters = new ArrayList<>();
7476
batchingParameters.add(parameters);
75-
parameters = new ArrayList<>(prepareResult.getNumParams());
77+
parameters = new HashMap<>();
7678
return this;
7779
}
7880

@@ -96,7 +98,8 @@ public MariadbServerParameterizedQueryStatement bind(int index, @Nullable Object
9698

9799
for (Codec<?> codec : Codecs.LIST) {
98100
if (codec.canEncode(value)) {
99-
parameters.add(index, (Parameter<?>) new Parameter(codec, value));
101+
102+
parameters.put(index, (Parameter<?>) new Parameter(codec, value));
100103
return this;
101104
}
102105
}
@@ -121,7 +124,7 @@ public MariadbServerParameterizedQueryStatement bindNull(int index, @Nullable Cl
121124
prepareResult.getNumParams() - 1, index));
122125
}
123126

124-
parameters.add(index, Parameter.NULL_PARAMETER);
127+
parameters.put(index, Parameter.NULL_PARAMETER);
125128
return this;
126129
}
127130

@@ -157,41 +160,39 @@ public Flux<org.mariadb.r2dbc.api.MariadbResult> execute() {
157160
sendPrepare().block();
158161
}
159162
}
163+
Flux<ServerMessage> fluxMsg =
164+
this.client.sendCommand(
165+
new ExecutePacket(prepareResult.getStatementId(), this.batchingParameters.get(0)));
166+
int index = 1;
167+
while (index < this.batchingParameters.size()) {
168+
fluxMsg =
169+
fluxMsg.concatWith(
170+
this.client.sendCommand(
171+
new ExecutePacket(
172+
prepareResult.getStatementId(), this.batchingParameters.get(index++))));
173+
}
174+
fluxMsg =
175+
fluxMsg.concatWith(
176+
Flux.create(
177+
sink -> {
178+
prepareResult.decrementUse(client);
179+
sink.complete();
180+
}));
160181

161-
Flux<Flux<ServerMessage>> fluxMsg =
162-
Flux.create(
163-
sink -> {
164-
for (List<Parameter<?>> parameters : this.batchingParameters) {
165-
Flux<ServerMessage> in =
166-
this.client.sendCommand(
167-
new ExecutePacket(
168-
prepareResult != null ? prepareResult.getStatementId() : -1,
169-
parameters));
170-
sink.next(in);
171-
in.subscribe();
172-
}
173-
sink.complete();
174-
});
175-
176-
Flux<org.mariadb.r2dbc.api.MariadbResult> f =
177-
fluxMsg
178-
.flatMap(Flux::from)
179-
.windowUntil(it -> it.resultSetEnd())
180-
.map(
181-
dataRow ->
182-
new MariadbResult(
183-
true,
184-
dataRow,
185-
ExceptionFactory.INSTANCE,
186-
null,
187-
client.getVersion().isMariaDBServer()
188-
&& client.getVersion().versionGreaterOrEqual(10, 5, 1)));
189-
return f.concatWith(
190-
Flux.create(
191-
sink -> {
192-
prepareResult.decrementUse(client);
193-
sink.complete();
194-
}));
182+
this.batchingParameters.clear();
183+
this.parameters = new HashMap<>();
184+
185+
return fluxMsg
186+
.windowUntil(it -> it.resultSetEnd())
187+
.map(
188+
dataRow ->
189+
new MariadbResult(
190+
false,
191+
dataRow,
192+
ExceptionFactory.INSTANCE,
193+
null,
194+
client.getVersion().isMariaDBServer()
195+
&& client.getVersion().versionGreaterOrEqual(10, 5, 1)));
195196
}
196197
}
197198

@@ -216,7 +217,7 @@ public MariadbServerParameterizedQueryStatement returnGeneratedValues(String...
216217
}
217218

218219
private Flux<org.mariadb.r2dbc.api.MariadbResult> execute(
219-
String sql, List<Parameter<?>> parameters, String[] generatedColumns) {
220+
String sql, Map<Integer, Parameter<?>> parameters, String[] generatedColumns) {
220221
ExceptionFactory factory = ExceptionFactory.withSql(sql);
221222

222223
if (prepareResult == null && client.getPrepareCache() != null) {
@@ -270,7 +271,7 @@ private Flux<org.mariadb.r2dbc.api.MariadbResult> execute(
270271
}
271272

272273
private Flux<org.mariadb.r2dbc.api.MariadbResult> sendPrepareAndExecute(
273-
ExceptionFactory factory, List<Parameter<?>> parameters, String[] generatedColumns) {
274+
ExceptionFactory factory, Map<Integer, Parameter<?>> parameters, String[] generatedColumns) {
274275
return this.client
275276
.sendCommand(new PreparePacket(sql), new ExecutePacket(-1, parameters))
276277
.windowUntil(it -> it.resultSetEnd())
@@ -311,7 +312,7 @@ private Mono<ServerPrepareResult> sendPrepare() {
311312
};
312313

313314
private Flux<org.mariadb.r2dbc.api.MariadbResult> sendExecuteCmd(
314-
ExceptionFactory factory, List<Parameter<?>> parameters, String[] generatedColumns) {
315+
ExceptionFactory factory, Map<Integer, Parameter<?>> parameters, String[] generatedColumns) {
315316
return this.client
316317
.sendCommand(
317318
new ExecutePacket(

src/main/java/org/mariadb/r2dbc/message/client/ExecutePacket.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@
1818

1919
import io.netty.buffer.ByteBuf;
2020
import io.netty.buffer.ByteBufAllocator;
21-
import java.util.List;
21+
import java.util.Map;
2222
import org.mariadb.r2dbc.client.ConnectionContext;
2323
import org.mariadb.r2dbc.codec.Parameter;
2424
import org.mariadb.r2dbc.message.server.Sequencer;
2525

2626
public final class ExecutePacket implements ClientMessage {
27-
private final List<Parameter<?>> parameters;
27+
private final Map<Integer, Parameter<?>> parameters;
2828
private final int statementId;
2929
private final Sequencer sequencer = new Sequencer((byte) 0xff);
3030

31-
public ExecutePacket(int statementId, List<Parameter<?>> parameters) {
31+
public ExecutePacket(int statementId, Map<Integer, Parameter<?>> parameters) {
3232
this.parameters = parameters;
3333
this.statementId = statementId;
3434
}

src/test/java/org/mariadb/r2dbc/integration/StatementBatchingTest.java

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,39 @@
1818

1919
import org.junit.jupiter.api.Test;
2020
import org.mariadb.r2dbc.BaseTest;
21+
import org.mariadb.r2dbc.api.MariadbConnection;
2122
import reactor.test.StepVerifier;
2223

2324
public class StatementBatchingTest extends BaseTest {
2425

2526
@Test
2627
void batchStatement() {
27-
sharedConn
28+
batchStatement(sharedConn);
29+
}
30+
31+
@Test
32+
void batchStatementPrepare() {
33+
batchStatement(sharedConnPrepare);
34+
}
35+
36+
void batchStatement(MariadbConnection connection) {
37+
connection
2838
.createStatement(
2939
"CREATE TEMPORARY TABLE batchStatement (id int not null primary key auto_increment, test varchar(10))")
3040
.execute()
3141
.blockLast();
3242

33-
sharedConn
43+
connection
3444
.createStatement("INSERT INTO batchStatement values (?, ?)")
3545
.bind(0, 1)
3646
.bind(1, "test")
3747
.add()
38-
.bind(0, 2)
3948
.bind(1, "test2")
49+
.bind(0, 2)
4050
.execute()
41-
.subscribe();
51+
.blockLast();
4252

43-
sharedConn
53+
connection
4454
.createStatement("SELECT * FROM batchStatement")
4555
.execute()
4656
.flatMap(r -> r.map((row, metadata) -> row.get(0, String.class) + row.get(1, String.class)))
@@ -51,16 +61,25 @@ void batchStatement() {
5161

5262
@Test
5363
void batchStatementResultSet() {
54-
sharedConn
64+
batchStatementResultSet(sharedConn);
65+
}
66+
67+
@Test
68+
void batchStatementResultSetPrepare() {
69+
batchStatementResultSet(sharedConnPrepare);
70+
}
71+
72+
void batchStatementResultSet(MariadbConnection connection) {
73+
connection
5574
.createStatement(
5675
"CREATE TEMPORARY TABLE batchStatementResultSet (id int not null primary key auto_increment, test varchar(10))")
5776
.execute()
5877
.blockLast();
59-
sharedConn
78+
connection
6079
.createStatement("INSERT INTO batchStatementResultSet values (1, 'test1'), (2, 'test2')")
6180
.execute()
6281
.blockLast();
63-
sharedConn
82+
connection
6483
.createStatement("SELECT test FROM batchStatementResultSet WHERE id = ?")
6584
.bind(0, 1)
6685
.add()

0 commit comments

Comments
 (0)