Skip to content

Commit e76de4f

Browse files
committed
Change Batch API to be consistent with Query()
Exec() method for batch was added & Query() method was refactored. Batch for now behaves the same way as query. patch by Oleksandr Luzhniy; reviewed by João Reis, Danylo Savchenko, Bohdan Siryk for CASSGO-7
1 parent 109a892 commit e76de4f

File tree

9 files changed

+63
-34
lines changed

9 files changed

+63
-34
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414

1515
- Remove global NewBatch function (CASSGO-15)
1616

17+
- Change Batch API to be consistent with Query() (CASSGO-7)
18+
1719
### Fixed
1820

1921
- Retry policy now takes into account query idempotency (CASSGO-27)

batch_test.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ func TestBatch_Errors(t *testing.T) {
4747
t.Fatal(err)
4848
}
4949

50-
b := session.NewBatch(LoggedBatch)
51-
b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil)
52-
if err := session.ExecuteBatch(b); err == nil {
50+
b := session.Batch(LoggedBatch)
51+
b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil)
52+
if err := b.Exec(); err == nil {
5353
t.Fatal("expected to get error for invalid query in batch")
5454
}
5555
}
@@ -68,15 +68,17 @@ func TestBatch_WithTimestamp(t *testing.T) {
6868

6969
micros := time.Now().UnixNano()/1e3 - 1000
7070

71-
b := session.NewBatch(LoggedBatch)
71+
b := session.Batch(LoggedBatch)
7272
b.WithTimestamp(micros)
73-
b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val")
74-
if err := session.ExecuteBatch(b); err != nil {
73+
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val")
74+
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val")
75+
76+
if err := b.Exec(); err != nil {
7577
t.Fatal(err)
7678
}
7779

7880
var storedTs int64
79-
if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
81+
if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
8082
t.Fatal(err)
8183
}
8284

cassandra_test.go

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ import (
4545
"time"
4646
"unicode"
4747

48-
inf "gopkg.in/inf.v0"
48+
"gopkg.in/inf.v0"
4949
)
5050

5151
func TestEmptyHosts(t *testing.T) {
@@ -454,15 +454,15 @@ func TestCAS(t *testing.T) {
454454
t.Fatal("truncate:", err)
455455
}
456456

457-
successBatch := session.NewBatch(LoggedBatch)
457+
successBatch := session.Batch(LoggedBatch)
458458
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
459459
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
460460
t.Fatal("insert:", err)
461461
} else if !applied {
462462
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
463463
}
464464

465-
successBatch = session.NewBatch(LoggedBatch)
465+
successBatch = session.Batch(LoggedBatch)
466466
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified)
467467
casMap := make(map[string]interface{})
468468
if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil {
@@ -471,22 +471,22 @@ func TestCAS(t *testing.T) {
471471
t.Fatal("insert should have been applied")
472472
}
473473

474-
failBatch := session.NewBatch(LoggedBatch)
474+
failBatch := session.Batch(LoggedBatch)
475475
failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
476476
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
477477
t.Fatal("insert:", err)
478478
} else if applied {
479479
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
480480
}
481481

482-
insertBatch := session.NewBatch(LoggedBatch)
482+
insertBatch := session.Batch(LoggedBatch)
483483
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
484484
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
485485
if err := session.ExecuteBatch(insertBatch); err != nil {
486486
t.Fatal("insert:", err)
487487
}
488488

489-
failBatch = session.NewBatch(LoggedBatch)
489+
failBatch = session.Batch(LoggedBatch)
490490
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
491491
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
492492
if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
@@ -611,7 +611,7 @@ func TestBatch(t *testing.T) {
611611
t.Fatal("create table:", err)
612612
}
613613

