diff --git a/internal/ast/compiler/default_as_typed.go b/internal/ast/compiler/default_as_typed.go new file mode 100644 index 000000000..cb28d3e31 --- /dev/null +++ b/internal/ast/compiler/default_as_typed.go @@ -0,0 +1,57 @@ +package compiler + +import ( + "github.com/grafana/cog/internal/ast" +) + +var _ Pass = (*DefaultAsTyped)(nil) + +// DefaultAsTyped converts the raw `Default any` field on every ast.Type into a +// typed *ast.TypeDefault, populating TypedDefault alongside the existing Default. +// Jennies can then access TypedDefault instead of performing unsafe type assertions +// on Default. +type DefaultAsTyped struct{} + +func (pass *DefaultAsTyped) Process(schemas []*ast.Schema) ([]*ast.Schema, error) { + visitor := &Visitor{ + OnObject: pass.processObject, + } + + return visitor.VisitSchemas(schemas) +} + +func (pass *DefaultAsTyped) processObject(_ *Visitor, _ *ast.Schema, object ast.Object) (ast.Object, error) { + object.Type = processTypeDefaults(object.Type) + return object, nil +} + +// processTypeDefaults recursively sets TypedDefault on any Type that has +// Default != nil, and recurses into child types. +func processTypeDefaults(t ast.Type) ast.Type { + if t.Default != nil { + t.TypedDefault = ast.AnyToTypedDefault(t.Default) + } + + switch { + case t.IsStruct(): + for i, field := range t.Struct.Fields { + field.Type = processTypeDefaults(field.Type) + t.Struct.Fields[i] = field + } + case t.IsIntersection(): + for i, branch := range t.Intersection.Branches { + t.Intersection.Branches[i] = processTypeDefaults(branch) + } + case t.IsDisjunction(): + for i, branch := range t.Disjunction.Branches { + t.Disjunction.Branches[i] = processTypeDefaults(branch) + } + case t.IsArray(): + t.Array.ValueType = processTypeDefaults(t.Array.ValueType) + case t.IsMap(): + t.Map.ValueType = processTypeDefaults(t.Map.ValueType) + } + + return t +} + diff --git a/internal/ast/compiler/default_as_typed_test.go b/internal/ast/compiler/default_as_typed_test.go new file mode 100644 index 000000000..cbce14e41 --- /dev/null +++ b/internal/ast/compiler/default_as_typed_test.go @@ -0,0 +1,140 @@ +package compiler + +import ( + "testing" + + "github.com/grafana/cog/internal/ast" + "github.com/grafana/cog/internal/testutils" +) + +func TestDefaultAsTyped_ScalarStringDefault(t *testing.T) { + schema := &ast.Schema{ + Package: "test", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "SomeObject", ast.NewStruct( + ast.NewStructField("AString", ast.String(ast.Default("hello"))), + )), + ), + } + + expected := &ast.Schema{ + Package: "test", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "SomeObject", ast.NewStruct( + ast.NewStructField("AString", ast.String( + ast.Default("hello"), + ast.TypedDefaultOpt(&ast.TypeDefault{Scalar: &ast.ScalarType{ScalarKind: ast.KindString, Value: "hello"}}), + )), + )), + ), + } + + pass := &DefaultAsTyped{} + runPassOnSchema(t, pass, schema, expected) +} + +func TestDefaultAsTyped_ScalarBoolDefault(t *testing.T) { + schema := &ast.Schema{ + Package: "test", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "SomeObject", ast.NewStruct( + ast.NewStructField("ABool", ast.Bool(ast.Default(true))), + )), + ), + } + + expected := &ast.Schema{ + Package: "test", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "SomeObject", ast.NewStruct( + ast.NewStructField("ABool", ast.Bool( + ast.Default(true), + ast.TypedDefaultOpt(&ast.TypeDefault{Scalar: &ast.ScalarType{ScalarKind: ast.KindBool, Value: true}}), + )), + )), + ), + } + + pass := &DefaultAsTyped{} + runPassOnSchema(t, pass, schema, expected) +} + +func TestDefaultAsTyped_Float64Default(t *testing.T) { + schema := &ast.Schema{ + Package: "test", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "SomeObject", ast.NewStruct( + ast.NewStructField("AFloat", ast.NewScalar(ast.KindFloat64, ast.Default(float64(3.14)))), + )), + ), + } + + expected := &ast.Schema{ + Package: "test", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "SomeObject", ast.NewStruct( + ast.NewStructField("AFloat", ast.NewScalar(ast.KindFloat64, + ast.Default(3.14), + ast.TypedDefaultOpt(&ast.TypeDefault{Scalar: &ast.ScalarType{ScalarKind: ast.KindFloat64, Value: float64(3.14)}}), + )), + )), + ), + } + + pass := &DefaultAsTyped{} + runPassOnSchema(t, pass, schema, expected) +} + +func TestDefaultAsTyped_NoDefault(t *testing.T) { + schema := &ast.Schema{ + Package: "test", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "SomeObject", ast.NewStruct( + ast.NewStructField("AString", ast.String()), + )), + ), + } + + pass := &DefaultAsTyped{} + runPassOnSchema(t, pass, schema, schema) +} + +func TestDefaultAsTyped_StructDefault(t *testing.T) { + structDefault := map[string]any{ + "name": "alice", + } + + // Build schema with a field whose type has a struct default set directly + innerType := ast.NewStruct( + ast.NewStructField("name", ast.String()), + ) + innerType.Default = structDefault + + schema := &ast.Schema{ + Package: "test", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "SomeObject", ast.NewStruct( + ast.NewStructField("inner", innerType), + )), + ), + } + + expectedInnerType := innerType + expectedInnerType.TypedDefault = &ast.TypeDefault{ + Struct: map[string]*ast.TypeDefault{ + "name": {Scalar: &ast.ScalarType{ScalarKind: ast.KindString, Value: "alice"}}, + }, + } + + expected := &ast.Schema{ + Package: "test", + Objects: testutils.ObjectsMap( + ast.NewObject("test", "SomeObject", ast.NewStruct( + ast.NewStructField("inner", expectedInnerType), + )), + ), + } + + pass := &DefaultAsTyped{} + runPassOnSchema(t, pass, schema, expected) +} diff --git a/internal/ast/types.go b/internal/ast/types.go index 41cf915fa..9c7070f37 100644 --- a/internal/ast/types.go +++ b/internal/ast/types.go @@ -84,6 +84,66 @@ func (constraint TypeConstraint) DeepCopy() TypeConstraint { // JenniesHints meant to be used by jennies, to gain a finer control on the codegen from schemas type JenniesHints map[string]any +// TypeDefault represents a typed default value for a field or type. +// Exactly one of Scalar, Array, or Struct should be non-nil. +type TypeDefault struct { + // Scalar holds the default for scalar and enum types. + // ScalarKind identifies the scalar type; Value holds the actual Go value. + Scalar *ScalarType + + // Array holds typed defaults for each element of an array default. + Array []*TypeDefault + + // Struct maps field names to typed defaults, used for struct/ref types + // whose default is an object literal. + Struct map[string]*TypeDefault +} + +// AnyToTypedDefault converts a raw any default value into a *TypeDefault. +// This is the canonical conversion used by the DefaultAsTyped compiler pass and +// by EffectiveTypedDefault for on-the-fly conversion when the pass hasn't run. +func AnyToTypedDefault(value any) *TypeDefault { + switch v := value.(type) { + case string: + return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindString, Value: v}} + case bool: + return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindBool, Value: v}} + case float64: + return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindFloat64, Value: v}} + case int64: + return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindInt64, Value: v}} + case int: + return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindInt64, Value: v}} + case []any: + arr := make([]*TypeDefault, len(v)) + for i, elem := range v { + arr[i] = AnyToTypedDefault(elem) + } + return &TypeDefault{Array: arr} + case map[string]any: + structMap := make(map[string]*TypeDefault, len(v)) + for k, val := range v { + structMap[k] = AnyToTypedDefault(val) + } + return &TypeDefault{Struct: structMap} + default: + return &TypeDefault{Scalar: &ScalarType{Value: value}} + } +} + +// EffectiveTypedDefault returns TypedDefault if populated by the DefaultAsTyped +// compiler pass. If not set (e.g. when the pass hasn't run), it derives a +// TypeDefault from the raw Default field on the fly. +func (t Type) EffectiveTypedDefault() *TypeDefault { + if t.TypedDefault != nil { + return t.TypedDefault + } + if t.Default != nil { + return AnyToTypedDefault(t.Default) + } + return nil +} + // Type representing every type defined by the IR. // Bonus: in a way that can be (un)marshaled to/from JSON, // which is useful for unit tests. @@ -92,6 +152,10 @@ type Type struct { Nullable bool Default any `json:",omitempty"` + // TypedDefault is a typed representation of Default, populated by the + // DefaultAsTyped compiler pass. Use this instead of type-asserting Default. + TypedDefault *TypeDefault `json:",omitempty"` + Disjunction *DisjunctionType `json:",omitempty"` Array *ArrayType `json:",omitempty"` Enum *EnumType `json:",omitempty"` @@ -340,6 +404,12 @@ func Default(value any) TypeOption { } } +func TypedDefaultOpt(td *TypeDefault) TypeOption { + return func(def *Type) { + def.TypedDefault = td + } +} + func Hints(hints JenniesHints) TypeOption { return func(def *Type) { def.Hints = hints diff --git a/internal/jennies/golang/jennies.go b/internal/jennies/golang/jennies.go index 1eed6598a..4af48158c 100644 --- a/internal/jennies/golang/jennies.go +++ b/internal/jennies/golang/jennies.go @@ -162,6 +162,7 @@ func (language *Language) CompilerPasses() compiler.Passes { &compiler.DisjunctionInferMapping{}, &compiler.UndiscriminatedDisjunctionToAny{}, &compiler.DisjunctionToType{}, + &compiler.DefaultAsTyped{}, } } diff --git a/internal/jennies/golang/rawtypes.go b/internal/jennies/golang/rawtypes.go index 2aa7fc937..c10d26252 100644 --- a/internal/jennies/golang/rawtypes.go +++ b/internal/jennies/golang/rawtypes.go @@ -241,7 +241,7 @@ func (jenny RawTypes) generateConstructor(buffer *strings.Builder, context langu buffer.WriteString("\n}\n") } -func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast.RefType, objectType ast.Type, maybeExtraDefaults any) string { +func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast.RefType, objectType ast.Type, extraDefaults *ast.TypeDefault) string { var buffer strings.Builder objectName := formatObjectName(objectRef.ReferredType) @@ -252,16 +252,16 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast buffer.WriteString(objectName + "{\n") - extraDefaults := map[string]any{} - if val, ok := maybeExtraDefaults.(map[string]any); ok { - extraDefaults = val - } - for _, field := range objectType.Struct.Fields { resolvedFieldType := context.ResolveRefs(field.Type) - needsExplicitDefault := field.Type.Default != nil || - extraDefaults[field.Name] != nil || + var fieldExtraDefault *ast.TypeDefault + if extraDefaults != nil && extraDefaults.Struct != nil { + fieldExtraDefault = extraDefaults.Struct[field.Name] + } + + needsExplicitDefault := field.Type.EffectiveTypedDefault() != nil || + fieldExtraDefault != nil || (field.Required && field.Type.IsRef() && resolvedFieldType.IsStruct()) || (field.Required && field.Type.IsArray()) || (field.Required && field.Type.IsMap()) || @@ -275,11 +275,10 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast defaultValue := "" // nolint:gocritic - if extraDefault, ok := extraDefaults[field.Name]; ok { - defaultValue = formatScalar(extraDefault) - + if fieldExtraDefault != nil { if field.Type.IsRef() && resolvedFieldType.IsStructGeneratedFromDisjunction() { - disjunctionBranchName := formatFieldName(anyToDisjunctionBranchName(extraDefault)) + defaultValue = formatTypedDefault(fieldExtraDefault) + disjunctionBranchName := formatFieldName(typedDefaultToDisjunctionBranchName(fieldExtraDefault)) disjunctionBranch, found := resolvedFieldType.Struct.FieldByName(disjunctionBranchName) if !found { disjunctionBranchName = "Any" @@ -295,22 +294,28 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast defaultValue += fmt.Sprintf("\t%s: %s,\n", formatFieldName(disjunctionBranchName), actualDefault) defaultValue += "}" + if field.Type.Nullable { + defaultValue = "&" + defaultValue + } + } else if field.Type.IsRef() && resolvedFieldType.IsStruct() && fieldExtraDefault.Struct != nil { + defaultValue = jenny.defaultsForStruct(context, *field.Type.Ref, resolvedFieldType, fieldExtraDefault) if field.Type.Nullable { defaultValue = "&" + defaultValue } } else { + defaultValue = formatTypedDefault(fieldExtraDefault) defaultValue = jenny.maybeValueAsPointer(defaultValue, field.Type.Nullable, resolvedFieldType) } } else if field.Type.IsConcreteScalar() { defaultValue = formatScalar(field.Type.Scalar.Value) defaultValue = jenny.maybeValueAsPointer(defaultValue, field.Type.Nullable, resolvedFieldType) - } else if resolvedFieldType.IsAnyOf(ast.KindScalar, ast.KindMap, ast.KindArray) && field.Type.Default != nil { - defaultValue = formatScalar(field.Type.Default) + } else if resolvedFieldType.IsAnyOf(ast.KindScalar, ast.KindMap, ast.KindArray) && field.Type.EffectiveTypedDefault() != nil { + defaultValue = formatTypedDefault(field.Type.EffectiveTypedDefault()) defaultValue = jenny.maybeValueAsPointer(defaultValue, field.Type.Nullable, resolvedFieldType) - } else if field.Type.IsRef() && resolvedFieldType.IsStruct() && field.Type.Default != nil { - defaultValue = jenny.defaultsForStruct(context, *field.Type.Ref, resolvedFieldType, field.Type.Default) + } else if field.Type.IsRef() && resolvedFieldType.IsStruct() && field.Type.EffectiveTypedDefault() != nil { + defaultValue = jenny.defaultsForStruct(context, *field.Type.Ref, resolvedFieldType, field.Type.EffectiveTypedDefault()) if field.Type.Nullable { defaultValue = "&" + defaultValue } @@ -327,10 +332,12 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast } } else if field.Type.IsRef() && resolvedFieldType.IsEnum() { memberName := resolvedFieldType.Enum.Values[0].Name - for _, member := range resolvedFieldType.Enum.Values { - if member.Value == field.Type.Default { - memberName = member.Name - break + if td := field.Type.EffectiveTypedDefault(); td != nil && td.Scalar != nil { + for _, member := range resolvedFieldType.Enum.Values { + if member.Value == td.Scalar.Value { + memberName = member.Name + break + } } } diff --git a/internal/jennies/golang/tools.go b/internal/jennies/golang/tools.go index 284c1cc16..d8bd9e7bb 100644 --- a/internal/jennies/golang/tools.go +++ b/internal/jennies/golang/tools.go @@ -6,6 +6,7 @@ import ( "regexp" "strings" + "github.com/grafana/cog/internal/ast" "github.com/grafana/cog/internal/tools" ) @@ -123,6 +124,40 @@ func anyToDisjunctionBranchName(value any) string { return valueToDisjunctionBranchName(reflect.ValueOf(value)) } +func formatTypedDefault(td *ast.TypeDefault) string { + if td == nil { + return "nil" + } + if td.Array != nil { + items := make([]string, 0, len(td.Array)) + for _, item := range td.Array { + items = append(items, formatTypedDefault(item)) + } + // FIXME: this is wrong, we can't just assume a list of strings. + return fmt.Sprintf("[]string{%s}", strings.Join(items, ", ")) + } + if td.Scalar != nil { + return formatScalar(td.Scalar.Value) + } + return "nil" +} + +func typedDefaultToDisjunctionBranchName(td *ast.TypeDefault) string { + if td.Array != nil { + if len(td.Array) > 0 { + return "ArrayOf" + typedDefaultToDisjunctionBranchName(td.Array[0]) + } + return "Array" + } + if td.Scalar != nil { + if td.Scalar.ScalarKind != "" { + return tools.UpperCamelCase(string(td.Scalar.ScalarKind)) + } + return tools.UpperCamelCase(reflect.TypeOf(td.Scalar.Value).Kind().String()) + } + return "Struct" +} + func valueToDisjunctionBranchName(value reflect.Value) string { reflectKind := value.Kind() diff --git a/internal/jennies/java/jennies.go b/internal/jennies/java/jennies.go index 1aac20c6d..a34c24f1e 100644 --- a/internal/jennies/java/jennies.go +++ b/internal/jennies/java/jennies.go @@ -158,6 +158,7 @@ func (language *Language) CompilerPasses() compiler.Passes { &compiler.DisjunctionToType{}, &compiler.RemoveIntersections{}, &compiler.InlineObjectsWithTypes{InlineTypes: []ast.Kind{ast.KindScalar, ast.KindMap, ast.KindArray}}, + &compiler.DefaultAsTyped{}, } } diff --git a/internal/jennies/java/rawtypes.go b/internal/jennies/java/rawtypes.go index 1f09c62b0..90756b57b 100644 --- a/internal/jennies/java/rawtypes.go +++ b/internal/jennies/java/rawtypes.go @@ -279,11 +279,11 @@ func (jenny RawTypes) constructors(object ast.Object) []ConstructorTemplate { ValueFromArg: name, }) - if field.Type.Default != nil { + if field.Type.EffectiveTypedDefault() != nil { defaultConstructorAssignments = append(defaultConstructorAssignments, ConstructorAssignmentTemplate{ Name: name, Type: field.Type, - Value: jenny.genDefaultForType(field.Type, field.Type.Default), + Value: jenny.genDefaultForType(field.Type, field.Type.EffectiveTypedDefault()), }) } else if field.Required && !field.Type.Nullable { // Fields without an explicit default, but that aren't allowed to be null defaultConstructorAssignments = append(defaultConstructorAssignments, ConstructorAssignmentTemplate{ @@ -312,41 +312,50 @@ func (jenny RawTypes) constructors(object ast.Object) []ConstructorTemplate { return constructors } -func (jenny RawTypes) genDefaultForType(t ast.Type, value any) string { +func (jenny RawTypes) genDefaultForType(t ast.Type, td *ast.TypeDefault) string { switch t.Kind { case ast.KindScalar: - return formatType(t.AsScalar().ScalarKind, value) + if td != nil && td.Scalar != nil { + return formatType(t.AsScalar().ScalarKind, td.Scalar.Value) + } + return formatType(t.AsScalar().ScalarKind, nil) case ast.KindRef: - return jenny.formatReferenceDefaults(t, value) + return jenny.formatReferenceDefaults(t, td) case ast.KindArray: - if value == nil { + if td == nil || td.Array == nil { return "List.of()" } - return fmt.Sprintf("List.of(%s)", jenny.genDefaultForType(t.AsArray().ValueType, value)) + items := make([]string, 0, len(td.Array)) + for _, elem := range td.Array { + items = append(items, jenny.genDefaultForType(t.AsArray().ValueType, elem)) + } + return fmt.Sprintf("List.of(%s)", strings.Join(items, ", ")) } return "" } -func (jenny RawTypes) formatReferenceDefaults(ref ast.Type, value any) string { - // Enums - if _, ok := value.(map[string]any); !ok { +func (jenny RawTypes) formatReferenceDefaults(ref ast.Type, td *ast.TypeDefault) string { + if td == nil || td.Struct == nil { + // not a struct default — treat as enum/scalar ref jenny.typeFormatter.packageMapper(ref.AsRef().ReferredPkg, ref.AsRef().ReferredType) - return jenny.typeFormatter.formatRefType(ref, value) + var scalarValue any + if td != nil && td.Scalar != nil { + scalarValue = td.Scalar.Value + } + return jenny.typeFormatter.formatRefType(ref, scalarValue) } obj, ok := jenny.typeFormatter.context.LocateObjectByRef(ref.AsRef()) if !ok { return "" } - - defaultValues := value.(map[string]any) objectFields := obj.Type.AsStruct().Fields args := make([]string, len(objectFields)) for i, f := range objectFields { - if v, ok := defaultValues[f.Name]; ok { - args[i] = jenny.genDefaultForType(f.Type, v) + if fieldTD, ok := td.Struct[f.Name]; ok { + args[i] = jenny.genDefaultForType(f.Type, fieldTD) } else { args[i] = jenny.typeFormatter.emptyValueForType(f.Type, false) } diff --git a/internal/jennies/jsonschema/jennies.go b/internal/jennies/jsonschema/jennies.go index 4073e73e3..9e8050aa4 100644 --- a/internal/jennies/jsonschema/jennies.go +++ b/internal/jennies/jsonschema/jennies.go @@ -60,6 +60,7 @@ func (language *Language) CompilerPasses() compiler.Passes { return compiler.Passes{ &compiler.DisjunctionWithNullToOptional{}, &compiler.InferEntrypoint{}, + &compiler.DefaultAsTyped{}, } } diff --git a/internal/jennies/openapi/jennies.go b/internal/jennies/openapi/jennies.go index 4f5900279..8c01bc435 100644 --- a/internal/jennies/openapi/jennies.go +++ b/internal/jennies/openapi/jennies.go @@ -53,5 +53,6 @@ func (language *Language) CompilerPasses() compiler.Passes { // should be a superset of the compiler passes defined for jsonschema jennies &compiler.DisjunctionWithNullToOptional{}, &compiler.InferEntrypoint{}, + &compiler.DefaultAsTyped{}, } } diff --git a/internal/jennies/php/jennies.go b/internal/jennies/php/jennies.go index b58e664a0..a478de5d8 100644 --- a/internal/jennies/php/jennies.go +++ b/internal/jennies/php/jennies.go @@ -167,6 +167,7 @@ func (language *Language) CompilerPasses() compiler.Passes { &compiler.InlineObjectsWithTypes{ InlineTypes: []ast.Kind{ast.KindScalar, ast.KindArray, ast.KindMap, ast.KindDisjunction}, }, + &compiler.DefaultAsTyped{}, } } diff --git a/internal/jennies/php/rawtypes.go b/internal/jennies/php/rawtypes.go index 4270bdffc..2c5d02cbc 100644 --- a/internal/jennies/php/rawtypes.go +++ b/internal/jennies/php/rawtypes.go @@ -309,13 +309,8 @@ func (jenny RawTypes) generateConstructor(context languages.Context, def ast.Obj } // set for default values for fields that need one or have one - if !field.Type.Nullable || field.Type.Default != nil { - var defaultsOverrides map[string]any - if overrides, ok := field.Type.Default.(map[string]any); ok { - defaultsOverrides = overrides - } - - defaultValue = defaultValueForType(jenny.config, context.Schemas, field.Type, orderedmap.FromMap(defaultsOverrides)) + if !field.Type.Nullable || field.Type.EffectiveTypedDefault() != nil { + defaultValue = defaultValueForType(jenny.config, context.Schemas, field.Type, field.Type.EffectiveTypedDefault()) } // initialize constant fields diff --git a/internal/jennies/php/tools.go b/internal/jennies/php/tools.go index c108f4f3d..12bfd044d 100644 --- a/internal/jennies/php/tools.go +++ b/internal/jennies/php/tools.go @@ -2,11 +2,12 @@ package php import ( "fmt" + "maps" "regexp" + "slices" "strings" "github.com/grafana/cog/internal/ast" - "github.com/grafana/cog/internal/orderedmap" "github.com/grafana/cog/internal/tools" ) @@ -160,9 +161,21 @@ func disjunctionCaseForType(typesFormatter *typeFormatter, input string, typeDef * Default and "empty" values management * *****************************************/ -func defaultValueForType(config Config, schemas ast.Schemas, typeDef ast.Type, defaultsOverrides *orderedmap.Map[string, any]) any { - if !typeDef.IsRef() && typeDef.Default != nil { - return typeDef.Default +func defaultValueForType(config Config, schemas ast.Schemas, typeDef ast.Type, overrideDefault *ast.TypeDefault) any { + if !typeDef.IsRef() && overrideDefault != nil { + if overrideDefault.Scalar != nil { + return overrideDefault.Scalar.Value + } + if overrideDefault.Array != nil { + items := make([]any, len(overrideDefault.Array)) + for i, elem := range overrideDefault.Array { + if elem.Scalar != nil { + items[i] = elem.Scalar.Value + } + } + return items + } + // Struct default → fall through } switch typeDef.Kind { @@ -178,10 +191,12 @@ func defaultValueForType(config Config, schemas ast.Schemas, typeDef ast.Type, d referredObj, found := schemas.LocateObject(ref.ReferredPkg, ref.ReferredType) if found && referredObj.Type.IsEnum() { enumName := formatObjectName(referredObj.Type.AsEnum().Values[0].Name) - for _, enumValue := range referredObj.Type.AsEnum().Values { - if enumValue.Value == typeDef.Default { - enumName = formatEnumMemberName(enumValue.Name) - break + if overrideDefault != nil && overrideDefault.Scalar != nil { + for _, enumValue := range referredObj.Type.AsEnum().Values { + if enumValue.Value == overrideDefault.Scalar.Value { + enumName = formatEnumMemberName(enumValue.Name) + break + } } } @@ -192,29 +207,36 @@ func defaultValueForType(config Config, schemas ast.Schemas, typeDef ast.Type, d var extraDefaults []string - if defaultsOverrides != nil { - extraDefaults = make([]string, 0, defaultsOverrides.Len()) - defaultsOverrides.Iterate(func(k string, v any) { + if overrideDefault != nil && overrideDefault.Struct != nil { + keys := slices.Sorted(maps.Keys(overrideDefault.Struct)) + extraDefaults = make([]string, 0, len(keys)) + for _, k := range keys { + fieldTD := overrideDefault.Struct[k] if !referredObj.Type.IsStruct() { - return + continue } field, fieldFound := referredObj.Type.AsStruct().FieldByName(k) if !fieldFound { - return + continue } - value := v + var value any if field.Type.IsRef() { - var fieldOverrides *orderedmap.Map[string, any] - if overrides, ok := value.(map[string]any); ok { - fieldOverrides = orderedmap.FromMap(overrides) + value = defaultValueForType(config, schemas, field.Type, fieldTD) + } else if fieldTD.Scalar != nil { + value = fieldTD.Scalar.Value + } else if fieldTD.Array != nil { + items := make([]any, len(fieldTD.Array)) + for i, elem := range fieldTD.Array { + if elem.Scalar != nil { + items[i] = elem.Scalar.Value + } } - - value = defaultValueForType(config, schemas, field.Type, fieldOverrides) + value = items } extraDefaults = append(extraDefaults, fmt.Sprintf("%s: %s", formatFieldName(k), formatValue(value))) - }) + } } formattedRef := formatObjectName(ref.ReferredType) diff --git a/internal/jennies/python/jennies.go b/internal/jennies/python/jennies.go index 87d63370a..eb594fa97 100644 --- a/internal/jennies/python/jennies.go +++ b/internal/jennies/python/jennies.go @@ -116,6 +116,7 @@ func (language *Language) CompilerPasses() compiler.Passes { &compiler.FlattenDisjunctions{}, &compiler.DisjunctionInferMapping{}, &compiler.RenameNumericEnumValues{}, + &compiler.DefaultAsTyped{}, } } diff --git a/internal/jennies/python/rawtypes.go b/internal/jennies/python/rawtypes.go index 6a57d0f02..020619662 100644 --- a/internal/jennies/python/rawtypes.go +++ b/internal/jennies/python/rawtypes.go @@ -11,8 +11,7 @@ import ( "github.com/grafana/cog/internal/jennies/common" "github.com/grafana/cog/internal/jennies/template" "github.com/grafana/cog/internal/languages" - "github.com/grafana/cog/internal/orderedmap" - "github.com/grafana/cog/internal/tools" +"github.com/grafana/cog/internal/tools" ) type RawTypes struct { @@ -197,13 +196,8 @@ func (jenny RawTypes) generateInitMethod(schemas ast.Schemas, object ast.Object) continue } - if !field.Type.Nullable || field.Type.Default != nil { - var defaultsOverrides map[string]any - if overrides, ok := field.Type.Default.(map[string]any); ok { - defaultsOverrides = overrides - } - - defaultValue = defaultValueForType(schemas, field.Type, jenny.importModule, orderedmap.FromMap(defaultsOverrides)) + if !field.Type.Nullable || field.Type.EffectiveTypedDefault() != nil { + defaultValue = defaultValueForType(schemas, field.Type, jenny.importModule, field.Type.EffectiveTypedDefault()) } if field.Type.IsConcreteScalar() { diff --git a/internal/jennies/python/tools.go b/internal/jennies/python/tools.go index f55d2b85d..48f50c79d 100644 --- a/internal/jennies/python/tools.go +++ b/internal/jennies/python/tools.go @@ -2,10 +2,11 @@ package python import ( "fmt" + "maps" + "slices" "strings" "github.com/grafana/cog/internal/ast" - "github.com/grafana/cog/internal/orderedmap" "github.com/grafana/cog/internal/tools" ) @@ -132,9 +133,21 @@ func isReservedPythonKeyword(input string) bool { * Default and "empty" values management * ******************************************************************************/ -func defaultValueForType(schemas ast.Schemas, typeDef ast.Type, importModule moduleImporter, defaultsOverrides *orderedmap.Map[string, any]) any { - if !typeDef.IsRef() && typeDef.Default != nil { - return typeDef.Default +func defaultValueForType(schemas ast.Schemas, typeDef ast.Type, importModule moduleImporter, overrideDefault *ast.TypeDefault) any { + if !typeDef.IsRef() && overrideDefault != nil { + if overrideDefault.Scalar != nil { + return overrideDefault.Scalar.Value + } + if overrideDefault.Array != nil { + items := make([]any, len(overrideDefault.Array)) + for i, elem := range overrideDefault.Array { + if elem.Scalar != nil { + items[i] = elem.Scalar.Value + } + } + return items + } + // Struct default → fall through } switch typeDef.Kind { @@ -152,10 +165,12 @@ func defaultValueForType(schemas ast.Schemas, typeDef ast.Type, importModule mod referredObj, found := schemas.LocateObject(ref.ReferredPkg, ref.ReferredType) if found && referredObj.Type.IsEnum() { enumName := tools.UpperSnakeCase(referredObj.Type.AsEnum().Values[0].Name) - for _, enumValue := range referredObj.Type.AsEnum().Values { - if enumValue.Value == typeDef.Default { - enumName = tools.UpperSnakeCase(enumValue.Name) - break + if overrideDefault != nil && overrideDefault.Scalar != nil { + for _, enumValue := range referredObj.Type.AsEnum().Values { + if enumValue.Value == overrideDefault.Scalar.Value { + enumName = tools.UpperSnakeCase(enumValue.Name) + break + } } } @@ -172,29 +187,36 @@ func defaultValueForType(schemas ast.Schemas, typeDef ast.Type, importModule mod var extraDefaults []string - if defaultsOverrides != nil { - extraDefaults = make([]string, 0, defaultsOverrides.Len()) - defaultsOverrides.Iterate(func(k string, v any) { + if overrideDefault != nil && overrideDefault.Struct != nil { + keys := slices.Sorted(maps.Keys(overrideDefault.Struct)) + extraDefaults = make([]string, 0, len(keys)) + for _, k := range keys { + fieldTD := overrideDefault.Struct[k] if !referredObj.Type.IsStruct() { - return + continue } field, fieldFound := referredObj.Type.AsStruct().FieldByName(k) if !fieldFound { - return + continue } - value := v + var value any if field.Type.IsRef() { - var fieldOverrides *orderedmap.Map[string, any] - if overrides, ok := value.(map[string]any); ok { - fieldOverrides = orderedmap.FromMap(overrides) + value = defaultValueForType(schemas, field.Type, importModule, fieldTD) + } else if fieldTD.Scalar != nil { + value = fieldTD.Scalar.Value + } else if fieldTD.Array != nil { + items := make([]any, len(fieldTD.Array)) + for i, elem := range fieldTD.Array { + if elem.Scalar != nil { + items[i] = elem.Scalar.Value + } } - - value = defaultValueForType(schemas, field.Type, importModule, fieldOverrides) + value = items } extraDefaults = append(extraDefaults, fmt.Sprintf("%s=%s", formatIdentifier(k), formatValue(value))) - }) + } } formattedRef := tools.UpperCamelCase(ref.ReferredType) diff --git a/internal/jennies/terraform/jennies.go b/internal/jennies/terraform/jennies.go index 95d41b187..e51089bdf 100644 --- a/internal/jennies/terraform/jennies.go +++ b/internal/jennies/terraform/jennies.go @@ -80,6 +80,7 @@ func (language *Language) CompilerPasses() compiler.Passes { &compiler.DisjunctionInferMapping{}, &compiler.UndiscriminatedDisjunctionToAny{}, &compiler.DisjunctionToType{}, + &compiler.DefaultAsTyped{}, } } diff --git a/internal/jennies/typescript/builder_test.go b/internal/jennies/typescript/builder_test.go index 8025fda70..b57a0e99d 100644 --- a/internal/jennies/typescript/builder_test.go +++ b/internal/jennies/typescript/builder_test.go @@ -3,6 +3,7 @@ package typescript import ( "testing" + "github.com/grafana/cog/internal/ast/compiler" "github.com/grafana/cog/internal/jennies/common" "github.com/grafana/cog/internal/languages" "github.com/grafana/cog/internal/testutils" @@ -32,6 +33,13 @@ func TestBuilder_Generate(t *testing.T) { req := require.New(tc) context := tc.UnmarshalJSONInput(testutils.BuildersContextInputFile) + + // Populate TypedDefault on all types so that default-handling code can + // work with typed defaults instead of the raw Default any field. + pass := &compiler.DefaultAsTyped{} + context.Schemas, err = pass.Process(context.Schemas) + req.NoError(err) + context, err = languages.GenerateBuilderNilChecks(language, context) req.NoError(err) diff --git a/internal/jennies/typescript/jennies.go b/internal/jennies/typescript/jennies.go index 128494ff5..8dea0740b 100644 --- a/internal/jennies/typescript/jennies.go +++ b/internal/jennies/typescript/jennies.go @@ -152,6 +152,7 @@ func (language *Language) Jennies(globalConfig languages.Config) *codejen.JennyL func (language *Language) CompilerPasses() compiler.Passes { return compiler.Passes{ &compiler.RenameNumericEnumValues{}, + &compiler.DefaultAsTyped{}, } } diff --git a/internal/jennies/typescript/rawtypes.go b/internal/jennies/typescript/rawtypes.go index 8bec11855..4c066532d 100644 --- a/internal/jennies/typescript/rawtypes.go +++ b/internal/jennies/typescript/rawtypes.go @@ -151,8 +151,8 @@ func (jenny RawTypes) defaultValueForObject(object ast.Object, packageMapper pac case ast.KindEnum: enum := object.Type.AsEnum() defaultValue := enum.Values[0].Value - if object.Type.Default != nil { - defaultValue = object.Type.Default + if td := object.Type.EffectiveTypedDefault(); td != nil && td.Scalar != nil { + defaultValue = td.Scalar.Value } return raw(jenny.typeFormatter.enums.formatValue(object, defaultValue)) @@ -162,8 +162,20 @@ func (jenny RawTypes) defaultValueForObject(object ast.Object, packageMapper pac } func (jenny RawTypes) defaultValueForType(typeDef ast.Type, packageMapper packageMapper) any { - if typeDef.Default != nil { - return typeDef.Default + if td := typeDef.EffectiveTypedDefault(); td != nil { + if td.Scalar != nil { + return td.Scalar.Value + } + if td.Array != nil { + items := make([]any, len(td.Array)) + for i, elem := range td.Array { + if elem.Scalar != nil { + items[i] = elem.Scalar.Value + } + } + return items + } + // Struct default → fall through to switch } switch typeDef.Kind { @@ -173,8 +185,8 @@ func (jenny RawTypes) defaultValueForType(typeDef ast.Type, packageMapper packag return jenny.defaultValuesForStructType(typeDef, packageMapper) case ast.KindEnum: // anonymous enum defaultValue := typeDef.AsEnum().Values[0].Value - if typeDef.Default != nil { - defaultValue = typeDef.Default + if td := typeDef.EffectiveTypedDefault(); td != nil && td.Scalar != nil { + defaultValue = td.Scalar.Value } return defaultValue @@ -199,17 +211,26 @@ func (jenny RawTypes) defaultValuesForStructType(structType ast.Type, packageMap defaults := orderedmap.New[string, any]() for _, field := range structType.AsStruct().Fields { - if field.Type.Default != nil { + if td := field.Type.EffectiveTypedDefault(); td != nil { switch field.Type.Kind { case ast.KindRef: defaults.Set(field.Name, jenny.defaultValuesForReference(field.Type, packageMapper)) continue case ast.KindStruct: - defaultMap := field.Type.Default.(map[string]any) - defaults.Set(field.Name, jenny.defaultValueForStructs(field.Type.AsStruct(), orderedmap.FromMap(defaultMap))) + defaults.Set(field.Name, jenny.defaultValueForStructs(field.Type.AsStruct(), td.Struct)) continue default: - defaults.Set(field.Name, field.Type.Default) + if td.Scalar != nil { + defaults.Set(field.Name, td.Scalar.Value) + } else if td.Array != nil { + items := make([]any, len(td.Array)) + for i, elem := range td.Array { + if elem.Scalar != nil { + items[i] = elem.Scalar.Value + } + } + defaults.Set(field.Name, items) + } continue } } @@ -299,12 +320,15 @@ func (jenny RawTypes) defaultValuesForReference(typeDef ast.Type, packageMapper } if referredType.Type.IsEnum() { - return raw(jenny.typeFormatter.enums.formatValue(referredType, typeDef.Default)) + var enumDefault any + if td := typeDef.EffectiveTypedDefault(); td != nil && td.Scalar != nil { + enumDefault = td.Scalar.Value + } + return raw(jenny.typeFormatter.enums.formatValue(referredType, enumDefault)) } - if hasStructDefaults(referredType.Type, typeDef.Default) { - defaultMap := typeDef.Default.(map[string]any) - return jenny.defaultValueForStructs(referredType.Type.AsStruct(), orderedmap.FromMap(defaultMap)) + if td := typeDef.EffectiveTypedDefault(); referredType.Type.IsStruct() && td != nil && td.Struct != nil { + return jenny.defaultValueForStructs(referredType.Type.AsStruct(), td.Struct) } if pkg != "" { @@ -314,28 +338,36 @@ func (jenny RawTypes) defaultValuesForReference(typeDef ast.Type, packageMapper return raw(fmt.Sprintf("default%s()", tools.UpperCamelCase(referredTypeName))) } -func (jenny RawTypes) defaultValueForStructs(def ast.StructType, m *orderedmap.Map[string, any]) any { +func (jenny RawTypes) defaultValueForStructs(def ast.StructType, defaults map[string]*ast.TypeDefault) any { var buffer strings.Builder for _, f := range def.Fields { - if m.Has(f.Name) { - switch x := m.Get(f.Name).(type) { - case map[string]any: - buffer.WriteString(fmt.Sprintf("%s: %v, ", f.Name, jenny.defaultValueForStructs(f.Type.AsStruct(), orderedmap.FromMap(x)))) - case nil: + if td, ok := defaults[f.Name]; ok { + if td.Struct != nil { + buffer.WriteString(fmt.Sprintf("%s: %v, ", f.Name, jenny.defaultValueForStructs(f.Type.AsStruct(), td.Struct))) + } else if td.Array != nil { + items := make([]any, len(td.Array)) + for i, elem := range td.Array { + if elem.Scalar != nil { + items[i] = elem.Scalar.Value + } + } + buffer.WriteString(fmt.Sprintf("%s: %v, ", f.Name, formatValue(items))) + } else if td.Scalar == nil { + // null/unset value buffer.WriteString(fmt.Sprintf("%s: %v, ", f.Name, formatValue([]any{}))) - default: + } else { if f.Type.IsRef() { ref := f.Type.AsRef() referredType, refFound := jenny.schemas.LocateObject(ref.ReferredPkg, ref.ReferredType) if refFound && referredType.Type.IsEnum() { - buffer.WriteString(fmt.Sprintf("%s: %v, ", f.Name, jenny.typeFormatter.enums.formatValue(referredType, x))) + buffer.WriteString(fmt.Sprintf("%s: %v, ", f.Name, jenny.typeFormatter.enums.formatValue(referredType, td.Scalar.Value))) continue } } - buffer.WriteString(fmt.Sprintf("%s: %v, ", f.Name, formatValue(x))) + buffer.WriteString(fmt.Sprintf("%s: %v, ", f.Name, formatValue(td.Scalar.Value))) } } else if f.Required { switch f.Type.Kind { @@ -387,7 +419,3 @@ func (jenny RawTypes) defaultValueForConstantReferences(def ast.ConstantReferenc return "unknown" } -func hasStructDefaults(typeDef ast.Type, defaults any) bool { - _, ok := defaults.(map[string]any) - return ok && typeDef.IsStruct() -} diff --git a/testdata/jennies/builders/initialization_safeguards/TypescriptBuilder/src/initializationSafeguards/somePanelBuilder.gen.ts b/testdata/jennies/builders/initialization_safeguards/TypescriptBuilder/src/initializationSafeguards/somePanelBuilder.gen.ts index 14f5e9295..94912a558 100644 --- a/testdata/jennies/builders/initialization_safeguards/TypescriptBuilder/src/initializationSafeguards/somePanelBuilder.gen.ts +++ b/testdata/jennies/builders/initialization_safeguards/TypescriptBuilder/src/initializationSafeguards/somePanelBuilder.gen.ts @@ -25,9 +25,7 @@ export class SomePanelBuilder implements cog.Builder