Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,21 +664,21 @@ func TestCAS(t *testing.T) {
}

failBatch = session.Batch(LoggedBatch)
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified)
failBatch.Query("UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified)
if _, _, err := session.ExecuteBatchCAS(failBatch, new(bool)); err == nil {
t.Fatal("update should have errored")
}
// make sure MapScanCAS does not panic when MapScan fails
casMap = make(map[string]interface{})
casMap["last_modified"] = false
if _, err := session.Query(`UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?`,
if _, err := session.Query(`UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?`,
modified).MapScanCAS(casMap); err == nil {
t.Fatal("update should hvae errored", err)
}

// make sure MapExecuteBatchCAS does not panic when MapScan fails
failBatch = session.Batch(LoggedBatch)
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified)
failBatch.Query("UPDATE cas_table SET last_modified = TOTIMESTAMP(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified = ?", modified)
casMap = make(map[string]interface{})
casMap["last_modified"] = false
if _, _, err := session.MapExecuteBatchCAS(failBatch, casMap); err == nil {
Expand Down
17 changes: 17 additions & 0 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"io/ioutil"
"net"
"runtime"
"strconv"
"strings"
"time"
)
Expand Down Expand Up @@ -913,6 +914,22 @@ func (f *framer) readTypeInfo() TypeInfo {
collection.Elem = f.readTypeInfo()

return collection
case TypeCustom:
if strings.HasPrefix(simple.custom, "org.apache.cassandra.db.marshal.VectorType") {
spec := strings.TrimPrefix(simple.custom, "org.apache.cassandra.db.marshal.VectorType")
spec = spec[1 : len(spec)-1] // remove parenthesis
idx := strings.LastIndex(spec, ",")
typeStr := spec[:idx]
dimStr := spec[idx+1:]
subType := getCassandraLongType(strings.TrimSpace(typeStr), f.proto, nopLogger{})
dim, _ := strconv.Atoi(strings.TrimSpace(dimStr))
vector := VectorType{
NativeType: simple,
SubType: subType,
Dimensions: dim,
}
return vector
}
}

return simple
Expand Down
212 changes: 189 additions & 23 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
package gocql

import (
"encoding/hex"
"fmt"
"math/big"
"reflect"
"strconv"
"strings"
"time"

Expand All @@ -39,6 +41,46 @@ type RowData struct {
Values []interface{}
}

// asVectorType attempts to convert a NativeType(custom) which represents a VectorType
// into a concrete VectorType. It also works recursively (nested vectors).
func asVectorType(t TypeInfo) (VectorType, bool) {
if v, ok := t.(VectorType); ok {
return v, true
}
n, ok := t.(NativeType)
if !ok || n.Type() != TypeCustom {
return VectorType{}, false
}
const prefix = "org.apache.cassandra.db.marshal.VectorType"
if !strings.HasPrefix(n.Custom(), prefix+"(") {
return VectorType{}, false
}

spec := strings.TrimPrefix(n.Custom(), prefix)
spec = strings.Trim(spec, "()")
// split last comma -> subtype spec , dimensions
idx := strings.LastIndex(spec, ",")
if idx <= 0 {
return VectorType{}, false
}
subStr := strings.TrimSpace(spec[:idx])
dimStr := strings.TrimSpace(spec[idx+1:])
dim, err := strconv.Atoi(dimStr)
if err != nil {
return VectorType{}, false
}
subType := getCassandraLongType(subStr, n.Version(), nopLogger{})
// recurse if subtype itself is still a custom vector
if innerVec, ok := asVectorType(subType); ok {
subType = innerVec
}
return VectorType{
NativeType: NewCustomType(n.Version(), TypeCustom, prefix),
SubType: subType,
Dimensions: dim,
}, true
}

func goType(t TypeInfo) (reflect.Type, error) {
switch t.Type() {
case TypeVarchar, TypeAscii, TypeInet, TypeText:
Expand Down Expand Up @@ -95,6 +137,20 @@ func goType(t TypeInfo) (reflect.Type, error) {
return reflect.TypeOf(*new(time.Time)), nil
case TypeDuration:
return reflect.TypeOf(*new(Duration)), nil
case TypeCustom:
// Handle VectorType encoded as custom
if vec, ok := asVectorType(t); ok {
innerPtr, err := vec.SubType.NewWithError()
if err != nil {
return nil, err
}
elemType := reflect.TypeOf(innerPtr)
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
return reflect.SliceOf(elemType), nil
}
return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t)
default:
return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t)
}
Expand Down Expand Up @@ -161,59 +217,165 @@ func getCassandraBaseType(name string) Type {
}
}

func getCassandraType(name string, logger StdLogger) TypeInfo {
// TODO: Cover with unit tests.
// Parses long Java-style type definition to internal data structures.
func getCassandraLongType(name string, protoVer byte, logger StdLogger) TypeInfo {
if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.SetType") {
return CollectionType{
NativeType: NewNativeType(protoVer, TypeSet),
Elem: getCassandraLongType(unwrapCompositeTypeDefinition(name, "org.apache.cassandra.db.marshal.SetType", '('), protoVer, logger),
}
} else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.ListType") {
return CollectionType{
NativeType: NewNativeType(protoVer, TypeList),
Elem: getCassandraLongType(unwrapCompositeTypeDefinition(name, "org.apache.cassandra.db.marshal.ListType", '('), protoVer, logger),
}
} else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.MapType") {
names := splitJavaCompositeTypes(name, "org.apache.cassandra.db.marshal.MapType")
if len(names) != 2 {
logger.Printf("gocql: error parsing map type, it has %d subelements, expecting 2\n", len(names))
return NewNativeType(protoVer, TypeCustom)
}
return CollectionType{
NativeType: NewNativeType(protoVer, TypeMap),
Key: getCassandraLongType(names[0], protoVer, logger),
Elem: getCassandraLongType(names[1], protoVer, logger),
}
} else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.TupleType") {
names := splitJavaCompositeTypes(name, "org.apache.cassandra.db.marshal.TupleType")
types := make([]TypeInfo, len(names))

