diff --git a/lib/column/bigint.go b/lib/column/bigint.go index 6c9e4c26d9..be45826a2b 100644 --- a/lib/column/bigint.go +++ b/lib/column/bigint.go @@ -68,17 +68,23 @@ func (col *BigInt) Append(v any) (nulls []uint8, err error) { case []big.Int: nulls = make([]uint8, len(v)) for i := range v { - col.append(&v[i]) + if err := col.append(&v[i]); err != nil { + return nil, err + } } case []*big.Int: nulls = make([]uint8, len(v)) for i := range v { switch { case v[i] != nil: - col.append(v[i]) + if err := col.append(v[i]); err != nil { + return nil, err + } default: nulls[i] = 1 - col.append(big.NewInt(0)) + if err := col.append(big.NewInt(0)); err != nil { + return nil, err + } } } default: @@ -106,16 +112,16 @@ func (col *BigInt) Append(v any) (nulls []uint8, err error) { func (col *BigInt) AppendRow(v any) error { switch v := v.(type) { case big.Int: - col.append(&v) + return col.append(&v) case *big.Int: switch { case v != nil: - col.append(v) + return col.append(v) default: - col.append(big.NewInt(0)) + return col.append(big.NewInt(0)) } case nil: - col.append(big.NewInt(0)) + return col.append(big.NewInt(0)) default: if valuer, ok := v.(driver.Valuer); ok { val, err := valuer.Value() @@ -135,7 +141,6 @@ func (col *BigInt) AppendRow(v any) error { From: fmt.Sprintf("%T", v), } } - return nil } func (col *BigInt) Decode(reader *proto.Reader, rows int) error { @@ -177,9 +182,16 @@ func (col *BigInt) row(i int) *big.Int { return big.NewInt(0) } -func (col *BigInt) append(v *big.Int) { +func (col *BigInt) append(v *big.Int) error { dest := make([]byte, col.size) - bigIntToRaw(dest, new(big.Int).Set(v)) + if err := bigIntToRaw(dest, v, col.signed); err != nil { + return &ColumnConverterError{ + Op: "Append", + To: string(col.chType), + From: "big.Int", + Hint: err.Error(), + } + } switch v := col.col.(type) { case *proto.ColInt128: v.Append(proto.Int128{ @@ -214,17 +226,39 @@ func (col *BigInt) append(v *big.Int) { }, }) } + return nil } -func bigIntToRaw(dest []byte, v *big.Int) { +func bigIntToRaw(dest []byte, v *big.Int, signed bool) error { + bits := len(dest) * 8 + if signed { + if v.Sign() >= 0 { + if v.BitLen() > bits-1 { + return fmt.Errorf("value overflows %d-byte signed buffer", len(dest)) + } + } else { + if new(big.Int).Not(v).BitLen() > bits-1 { + return fmt.Errorf("value overflows %d-byte signed buffer", len(dest)) + } + } + } else { + if v.Sign() < 0 { + return fmt.Errorf("negative value not allowed for unsigned type") + } + if v.BitLen() > bits { + return fmt.Errorf("value overflows %d-byte unsigned buffer", len(dest)) + } + } + var sign int if v.Sign() < 0 { - v.Not(v).FillBytes(dest) + new(big.Int).Not(v).FillBytes(dest) sign = -1 } else { v.FillBytes(dest) } endianSwap(dest, sign < 0) + return nil } func rawToBigInt(v []byte, signed bool) *big.Int { diff --git a/lib/column/decimal.go b/lib/column/decimal.go index 66d103d873..aab7e5e010 100644 --- a/lib/column/decimal.go +++ b/lib/column/decimal.go @@ -6,7 +6,7 @@ import ( "encoding/binary" "errors" "fmt" - "math/big" + "math" "reflect" "strconv" "strings" @@ -139,18 +139,24 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) { case []decimal.Decimal: nulls = make([]uint8, len(v)) for i := range v { - col.append(&v[i]) + if err := col.append(&v[i]); err != nil { + return nil, err + } } case []*decimal.Decimal: nulls = make([]uint8, len(v)) for i := range v { switch { case v[i] != nil: - col.append(v[i]) + if err := col.append(v[i]); err != nil { + return nil, err + } default: nulls[i] = 1 value := decimal.New(0, 0) - col.append(&value) + if err := col.append(&value); err != nil { + return nil, err + } } } case []string: @@ -160,7 +166,9 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) { if err != nil { return nil, fmt.Errorf("could not convert \"%v\" to decimal: %w", v[i], err) } - col.append(&d) + if err := col.append(&d); err != nil { + return nil, err + } } case []*string: nulls = make([]uint8, len(v)) @@ -168,8 +176,9 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) { if v[i] == nil { nulls[i] = 1 value := decimal.New(0, 0) - col.append(&value) - + if err := col.append(&value); err != nil { + return nil, err + } continue } @@ -177,7 +186,9 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) { if err != nil { return nil, fmt.Errorf("could not convert \"%v\" to decimal: %w", *v[i], err) } - col.append(&d) + if err := col.append(&d); err != nil { + return nil, err + } } default: if valuer, ok := v.(driver.Valuer); ok { @@ -244,34 +255,41 @@ func (col *Decimal) AppendRow(v any) error { From: fmt.Sprintf("%T", v), } } - col.append(&value) - return nil + return col.append(&value) } -func (col *Decimal) append(v *decimal.Decimal) { +func (col *Decimal) append(v *decimal.Decimal) error { switch vCol := col.col.(type) { case *proto.ColDecimal32: - var part uint32 - part = uint32(decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).IntPart()) - vCol.Append(proto.Decimal32(part)) + scaled := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)) + bi := scaled.BigInt() + if !bi.IsInt64() || bi.Int64() > math.MaxInt32 || bi.Int64() < math.MinInt32 { + return fmt.Errorf("value %s overflows decimal32 range", v.String()) + } + vCol.Append(proto.Decimal32(uint32(bi.Int64()))) case *proto.ColDecimal64: - var part uint64 - part = uint64(decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).IntPart()) - vCol.Append(proto.Decimal64(part)) + scaled := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)) + bi := scaled.BigInt() + if !bi.IsInt64() { + return fmt.Errorf("value %s overflows decimal64 range", v.String()) + } + vCol.Append(proto.Decimal64(uint64(bi.Int64()))) case *proto.ColDecimal128: - var bi *big.Int - bi = decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt() + bi := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt() dest := make([]byte, 16) - bigIntToRaw(dest, bi) + if err := bigIntToRaw(dest, bi, true); err != nil { + return fmt.Errorf("value %s overflows decimal128 range", v.String()) + } vCol.Append(proto.Decimal128{ Low: binary.LittleEndian.Uint64(dest[0 : 64/8]), High: binary.LittleEndian.Uint64(dest[64/8 : 128/8]), }) case *proto.ColDecimal256: - var bi *big.Int - bi = decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt() + bi := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt() dest := make([]byte, 32) - bigIntToRaw(dest, bi) + if err := bigIntToRaw(dest, bi, true); err != nil { + return fmt.Errorf("value %s overflows decimal256 range", v.String()) + } vCol.Append(proto.Decimal256{ Low: proto.UInt128{ Low: binary.LittleEndian.Uint64(dest[0 : 64/8]), @@ -283,6 +301,7 @@ func (col *Decimal) append(v *decimal.Decimal) { }, }) } + return nil } func (col *Decimal) Decode(reader *proto.Reader, rows int) error { diff --git a/lib/column/decimal_overflow_test.go b/lib/column/decimal_overflow_test.go new file mode 100644 index 0000000000..da4d333980 --- /dev/null +++ b/lib/column/decimal_overflow_test.go @@ -0,0 +1,92 @@ +package column + +import ( + "math/big" + "testing" + + "github.com/ClickHouse/ch-go/proto" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDecimal32OverflowReturnsError(t *testing.T) { + col := &Decimal{} + _, err := col.parse("Decimal(9, 2)") + require.NoError(t, err) + + // max int32 is 2147483647; scaled by 10^2 the max representable value is ~21474836.47 + overflow, err := decimal.NewFromString("21474836.48") + require.NoError(t, err) + err = col.AppendRow(overflow) + assert.ErrorContains(t, err, "overflow") +} + +func TestDecimal64OverflowReturnsError(t *testing.T) { + col := &Decimal{} + _, err := col.parse("Decimal(18, 2)") + require.NoError(t, err) + + // max int64 is 9223372036854775807; scaled by 10^2 the max representable value is ~92233720368547758.07 + overflow, err := decimal.NewFromString("92233720368547758.08") + require.NoError(t, err) + err = col.AppendRow(overflow) + assert.ErrorContains(t, err, "overflow") +} + +func TestDecimal128OverflowReturnsError(t *testing.T) { + col := &Decimal{} + _, err := col.parse("Decimal(38, 0)") + require.NoError(t, err) + + // 2^127 exceeds Decimal128 signed range + big2_127 := new(big.Int).Lsh(big.NewInt(1), 127) + overflow := decimal.NewFromBigInt(big2_127, 0) + err = col.AppendRow(overflow) + assert.ErrorContains(t, err, "overflow") +} + +func TestDecimal256OverflowReturnsError(t *testing.T) { + col := &Decimal{} + _, err := col.parse("Decimal(76, 0)") + require.NoError(t, err) + + // 2^255 exceeds Decimal256 signed range + big2_255 := new(big.Int).Lsh(big.NewInt(1), 255) + overflow := decimal.NewFromBigInt(big2_255, 0) + err = col.AppendRow(overflow) + assert.ErrorContains(t, err, "overflow") +} + +func TestBigIntOverflowReturnsError(t *testing.T) { + // Int128: signed 128-bit, max positive is 2^127-1 + col128 := &BigInt{size: 16, chType: "Int128", signed: true, col: &proto.ColInt128{}} + + big2_127 := new(big.Int).Lsh(big.NewInt(1), 127) + err := col128.AppendRow(*big2_127) + assert.ErrorContains(t, err, "overflow") +} + +func TestBigIntValidValuesNoError(t *testing.T) { + col128 := &BigInt{size: 16, chType: "Int128", signed: true, col: &proto.ColInt128{}} + + // 2^127 - 1 is the max valid Int128 value + maxInt128 := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 127), big.NewInt(1)) + err := col128.AppendRow(*maxInt128) + assert.NoError(t, err) + + // min valid Int128 value is -2^127 + minInt128 := new(big.Int).Neg(new(big.Int).Lsh(big.NewInt(1), 127)) + err = col128.AppendRow(*minInt128) + assert.NoError(t, err) +} + +func TestBigIntNegativeOverflowReturnsError(t *testing.T) { + col128 := &BigInt{size: 16, chType: "Int128", signed: true, col: &proto.ColInt128{}} + + // -2^127 - 1 is below the minimum Int128 value (-2^127) + minInt128 := new(big.Int).Neg(new(big.Int).Lsh(big.NewInt(1), 127)) + overflow := new(big.Int).Sub(minInt128, big.NewInt(1)) + err := col128.AppendRow(*overflow) + assert.ErrorContains(t, err, "overflow") +} diff --git a/tests/issues/issue_1849_test.go b/tests/issues/issue_1849_test.go new file mode 100644 index 0000000000..3a1ad427aa --- /dev/null +++ b/tests/issues/issue_1849_test.go @@ -0,0 +1,134 @@ +package issues + +import ( + "context" + "database/sql" + "fmt" + "math/big" + "strconv" + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ClickHouse/clickhouse-go/v2" + clickhouse_tests "github.com/ClickHouse/clickhouse-go/v2/tests" + clickhouse_std_tests "github.com/ClickHouse/clickhouse-go/v2/tests/std" +) + +// overflow1849Cases captures the four overflow scenarios that previously +// either panicked (BigInt) or silently truncated (Decimal). For each case the +// driver must now return an error containing "overflow". +func overflow1849Cases(t *testing.T) []struct { + name string + d128 any + i128 any +} { + t.Helper() + big2_127 := new(big.Int).Lsh(big.NewInt(1), 127) + min128 := new(big.Int).Neg(big2_127) + belowMin128 := new(big.Int).Sub(min128, big.NewInt(1)) + + return []struct { + name string + d128 any + i128 any + }{ + {"Decimal128PositiveOverflow", decimal.NewFromBigInt(big2_127, 0), big.NewInt(0)}, + {"Decimal128NegativeOverflow", decimal.NewFromBigInt(belowMin128, 0), big.NewInt(0)}, + {"Int128PositiveOverflow", decimal.NewFromBigInt(big.NewInt(0), 0), *big2_127}, + {"Int128NegativeOverflow", decimal.NewFromBigInt(big.NewInt(0), 0), *belowMin128}, + } +} + +// TestIssue1849 verifies that appending overflow values to Decimal and +// BigInt columns returns an error instead of panicking. Covered surfaces: +// - native driver.Conn over TCP +// - native driver.Conn over HTTP +// - database/sql over TCP +// - database/sql over HTTP +// +// Regression test for https://github.com/ClickHouse/clickhouse-go/issues/1849. +func TestIssue1849(t *testing.T) { + const ddl = `CREATE TABLE test_issue_1849 ( + d128 Decimal(38, 0), + i128 Int128 + ) Engine MergeTree() ORDER BY tuple()` + + t.Run("Native", func(t *testing.T) { + ctx := context.Background() + for _, protocol := range []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP} { + t.Run(protocol.String(), func(t *testing.T) { + conn, err := clickhouse_tests.GetConnection(testSet, t, protocol, nil, nil, nil) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + + require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS test_issue_1849")) + require.NoError(t, conn.Exec(ctx, ddl)) + t.Cleanup(func() { _ = conn.Exec(ctx, "DROP TABLE IF EXISTS test_issue_1849") }) + + for _, tc := range overflow1849Cases(t) { + t.Run(tc.name, func(t *testing.T) { + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_issue_1849") + require.NoError(t, err) + t.Cleanup(func() { _ = batch.Abort() }) + + err = batch.Append(tc.d128, tc.i128) + assert.ErrorContains(t, err, "overflow") + }) + } + }) + } + }) + + t.Run("Std", func(t *testing.T) { + useSSL, err := strconv.ParseBool(clickhouse_tests.GetEnv("CLICKHOUSE_USE_SSL", "false")) + require.NoError(t, err) + + for _, protocol := range []clickhouse.Protocol{clickhouse.Native, clickhouse.HTTP} { + t.Run(protocol.String(), func(t *testing.T) { + db, err := clickhouse_std_tests.GetDSNConnection(testSet, protocol, useSSL, nil) + require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + + _, _ = db.Exec("DROP TABLE IF EXISTS test_issue_1849") + _, err = db.Exec(ddl) + require.NoError(t, err) + t.Cleanup(func() { _, _ = db.Exec("DROP TABLE IF EXISTS test_issue_1849") }) + + for _, tc := range overflow1849Cases(t) { + t.Run(tc.name, func(t *testing.T) { + appendErr := stdInsertOverflow(db, tc.d128, tc.i128) + assert.ErrorContains(t, appendErr, "overflow") + }) + } + }) + } + }) +} + +// stdInsertOverflow runs a single-row INSERT through the database/sql +// surface and returns the first error encountered. The overflow check +// fires inside the column converter, so the error normally surfaces from +// ExecContext. We still drive the full Begin → Prepare → Exec → Commit +// flow so any caller-visible regression is caught. +func stdInsertOverflow(db *sql.DB, d128, i128 any) error { + ctx := context.Background() + scope, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin: %w", err) + } + defer func() { _ = scope.Rollback() }() + + stmt, err := scope.PrepareContext(ctx, "INSERT INTO test_issue_1849") + if err != nil { + return fmt.Errorf("prepare: %w", err) + } + defer stmt.Close() + + if _, err := stmt.ExecContext(ctx, d128, i128); err != nil { + return err + } + return scope.Commit() +}