Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions internal/ast/compiler/default_as_typed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package compiler

import (
"github.com/grafana/cog/internal/ast"
)

var _ Pass = (*DefaultAsTyped)(nil)

// DefaultAsTyped converts the raw `Default any` field on every ast.Type into a
// typed *ast.TypeDefault, populating TypedDefault alongside the existing Default.
// Jennies can then access TypedDefault instead of performing unsafe type assertions
// on Default.
type DefaultAsTyped struct{}

func (pass *DefaultAsTyped) Process(schemas []*ast.Schema) ([]*ast.Schema, error) {
visitor := &Visitor{
OnObject: pass.processObject,
}

return visitor.VisitSchemas(schemas)
}

func (pass *DefaultAsTyped) processObject(_ *Visitor, _ *ast.Schema, object ast.Object) (ast.Object, error) {
object.Type = processTypeDefaults(object.Type)
return object, nil
}

// processTypeDefaults recursively sets TypedDefault on any Type that has
// Default != nil, and recurses into child types.
func processTypeDefaults(t ast.Type) ast.Type {
if t.Default != nil {
t.TypedDefault = ast.AnyToTypedDefault(t.Default)
}

switch {
case t.IsStruct():
for i, field := range t.Struct.Fields {
field.Type = processTypeDefaults(field.Type)
t.Struct.Fields[i] = field
}
case t.IsIntersection():
for i, branch := range t.Intersection.Branches {
t.Intersection.Branches[i] = processTypeDefaults(branch)
}
case t.IsDisjunction():
for i, branch := range t.Disjunction.Branches {
t.Disjunction.Branches[i] = processTypeDefaults(branch)
}
case t.IsArray():
t.Array.ValueType = processTypeDefaults(t.Array.ValueType)
case t.IsMap():
t.Map.ValueType = processTypeDefaults(t.Map.ValueType)
}

return t
}

140 changes: 140 additions & 0 deletions internal/ast/compiler/default_as_typed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package compiler

import (
"testing"

"github.com/grafana/cog/internal/ast"
"github.com/grafana/cog/internal/testutils"
)

func TestDefaultAsTyped_ScalarStringDefault(t *testing.T) {
schema := &ast.Schema{
Package: "test",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "SomeObject", ast.NewStruct(
ast.NewStructField("AString", ast.String(ast.Default("hello"))),
)),
),
}

expected := &ast.Schema{
Package: "test",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "SomeObject", ast.NewStruct(
ast.NewStructField("AString", ast.String(
ast.Default("hello"),
ast.TypedDefaultOpt(&ast.TypeDefault{Scalar: &ast.ScalarType{ScalarKind: ast.KindString, Value: "hello"}}),
)),
)),
),
}

pass := &DefaultAsTyped{}
runPassOnSchema(t, pass, schema, expected)
}

func TestDefaultAsTyped_ScalarBoolDefault(t *testing.T) {
schema := &ast.Schema{
Package: "test",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "SomeObject", ast.NewStruct(
ast.NewStructField("ABool", ast.Bool(ast.Default(true))),
)),
),
}

expected := &ast.Schema{
Package: "test",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "SomeObject", ast.NewStruct(
ast.NewStructField("ABool", ast.Bool(
ast.Default(true),
ast.TypedDefaultOpt(&ast.TypeDefault{Scalar: &ast.ScalarType{ScalarKind: ast.KindBool, Value: true}}),
)),
)),
),
}

pass := &DefaultAsTyped{}
runPassOnSchema(t, pass, schema, expected)
}

func TestDefaultAsTyped_Float64Default(t *testing.T) {
schema := &ast.Schema{
Package: "test",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "SomeObject", ast.NewStruct(
ast.NewStructField("AFloat", ast.NewScalar(ast.KindFloat64, ast.Default(float64(3.14)))),
)),
),
}

expected := &ast.Schema{
Package: "test",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "SomeObject", ast.NewStruct(
ast.NewStructField("AFloat", ast.NewScalar(ast.KindFloat64,
ast.Default(3.14),
ast.TypedDefaultOpt(&ast.TypeDefault{Scalar: &ast.ScalarType{ScalarKind: ast.KindFloat64, Value: float64(3.14)}}),
)),
)),
),
}

pass := &DefaultAsTyped{}
runPassOnSchema(t, pass, schema, expected)
}

func TestDefaultAsTyped_NoDefault(t *testing.T) {
schema := &ast.Schema{
Package: "test",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "SomeObject", ast.NewStruct(
ast.NewStructField("AString", ast.String()),
)),
),
}

pass := &DefaultAsTyped{}
runPassOnSchema(t, pass, schema, schema)
}

