Skip to content

Commit f3c77bc

Browse files
author
tengu-alt
committed
scan to any were implemented for all simple types
1 parent 974fa12 commit f3c77bc

File tree

2 files changed

+247
-1
lines changed

2 files changed

+247
-1
lines changed

cassandra_test.go

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import (
4444
"time"
4545
"unicode"
4646

47-
inf "gopkg.in/inf.v0"
47+
"gopkg.in/inf.v0"
4848
)
4949

5050
func TestEmptyHosts(t *testing.T) {
@@ -3288,3 +3288,163 @@ func TestQuery_NamedValues(t *testing.T) {
32883288
t.Fatal(err)
32893289
}
32903290
}
3291+
3292+
func TestScanToAny(t *testing.T) {
3293+
session := createSession(t)
3294+
defer session.Close()
3295+
ctx := context.Background()
3296+
3297+
dataTypes := []struct {
3298+
tableName string
3299+
createQuery string
3300+
insertQuery string
3301+
expectedVal interface{}
3302+
}{
3303+
{
3304+
"scan_to_any_varchar",
3305+
"CREATE TABLE IF NOT EXISTS scan_to_any_varchar (id int PRIMARY KEY, val varchar)",
3306+
"INSERT INTO scan_to_any_varchar (id, val) VALUES (?, ?)",
3307+
"test",
3308+
},
3309+
{
3310+
"scan_to_any_bool",
3311+
"CREATE TABLE IF NOT EXISTS scan_to_any_bool (id int PRIMARY KEY, val boolean)",
3312+
"INSERT INTO scan_to_any_bool (id, val) VALUES (?, ?)",
3313+
true,
3314+
},
3315+
{
3316+
"scan_to_any_int",
3317+
"CREATE TABLE IF NOT EXISTS scan_to_any_int (id int PRIMARY KEY, val int)",
3318+
"INSERT INTO scan_to_any_int (id, val) VALUES (?, ?)",
3319+
42,
3320+
},
3321+
{
3322+
"scan_to_any_float",
3323+
"CREATE TABLE IF NOT EXISTS scan_to_any_float (id int PRIMARY KEY, val float)",
3324+
"INSERT INTO scan_to_any_float (id, val) VALUES (?, ?)",
3325+
float32(3.14),
3326+
},
3327+
{
3328+
"scan_to_any_double",
3329+
"CREATE TABLE IF NOT EXISTS scan_to_any_double (id int PRIMARY KEY, val double)",
3330+
"INSERT INTO scan_to_any_double (id, val) VALUES (?, ?)",
3331+
3.14159,
3332+
},
3333+
{
3334+
"scan_to_any_decimal",
3335+
"CREATE TABLE IF NOT EXISTS scan_to_any_decimal (id int PRIMARY KEY, val decimal)",
3336+
"INSERT INTO scan_to_any_decimal (id, val) VALUES (?, ?)",
3337+
inf.NewDec(12345, 2), // Example decimal value
3338+
},
3339+
{
3340+
"scan_to_any_time",
3341+
"CREATE TABLE IF NOT EXISTS scan_to_any_time (id int PRIMARY KEY, val time)",
3342+
"INSERT INTO scan_to_any_time (id, val) VALUES (?, ?)",
3343+
time.Duration(1000),
3344+
},
3345+
{
3346+
"scan_to_any_timestamp",
3347+
"CREATE TABLE IF NOT EXISTS scan_to_any_timestamp (id int PRIMARY KEY, val timestamp)",
3348+
"INSERT INTO scan_to_any_timestamp (id, val) VALUES (?, ?)",
3349+
time.Now().UTC().Truncate(time.Millisecond),
3350+
},
3351+
{
3352+
"scan_to_any_inet",
3353+
"CREATE TABLE IF NOT EXISTS scan_to_any_inet (id int PRIMARY KEY, val inet)",
3354+
"INSERT INTO scan_to_any_inet (id, val) VALUES (?, ?)",
3355+
net.ParseIP("192.168.0.1"),
3356+
},
3357+
{
3358+
"scan_to_any_uuid",
3359+
"CREATE TABLE IF NOT EXISTS scan_to_any_uuid (id int PRIMARY KEY, val uuid)",
3360+
"INSERT INTO scan_to_any_uuid (id, val) VALUES (?, ?)",
3361+
TimeUUID().String(),
3362+
},
3363+
{
3364+
"scan_to_any_date",
3365+
"CREATE TABLE IF NOT EXISTS scan_to_any_date (id int PRIMARY KEY, val date)",
3366+
"INSERT INTO scan_to_any_date (id, val) VALUES (?, ?)",
3367+
time.Now().UTC().Truncate(time.Hour * 24),
3368+
},
3369+
{
3370+
"scan_to_any_duration",
3371+
"CREATE TABLE IF NOT EXISTS scan_to_any_duration (id int PRIMARY KEY, val duration)",
3372+
"INSERT INTO scan_to_any_duration (id, val) VALUES (?, ?)",
3373+
Duration{0, 0, 123},
3374+
},
3375+
}
3376+
3377+
for _, dt := range dataTypes {
3378+
t.Run(fmt.Sprintf("Test_%s", dt.tableName), func(t *testing.T) {
3379+
if err := session.Query(dt.createQuery).WithContext(ctx).Exec(); err != nil {
3380+
t.Fatal(err)
3381+
}
3382+
3383+
if err := session.Query(dt.insertQuery, 1, dt.expectedVal).WithContext(ctx).Exec(); err != nil {
3384+
t.Fatal(err)
3385+
}
3386+
3387+
var out interface{}
3388+
if err := session.Query(fmt.Sprintf("SELECT val FROM %s WHERE id = 1", dt.tableName)).WithContext(ctx).Scan(&out); err != nil {
3389+
t.Fatal(err)
3390+
}
3391+
3392+
if err := session.Query(fmt.Sprintf("DROP TABLE %s", dt.tableName)).WithContext(ctx).Exec(); err != nil {
3393+
t.Fatal(err)
3394+
}
3395+
3396+
switch dt.tableName {
3397+
case "scan_to_any_decimal":
3398+
result, ok := out.(inf.Dec)
3399+
if !ok {
3400+
t.Fatal("expected inf.Dec, got", out)
3401+
}
3402+
expected := inf.NewDec(12345, 2)
3403+
3404+
if result.Cmp(expected) != 0 {
3405+
t.Fatalf("expected %v, got %v", expected, out)
3406+
}
3407+
case "scan_to_any_inet":
3408+
result, ok := out.(net.IP)
3409+
if !ok {
3410+
t.Fatal("expected net.IP, got", out)
3411+
}
3412+
expected, ok := dt.expectedVal.(net.IP)
3413+
if !ok {
3414+
t.Fatal("expected net.IP, got", dt.expectedVal)
3415+
}
3416+
if result.String() != expected.String() {
3417+
t.Fatalf("expected %v, got %v", expected, out)
3418+
}
3419+
case "scan_to_any_date":
3420+
result, ok := out.(time.Time)
3421+
if !ok {
3422+
t.Fatal("expected time.Time, got", out)
3423+
}
3424+
expected, ok := dt.expectedVal.(time.Time)
3425+
if !ok {
3426+
t.Fatal("expected time.Time, got", dt.expectedVal)
3427+
}
3428+
if result.String() != expected.String() {
3429+
t.Fatalf("expected %v, got %v", expected, out)
3430+
}
3431+
case "scan_to_any_duration":
3432+
result, ok := out.(Duration)
3433+
if !ok {
3434+
t.Fatal("expected time.Duration, got", out)
3435+
}
3436+
expected, ok := dt.expectedVal.(Duration)
3437+
if !ok {
3438+
t.Fatal("expected time.Duration, got", dt.expectedVal)
3439+
}
3440+
if result != expected {
3441+
t.Fatalf("expected %v, got %v", expected, out)
3442+
}
3443+
default:
3444+
if out != dt.expectedVal {
3445+
t.Fatalf("expected %v, got %v", dt.expectedVal, out)
3446+
}
3447+
}
3448+
})
3449+
}
3450+
}