614-
batch := session.NewBatch(LoggedBatch)
614+
batch := session.Batch(LoggedBatch)
615615
for i := 0; i < 100; i++ {
616616
batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i)
617617
}
@@ -643,9 +643,9 @@ func TestUnpreparedBatch(t *testing.T) {
643643

644644
var batch *Batch
645645
if session.cfg.ProtoVersion == 2 {
646-
batch = session.NewBatch(CounterBatch)
646+
batch = session.Batch(CounterBatch)
647647
} else {
648-
batch = session.NewBatch(UnloggedBatch)
648+
batch = session.Batch(UnloggedBatch)
649649
}
650650

651651
for i := 0; i < 100; i++ {
@@ -684,7 +684,7 @@ func TestBatchLimit(t *testing.T) {
684684
t.Fatal("create table:", err)
685685
}
686686

687-
batch := session.NewBatch(LoggedBatch)
687+
batch := session.Batch(LoggedBatch)
688688
for i := 0; i < 65537; i++ {
689689
batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i)
690690
}
@@ -738,7 +738,7 @@ func TestTooManyQueryArgs(t *testing.T) {
738738
t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error")
739739
}
740740

741-
batch := session.NewBatch(UnloggedBatch)
741+
batch := session.Batch(UnloggedBatch)
742742
batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3)
743743
err = session.ExecuteBatch(batch)
744744

@@ -770,7 +770,7 @@ func TestNotEnoughQueryArgs(t *testing.T) {
770770
t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error")
771771
}
772772

773-
batch := session.NewBatch(UnloggedBatch)
773+
batch := session.Batch(UnloggedBatch)
774774
batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2)
775775
err = session.ExecuteBatch(batch)
776776

@@ -1392,7 +1392,7 @@ func TestBatchQueryInfo(t *testing.T) {
13921392
return values, nil
13931393
}
13941394

1395-
batch := session.NewBatch(LoggedBatch)
1395+
batch := session.Batch(LoggedBatch)
13961396
batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write)
13971397

13981398
if err := session.ExecuteBatch(batch); err != nil {
@@ -1520,7 +1520,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
15201520
}
15211521

15221522
stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
1523-
batch := session.NewBatch(UnloggedBatch)
1523+
batch := session.Batch(UnloggedBatch)
15241524
batch.Query(stmt, "bar")
15251525
if err := conn.executeBatch(ctx, batch).Close(); err != nil {
15261526
t.Fatalf("Failed to execute query for reprepare statement: %v", err)
@@ -1904,7 +1904,7 @@ func TestBatchStats(t *testing.T) {
19041904
t.Fatalf("failed to create table with error '%v'", err)
19051905
}
19061906

1907-
b := session.NewBatch(LoggedBatch)
1907+
b := session.Batch(LoggedBatch)
19081908
b.Query("INSERT INTO batchStats (id) VALUES (?)", 1)
19091909
b.Query("INSERT INTO batchStats (id) VALUES (?)", 2)
19101910

@@ -1947,7 +1947,7 @@ func TestBatchObserve(t *testing.T) {
19471947

19481948
var observedBatch *observation
19491949

1950-
batch := session.NewBatch(LoggedBatch)
1950+
batch := session.Batch(LoggedBatch)
19511951
batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) {
19521952
if observedBatch != nil {
19531953
t.Fatal("batch observe called more than once")
@@ -3286,7 +3286,7 @@ func TestUnsetColBatch(t *testing.T) {
32863286
t.Fatalf("failed to create table with error '%v'", err)
32873287
}
32883288

3289-
b := session.NewBatch(LoggedBatch)
3289+
b := session.Batch(LoggedBatch)
32903290
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue)
32913291
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "")
32923292
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue)

doc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@
310310
// # Batches
311311
//
312312
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
313-
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
313+
// Use Session.Batch to create a new batch and then fill-in details of individual queries.
314314
// Then execute the batch with Session.ExecuteBatch.
315315
//
316316
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have

example_batch_test.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import (
2929
"fmt"
3030
"log"
3131

32-
gocql "github.com/gocql/gocql"
32+
"github.com/gocql/gocql"
3333
)
3434

3535
// Example_batch demonstrates how to execute a batch of statements.
@@ -49,7 +49,7 @@ func Example_batch() {
4949

5050
ctx := context.Background()
5151

52-
b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx)
52+
b := session.Batch(gocql.UnloggedBatch).WithContext(ctx)
5353
b.Entries = append(b.Entries, gocql.BatchEntry{
5454
Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)",
5555
Args: []interface{}{1, 2, "1.2"},
@@ -60,11 +60,19 @@ func Example_batch() {
6060
Args: []interface{}{1, 3, "1.3"},
6161
Idempotent: true,
6262
})
63+
6364
err = session.ExecuteBatch(b)
6465
if err != nil {
6566
log.Fatal(err)
6667
}
6768

