Skip to content

Commit 7337d08

Browse files
author
tengu-alt
committed
Exec() method for batch was added & Query() method was refactored
1 parent 974fa12 commit 7337d08

File tree

8 files changed

+78
-34
lines changed

8 files changed

+78
-34
lines changed

batch_test.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ func TestBatch_Errors(t *testing.T) {
4747
t.Fatal(err)
4848
}
4949

50-
b := session.NewBatch(LoggedBatch)
51-
b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil)
52-
if err := session.ExecuteBatch(b); err == nil {
50+
b := session.Batch(LoggedBatch)
51+
b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil)
52+
if err := b.Exec(); err == nil {
5353
t.Fatal("expected to get error for invalid query in batch")
5454
}
5555
}
@@ -68,15 +68,17 @@ func TestBatch_WithTimestamp(t *testing.T) {
6868

6969
micros := time.Now().UnixNano()/1e3 - 1000
7070

71-
b := session.NewBatch(LoggedBatch)
71+
b := session.Batch(LoggedBatch)
7272
b.WithTimestamp(micros)
73-
b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val")
74-
if err := session.ExecuteBatch(b); err != nil {
73+
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val")
74+
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val")
75+
76+
if err := b.Exec(); err != nil {
7577
t.Fatal(err)
7678
}
7779

7880
var storedTs int64
79-
if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
81+
if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
8082
t.Fatal(err)
8183
}
8284

cassandra_test.go

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import (
4444
"time"
4545
"unicode"
4646

47-
inf "gopkg.in/inf.v0"
47+
"gopkg.in/inf.v0"
4848
)
4949

5050
func TestEmptyHosts(t *testing.T) {
@@ -453,15 +453,15 @@ func TestCAS(t *testing.T) {
453453
t.Fatal("truncate:", err)
454454
}
455455

456-
successBatch := session.NewBatch(LoggedBatch)
456+
successBatch := session.Batch(LoggedBatch)
457457
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
458458
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
459459
t.Fatal("insert:", err)
460460
} else if !applied {
461461
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
462462
}
463463

464-
successBatch = session.NewBatch(LoggedBatch)
464+
successBatch = session.Batch(LoggedBatch)
465465
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified)
466466
casMap := make(map[string]interface{})
467467
if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil {
@@ -470,22 +470,22 @@ func TestCAS(t *testing.T) {
470470
t.Fatal("insert should have been applied")
471471
}
472472

473-
failBatch := session.NewBatch(LoggedBatch)
473+
failBatch := session.Batch(LoggedBatch)
474474
failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
475475
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
476476
t.Fatal("insert:", err)
477477
} else if applied {
478478
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
479479
}
480480

481-
insertBatch := session.NewBatch(LoggedBatch)
481+
insertBatch := session.Batch(LoggedBatch)
482482
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
483483
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
484484
if err := session.ExecuteBatch(insertBatch); err != nil {
485485
t.Fatal("insert:", err)
486486
}
487487

488-
failBatch = session.NewBatch(LoggedBatch)
488+
failBatch = session.Batch(LoggedBatch)
489489
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
490490
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
491491
if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
@@ -610,7 +610,7 @@ func TestBatch(t *testing.T) {
610610
t.Fatal("create table:", err)
611611
}
612612

613-
batch := session.NewBatch(LoggedBatch)
613+
batch := session.Batch(LoggedBatch)
614614
for i := 0; i < 100; i++ {
615615
batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i)
616616
}
@@ -642,9 +642,9 @@ func TestUnpreparedBatch(t *testing.T) {
642642

643643
var batch *Batch
644644
if session.cfg.ProtoVersion == 2 {
645-
batch = session.NewBatch(CounterBatch)
645+
batch = session.Batch(CounterBatch)
646646
} else {
647-
batch = session.NewBatch(UnloggedBatch)
647+
batch = session.Batch(UnloggedBatch)
648648
}
649649

650650
for i := 0; i < 100; i++ {
@@ -683,7 +683,7 @@ func TestBatchLimit(t *testing.T) {
683683
t.Fatal("create table:", err)
684684
}
685685

686-
batch := session.NewBatch(LoggedBatch)
686+
batch := session.Batch(LoggedBatch)
687687
for i := 0; i < 65537; i++ {
688688
batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i)
689689
}
@@ -737,7 +737,7 @@ func TestTooManyQueryArgs(t *testing.T) {
737737
t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error")
738738
}
739739

740-
batch := session.NewBatch(UnloggedBatch)
740+
batch := session.Batch(UnloggedBatch)
741741
batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3)
742742
err = session.ExecuteBatch(batch)
743743

