Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 81 additions & 5 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,50 @@ func TestSkipMetadata(t *testing.T) {
}
}

func TestPrepareBatchMetadataMultipleKeyspaceTables(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

srv := NewTestServer(t, protoVersion4, ctx)
defer srv.Stop()

cfg := testCluster(protoVersion4, srv.Address)
db, err := cfg.CreateSession()
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
defer db.Close()

conn := db.getConn()
if conn == nil {
t.Fatal("expected connection, got nil")
}

stmt := "BEGIN BATCH INSERT INTO ks1.tbl1 (col1) VALUES (?) INSERT INTO ks2.tbl2 (col2) VALUES (?) APPLY BATCH"
info, err := conn.prepareStatement(ctx, stmt, nil, time.Second)
if err != nil {
t.Fatalf("prepareStatement failed: %v", err)
}

if got := len(info.request.columns); got != 2 {
t.Fatalf("expected 2 request columns, got %d", got)
}

col0 := info.request.columns[0]
if col0.Keyspace != "ks1" || col0.Table != "tbl1" || col0.Name != "col1" {
t.Fatalf("unexpected column 0: %+v", col0)
}

col1 := info.request.columns[1]
if col1.Keyspace != "ks2" || col1.Table != "tbl2" || col1.Name != "col2" {
t.Fatalf("unexpected column 1: %+v", col1)
}

if info.request.keyspace != "" || info.request.table != "" {
t.Fatalf("expected empty prepared keyspace/table for mixed batch, got %q/%q", info.request.keyspace, info.request.table)
}
}

