@@ -449,3 +449,172 @@ func TestCastVisit(t *testing.T) {
449449		})
450450	}
451451}
452+ 
453+ func TestSubqueryExpressionRoundtrip(t *testing.T) {
454+ 	const substraitExtURI = "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
455+ 	// define extensions with no plan for now
456+ 	const planExt = `{
457+ 		"extensionUris": [
458+ 			{
459+ 				"extensionUriAnchor": 1,
460+ 				"uri": "` + substraitExtURI + `"
461+ 			}
462+ 		],
463+ 		"extensions": [],
464+ 		"relations": []
465+ 	}`
466+ 
467+ 	var planProto proto.Plan
468+ 	if err := protojson.Unmarshal([]byte(planExt), &planProto); err != nil {
469+ 		panic(err)
470+ 	}
471+ 
472+ 	// get the extension set and create registry with subquery handler
473+ 	extSet := ext.GetExtensionSet(&planProto)
474+ 	c := ext.GetDefaultCollectionWithNoError()
475+ 
476+ 	// Create extension registry with subquery handler properly
477+ 	baseReg := expr.NewExtensionRegistry(extSet, c)
478+ 	subqueryReg := &plan.ExpressionConverter{ExtensionRegistry: baseReg}
479+ 
480+ 	// Create a simple mock relation for subqueries - single column of int32
481+ 	mockSchema := types.NamedStruct{
482+ 		Names: []string{"col1"},
483+ 		Struct: types.StructType{
484+ 			Types: []types.Type{&types.Int32Type{}},
485+ 		},
486+ 	}
487+ 	mockRel := plan.NewBuilderDefault().NamedScan([]string{"test_table"}, mockSchema)
488+ 
489+ 	// Create base schema for needle expressions
490+ 	baseSchema := types.NewRecordTypeFromTypes([]types.Type{&types.Int32Type{}, &types.StringType{}})
491+ 
492+ 	tests := []struct {
493+ 		name    string
494+ 		subExpr expr.Expression
495+ 	}{
496+ 		{
497+ 			name:    "ScalarSubquery",
498+ 			subExpr: plan.NewScalarSubquery(mockRel),
499+ 		},
500+ 		{
501+ 			name: "InPredicateSubquery",
502+ 			subExpr: plan.NewInPredicateSubquery(
503+ 				[]expr.Expression{expr.NewPrimitiveLiteral(int32(42), false)},
504+ 				mockRel,
505+ 			),
506+ 		},
507+ 		{
508+ 			name: "InPredicateSubquery_MultipleNeedles",
509+ 			subExpr: func() expr.Expression {
510+ 				// Create a 2-column relation for multi-needle test
511+ 				twoColSchema := types.NamedStruct{
512+ 					Names: []string{"col1", "col2"},
513+ 					Struct: types.StructType{
514+ 						Types: []types.Type{&types.Int32Type{}, &types.StringType{}},
515+ 					},
516+ 				}
517+ 				twoColRel := plan.NewBuilderDefault().NamedScan([]string{"two_col_table"}, twoColSchema)
518+ 
519+ 				return plan.NewInPredicateSubquery(
520+ 					[]expr.Expression{
521+ 						expr.NewPrimitiveLiteral(int32(42), false),
522+ 						expr.NewPrimitiveLiteral("test", false),
523+ 					},
524+ 					twoColRel,
525+ 				)
526+ 			}(),
527+ 		},
528+ 		{
529+ 			name: "SetPredicateSubquery_EXISTS",
530+ 			subExpr: plan.NewSetPredicateSubquery(
531+ 				proto.Expression_Subquery_SetPredicate_PREDICATE_OP_EXISTS,
532+ 				mockRel,
533+ 			),
534+ 		},
535+ 		{
536+ 			name: "SetPredicateSubquery_UNIQUE",
537+ 			subExpr: plan.NewSetPredicateSubquery(
538+ 				proto.Expression_Subquery_SetPredicate_PREDICATE_OP_UNIQUE,
539+ 				mockRel,
540+ 			),
541+ 		},
542+ 		{
543+ 			name: "SetComparisonSubquery_ANY_EQ",
544+ 			subExpr: plan.NewSetComparisonSubquery(
545+ 				proto.Expression_Subquery_SetComparison_REDUCTION_OP_ANY,
546+ 				proto.Expression_Subquery_SetComparison_COMPARISON_OP_EQ,
547+ 				expr.NewPrimitiveLiteral(int32(42), false),
548+ 				mockRel,
549+ 			),
550+ 		},
551+ 		{
552+ 			name: "SetComparisonSubquery_ALL_GT",
553+ 			subExpr: plan.NewSetComparisonSubquery(
554+ 				proto.Expression_Subquery_SetComparison_REDUCTION_OP_ALL,
555+ 				proto.Expression_Subquery_SetComparison_COMPARISON_OP_GT,
556+ 				expr.NewPrimitiveLiteral(int32(100), false),
557+ 				mockRel,
558+ 			),
559+ 		},
560+ 		{
561+ 			name: "SetComparisonSubquery_ANY_NE",
562+ 			subExpr: plan.NewSetComparisonSubquery(
563+ 				proto.Expression_Subquery_SetComparison_REDUCTION_OP_ANY,
564+ 				proto.Expression_Subquery_SetComparison_COMPARISON_OP_NE,
565+ 				expr.NewPrimitiveLiteral(int32(0), false),
566+ 				mockRel,
567+ 			),
568+ 		},
569+ 		{
570+ 			name: "SetComparisonSubquery_ALL_LE",
571+ 			subExpr: plan.NewSetComparisonSubquery(
572+ 				proto.Expression_Subquery_SetComparison_REDUCTION_OP_ALL,
573+ 				proto.Expression_Subquery_SetComparison_COMPARISON_OP_LE,
574+ 				expr.NewPrimitiveLiteral(int32(50), false),
575+ 				mockRel,
576+ 			),
577+ 		},
578+ 	}
579+ 
580+ 	for _, tt := range tests {
581+ 		t.Run(tt.name, func(t *testing.T) {
582+ 			// Convert expression to protobuf
583+ 			protoExpr := tt.subExpr.ToProto()
584+ 			require.NotNil(t, protoExpr)
585+ 			require.NotNil(t, protoExpr.GetSubquery())
586+ 
587+ 			// Convert back from protobuf using ExprFromProto with subquery handler
588+ 			baseReg.SetSubqueryConverter(subqueryReg)
589+ 			fromProto, err := expr.ExprFromProto(protoExpr, baseSchema, baseReg)
590+ 			require.NoError(t, err)
591+ 			require.NotNil(t, fromProto)
592+ 
593+ 			// Verify that we got the right type of subquery back
594+ 			switch tt.subExpr.(type) {
595+ 			case *plan.ScalarSubquery:
596+ 				assert.IsType(t, &plan.ScalarSubquery{}, fromProto)
597+ 			case *plan.InPredicateSubquery:
598+ 				assert.IsType(t, &plan.InPredicateSubquery{}, fromProto)
599+ 			case *plan.SetPredicateSubquery:
600+ 				assert.IsType(t, &plan.SetPredicateSubquery{}, fromProto)
601+ 			case *plan.SetComparisonSubquery:
602+ 				assert.IsType(t, &plan.SetComparisonSubquery{}, fromProto)
603+ 			}
604+ 
605+ 			// Verify protobuf roundtrip
606+ 			roundtripProto := fromProto.ToProto()
607+ 			assert.True(t, pb.Equal(protoExpr, roundtripProto), "protobuf roundtrip failed")
608+ 
609+ 			// Verify basic properties
610+ 			assert.Equal(t, tt.subExpr.IsScalar(), fromProto.IsScalar())
611+ 			assert.True(t, tt.subExpr.GetType().Equals(fromProto.GetType()))
612+ 
613+ 			// Note: We don't test Equals() here because the current implementation
614+ 			// of isRelEqual() only does pointer equality, so relations created from
615+ 			// protobuf will never be equal to the original relations, even if they
616+ 			// have identical content. This is a known limitation noted in the TODO
617+ 			// comment in plan/subquery.go
618+ 		})
619+ 	}
620+ }
0 commit comments