diff --git a/conn_test.go b/conn_test.go index 20d7e924e..171a11dbb 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1082,6 +1082,50 @@ func TestSkipMetadata(t *testing.T) { } } +func TestPrepareBatchMetadataMultipleKeyspaceTables(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + srv := NewTestServer(t, protoVersion4, ctx) + defer srv.Stop() + + cfg := testCluster(protoVersion4, srv.Address) + db, err := cfg.CreateSession() + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + defer db.Close() + + conn := db.getConn() + if conn == nil { + t.Fatal("expected connection, got nil") + } + + stmt := "BEGIN BATCH INSERT INTO ks1.tbl1 (col1) VALUES (?) INSERT INTO ks2.tbl2 (col2) VALUES (?) APPLY BATCH" + info, err := conn.prepareStatement(ctx, stmt, nil, time.Second) + if err != nil { + t.Fatalf("prepareStatement failed: %v", err) + } + + if got := len(info.request.columns); got != 2 { + t.Fatalf("expected 2 request columns, got %d", got) + } + + col0 := info.request.columns[0] + if col0.Keyspace != "ks1" || col0.Table != "tbl1" || col0.Name != "col1" { + t.Fatalf("unexpected column 0: %+v", col0) + } + + col1 := info.request.columns[1] + if col1.Keyspace != "ks2" || col1.Table != "tbl2" || col1.Name != "col2" { + t.Fatalf("unexpected column 1: %+v", col1) + } + + if info.request.keyspace != "" || info.request.table != "" { + t.Fatalf("expected empty prepared keyspace/table for mixed batch, got %q/%q", info.request.keyspace, info.request.table) + } +} + type recordingFrameHeaderObserver struct { t *testing.T mu sync.Mutex @@ -1425,12 +1469,20 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer, exts map[string] respFrame.writeHeader(0, frm.OpError, head.Stream) respFrame.buf = append(respFrame.buf, reqFrame.buf...) case frm.OpPrepare: - query := reqFrame.readLongString() - name := strings.TrimPrefix(query, "select ") - if n := strings.Index(name, " "); n > 0 { - name = name[:n] + query := strings.TrimSpace(reqFrame.readLongString()) + lower := strings.ToLower(query) + name := "" + if strings.HasPrefix(lower, "select ") { + name = strings.TrimPrefix(lower, "select ") + if n := strings.Index(name, " "); n > 0 { + name = name[:n] + } + } else if strings.HasPrefix(lower, "begin batch") { + name = "batchmetadata" + } else { + name = lower } - switch strings.ToLower(name) { + switch name { case "nometadata": respFrame.writeHeader(0, frm.OpResult, head.Stream) respFrame.writeInt(frm.ResultKindPrepared) @@ -1465,6 +1517,30 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer, exts map[string] // respFrame.writeString("col0") // respFrame.writeShort(uint16(TypeBoolean)) // + case "batchmetadata": + respFrame.writeHeader(0, frm.OpResult, head.Stream) + respFrame.writeInt(frm.ResultKindPrepared) + // + respFrame.writeShortBytes(binary.BigEndian.AppendUint64(nil, 3)) + // + respFrame.writeInt(0) // + respFrame.writeInt(2) // + if srv.protocol >= protoVersion4 { + respFrame.writeInt(0) // + } + // + respFrame.writeString("ks1") + respFrame.writeString("tbl1") + respFrame.writeString("col1") + respFrame.writeShort(uint16(TypeInt)) + // + respFrame.writeString("ks2") + respFrame.writeString("tbl2") + respFrame.writeString("col2") + respFrame.writeShort(uint16(TypeInt)) + // + respFrame.writeInt(int32(frm.FlagNoMetaData)) + respFrame.writeInt(0) default: respFrame.writeHeader(0, frm.OpError, head.Stream) respFrame.writeInt(0) diff --git a/frame.go b/frame.go index 4303c0c89..fac334235 100644 --- a/frame.go +++ b/frame.go @@ -712,6 +712,11 @@ func (f *framer) readTypeInfo() TypeInfo { typ: Type(id), } + // Fast path for simple native types (through TypeDuration). + if id > 0 && id <= uint16(TypeDuration) { + return simple + } + if simple.typ == TypeCustom { simple.custom = f.readString() if cassType := getApacheCassandraType(simple.custom); cassType != TypeCustom { @@ -835,22 +840,36 @@ func (f *framer) parsePreparedMetadata() preparedMetadata { } var cols []ColumnInfo + readPerColumnSpec := !globalSpec + var tracker keyspaceTableTracker if meta.colCount < 1000 { // preallocate columninfo to avoid excess copying cols = make([]ColumnInfo, meta.colCount) for i := 0; i < meta.colCount; i++ { - f.readCol(&cols[i], &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) + col := &cols[i] + keyspace, table := f.readColWithSpec(col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table, i, readPerColumnSpec) + if readPerColumnSpec { + tracker.track(i, keyspace, table) + } } } else { // use append, huge number of columns usually indicates a corrupt frame or // just a huge row. for i := 0; i < meta.colCount; i++ { var col ColumnInfo - f.readCol(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table) + keyspace, table := f.readColWithSpec(&col, &meta.resultMetadata, globalSpec, meta.keyspace, meta.table, i, readPerColumnSpec) + if readPerColumnSpec { + tracker.track(i, keyspace, table) + } cols = append(cols, col) } } + if !globalSpec && meta.colCount > 0 && tracker.allSame { + meta.keyspace = tracker.keyspace + meta.table = tracker.table + } + meta.columns = cols return meta @@ -875,23 +894,46 @@ func (r resultMetadata) String() string { return fmt.Sprintf("[metadata flags=0x%x paging_state=% X columns=%v]", r.flags, r.pagingState, r.columns) } -func (f *framer) readCol(col *ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string) { - if !globalSpec { +// keyspaceTableTracker tracks whether all columns share the same keyspace/table. +type keyspaceTableTracker struct { + keyspace string + table string + allSame bool +} + +func (t *keyspaceTableTracker) track(colIndex int, keyspace, table string) { + if colIndex == 0 { + t.keyspace = keyspace + t.table = table + t.allSame = true + } else if t.allSame && (keyspace != t.keyspace || table != t.table) { + t.allSame = false + } +} + +func (f *framer) readColWithSpec(col *ColumnInfo, meta *resultMetadata, globalSpec bool, keyspace, table string, colIndex int, readPerColumnSpec bool) (string, string) { + if readPerColumnSpec { + // Per-column table spec encoding: read keyspace/table for this column. col.Keyspace = f.readString() col.Table = f.readString() } else { + if !globalSpec && colIndex != 0 { + // Skip per-column keyspace/table already read from column 0. + f.skipString() + f.skipString() + } col.Keyspace = keyspace col.Table = table } col.Name = f.readString() col.TypeInfo = f.readTypeInfo() - switch v := col.TypeInfo.(type) { - // maybe also UDT - case TupleTypeInfo: + if tuple, ok := col.TypeInfo.(TupleTypeInfo); ok { // -1 because we already included the tuple column - meta.actualColCount += len(v.Elems) - 1 + meta.actualColCount += len(tuple.Elems) - 1 } + + return col.Keyspace, col.Table } func (f *framer) parseResultMetadata() resultMetadata { @@ -912,9 +954,13 @@ func (f *framer) parseResultMetadata() resultMetadata { return meta } - var keyspace, table string globalSpec := meta.flags&frm.FlagGlobalTableSpec == frm.FlagGlobalTableSpec - if globalSpec { + + // Read keyspace/table once and reuse for all columns. ROWS results are + // always single-table; when !globalSpec this consumes column 0's wire + // values and readColWithSpec skips the rest via skipString(). + var keyspace, table string + if globalSpec || meta.colCount > 0 { keyspace = f.readString() table = f.readString() } @@ -924,7 +970,7 @@ func (f *framer) parseResultMetadata() resultMetadata { // preallocate columninfo to avoid excess copying cols = make([]ColumnInfo, meta.colCount) for i := 0; i < meta.colCount; i++ { - f.readCol(&cols[i], &meta, globalSpec, keyspace, table) + f.readColWithSpec(&cols[i], &meta, globalSpec, keyspace, table, i, false) } } else { @@ -932,7 +978,7 @@ func (f *framer) parseResultMetadata() resultMetadata { // just a huge row. for i := 0; i < meta.colCount; i++ { var col ColumnInfo - f.readCol(&col, &meta, globalSpec, keyspace, table) + f.readColWithSpec(&col, &meta, globalSpec, keyspace, table, i, false) cols = append(cols, col) } } @@ -1484,6 +1530,17 @@ func (f *framer) readString() (s string) { return } +// skipString advances past a string without allocating. +func (f *framer) skipString() { + size := f.readShort() + + if len(f.buf) < int(size) { + panic(fmt.Errorf("not enough bytes in buffer to skip string, requires %d got %d", size, len(f.buf))) + } + + f.buf = f.buf[size:] +} + func (f *framer) readLongString() (s string) { size := f.readInt() diff --git a/frame_test.go b/frame_test.go index d4c933840..7dc4e6861 100644 --- a/frame_test.go +++ b/frame_test.go @@ -137,6 +137,84 @@ func TestFrameReadTooLong(t *testing.T) { } } +func TestParseResultMetadata_PerColumnSpec(t *testing.T) { + t.Parallel() + + // Build a synthetic ROWS result metadata frame with FlagGlobalTableSpec unset + // (per-column keyspace/table encoding). This tests the !globalSpec optimization + // in parseResultMetadata() which reads keyspace/table from the first column + // position and reuses them for all columns via skipString(). + fr := newFramer(nil, protoVersion4) + fr.header = &frm.FrameHeader{Version: protoVersion4} + + // flags: no FlagGlobalTableSpec — per-column keyspace/table + fr.writeInt(0) + // colCount + fr.writeInt(3) + + // Column 0: keyspace/table + name + type + fr.writeString("test_ks") + fr.writeString("test_tbl") + fr.writeString("col_a") + fr.writeShort(uint16(TypeInt)) + + // Column 1: same keyspace/table (will be skipped by optimization) + fr.writeString("test_ks") + fr.writeString("test_tbl") + fr.writeString("col_b") + fr.writeShort(uint16(TypeVarchar)) + + // Column 2: same keyspace/table + fr.writeString("test_ks") + fr.writeString("test_tbl") + fr.writeString("col_c") + fr.writeShort(uint16(TypeBoolean)) + + meta := fr.parseResultMetadata() + + if meta.colCount != 3 { + t.Fatalf("colCount = %d, want 3", meta.colCount) + } + if len(meta.columns) != 3 { + t.Fatalf("len(columns) = %d, want 3", len(meta.columns)) + } + + // Verify all columns got the correct keyspace/table from the optimization + for i, col := range meta.columns { + if col.Keyspace != "test_ks" { + t.Errorf("columns[%d].Keyspace = %q, want %q", i, col.Keyspace, "test_ks") + } + if col.Table != "test_tbl" { + t.Errorf("columns[%d].Table = %q, want %q", i, col.Table, "test_tbl") + } + } + + // Verify column names + expectedNames := []string{"col_a", "col_b", "col_c"} + for i, col := range meta.columns { + if col.Name != expectedNames[i] { + t.Errorf("columns[%d].Name = %q, want %q", i, col.Name, expectedNames[i]) + } + } + + // Verify column types + expectedTypes := []Type{TypeInt, TypeVarchar, TypeBoolean} + for i, col := range meta.columns { + nt, ok := col.TypeInfo.(NativeType) + if !ok { + t.Fatalf("columns[%d].TypeInfo is %T, want NativeType", i, col.TypeInfo) + } + if nt.typ != expectedTypes[i] { + t.Errorf("columns[%d].Type = %v, want %v", i, nt.typ, expectedTypes[i]) + } + } + + // Verify the entire buffer was consumed (no misalignment from skipString) + if len(fr.buf) != 0 { + t.Errorf("buffer has %d unconsumed bytes, want 0 (possible skipString misalignment)", len(fr.buf)) + } +} + func TestParseEventFrame_ClientRoutesChanged(t *testing.T) { t.Parallel()