for i, name := range names {
types[i] = getCassandraLongType(name, protoVer, logger)
}

return TupleTypeInfo{
NativeType: NewNativeType(protoVer, TypeTuple),
Elems: types,
}
} else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.UserType") {
names := splitJavaCompositeTypes(name, "org.apache.cassandra.db.marshal.UserType")
fields := make([]UDTField, len(names)-2)

for i := 2; i < len(names); i++ {
spec := strings.Split(names[i], ":")
fieldName, _ := hex.DecodeString(spec[0])
fields[i-2] = UDTField{
Name: string(fieldName),
Type: getCassandraLongType(spec[1], protoVer, logger),
}
}

udtName, _ := hex.DecodeString(names[1])
return UDTTypeInfo{
NativeType: NewNativeType(protoVer, TypeUDT),
KeySpace: names[0],
Name: string(udtName),
Elements: fields,
}
} else if strings.HasPrefix(name, "org.apache.cassandra.db.marshal.VectorType") {
names := splitJavaCompositeTypes(name, "org.apache.cassandra.db.marshal.VectorType")
subType := getCassandraLongType(strings.TrimSpace(names[0]), protoVer, logger)
dim, err := strconv.Atoi(strings.TrimSpace(names[1]))
if err != nil {
logger.Printf("gocql: error parsing vector dimensions: %v\n", err)
return NewNativeType(protoVer, TypeCustom)
}

return VectorType{
NativeType: NewCustomType(protoVer, TypeCustom, "org.apache.cassandra.db.marshal.VectorType"),
SubType: subType,
Dimensions: dim,
}
} else {
// basic type
return NativeType{
proto: protoVer,
typ: getApacheCassandraType(name),
}
}
}

