Skip to content

Commit 97336ec

Browse files
committed
CASSGO-42: don't panic if no applied column is returned
I also added checks against MapScan failing which could also trigger this bug. Several integration tests were added to validate various edge cases. Patch by James Hartig for CASSGO-42
1 parent 37030fb commit 97336ec

File tree

4 files changed

+86
-12
lines changed

4 files changed

+86
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222

2323
- Retry policy now takes into account query idempotency (CASSGO-27)
2424
- Don't return error to caller with RetryType Ignore (CASSGO-28)
25+
- Don't panic in MapExecuteBatchCAS if no `[applied]` column is returned (CASSGO-42)
2526

2627
## [1.7.0] - 2024-09-23
2728

cassandra_test.go

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import (
3232
"context"
3333
"errors"
3434
"fmt"
35-
"github.com/stretchr/testify/require"
3635
"io"
3736
"math"
3837
"math/big"
@@ -45,6 +44,8 @@ import (
4544
"time"
4645
"unicode"
4746

47+
"github.com/stretchr/testify/require"
48+
4849
"gopkg.in/inf.v0"
4950
)
5051

@@ -476,7 +477,7 @@ func TestCAS(t *testing.T) {
476477
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
477478
t.Fatal("insert:", err)
478479
} else if applied {
479-
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
480+
t.Fatalf("insert should have not been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
480481
}
481482

482483
insertBatch := session.Batch(LoggedBatch)
@@ -492,7 +493,7 @@ func TestCAS(t *testing.T) {
492493
if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
493494
t.Fatal("insert:", err)
494495
} else if applied {
495-
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
496+
t.Fatalf("insert should have not been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
496497
} else {
497498
if scan := iter.Scan(&applied, &titleCAS, &revidCAS, &modifiedCAS); scan && applied {
498499
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
@@ -503,6 +504,55 @@ func TestCAS(t *testing.T) {
503504
t.Fatal("scan:", err)
504505
}
505506
}
507+
508+
casMap = make(map[string]interface{})
509+
if applied, err := session.Query(`SELECT revid FROM cas_table WHERE title = ?`,
510+
title+"_foo").MapScanCAS(casMap); err != nil {
511+
t.Fatal("select:", err)
512+
} else if applied {
513+
t.Fatal("select shouldn't have returned applied")
514+
}
515+
516+
if _, err := session.Query(`SELECT revid FROM cas_table WHERE title = ?`,
517+
title+"_foo").ScanCAS(&revidCAS); err == nil {
518+
t.Fatal("select: should have returned an error")
519+
}
520+
521+
notCASBatch := session.Batch(LoggedBatch)
522+
notCASBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?)", title+"_baz", revid, modified)
523+
casMap = make(map[string]interface{})
524+
if _, _, err := session.MapExecuteBatchCAS(notCASBatch, casMap); err != ErrNotFound {
525+
t.Fatal("insert should have returned not found:", err)
526+
}
527+
528+
notCASBatch = session.Batch(LoggedBatch)
529+
notCASBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?)", title+"_baz", revid, modified)
530+
casMap = make(map[string]interface{})
531+
if _, _, err := session.ExecuteBatchCAS(notCASBatch, &revidCAS); err != ErrNotFound {
532+
t.Fatal("insert should have returned not found:", err)
533+
}
534+
535+
failBatch = session.Batch(LoggedBatch)
536+
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified)
537+
if _, _, err := session.ExecuteBatchCAS(failBatch, new(bool)); err == nil {
538+
t.Fatal("update should have errored")
539+
}
540+
// make sure MapScanCAS does not panic when MapScan fails
541+
casMap = make(map[string]interface{})
542+
casMap["last_modified"] = false
543+
if _, err := session.Query(`UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?`,
544+
modified).MapScanCAS(casMap); err == nil {
545+
t.Fatal("update should hvae errored", err)
546+
}
547+
548+
// make sure MapExecuteBatchCAS does not panic when MapScan fails
549+
failBatch = session.Batch(LoggedBatch)
550+
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified)
551+
casMap = make(map[string]interface{})
552+
casMap["last_modified"] = false
553+
if _, _, err := session.MapExecuteBatchCAS(failBatch, casMap); err == nil {
554+
t.Fatal("update should have errored")
555+
}
506556
}
507557