69+
err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4").
70+
Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5").
71+
Exec()
72+
if err != nil {
73+
log.Fatal(err)
74+
}
75+
6876
scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner()
6977
for scanner.Next() {
7078
var pk, ck int32
@@ -77,4 +85,6 @@ func Example_batch() {
7785
}
7886
// 1 2 1.2
7987
// 1 3 1.3
88+
// 1 4 1.4
89+
// 1 5 1.5
8090
}

example_lwt_batch_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import (
2929
"fmt"
3030
"log"
3131

32-
gocql "github.com/gocql/gocql"
32+
"github.com/gocql/gocql"
3333
)
3434

3535
// ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction.
@@ -62,7 +62,7 @@ func ExampleSession_MapExecuteBatchCAS() {
6262
}
6363

6464
executeBatch := func(ck2Version int) {
65-
b := session.NewBatch(gocql.LoggedBatch)
65+
b := session.Batch(gocql.LoggedBatch)
6666
b.Entries = append(b.Entries, gocql.BatchEntry{
6767
Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?",
6868
Args: []interface{}{"b", "pk1", "ck1", 1},

integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func TestCustomPayloadMessages(t *testing.T) {
218218
iter.Close()
219219

220220
// Batch Message
221-
b := session.NewBatch(LoggedBatch)
221+
b := session.Batch(LoggedBatch)
222222
b.CustomPayload = customPayload
223223
b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)")
224224
if err := session.ExecuteBatch(b); err != nil {

session.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,13 @@ func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter {
731731
return conn.executeBatch(ctx, b)
732732
}
733733

734+
// Exec executes a batch operation and returns nil if successful
735+
// otherwise an error is returned describing the failure.
736+
func (b *Batch) Exec() error {
737+
iter := b.session.executeBatch(b)
738+
return iter.Close()
739+
}
740+
734741
func (s *Session) executeBatch(batch *Batch) *Iter {
735742
// fail fast
736743
if s.Closed() {
@@ -1748,7 +1755,14 @@ type Batch struct {
17481755
}
17491756

17501757
// NewBatch creates a new batch operation using defaults defined in the cluster
1758+
//
1759+
// Deprecated: use session.Batch instead
17511760
func (s *Session) NewBatch(typ BatchType) *Batch {
1761+
return s.Batch(typ)
1762+
}
1763+
1764+
// Batch creates a new batch operation using defaults defined in the cluster
1765+
func (s *Session) Batch(typ BatchType) *Batch {
17521766
s.mu.RLock()
17531767
batch := &Batch{
17541768
Type: typ,
@@ -1848,8 +1862,9 @@ func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch
18481862
}
18491863

18501864
// Query adds the query to the batch operation
1851-
func (b *Batch) Query(stmt string, args ...interface{}) {
1865+
func (b *Batch) Query(stmt string, args ...interface{}) *Batch {
18521866
b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
1867+
return b
18531868
}
18541869

18551870
// Bind adds the query to the batch operation and correlates it with a binding callback

session_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func TestSessionAPI(t *testing.T) {
9696
t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err)
9797
}
9898

99-
testBatch := s.NewBatch(LoggedBatch)
99+
testBatch := s.Batch(LoggedBatch)
100100
testBatch.Query("test")
101101
err := s.ExecuteBatch(testBatch)
102102

@@ -219,15 +219,15 @@ func TestBatchBasicAPI(t *testing.T) {
219219
s.pool = cfg.PoolConfig.buildPool(s)
220220

221221
// Test UnloggedBatch
222-
b := s.NewBatch(UnloggedBatch)
222+
b := s.Batch(UnloggedBatch)
223223
if b.Type != UnloggedBatch {
224224
t.Fatalf("expceted batch.Type to be '%v', got '%v'", UnloggedBatch, b.Type)
225225
} else if b.rt != cfg.RetryPolicy {
226226
t.Fatalf("expceted batch.RetryPolicy to be '%v', got '%v'", cfg.RetryPolicy, b.rt)
227227
}
228228

229229
// Test LoggedBatch
230-
b = s.NewBatch(LoggedBatch)
230+
b = s.Batch(LoggedBatch)
231231
if b.Type != LoggedBatch {
232232
t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type)
233233
}

0 commit comments

Comments
 (0)