marshal.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,9 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error {
350350
*v = nil
351351
}
352352
return nil
353+
case *interface{}:
354+
*v = string(data)
355+
return nil
353356
}
354357

355358
rv := reflect.ValueOf(value)
@@ -743,6 +746,9 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error {
743746
*v = bytesToUint64(data[1:])
744747
return nil
745748
}
749+
case *interface{}:
750+
return unmarshalBigInt(info, data, value)
751+
746752
}
747753

748754
if len(data) > 8 {
@@ -904,6 +910,12 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac
904910
case *string:
905911
*v = strconv.FormatInt(int64Val, 10)
906912
return nil
913+
case *interface{}:
914+
if ^uint(0) == math.MaxUint32 && (int64Val < math.MinInt32 || int64Val > math.MaxInt32) {
915+
return unmarshalErrorf("unmarshal int: value %d out of range for %s", int64Val, info.Type())
916+
}
917+
*v = int(int64Val)
918+
return nil
907919
}
908920

909921
rv := reflect.ValueOf(value)
@@ -1055,6 +1067,9 @@ func unmarshalBool(info TypeInfo, data []byte, value interface{}) error {
10551067
case *bool:
10561068
*v = decBool(data)
10571069
return nil
1070+
case *interface{}:
1071+
*v = decBool(data)
1072+
return nil
10581073
}
10591074
rv := reflect.ValueOf(value)
10601075
if rv.Kind() != reflect.Ptr {
@@ -1105,6 +1120,9 @@ func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error {
11051120
case *float32:
11061121
*v = math.Float32frombits(uint32(decInt(data)))
11071122
return nil
1123+
case *interface{}:
1124+
*v = math.Float32frombits(uint32(decInt(data)))
1125+
return nil
11081126
}
11091127
rv := reflect.ValueOf(value)
11101128
if rv.Kind() != reflect.Ptr {
@@ -1146,6 +1164,9 @@ func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error {
11461164
case *float64:
11471165
*v = math.Float64frombits(uint64(decBigInt(data)))
11481166
return nil
1167+
case *interface{}:
1168+
*v = math.Float64frombits(uint64(decBigInt(data)))
1169+
return nil
11491170
}
11501171
rv := reflect.ValueOf(value)
11511172
if rv.Kind() != reflect.Ptr {
@@ -1196,6 +1217,14 @@ func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error {
11961217
unscaled := decBigInt2C(data[4:], nil)
11971218
*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
11981219
return nil
1220+
case *interface{}:
1221+
if len(data) < 4 {
1222+
return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data))
1223+
}
1224+
scale := decInt(data[0:4])
1225+
unscaled := decBigInt2C(data[4:], nil)
1226+
*v = *inf.NewDecBig(unscaled, inf.Scale(scale))
1227+
return nil
11991228
}
12001229
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
12011230
}
@@ -1302,6 +1331,9 @@ func unmarshalTime(info TypeInfo, data []byte, value interface{}) error {
13021331
case *time.Duration:
13031332
*v = time.Duration(decBigInt(data))
13041333
return nil
1334+
case *interface{}:
1335+
*v = time.Duration(decBigInt(data))
1336+
return nil
13051337
}
13061338