func TestDefaultAsTyped_StructDefault(t *testing.T) {
structDefault := map[string]any{
"name": "alice",
}

// Build schema with a field whose type has a struct default set directly
innerType := ast.NewStruct(
ast.NewStructField("name", ast.String()),
)
innerType.Default = structDefault

schema := &ast.Schema{
Package: "test",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "SomeObject", ast.NewStruct(
ast.NewStructField("inner", innerType),
)),
),
}

expectedInnerType := innerType
expectedInnerType.TypedDefault = &ast.TypeDefault{
Struct: map[string]*ast.TypeDefault{
"name": {Scalar: &ast.ScalarType{ScalarKind: ast.KindString, Value: "alice"}},
},
}

expected := &ast.Schema{
Package: "test",
Objects: testutils.ObjectsMap(
ast.NewObject("test", "SomeObject", ast.NewStruct(
ast.NewStructField("inner", expectedInnerType),
)),
),
}

pass := &DefaultAsTyped{}
runPassOnSchema(t, pass, schema, expected)
}
70 changes: 70 additions & 0 deletions internal/ast/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,66 @@ func (constraint TypeConstraint) DeepCopy() TypeConstraint {
// JenniesHints meant to be used by jennies, to gain a finer control on the codegen from schemas
type JenniesHints map[string]any

// TypeDefault represents a typed default value for a field or type.
// Exactly one of Scalar, Array, or Struct should be non-nil.
type TypeDefault struct {
// Scalar holds the default for scalar and enum types.
// ScalarKind identifies the scalar type; Value holds the actual Go value.
Scalar *ScalarType

// Array holds typed defaults for each element of an array default.
Array []*TypeDefault

// Struct maps field names to typed defaults, used for struct/ref types
// whose default is an object literal.
Struct map[string]*TypeDefault
}

// AnyToTypedDefault converts a raw any default value into a *TypeDefault.
// This is the canonical conversion used by the DefaultAsTyped compiler pass and
// by EffectiveTypedDefault for on-the-fly conversion when the pass hasn't run.
func AnyToTypedDefault(value any) *TypeDefault {
switch v := value.(type) {
case string:
return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindString, Value: v}}
case bool:
return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindBool, Value: v}}
case float64:
return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindFloat64, Value: v}}
case int64:
return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindInt64, Value: v}}
case int:
return &TypeDefault{Scalar: &ScalarType{ScalarKind: KindInt64, Value: v}}
case []any:
arr := make([]*TypeDefault, len(v))
for i, elem := range v {
arr[i] = AnyToTypedDefault(elem)
}
return &TypeDefault{Array: arr}
case map[string]any:
structMap := make(map[string]*TypeDefault, len(v))
for k, val := range v {
structMap[k] = AnyToTypedDefault(val)
}
return &TypeDefault{Struct: structMap}
default:
return &TypeDefault{Scalar: &ScalarType{Value: value}}
}
}

// EffectiveTypedDefault returns TypedDefault if populated by the DefaultAsTyped
// compiler pass. If not set (e.g. when the pass hasn't run), it derives a
// TypeDefault from the raw Default field on the fly.
func (t Type) EffectiveTypedDefault() *TypeDefault {
if t.TypedDefault != nil {
return t.TypedDefault
}
if t.Default != nil {
return AnyToTypedDefault(t.Default)
}
return nil
}

// Type representing every type defined by the IR.
// Bonus: in a way that can be (un)marshaled to/from JSON,
// which is useful for unit tests.
Expand All @@ -92,6 +152,10 @@ type Type struct {
Nullable bool
Default any `json:",omitempty"`

// TypedDefault is a typed representation of Default, populated by the
// DefaultAsTyped compiler pass. Use this instead of type-asserting Default.
TypedDefault *TypeDefault `json:",omitempty"`

Disjunction *DisjunctionType `json:",omitempty"`
Array *ArrayType `json:",omitempty"`
Enum *EnumType `json:",omitempty"`
Expand Down Expand Up @@ -340,6 +404,12 @@ func Default(value any) TypeOption {
}
}

func TypedDefaultOpt(td *TypeDefault) TypeOption {
return func(def *Type) {
def.TypedDefault = td
}
}

