Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Externally-defined type registration (CASSGO-43)
- Add Query and Batch to ObservedQuery and ObservedBatch (CASSGO-73)
- Add way to create HostInfo objects for testing purposes (CASSGO-71)
- Add missing Context methods on Query and Batch (CASSGO-81)

### Changed

Expand Down
78 changes: 68 additions & 10 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string, keyspace stri
// Exec executes a batch operation and returns nil if successful
// otherwise an error is returned describing the failure.
func (b *Batch) Exec() error {
iter := b.session.executeBatch(b, nil)
iter := b.session.executeBatch(b, b.context)
return iter.Close()
}

Expand All @@ -732,7 +732,7 @@ func (b *Batch) ExecContext(ctx context.Context) error {

// Iter executes a batch operation and returns an Iter object
// that can be used to access properties related to the execution like Iter.Attempts and Iter.Latency
func (b *Batch) Iter() *Iter { return b.IterContext(nil) }
func (b *Batch) Iter() *Iter { return b.IterContext(b.context) }

// IterContext executes a batch operation with the provided context and returns an Iter object
// that can be used to access properties related to the execution like Iter.Attempts and Iter.Latency
Expand Down Expand Up @@ -766,7 +766,7 @@ func (s *Session) executeBatch(batch *Batch, ctx context.Context) *Iter {
// ExecuteBatch executes a batch operation and returns nil if successful
// otherwise an error is returned describing the failure.
func (s *Session) ExecuteBatch(batch *Batch) error {
iter := s.executeBatch(batch, nil)
iter := s.executeBatch(batch, batch.context)
return iter.Close()
}

Expand All @@ -786,7 +786,16 @@ func (s *Session) ExecuteBatchCAS(batch *Batch, dest ...interface{}) (applied bo
// Further scans on the interator must also remember to include
// the applied boolean as the first argument to *Iter.Scan
func (b *Batch) ExecCAS(dest ...interface{}) (applied bool, iter *Iter, err error) {
iter = b.session.executeBatch(b, nil)
return b.ExecCASContext(b.context, dest...)
}

// ExecCASContext executes a batch operation with the provided context and returns true if successful and
// an iterator (to scan additional rows if more than one conditional statement)
// was sent.
// Further scans on the interator must also remember to include
// the applied boolean as the first argument to *Iter.Scan
func (b *Batch) ExecCASContext(ctx context.Context, dest ...interface{}) (applied bool, iter *Iter, err error) {
iter = b.session.executeBatch(b, ctx)
if err := iter.checkErrAndNotFound(); err != nil {
iter.Close()
return false, nil, err
Expand Down Expand Up @@ -814,7 +823,14 @@ func (s *Session) MapExecuteBatchCAS(batch *Batch, dest map[string]interface{})
// however it accepts a map rather than a list of arguments for the initial
// scan.
func (b *Batch) MapExecCAS(dest map[string]interface{}) (applied bool, iter *Iter, err error) {
iter = b.session.executeBatch(b, nil)
return b.MapExecCASContext(b.context, dest)
}

// MapExecCASContext executes a batch operation with the provided context much like ExecuteBatchCAS,
// however it accepts a map rather than a list of arguments for the initial
// scan.
func (b *Batch) MapExecCASContext(ctx context.Context, dest map[string]interface{}) (applied bool, iter *Iter, err error) {
iter = b.session.executeBatch(b, ctx)
if err := iter.checkErrAndNotFound(); err != nil {
iter.Close()
return false, nil, err
Expand Down Expand Up @@ -1057,6 +1073,8 @@ func (q *Query) CustomPayload(customPayload map[string][]byte) *Query {
return q
}

// Deprecated: Context retrieval is deprecated. Pass context directly to execution methods
// like ExecContext or IterContext instead.
func (q *Query) Context() context.Context {
if q.context == nil {
return context.Background()
Expand Down Expand Up @@ -1272,7 +1290,7 @@ func isUseStatement(stmt string) bool {
// Iter executes the query and returns an iterator capable of iterating
// over all results.
func (q *Query) Iter() *Iter {
return q.IterContext(nil)
return q.IterContext(q.context)
}

// IterContext executes the query with the provided context and returns an iterator capable of iterating
Expand All @@ -1297,7 +1315,14 @@ func (q *Query) iterInternal(c *Conn, ctx context.Context) *Iter {
// row into the map pointed at by m and discards the rest. If no rows
// were selected, ErrNotFound is returned.
func (q *Query) MapScan(m map[string]interface{}) error {
iter := q.Iter()
return q.MapScanContext(q.context, m)
}

// MapScanContext executes the query with the provided context, copies the columns of the first selected
// row into the map pointed at by m and discards the rest. If no rows
// were selected, ErrNotFound is returned.
func (q *Query) MapScanContext(ctx context.Context, m map[string]interface{}) error {
iter := q.IterContext(ctx)
if err := iter.checkErrAndNotFound(); err != nil {
return err
}
Expand All @@ -1309,7 +1334,14 @@ func (q *Query) MapScan(m map[string]interface{}) error {
// row into the values pointed at by dest and discards the rest. If no rows
// were selected, ErrNotFound is returned.
func (q *Query) Scan(dest ...interface{}) error {
iter := q.Iter()
return q.ScanContext(q.context, dest...)
}

// ScanContext executes the query with the provided context, copies the columns of the first selected
// row into the values pointed at by dest and discards the rest. If no rows
// were selected, ErrNotFound is returned.
func (q *Query) ScanContext(ctx context.Context, dest ...interface{}) error {
iter := q.IterContext(ctx)
if err := iter.checkErrAndNotFound(); err != nil {
return err
}
Expand All @@ -1326,8 +1358,20 @@ func (q *Query) Scan(dest ...interface{}) error {
// SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
// column mismatching. Use MapScanCAS to capture them safely.
func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
return q.ScanCASContext(q.context, dest...)
}

// ScanCASContext executes a lightweight transaction (i.e. an UPDATE or INSERT
// statement containing an IF clause) with the provided context. If the transaction fails because
// the existing values did not match, the previous values will be stored
// in dest.
//
// As for INSERT .. IF NOT EXISTS, previous values will be returned as if
// SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
// column mismatching. Use MapScanCAS to capture them safely.
func (q *Query) ScanCASContext(ctx context.Context, dest ...interface{}) (applied bool, err error) {
q.disableSkipMetadata = true
iter := q.Iter()
iter := q.IterContext(ctx)
if err := iter.checkErrAndNotFound(); err != nil {
return false, err
}
Expand All @@ -1349,8 +1393,20 @@ func (q *Query) ScanCAS(dest ...interface{}) (applied bool, err error) {
// SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
// column mismatching. MapScanCAS is added to capture them safely.
func (q *Query) MapScanCAS(dest map[string]interface{}) (applied bool, err error) {
return q.MapScanCASContext(q.context, dest)
}

// MapScanCASContext executes a lightweight transaction (i.e. an UPDATE or INSERT
// statement containing an IF clause) with the provided context. If the transaction fails because
// the existing values did not match, the previous values will be stored
// in dest map.
//
// As for INSERT .. IF NOT EXISTS, previous values will be returned as if
// SELECT * FROM. So using ScanCAS with INSERT is inherently prone to
// column mismatching. MapScanCAS is added to capture them safely.
func (q *Query) MapScanCASContext(ctx context.Context, dest map[string]interface{}) (applied bool, err error) {
q.disableSkipMetadata = true
iter := q.Iter()
iter := q.IterContext(ctx)
if err := iter.checkErrAndNotFound(); err != nil {
return false, err
}
Expand Down Expand Up @@ -1834,6 +1890,8 @@ func (b *Batch) SetConsistency(c Consistency) {
b.Cons = c
}

// Deprecated: Context retrieval is deprecated. Pass context directly to execution methods
// like ExecContext or IterContext instead.
func (b *Batch) Context() context.Context {
if b.context == nil {
return context.Background()
Expand Down