Skip to content

Commit 0985fcf

Browse files
committed
Add supports for slice of struct ptr
1 parent ada350f commit 0985fcf

File tree

2 files changed

+109
-10
lines changed

2 files changed

+109
-10
lines changed

result_set.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,16 @@ func (res ResultSet) Scan(v interface{}) error {
338338
func (res ResultSet) scanRow(row *nebula.Row, colNames []string, rowType reflect.Type) (reflect.Value, error) {
339339
rowVals := row.GetValues()
340340

341-
val := reflect.New(rowType).Elem()
341+
var result reflect.Value
342+
if rowType.Kind() == reflect.Ptr {
343+
result = reflect.New(rowType.Elem())
344+
} else {
345+
result = reflect.New(rowType).Elem()
346+
}
347+
structVal := reflect.Indirect(result)
342348

343-
for fIdx := 0; fIdx < rowType.NumField(); fIdx++ {
344-
f := rowType.Field(fIdx)
349+
for fIdx := 0; fIdx < structVal.Type().NumField(); fIdx++ {
350+
f := structVal.Type().Field(fIdx)
345351
tag := f.Tag.Get("nebula")
346352

347353
if tag == "" {
@@ -358,19 +364,19 @@ func (res ResultSet) scanRow(row *nebula.Row, colNames []string, rowType reflect
358364

359365
if f.Type.Kind() == reflect.Slice {
360366
list := rowVal.GetLVal()
361-
err := scanListCol(list.Values, val.Field(fIdx), f.Type)
367+
err := scanListCol(list.Values, structVal.Field(fIdx), f.Type)
362368
if err != nil {
363-
return val, err
369+
return result, err
364370
}
365371
} else {
366-
err := scanPrimitiveCol(rowVal, val.Field(fIdx), f.Type.Kind())
372+
err := scanPrimitiveCol(rowVal, structVal.Field(fIdx), f.Type.Kind())
367373
if err != nil {
368-
return val, err
374+
return result, err
369375
}
370376
}
371377
}
372378

373-
return val, nil
379+
return result, nil
374380
}
375381

376382
func scanListCol(vals []*nebula.Value, listVal reflect.Value, sliceType reflect.Type) error {

result_set_test.go

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,54 @@ func TestScan(t *testing.T) {
835835
assert.Equal(t, true, testStructList[1].Col3)
836836
}
837837

838+
func TestScanPtr(t *testing.T) {
839+
resp := &graph.ExecutionResponse{
840+
ErrorCode: nebula.ErrorCode_SUCCEEDED,
841+
LatencyInUs: 1000,
842+
Data: getDateset2(),
843+
SpaceName: []byte("test_space"),
844+
ErrorMsg: []byte("test"),
845+
PlanDesc: graph.NewPlanDescription(),
846+
Comment: []byte("test_comment")}
847+
resultSet, err := genResultSet(resp, testTimezone)
848+
if err != nil {
849+
t.Error(err)
850+
}
851+
852+
type testStruct struct {
853+
Col0 int64 `nebula:"col0_int64"`
854+
Col1 float64 `nebula:"col1_float64"`
855+
Col2 string `nebula:"col2_string"`
856+
Col3 bool `nebula:"col3_bool"`
857+
}
858+
859+
var testStructList []*testStruct
860+
err = resultSet.Scan(&testStructList)
861+
if err != nil {
862+
t.Error(err)
863+
}
864+
assert.Equal(t, 1, len(testStructList))
865+
assert.Equal(t, int64(1), testStructList[0].Col0)
866+
assert.Equal(t, float64(2.0), testStructList[0].Col1)
867+
assert.Equal(t, "string", testStructList[0].Col2)
868+
assert.Equal(t, true, testStructList[0].Col3)
869+
870+
// Scan again should work
871+
err = resultSet.Scan(&testStructList)
872+
if err != nil {
873+
t.Error(err)
874+
}
875+
assert.Equal(t, 2, len(testStructList))
876+
assert.Equal(t, int64(1), testStructList[0].Col0)
877+
assert.Equal(t, float64(2.0), testStructList[0].Col1)
878+
assert.Equal(t, "string", testStructList[0].Col2)
879+
assert.Equal(t, true, testStructList[0].Col3)
880+
assert.Equal(t, int64(1), testStructList[1].Col0)
881+
assert.Equal(t, float64(2.0), testStructList[1].Col1)
882+
assert.Equal(t, "string", testStructList[1].Col2)
883+
assert.Equal(t, true, testStructList[1].Col3)
884+
}
885+
838886
func TestScanWithNestStruct(t *testing.T) {
839887
resp := &graph.ExecutionResponse{
840888
ErrorCode: nebula.ErrorCode_SUCCEEDED,
@@ -916,8 +964,6 @@ func TestScanWithNestStructPtr(t *testing.T) {
916964
Edges []*Friend `nebula:"relationships"`
917965
}
918966

919-
// TODO: actually, the `results` should be []*Result,
920-
// we still need to support this case
921967
var results []Result
922968
err = resultSet.Scan(&results)
923969
if err != nil {
@@ -939,6 +985,53 @@ func TestScanWithNestStructPtr(t *testing.T) {
939985
assert.Equal(t, 2, len(results))
940986
}
941987

988+
func TestScanWithStructPtr(t *testing.T) {
989+
resp := &graph.ExecutionResponse{
990+
ErrorCode: nebula.ErrorCode_SUCCEEDED,
991+
LatencyInUs: 1000,
992+
Data: getNestDateset(),
993+
SpaceName: []byte("test_space"),
994+
ErrorMsg: []byte("test"),
995+
PlanDesc: graph.NewPlanDescription(),
996+
Comment: []byte("test_comment")}
997+
resultSet, err := genResultSet(resp, testTimezone)
998+
if err != nil {
999+
t.Error(err)
1000+
}
1001+
1002+
type Person struct {
1003+
Name string `nebula:"name"`
1004+
City string `nebula:"city"`
1005+
}
1006+
type Friend struct {
1007+
CreatedAt string `nebula:"created_at"`
1008+
}
1009+
type Result struct {
1010+
Nodes []*Person `nebula:"nodes"`
1011+
Edges []*Friend `nebula:"relationships"`
1012+
}
1013+
1014+
var results []*Result
1015+
err = resultSet.Scan(&results)
1016+
if err != nil {
1017+
t.Error(err)
1018+
}
1019+
assert.Equal(t, 1, len(results))
1020+
assert.Equal(t, "Tom", results[0].Nodes[0].Name)
1021+
assert.Equal(t, "Shanghai", results[0].Nodes[0].City)
1022+
assert.Equal(t, "Bob", results[0].Nodes[1].Name)
1023+
assert.Equal(t, "Hangzhou", results[0].Nodes[1].City)
1024+
assert.Equal(t, "2024-07-07", results[0].Edges[0].CreatedAt)
1025+
assert.Equal(t, "2024-07-07", results[0].Edges[1].CreatedAt)
1026+
1027+
// Scan again should work
1028+
err = resultSet.Scan(&results)
1029+
if err != nil {
1030+
t.Error(err)
1031+
}
1032+
assert.Equal(t, 2, len(results))
1033+
}
1034+
9421035
func TestIntVid(t *testing.T) {
9431036
vertex := getVertexInt(101, 3, 5)
9441037
node, err := genNode(vertex, testTimezone)

0 commit comments

Comments
 (0)