13071339
rv := reflect.ValueOf(value)
@@ -1334,6 +1366,16 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error {
13341366
nsec := (x - sec*1000) * 1000000
13351367
*v = time.Unix(sec, nsec).In(time.UTC)
13361368
return nil
1369+
case *interface{}:
1370+
if len(data) == 0 {
1371+
*v = time.Time{}
1372+
return nil
1373+
}
1374+
x := decBigInt(data)
1375+
sec := x / 1000
1376+
nsec := (x - sec*1000) * 1000000
1377+
*v = time.Unix(sec, nsec).In(time.UTC)
1378+
return nil
13371379
}
13381380

13391381
rv := reflect.ValueOf(value)
@@ -1419,6 +1461,16 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error {
14191461
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
14201462
*v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02")
14211463
return nil
1464+
case *interface{}:
1465+
if len(data) == 0 {
1466+
*v = time.Time{}
1467+
return nil
1468+
}
1469+
var origin uint32 = 1 << 31
1470+
var current uint32 = binary.BigEndian.Uint32(data)
1471+
timestamp := (int64(current) - int64(origin)) * millisecondsInADay
1472+
*v = time.UnixMilli(timestamp).In(time.UTC)
1473+
return nil
14221474
}
14231475
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
14241476
}
@@ -1478,6 +1530,25 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error {
14781530
Nanoseconds: nanos,
14791531
}
14801532
return nil
1533+
case *interface{}:
1534+
if len(data) == 0 {
1535+
*v = Duration{
1536+
Months: 0,
1537+
Days: 0,
1538+
Nanoseconds: 0,
1539+
}
1540+
return nil
1541+
}
1542+
months, days, nanos, err := decVints(data)
1543+
if err != nil {
1544+
return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error())
1545+
}
1546+
*v = Duration{
1547+
Months: months,
1548+
Days: days,
1549+
Nanoseconds: nanos,
1550+
}
1551+
return nil
14811552
}
14821553
return unmarshalErrorf("can not unmarshal %s into %T", info, value)
14831554
}
@@ -1914,6 +1985,9 @@ func unmarshalUUID(info TypeInfo, data []byte, value interface{}) error {
19141985
case *[]byte:
19151986
*v = u[:]
19161987
return nil
1988+
case *interface{}:
1989+
*v = u.String()
1990+
return nil
19171991
}
19181992
return unmarshalErrorf("can not unmarshal X %s into %T", info, value)
19191993
}
@@ -1996,6 +2070,18 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error {
19962070
}
19972071
*v = ip.String()
19982072
return nil
2073+
case *interface{}:
2074+
if x := len(data); !(x == 4 || x == 16) {
2075+
return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x)
2076+
}
2077+
buf := copyBytes(data)
2078+
ip := net.IP(buf)
2079+
if v4 := ip.To4(); v4 != nil {
2080+
*v = v4
2081+
return nil
2082+
}
2083+
*v = ip
2084+
return nil
19992085
}
20002086
return unmarshalErrorf("cannot unmarshal %s into %T", info, value)
20012087
}

0 commit comments

Comments
 (0)