func Hints(hints JenniesHints) TypeOption {
return func(def *Type) {
def.Hints = hints
Expand Down
1 change: 1 addition & 0 deletions internal/jennies/golang/jennies.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ func (language *Language) CompilerPasses() compiler.Passes {
&compiler.DisjunctionInferMapping{},
&compiler.UndiscriminatedDisjunctionToAny{},
&compiler.DisjunctionToType{},
&compiler.DefaultAsTyped{},
}
}

Expand Down
47 changes: 27 additions & 20 deletions internal/jennies/golang/rawtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func (jenny RawTypes) generateConstructor(buffer *strings.Builder, context langu
buffer.WriteString("\n}\n")
}

func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast.RefType, objectType ast.Type, maybeExtraDefaults any) string {
func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast.RefType, objectType ast.Type, extraDefaults *ast.TypeDefault) string {
var buffer strings.Builder

objectName := formatObjectName(objectRef.ReferredType)
Expand All @@ -252,16 +252,16 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast

buffer.WriteString(objectName + "{\n")

extraDefaults := map[string]any{}
if val, ok := maybeExtraDefaults.(map[string]any); ok {
extraDefaults = val
}

for _, field := range objectType.Struct.Fields {
resolvedFieldType := context.ResolveRefs(field.Type)

needsExplicitDefault := field.Type.Default != nil ||
extraDefaults[field.Name] != nil ||
var fieldExtraDefault *ast.TypeDefault
if extraDefaults != nil && extraDefaults.Struct != nil {
fieldExtraDefault = extraDefaults.Struct[field.Name]
}

needsExplicitDefault := field.Type.EffectiveTypedDefault() != nil ||
fieldExtraDefault != nil ||
(field.Required && field.Type.IsRef() && resolvedFieldType.IsStruct()) ||
(field.Required && field.Type.IsArray()) ||
(field.Required && field.Type.IsMap()) ||
Expand All @@ -275,11 +275,10 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast
defaultValue := ""

// nolint:gocritic
if extraDefault, ok := extraDefaults[field.Name]; ok {
defaultValue = formatScalar(extraDefault)

if fieldExtraDefault != nil {
if field.Type.IsRef() && resolvedFieldType.IsStructGeneratedFromDisjunction() {
disjunctionBranchName := formatFieldName(anyToDisjunctionBranchName(extraDefault))
defaultValue = formatTypedDefault(fieldExtraDefault)
disjunctionBranchName := formatFieldName(typedDefaultToDisjunctionBranchName(fieldExtraDefault))
disjunctionBranch, found := resolvedFieldType.Struct.FieldByName(disjunctionBranchName)
if !found {
disjunctionBranchName = "Any"
Expand All @@ -295,22 +294,28 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast
defaultValue += fmt.Sprintf("\t%s: %s,\n", formatFieldName(disjunctionBranchName), actualDefault)
defaultValue += "}"

if field.Type.Nullable {
defaultValue = "&" + defaultValue
}
} else if field.Type.IsRef() && resolvedFieldType.IsStruct() && fieldExtraDefault.Struct != nil {
defaultValue = jenny.defaultsForStruct(context, *field.Type.Ref, resolvedFieldType, fieldExtraDefault)
if field.Type.Nullable {
defaultValue = "&" + defaultValue
}
} else {
defaultValue = formatTypedDefault(fieldExtraDefault)
defaultValue = jenny.maybeValueAsPointer(defaultValue, field.Type.Nullable, resolvedFieldType)
}
} else if field.Type.IsConcreteScalar() {
defaultValue = formatScalar(field.Type.Scalar.Value)

defaultValue = jenny.maybeValueAsPointer(defaultValue, field.Type.Nullable, resolvedFieldType)
} else if resolvedFieldType.IsAnyOf(ast.KindScalar, ast.KindMap, ast.KindArray) && field.Type.Default != nil {
defaultValue = formatScalar(field.Type.Default)
} else if resolvedFieldType.IsAnyOf(ast.KindScalar, ast.KindMap, ast.KindArray) && field.Type.EffectiveTypedDefault() != nil {
defaultValue = formatTypedDefault(field.Type.EffectiveTypedDefault())

defaultValue = jenny.maybeValueAsPointer(defaultValue, field.Type.Nullable, resolvedFieldType)
} else if field.Type.IsRef() && resolvedFieldType.IsStruct() && field.Type.Default != nil {
defaultValue = jenny.defaultsForStruct(context, *field.Type.Ref, resolvedFieldType, field.Type.Default)
} else if field.Type.IsRef() && resolvedFieldType.IsStruct() && field.Type.EffectiveTypedDefault() != nil {
defaultValue = jenny.defaultsForStruct(context, *field.Type.Ref, resolvedFieldType, field.Type.EffectiveTypedDefault())
if field.Type.Nullable {
defaultValue = "&" + defaultValue
}
Expand All @@ -327,10 +332,12 @@ func (jenny RawTypes) defaultsForStruct(context languages.Context, objectRef ast
}
} else if field.Type.IsRef() && resolvedFieldType.IsEnum() {
memberName := resolvedFieldType.Enum.Values[0].Name
for _, member := range resolvedFieldType.Enum.Values {
if member.Value == field.Type.Default {
memberName = member.Name
break
if td := field.Type.EffectiveTypedDefault(); td != nil && td.Scalar != nil {
for _, member := range resolvedFieldType.Enum.Values {
if member.Value == td.Scalar.Value {
memberName = member.Name
break
}
}
}

Expand Down
Loading
Loading