@@ -769,7 +769,7 @@ func TestNotEnoughQueryArgs(t *testing.T) {
769769
t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error")
770770
}
771771

772-
batch := session.NewBatch(UnloggedBatch)
772+
batch := session.Batch(UnloggedBatch)
773773
batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2)
774774
err = session.ExecuteBatch(batch)
775775

@@ -1342,7 +1342,7 @@ func TestBatchQueryInfo(t *testing.T) {
13421342
return values, nil
13431343
}
13441344

1345-
batch := session.NewBatch(LoggedBatch)
1345+
batch := session.Batch(LoggedBatch)
13461346
batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write)
13471347

13481348
if err := session.ExecuteBatch(batch); err != nil {
@@ -1470,7 +1470,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
14701470
}
14711471

14721472
stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
1473-
batch := session.NewBatch(UnloggedBatch)
1473+
batch := session.Batch(UnloggedBatch)
14741474
batch.Query(stmt, "bar")
14751475
if err := conn.executeBatch(ctx, batch).Close(); err != nil {
14761476
t.Fatalf("Failed to execute query for reprepare statement: %v", err)
@@ -1854,7 +1854,7 @@ func TestBatchStats(t *testing.T) {
18541854
t.Fatalf("failed to create table with error '%v'", err)
18551855
}
18561856

1857-
b := session.NewBatch(LoggedBatch)
1857+
b := session.Batch(LoggedBatch)
18581858
b.Query("INSERT INTO batchStats (id) VALUES (?)", 1)
18591859
b.Query("INSERT INTO batchStats (id) VALUES (?)", 2)
18601860

@@ -1897,7 +1897,7 @@ func TestBatchObserve(t *testing.T) {
18971897

18981898
var observedBatch *observation
18991899

1900-
batch := session.NewBatch(LoggedBatch)
1900+
batch := session.Batch(LoggedBatch)
19011901
batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) {
19021902
if observedBatch != nil {
19031903
t.Fatal("batch observe called more than once")
@@ -3236,7 +3236,7 @@ func TestUnsetColBatch(t *testing.T) {
32363236
t.Fatalf("failed to create table with error '%v'", err)
32373237
}
32383238

3239-
b := session.NewBatch(LoggedBatch)
3239+
b := session.Batch(LoggedBatch)
32403240
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue)
32413241
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "")
32423242
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue)

doc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@
300300
// # Batches
301301
//
302302
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
303-
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
303+
// Use Session.Batch to create a new batch and then fill-in details of individual queries.
304304
// Then execute the batch with Session.ExecuteBatch.
305305
//
306306
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have

example_batch_test.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import (
2929
"fmt"
3030
"log"
3131

32-
gocql "github.com/gocql/gocql"
32+
"github.com/gocql/gocql"
3333
)
3434

3535
// Example_batch demonstrates how to execute a batch of statements.
@@ -49,7 +49,7 @@ func Example_batch() {
4949

5050
ctx := context.Background()
5151

52-
b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx)
52+
b := session.Batch(gocql.UnloggedBatch).WithContext(ctx)
5353
b.Entries = append(b.Entries, gocql.BatchEntry{
5454
Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)",
5555
Args: []interface{}{1, 2, "1.2"},
@@ -60,11 +60,19 @@ func Example_batch() {
6060
Args: []interface{}{1, 3, "1.3"},
6161
Idempotent: true,
6262
})
63+
6364
err = session.ExecuteBatch(b)
6465
if err != nil {
6566
log.Fatal(err)
6667
}
6768

69+
err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4").
70+
Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5").
71+
Exec()
72+
if err != nil {
73+
log.Fatal(err)
74+
}
75+
6876
scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner()
6977
for scanner.Next() {
7078
var pk, ck int32
@@ -77,4 +85,6 @@ func Example_batch() {
7785
}
7886
// 1 2 1.2
7987
// 1 3 1.3
88+
// 1 4 1.4
89+
// 1 5 1.5
8090
}

example_lwt_batch_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import (
2929
"fmt"
3030
"log"
3131

32-
gocql "github.com/gocql/gocql"
32+
"github.com/gocql/gocql"
3333
)
3434

3535
// ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction.
@@ -62,7 +62,7 @@ func ExampleSession_MapExecuteBatchCAS() {
6262
}
6363