type recordingFrameHeaderObserver struct {
t *testing.T
mu sync.Mutex
Expand Down Expand Up @@ -1425,12 +1469,20 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer, exts map[string]
respFrame.writeHeader(0, frm.OpError, head.Stream)
respFrame.buf = append(respFrame.buf, reqFrame.buf...)
case frm.OpPrepare:
query := reqFrame.readLongString()
name := strings.TrimPrefix(query, "select ")
if n := strings.Index(name, " "); n > 0 {
name = name[:n]
query := strings.TrimSpace(reqFrame.readLongString())
lower := strings.ToLower(query)
name := ""
if strings.HasPrefix(lower, "select ") {
name = strings.TrimPrefix(lower, "select ")
if n := strings.Index(name, " "); n > 0 {
name = name[:n]
}
} else if strings.HasPrefix(lower, "begin batch") {
name = "batchmetadata"
} else {
name = lower
}
switch strings.ToLower(name) {
switch name {
case "nometadata":
respFrame.writeHeader(0, frm.OpResult, head.Stream)
respFrame.writeInt(frm.ResultKindPrepared)
Expand Down Expand Up @@ -1465,6 +1517,30 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer, exts map[string]
// <col_spec_0>
respFrame.writeString("col0") // <name>
respFrame.writeShort(uint16(TypeBoolean)) // <type>
case "batchmetadata":
respFrame.writeHeader(0, frm.OpResult, head.Stream)
respFrame.writeInt(frm.ResultKindPrepared)
// <id>
respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 3))
// <metadata>
respFrame.writeInt(0) // <flags>
respFrame.writeInt(2) // <columns_count>
if srv.protocol >= protoVersion4 {
respFrame.writeInt(0) // <pk_count>
}
// <col_spec_0>
respFrame.writeString("ks1")
respFrame.writeString("tbl1")
respFrame.writeString("col1")
respFrame.writeShort(uint16(TypeInt))
// <col_spec_1>
respFrame.writeString("ks2")
respFrame.writeString("tbl2")
respFrame.writeString("col2")
respFrame.writeShort(uint16(TypeInt))
// <result_metadata>
respFrame.writeInt(int32(frm.FlagNoMetaData))
respFrame.writeInt(0)
default:
respFrame.writeHeader(0, frm.OpError, head.Stream)
respFrame.writeInt(0)
Expand Down
81 changes: 69 additions & 12 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,11 @@ func (f *framer) readTypeInfo() TypeInfo {
typ: Type(id),
}

// Fast path for simple native types (through TypeDuration).
if id > 0 && id <= uint16(TypeDuration) {
return simple
}

if simple.typ == TypeCustom {
simple.custom = f.readString()
if cassType := getApacheCassandraType(simple.custom); cassType != TypeCustom {
Expand Down Expand Up @@ -835,22 +840,36 @@ func (f *framer) parsePreparedMetadata() preparedMetadata {
}

var cols []ColumnInfo
readPerColumnSpec := !globalSpec
var tracker keyspaceTableTracker
if meta.colCount < 1000 {
// preallocate columninfo to avoid excess copying
cols = make([]ColumnInfo, meta.colCount)
for i := 0; i < meta.colCount; i++ {
f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table)
col := &cols[i]
keyspace, table := f.readColWithSpec(col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table, i, readPerColumnSpec)
if readPerColumnSpec {
tracker.track(i, keyspace, table)
}
}
} else {
// use append, huge number of columns usually indicates a corrupt frame or
// just a huge row.
for i := 0; i < meta.colCount; i++ {
var col ColumnInfo
f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table)
keyspace, table := f.readColWithSpec(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table, i, readPerColumnSpec)
if readPerColumnSpec {
tracker.track(i, keyspace, table)
}
cols = append(cols, col)
}
}

if !globalSpec && meta.colCount > 0 && tracker.allSame {
meta.keyspace = tracker.keyspace
meta.table = tracker.table
}

meta.columns = cols

return meta
Expand All @@ -875,23 +894,46 @@ func (r resultMetadata) String() string {
return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns)
}

func (f *framer) readCol(col *ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string) {
if !globalSpec {
// keyspaceTableTracker tracks whether all columns share the same keyspace/table.
type keyspaceTableTracker struct {
keyspace string
table string
allSame bool
}

func (t *keyspaceTableTracker) track(colIndex int, keyspace, table string) {
if colIndex == 0 {
t.keyspace = keyspace
t.table = table
t.allSame = true
Comment thread
mykaul marked this conversation as resolved.
} else if t.allSame && (keyspace != t.keyspace || table != t.table) {
t.allSame = false
}
}

func (f *framer) readColWithSpec(col *ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string, colIndex int, readPerColumnSpec bool) (string, string) {
if readPerColumnSpec {
// Per-column table spec encoding: read keyspace/table for this column.
col.Keyspace = f.readString()
col.Table = f.readString()
} else {
if !globalSpec && colIndex != 0 {
// Skip per-column keyspace/table already read from column 0.
f.skipString()
f.skipString()
}
col.Keyspace = keyspace
col.Table = table
Comment thread
mykaul marked this conversation as resolved.
}

col.Name = f.readString()
col.TypeInfo = f.readTypeInfo()
switch v := col.TypeInfo.(type) {
// maybe also UDT
case TupleTypeInfo:
if tuple, ok := col.TypeInfo.(TupleTypeInfo); ok {
// -1 because we already included the tuple column
meta.actualColCount += len(v.Elems) - 1
meta.actualColCount += len(tuple.Elems) - 1
}

return col.Keyspace, col.Table
}

func (f *framer) parseResultMetadata() resultMetadata {
Expand All @@ -912,9 +954,13 @@ func (f *framer) parseResultMetadata() resultMetadata {
return meta
}

var keyspace, table string
globalSpec := meta.flags&frm.FlagGlobalTableSpec == frm.FlagGlobalTableSpec
if globalSpec {

// Read keyspace/table once and reuse for all columns. ROWS results are
// always single-table; when !globalSpec this consumes column 0's wire
// values and readColWithSpec skips the rest via skipString().
var keyspace, table string
if globalSpec || meta.colCount > 0 {
keyspace = f.readString()
table = f.readString()
}
Comment thread
mykaul marked this conversation as resolved.
Comment thread
mykaul marked this conversation as resolved.
Expand All @@ -924,15 +970,15 @@ func (f *framer) parseResultMetadata() resultMetadata {
// preallocate columninfo to avoid excess copying
cols = make([]ColumnInfo, meta.colCount)
for i := 0; i < meta.colCount; i++ {
f.readCol(&cols[i], &meta, globalSpec, keyspace, table)
f.readColWithSpec(&cols[i], &meta, globalSpec, keyspace, table, i, false)
}

} else {
// use append, huge number of columns usually indicates a corrupt frame or
// just a huge row.
for i := 0; i < meta.colCount; i++ {
var col ColumnInfo
f.readCol(&col, &meta, globalSpec, keyspace, table)
f.readColWithSpec(&col, &meta, globalSpec, keyspace, table, i, false)
cols = append(cols, col)
}
}
Expand Down Expand Up @@ -1484,6 +1530,17 @@ func (f *framer) readString() (s string) {
return
}

// skipString advances past a string without allocating.
func (f *framer) skipString() {
size := f.readShort()

if len(f.buf) < int(size) {
panic(fmt.Errorf("not enough bytes in buffer to skip string, requires %d got %d", size, len(f.buf)))
}

f.buf = f.buf[size:]
}

func (f *framer) readLongString() (s string) {
size := f.readInt()

Expand Down
78 changes: 78 additions & 0 deletions frame_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,84 @@ func TestFrameReadTooLong(t *testing.T) {
}
}

func TestParseResultMetadata_PerColumnSpec(t *testing.T) {
t.Parallel()

// Build a synthetic ROWS result metadata frame with FlagGlobalTableSpec unset
// (per-column keyspace/table encoding). This tests the !globalSpec optimization
// in parseResultMetadata() which reads keyspace/table from the first column
// position and reuses them for all columns via skipString().
fr := newFramer(nil, protoVersion4)
fr.header = &frm.FrameHeader{Version: protoVersion4}

// flags: no FlagGlobalTableSpec — per-column keyspace/table
fr.writeInt(0)
// colCount
fr.writeInt(3)

// Column 0: keyspace/table + name + type
fr.writeString("test_ks")
fr.writeString("test_tbl")
fr.writeString("col_a")
fr.writeShort(uint16(TypeInt))

// Column 1: same keyspace/table (will be skipped by optimization)
fr.writeString("test_ks")
fr.writeString("test_tbl")
fr.writeString("col_b")
fr.writeShort(uint16(TypeVarchar))

// Column 2: same keyspace/table
fr.writeString("test_ks")
fr.writeString("test_tbl")
fr.writeString("col_c")
fr.writeShort(uint16(TypeBoolean))

meta := fr.parseResultMetadata()

if meta.colCount != 3 {
t.Fatalf("colCount = %d, want 3", meta.colCount)
}
if len(meta.columns) != 3 {
t.Fatalf("len(columns) = %d, want 3", len(meta.columns))
}

// Verify all columns got the correct keyspace/table from the optimization
for i, col := range meta.columns {
if col.Keyspace != "test_ks" {
t.Errorf("columns[%d].Keyspace = %q, want %q", i, col.Keyspace, "test_ks")
}
if col.Table != "test_tbl" {
t.Errorf("columns[%d].Table = %q, want %q", i, col.Table, "test_tbl")
}
}

// Verify column names
expectedNames := []string{"col_a", "col_b", "col_c"}
for i, col := range meta.columns {
if col.Name != expectedNames[i] {
t.Errorf("columns[%d].Name = %q, want %q", i, col.Name, expectedNames[i])
}
}

// Verify column types
expectedTypes := []Type{TypeInt, TypeVarchar, TypeBoolean}
for i, col := range meta.columns {
nt, ok := col.TypeInfo.(NativeType)
if !ok {
t.Fatalf("columns[%d].TypeInfo is %T, want NativeType", i, col.TypeInfo)
}
if nt.typ != expectedTypes[i] {
t.Errorf("columns[%d].Type = %v, want %v", i, nt.typ, expectedTypes[i])
}
}

// Verify the entire buffer was consumed (no misalignment from skipString)
if len(fr.buf) != 0 {
t.Errorf("buffer has %d unconsumed bytes, want 0 (possible skipString misalignment)", len(fr.buf))
}
}

func TestParseEventFrame_ClientRoutesChanged(t *testing.T) {
t.Parallel()

Expand Down
Loading