// Parses short CQL type representation (e.g. map<text, text>) to internal data structures.
func getCassandraType(name string, protoVer byte, logger StdLogger) TypeInfo {
if strings.HasPrefix(name, "frozen<") {
return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger)
return getCassandraType(unwrapCompositeTypeDefinition(name, "frozen", '<'), protoVer, logger)
} else if strings.HasPrefix(name, "set<") {
return CollectionType{
NativeType: NativeType{typ: TypeSet},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger),
NativeType: NewNativeType(protoVer, TypeSet),
Elem: getCassandraType(unwrapCompositeTypeDefinition(name, "set", '<'), protoVer, logger),
}
} else if strings.HasPrefix(name, "list<") {
return CollectionType{
NativeType: NativeType{typ: TypeList},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger),
NativeType: NewNativeType(protoVer, TypeList),
Elem: getCassandraType(unwrapCompositeTypeDefinition(name, "list", '<'), protoVer, logger),
}
} else if strings.HasPrefix(name, "map<") {
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<"))
names := splitCQLCompositeTypes(name, "map")
if len(names) != 2 {
logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names))
return NativeType{
typ: TypeCustom,
}
return NewNativeType(protoVer, TypeCustom)
}
return CollectionType{
NativeType: NativeType{typ: TypeMap},
Key: getCassandraType(names[0], logger),
Elem: getCassandraType(names[1], logger),
NativeType: NewNativeType(protoVer, TypeMap),
Key: getCassandraType(names[0], protoVer, logger),
Elem: getCassandraType(names[1], protoVer, logger),
}
} else if strings.HasPrefix(name, "tuple<") {
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<"))
names := splitCQLCompositeTypes(name, "tuple")
types := make([]TypeInfo, len(names))

for i, name := range names {
types[i] = getCassandraType(name, logger)
types[i] = getCassandraType(name, protoVer, logger)
}

return TupleTypeInfo{
NativeType: NativeType{typ: TypeTuple},
NativeType: NewNativeType(protoVer, TypeTuple),
Elems: types,
}
} else if strings.HasPrefix(name, "vector<") {
names := splitCQLCompositeTypes(name, "vector")
subType := getCassandraType(strings.TrimSpace(names[0]), protoVer, logger)
dim, _ := strconv.Atoi(strings.TrimSpace(names[1]))

return VectorType{
NativeType: NewCustomType(protoVer, TypeCustom, "org.apache.cassandra.db.marshal.VectorType"),
SubType: subType,
Dimensions: dim,
}
} else {
return NativeType{
typ: getCassandraBaseType(name),
proto: protoVer,
typ: getCassandraBaseType(name),
}
}
}

func splitCompositeTypes(name string) []string {
if !strings.Contains(name, "<") {
return strings.Split(name, ", ")
func splitCQLCompositeTypes(name string, typeName string) []string {
return splitCompositeTypes(name, typeName, '<', '>')
}

func splitJavaCompositeTypes(name string, typeName string) []string {
return splitCompositeTypes(name, typeName, '(', ')')
}

func unwrapCompositeTypeDefinition(name string, typeName string, typeOpen int32) string {
return strings.TrimPrefix(name[:len(name)-1], typeName+string(typeOpen))
}

func splitCompositeTypes(name string, typeName string, typeOpen int32, typeClose int32) []string {
def := unwrapCompositeTypeDefinition(name, typeName, typeOpen)
if !strings.Contains(def, string(typeOpen)) {
parts := strings.Split(def, ",")
for i := range parts {
parts[i] = strings.TrimSpace(parts[i])
}
return parts
}
var parts []string
lessCount := 0
segment := ""
for _, char := range name {
for _, char := range def {
if char == ',' && lessCount == 0 {
if segment != "" {
parts = append(parts, strings.TrimSpace(segment))
Expand All @@ -222,9 +384,9 @@ func splitCompositeTypes(name string) []string {
continue
}
segment += string(char)
if char == '<' {
if char == typeOpen {
lessCount++
} else if char == '>' {
} else if char == typeClose {
lessCount--
}
}
Expand Down Expand Up @@ -282,6 +444,10 @@ func getApacheCassandraType(class string) Type {
return TypeTuple
case "DurationType":
return TypeDuration
case "SimpleDateType":
return TypeDate
case "UserType":
return TypeUDT
default:
return TypeCustom
}
Expand Down
Loading
Loading