6464
executeBatch := func(ck2Version int) {
65-
b := session.NewBatch(gocql.LoggedBatch)
65+
b := session.Batch(gocql.LoggedBatch)
6666
b.Entries = append(b.Entries, gocql.BatchEntry{
6767
Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?",
6868
Args: []interface{}{"b", "pk1", "ck1", 1},

integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func TestCustomPayloadMessages(t *testing.T) {
218218
iter.Close()
219219

220220
// Batch Message
221-
b := session.NewBatch(LoggedBatch)
221+
b := session.Batch(LoggedBatch)
222222
b.CustomPayload = customPayload
223223
b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)")
224224
if err := session.ExecuteBatch(b); err != nil {

session.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,13 @@ func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter {
731731
return conn.executeBatch(ctx, b)
732732
}
733733

734+
// Exec executes a batch operation and returns nil if successful
735+
// otherwise an error is returned describing the failure.
736+
func (b *Batch) Exec() error {
737+
iter := b.session.executeBatch(b)
738+
return iter.Close()
739+
}
740+
734741
func (s *Session) executeBatch(batch *Batch) *Iter {
735742
// fail fast
736743
if s.Closed() {
@@ -1760,6 +1767,8 @@ func NewBatch(typ BatchType) *Batch {
17601767
}
17611768

17621769
// NewBatch creates a new batch operation using defaults defined in the cluster
1770+
//
1771+
// Deprecated: use session.Batch instead
17631772
func (s *Session) NewBatch(typ BatchType) *Batch {
17641773
s.mu.RLock()
17651774
batch := &Batch{
@@ -1781,6 +1790,28 @@ func (s *Session) NewBatch(typ BatchType) *Batch {
17811790
return batch
17821791
}
17831792

1793+
// Batch creates a new batch operation using defaults defined in the cluster
1794+
func (s *Session) Batch(typ BatchType) *Batch {
1795+
s.mu.RLock()
1796+
batch := &Batch{
1797+
Type: typ,
1798+
rt: s.cfg.RetryPolicy,
1799+
serialCons: s.cfg.SerialConsistency,
1800+
trace: s.trace,
1801+
observer: s.batchObserver,
1802+
session: s,
1803+
Cons: s.cons,
1804+
defaultTimestamp: s.cfg.DefaultTimestamp,
1805+
keyspace: s.cfg.Keyspace,
1806+
metrics: &queryMetrics{m: make(map[string]*hostMetrics)},
1807+
spec: &NonSpeculativeExecution{},
1808+
routingInfo: &queryRoutingInfo{},
1809+
}
1810+
1811+
s.mu.RUnlock()
1812+
return batch
1813+
}
1814+
17841815
// Trace enables tracing of this batch. Look at the documentation of the
17851816
// Tracer interface to learn more about tracing.
17861817
func (b *Batch) Trace(trace Tracer) *Batch {
@@ -1860,8 +1891,9 @@ func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch
18601891
}
18611892

18621893
// Query adds the query to the batch operation
1863-
func (b *Batch) Query(stmt string, args ...interface{}) {
1894+
func (b *Batch) Query(stmt string, args ...interface{}) *Batch {
18641895
b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
1896+
return b
18651897
}
18661898

18671899
// Bind adds the query to the batch operation and correlates it with a binding callback

session_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func TestSessionAPI(t *testing.T) {
9696
t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err)
9797
}
9898

99-
testBatch := s.NewBatch(LoggedBatch)
99+
testBatch := s.Batch(LoggedBatch)
100100
testBatch.Query("test")
101101
err := s.ExecuteBatch(testBatch)
102102

@@ -219,15 +219,15 @@ func TestBatchBasicAPI(t *testing.T) {
219219
s.pool = cfg.PoolConfig.buildPool(s)
220220

221221
// Test UnloggedBatch
222-
b := s.NewBatch(UnloggedBatch)
222+
b := s.Batch(UnloggedBatch)
223223
if b.Type != UnloggedBatch {
224224
t.Fatalf("expceted batch.Type to be '%v', got '%v'", UnloggedBatch, b.Type)
225225
} else if b.rt != cfg.RetryPolicy {
226226
t.Fatalf("expceted batch.RetryPolicy to be '%v', got '%v'", cfg.RetryPolicy, b.rt)
227227
}
228228

229229
// Test LoggedBatch
230-
b = s.NewBatch(LoggedBatch)
230+
b = s.Batch(LoggedBatch)
231231
if b.Type != LoggedBatch {
232232
t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type)
233233
}

0 commit comments

Comments
 (0)