508558
func TestDurationType(t *testing.T) {

helpers.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ func TupleColumnName(c string, n int) string {
322322
return fmt.Sprintf("%s[%d]", c, n)
323323
}
324324

325+
// RowData returns the RowData for the iterator.
325326
func (iter *Iter) RowData() (RowData, error) {
326327
if iter.err != nil {
327328
return RowData{}, iter.err
@@ -334,6 +335,7 @@ func (iter *Iter) RowData() (RowData, error) {
334335
if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
335336
val, err := column.TypeInfo.NewWithError()
336337
if err != nil {
338+
iter.err = err
337339
return RowData{}, err
338340
}
339341
columns = append(columns, column.Name)
@@ -343,6 +345,7 @@ func (iter *Iter) RowData() (RowData, error) {
343345
columns = append(columns, TupleColumnName(column.Name, i))
344346
val, err := elem.NewWithError()
345347
if err != nil {
348+
iter.err = err
346349
return RowData{}, err
347350
}
348351
values = append(values, val)
@@ -364,7 +367,10 @@ func (iter *Iter) rowMap() (map[string]interface{}, error) {
364367
return nil, iter.err
365368
}
366369

367-
rowData, _ := iter.RowData()
370+
rowData, err := iter.RowData()
371+
if err != nil {
372+
return nil, err
373+
}
368374
iter.Scan(rowData.Values...)
369375
m := make(map[string]interface{}, len(rowData.Columns))
370376
rowData.rowMap(m)
@@ -379,7 +385,10 @@ func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
379385
}
380386

381387
// Not checking for the error because we just did
382-
rowData, _ := iter.RowData()
388+
rowData, err := iter.RowData()
389+
if err != nil {
390+
return nil, err
391+
}
383392
dataToReturn := make([]map[string]interface{}, 0)
384393
for iter.Scan(rowData.Values...) {
385394
m := make(map[string]interface{}, len(rowData.Columns))
@@ -435,8 +444,10 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool {
435444
return false
436445
}
437446

438-
// Not checking for the error because we just did
439-
rowData, _ := iter.RowData()
447+
rowData, err := iter.RowData()
448+
if err != nil {
449+
return false
450+
}
440451

441452
for i, col := range rowData.Columns {
442453
if dest, ok := m[col]; ok {

session.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bo
785785
iter.Scan(&applied)
786786
}
787787

788-
return applied, iter, nil
788+
return applied, iter, iter.err
789789
}
790790

791791
// MapExecuteBatchCAS executes a batch operation much like ExecuteBatchCAS,
@@ -798,8 +798,14 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
798798
return false, nil, err
799799
}
800800
iter.MapScan(dest)
801-
applied = dest["[applied]"].(bool)
802-
delete(dest, "[applied]")
801+
if iter.err != nil {
802+
return false, iter, iter.err
803+
}
804+
// check if [applied] was returned, otherwise it might not be CAS
805+
if _, ok := dest["[applied]"]; ok {
806+
applied = dest["[applied]"].(bool)
807+
delete(dest, "[applied]")
808+
}
803809

804810
// we usually close here, but instead of closing, just returin an error
805811
// if MapScan failed. Although Close just returns err, using Close
@@ -1387,8 +1393,14 @@ func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error
13871393
return false, err
13881394
}
13891395
iter.MapScan(dest)
1390-
applied = dest["[applied]"].(bool)
1391-
delete(dest, "[applied]")
1396+
if iter.err != nil {
1397+
return false, iter.err
1398+
}
1399+
// check if [applied] was returned, otherwise it might not be CAS
1400+
if _, ok := dest["[applied]"]; ok {
1401+
applied = dest["[applied]"].(bool)
1402+
delete(dest, "[applied]")
1403+
}
13921404

13931405
return applied, iter.Close()
13941406
}

0 commit comments

Comments
 (0)