Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions lib/column/bigint.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,23 @@ func (col *BigInt) Append(v any) (nulls []uint8, err error) {
case []big.Int:
nulls = make([]uint8, len(v))
for i := range v {
col.append(&v[i])
if err := col.append(&v[i]); err != nil {
return nil, err
}
}
case []*big.Int:
nulls = make([]uint8, len(v))
for i := range v {
switch {
case v[i] != nil:
col.append(v[i])
if err := col.append(v[i]); err != nil {
return nil, err
}
default:
nulls[i] = 1
col.append(big.NewInt(0))
if err := col.append(big.NewInt(0)); err != nil {
return nil, err
}
}
}
default:
Expand Down Expand Up @@ -106,16 +112,16 @@ func (col *BigInt) Append(v any) (nulls []uint8, err error) {
func (col *BigInt) AppendRow(v any) error {
switch v := v.(type) {
case big.Int:
col.append(&v)
return col.append(&v)
case *big.Int:
switch {
case v != nil:
col.append(v)
return col.append(v)
default:
col.append(big.NewInt(0))
return col.append(big.NewInt(0))
}
case nil:
col.append(big.NewInt(0))
return col.append(big.NewInt(0))
default:
if valuer, ok := v.(driver.Valuer); ok {
val, err := valuer.Value()
Expand All @@ -135,7 +141,6 @@ func (col *BigInt) AppendRow(v any) error {
From: fmt.Sprintf("%T", v),
}
}
return nil
}

func (col *BigInt) Decode(reader *proto.Reader, rows int) error {
Expand Down Expand Up @@ -177,9 +182,16 @@ func (col *BigInt) row(i int) *big.Int {
return big.NewInt(0)
}

func (col *BigInt) append(v *big.Int) {
func (col *BigInt) append(v *big.Int) error {
dest := make([]byte, col.size)
bigIntToRaw(dest, new(big.Int).Set(v))
if err := bigIntToRaw(dest, v, col.signed); err != nil {
return &ColumnConverterError{
Op: "Append",
To: string(col.chType),
From: "big.Int",
Hint: err.Error(),
}
}
switch v := col.col.(type) {
case *proto.ColInt128:
v.Append(proto.Int128{
Expand Down Expand Up @@ -214,17 +226,39 @@ func (col *BigInt) append(v *big.Int) {
},
})
}
return nil
}

func bigIntToRaw(dest []byte, v *big.Int) {
func bigIntToRaw(dest []byte, v *big.Int, signed bool) error {
bits := len(dest) * 8
if signed {
if v.Sign() >= 0 {
if v.BitLen() > bits-1 {
return fmt.Errorf("value overflows %d-byte signed buffer", len(dest))
}
} else {
if new(big.Int).Not(v).BitLen() > bits-1 {
return fmt.Errorf("value overflows %d-byte signed buffer", len(dest))
}
}
} else {
if v.Sign() < 0 {
return fmt.Errorf("negative value not allowed for unsigned type")
}
if v.BitLen() > bits {
return fmt.Errorf("value overflows %d-byte unsigned buffer", len(dest))
}
}

var sign int
if v.Sign() < 0 {
v.Not(v).FillBytes(dest)
new(big.Int).Not(v).FillBytes(dest)
sign = -1
} else {
v.FillBytes(dest)
}
endianSwap(dest, sign < 0)
return nil
}

func rawToBigInt(v []byte, signed bool) *big.Int {
Expand Down
65 changes: 42 additions & 23 deletions lib/column/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"encoding/binary"
"errors"
"fmt"
"math/big"
"math"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -139,18 +139,24 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) {
case []decimal.Decimal:
nulls = make([]uint8, len(v))
for i := range v {
col.append(&v[i])
if err := col.append(&v[i]); err != nil {
return nil, err
}
}
case []*decimal.Decimal:
nulls = make([]uint8, len(v))
for i := range v {
switch {
case v[i] != nil:
col.append(v[i])
if err := col.append(v[i]); err != nil {
return nil, err
}
default:
nulls[i] = 1
value := decimal.New(0, 0)
col.append(&value)
if err := col.append(&value); err != nil {
return nil, err
}
}
}
case []string:
Expand All @@ -160,24 +166,29 @@ func (col *Decimal) Append(v any) (nulls []uint8, err error) {
if err != nil {
return nil, fmt.Errorf("could not convert \"%v\" to decimal: %w", v[i], err)
}
col.append(&d)
if err := col.append(&d); err != nil {
return nil, err
}
}
case []*string:
nulls = make([]uint8, len(v))
for i := range v {
if v[i] == nil {
nulls[i] = 1
value := decimal.New(0, 0)
col.append(&value)

if err := col.append(&value); err != nil {
return nil, err
}
continue
}

d, err := decimal.NewFromString(*v[i])
if err != nil {
return nil, fmt.Errorf("could not convert \"%v\" to decimal: %w", *v[i], err)
}
col.append(&d)
if err := col.append(&d); err != nil {
return nil, err
}
}
default:
if valuer, ok := v.(driver.Valuer); ok {
Expand Down Expand Up @@ -244,34 +255,41 @@ func (col *Decimal) AppendRow(v any) error {
From: fmt.Sprintf("%T", v),
}
}
col.append(&value)
return nil
return col.append(&value)
}

func (col *Decimal) append(v *decimal.Decimal) {
func (col *Decimal) append(v *decimal.Decimal) error {
switch vCol := col.col.(type) {
case *proto.ColDecimal32:
var part uint32
part = uint32(decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).IntPart())
vCol.Append(proto.Decimal32(part))
scaled := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale))
bi := scaled.BigInt()
if !bi.IsInt64() || bi.Int64() > math.MaxInt32 || bi.Int64() < math.MinInt32 {
return fmt.Errorf("value %s overflows decimal32 range", v.String())
}
vCol.Append(proto.Decimal32(uint32(bi.Int64())))
case *proto.ColDecimal64:
var part uint64
part = uint64(decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).IntPart())
vCol.Append(proto.Decimal64(part))
scaled := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale))
bi := scaled.BigInt()
if !bi.IsInt64() {
return fmt.Errorf("value %s overflows decimal64 range", v.String())
}
vCol.Append(proto.Decimal64(uint64(bi.Int64())))
case *proto.ColDecimal128:
var bi *big.Int
bi = decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt()
bi := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt()
dest := make([]byte, 16)
bigIntToRaw(dest, bi)
if err := bigIntToRaw(dest, bi, true); err != nil {
return fmt.Errorf("value %s overflows decimal128 range", v.String())
}
vCol.Append(proto.Decimal128{
Low: binary.LittleEndian.Uint64(dest[0 : 64/8]),
High: binary.LittleEndian.Uint64(dest[64/8 : 128/8]),
})
case *proto.ColDecimal256:
var bi *big.Int
bi = decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt()
bi := decimal.NewFromBigInt(v.Coefficient(), v.Exponent()+int32(col.scale)).BigInt()
dest := make([]byte, 32)
bigIntToRaw(dest, bi)
if err := bigIntToRaw(dest, bi, true); err != nil {
return fmt.Errorf("value %s overflows decimal256 range", v.String())
}
vCol.Append(proto.Decimal256{
Low: proto.UInt128{
Low: binary.LittleEndian.Uint64(dest[0 : 64/8]),
Expand All @@ -283,6 +301,7 @@ func (col *Decimal) append(v *decimal.Decimal) {
},
})
}
return nil
}

func (col *Decimal) Decode(reader *proto.Reader, rows int) error {
Expand Down
92 changes: 92 additions & 0 deletions lib/column/decimal_overflow_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package column

import (
"math/big"
"testing"

"github.com/ClickHouse/ch-go/proto"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDecimal32OverflowReturnsError(t *testing.T) {
col := &Decimal{}
_, err := col.parse("Decimal(9, 2)")
require.NoError(t, err)

// max int32 is 2147483647; scaled by 10^2 the max representable value is ~21474836.47
overflow, err := decimal.NewFromString("21474836.48")
require.NoError(t, err)
err = col.AppendRow(overflow)
assert.ErrorContains(t, err, "overflow")
}

func TestDecimal64OverflowReturnsError(t *testing.T) {
col := &Decimal{}
_, err := col.parse("Decimal(18, 2)")
require.NoError(t, err)

// max int64 is 9223372036854775807; scaled by 10^2 the max representable value is ~92233720368547758.07
overflow, err := decimal.NewFromString("92233720368547758.08")
require.NoError(t, err)
err = col.AppendRow(overflow)
assert.ErrorContains(t, err, "overflow")
}

func TestDecimal128OverflowReturnsError(t *testing.T) {
col := &Decimal{}
_, err := col.parse("Decimal(38, 0)")
require.NoError(t, err)

// 2^127 exceeds Decimal128 signed range
big2_127 := new(big.Int).Lsh(big.NewInt(1), 127)
overflow := decimal.NewFromBigInt(big2_127, 0)
err = col.AppendRow(overflow)
assert.ErrorContains(t, err, "overflow")
}

func TestDecimal256OverflowReturnsError(t *testing.T) {
col := &Decimal{}
_, err := col.parse("Decimal(76, 0)")
require.NoError(t, err)

// 2^255 exceeds Decimal256 signed range
big2_255 := new(big.Int).Lsh(big.NewInt(1), 255)
overflow := decimal.NewFromBigInt(big2_255, 0)
err = col.AppendRow(overflow)
assert.ErrorContains(t, err, "overflow")
}

func TestBigIntOverflowReturnsError(t *testing.T) {
// Int128: signed 128-bit, max positive is 2^127-1
col128 := &BigInt{size: 16, chType: "Int128", signed: true, col: &proto.ColInt128{}}

big2_127 := new(big.Int).Lsh(big.NewInt(1), 127)
err := col128.AppendRow(*big2_127)
assert.ErrorContains(t, err, "overflow")
}

func TestBigIntValidValuesNoError(t *testing.T) {
col128 := &BigInt{size: 16, chType: "Int128", signed: true, col: &proto.ColInt128{}}

// 2^127 - 1 is the max valid Int128 value
maxInt128 := new(big.Int).Sub(new(big.Int).Lsh(big.NewInt(1), 127), big.NewInt(1))
err := col128.AppendRow(*maxInt128)
assert.NoError(t, err)

// min valid Int128 value is -2^127
minInt128 := new(big.Int).Neg(new(big.Int).Lsh(big.NewInt(1), 127))
err = col128.AppendRow(*minInt128)
assert.NoError(t, err)
}

func TestBigIntNegativeOverflowReturnsError(t *testing.T) {
col128 := &BigInt{size: 16, chType: "Int128", signed: true, col: &proto.ColInt128{}}

// -2^127 - 1 is below the minimum Int128 value (-2^127)
minInt128 := new(big.Int).Neg(new(big.Int).Lsh(big.NewInt(1), 127))
overflow := new(big.Int).Sub(minInt128, big.NewInt(1))
err := col128.AppendRow(*overflow)
assert.ErrorContains(t, err, "overflow")
}
Loading