Skip to content

Commit 6224a0b

Browse files
committed
Support of keyspace field for BATCH message
1 parent 974fa12 commit 6224a0b

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

cassandra_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3288,3 +3288,46 @@ func TestQuery_NamedValues(t *testing.T) {
32883288
t.Fatal(err)
32893289
}
32903290
}
3291+
3292+
func TestBatchKeyspaceField(t *testing.T) {
3293+
session := createSession(t)
3294+
defer session.Close()
3295+
3296+
if session.cfg.ProtoVersion < protoVersion5 {
3297+
t.Skip("keyspace for BATCH message is not supported in protocol < 5")
3298+
}
3299+
3300+
err := createTable(session, "CREATE TABLE batch_keyspace(id int, value text, PRIMARY KEY (id))")
3301+
if err != nil {
3302+
t.Fatal(err)
3303+
}
3304+
3305+
ids := []int{1, 2}
3306+
texts := []string{"val1", "val2"}
3307+
3308+
b := session.NewBatch(LoggedBatch)
3309+
b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[0], texts[0])
3310+
b.Query("INSERT INTO batch_keyspace(id, value) VALUES (?, ?)", ids[1], texts[1])
3311+
err = session.ExecuteBatch(b)
3312+
if err != nil {
3313+
t.Fatal(err)
3314+
}
3315+
3316+
var (
3317+
id int
3318+
text string
3319+
)
3320+
3321+
iter := session.Query("SELECT * FROM batch_keyspace").Iter()
3322+
defer iter.Close()
3323+
3324+
for i := 0; iter.Scan(&id, &text); i++ {
3325+
if id != ids[i] {
3326+
t.Fatalf("expected id %v, got %v", ids[i], id)
3327+
}
3328+
3329+
if text != texts[i] {
3330+
t.Fatalf("expected text %v, got %v", texts[i], text)
3331+
}
3332+
}
3333+
}

conn.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,6 +1554,10 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter {
15541554
customPayload: batch.CustomPayload,
15551555
}
15561556

1557+
if c.version > protoVersion4 {
1558+
req.keyspace = c.currentKeyspace
1559+
}
1560+
15571561
stmts := make(map[string]string, len(batch.Entries))
15581562

15591563
for i := 0; i < n; i++ {

frame.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1659,6 +1659,9 @@ type writeBatchFrame struct {
16591659

16601660
//v4+
16611661
customPayload map[string][]byte
1662+
1663+
//v5+
1664+
keyspace string
16621665
}
16631666

16641667
func (w *writeBatchFrame) buildFrame(framer *framer, streamID int) error {
@@ -1718,6 +1721,13 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload
17181721
flags |= flagDefaultTimestamp
17191722
}
17201723

1724+
if w.keyspace != "" {
1725+
if f.proto < protoVersion5 {
1726+
panic(fmt.Errorf("the keyspace can only be set with protocol 5 or higher"))
1727+
}
1728+
flags |= flagWithKeyspace
1729+
}
1730+
17211731
if f.proto > protoVersion4 {
17221732
f.writeUint(uint32(flags))
17231733
} else {
@@ -1737,6 +1747,10 @@ func (f *framer) writeBatchFrame(streamID int, w *writeBatchFrame, customPayload
17371747
}
17381748
f.writeLong(ts)
17391749
}
1750+
1751+
if w.keyspace != "" {
1752+
f.writeString(w.keyspace)
1753+
}
17401754
}
17411755

17421756
return f.finish()

0 commit comments

Comments
 (0)