Skip to content

Commit 02e6d94

Browse files
authored
Merge pull request #4 from brucechen7274/main
feat: Add enum support to code generation and JSON schema generation
2 parents cd42837 + bd783cb commit 02e6d94

File tree

5 files changed

+160
-38
lines changed

5 files changed

+160
-38
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
github.com/sashabaranov/go-openai v1.41.1
77
github.com/spf13/cobra v1.9.1
88
github.com/xeipuuv/gojsonschema v1.2.0
9+
golang.org/x/tools v0.36.0
910
gopkg.in/yaml.v3 v3.0.1
1011
)
1112

@@ -16,5 +17,4 @@ require (
1617
github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect
1718
golang.org/x/mod v0.27.0 // indirect
1819
golang.org/x/sync v0.16.0 // indirect
19-
golang.org/x/tools v0.36.0 // indirect
2020
)

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
22
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
33
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4+
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
5+
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
46
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
57
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
68
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

pkg/gen/gogen.go

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,16 @@ func (gen *CodeGenerator) Generate(spec *spec.Spec) ([]byte, error) {
4444
gen.write("// Code generated by suricata-gen; DO NOT EDIT.\n\n")
4545
gen.write("package %s\n\n", packageName(spec.Package))
4646

47+
// Generate enums first
48+
if len(spec.Enums) > 0 {
49+
gen.generateEnums(spec.Enums)
50+
}
51+
4752
if len(spec.Messages) > 0 {
48-
if err := gen.generateMessageSchemas(spec.Messages); err != nil {
53+
if err := gen.generateMessageSchemas(spec.Messages, spec.Enums); err != nil {
4954
return nil, err
5055
}
51-
gen.generateTypes(spec.Messages)
56+
gen.generateTypes(spec.Messages, spec.Enums)
5257
}
5358

5459
// Generate RPC methods
@@ -64,12 +69,69 @@ func (gen *CodeGenerator) Generate(spec *spec.Spec) ([]byte, error) {
6469
return src, nil
6570
}
6671

67-
func (gen *CodeGenerator) generateMessageSchemas(messages map[string]spec.Message) error {
72+
func (gen *CodeGenerator) generateEnums(enums map[string]spec.Enum) {
73+
if len(enums) == 0 {
74+
return
75+
}
76+
77+
// Generate enum type definitions
78+
gen.write("// Enum types\n")
79+
gen.write("type (\n")
80+
for name := range enums {
81+
gen.write("\t%s string\n", name)
82+
}
83+
gen.write(")\n\n")
84+
85+
// Generate enum constants and methods for each enum
86+
for name, enum := range enums {
87+
gen.generateEnumConstants(name, enum)
88+
gen.generateEnumMethods(name, enum)
89+
}
90+
}
91+
92+
func (gen *CodeGenerator) generateEnumConstants(name string, enum spec.Enum) {
93+
gen.write("// %s values\n", name)
94+
gen.write("const (\n")
95+
for _, value := range enum.Values {
96+
constName := name + CapitalizeFirst(toCamelCase(value))
97+
gen.write("\t%s %s = \"%s\"\n", constName, name, value)
98+
}
99+
gen.write(")\n\n")
100+
}
101+
102+
func (gen *CodeGenerator) generateEnumMethods(name string, enum spec.Enum) {
103+
// Generate IsValid method
104+
gen.write("// IsValid checks if the %s value is valid\n", name)
105+
gen.write("func (e %s) IsValid() bool {\n", name)
106+
gen.write("\tswitch e {\n")
107+
gen.write("\tcase ")
108+
for i, value := range enum.Values {
109+
if i > 0 {
110+
gen.write(", ")
111+
}
112+
constName := name + CapitalizeFirst(toCamelCase(value))
113+
gen.write(constName)
114+
}
115+
gen.write(":\n")
116+
gen.write("\t\treturn true\n")
117+
gen.write("\tdefault:\n")
118+
gen.write("\t\treturn false\n")
119+
gen.write("\t}\n")
120+
gen.write("}\n\n")
121+
122+
// Generate String method
123+
gen.write("// String returns the string representation of %s\n", name)
124+
gen.write("func (e %s) String() string {\n", name)
125+
gen.write("\treturn string(e)\n")
126+
gen.write("}\n\n")
127+
}
128+
129+
func (gen *CodeGenerator) generateMessageSchemas(messages map[string]spec.Message, enums map[string]spec.Enum) error {
68130
schemaGen := NewJSONSchemaGenerator()
69131

70132
gen.write("var (\n")
71133
for name, msg := range messages {
72-
schema, err := schemaGen.GenerateJSONSchema(name, &msg, messages)
134+
schema, err := schemaGen.GenerateJSONSchema(name, &msg, messages, enums)
73135
if err != nil {
74136
return err
75137
}
@@ -85,13 +147,13 @@ func (gen *CodeGenerator) generateMessageSchemas(messages map[string]spec.Messag
85147
return nil
86148
}
87149

88-
func (gen *CodeGenerator) generateTypes(messages map[string]spec.Message) {
150+
func (gen *CodeGenerator) generateTypes(messages map[string]spec.Message, enums map[string]spec.Enum) {
89151
// Generate structs for messages
90152
gen.write("type (\n")
91153
for name, msg := range messages {
92154
gen.write(fmt.Sprintf("\t%s struct {\n", name))
93155
for _, field := range msg.Fields {
94-
goType := goTypeForField(field)
156+
goType := goTypeForField(field, enums)
95157
fieldName := toCamelCase(field.Name)
96158

97159
tagParts := []string{field.Name}
@@ -253,7 +315,7 @@ func toCamelCase(s string) string {
253315
return strings.Join(parts, "")
254316
}
255317

256-
func goTypeForField(f spec.Field) string {
318+
func goTypeForField(f spec.Field, enums map[string]spec.Enum) string {
257319
var goType string
258320
switch f.Type {
259321
case "string":
@@ -267,8 +329,13 @@ func goTypeForField(f spec.Field) string {
267329
case "datetime":
268330
goType = "time.Time" // RFC3339 format
269331
default:
270-
// Custom message type
271-
goType = f.Type
332+
// Check if it's an enum type
333+
if _, isEnum := enums[f.Type]; isEnum {
334+
goType = f.Type // Use the enum type name directly
335+
} else {
336+
// Custom message type
337+
goType = f.Type
338+
}
272339
}
273340

274341
// Pointer for optional scalar or custom type (but not slices)

pkg/gen/jsonschema.go

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ func NewJSONSchemaGenerator() *JSONSchemaGenerator {
3535

3636
// GenerateJSONSchema returns a JSON Schema object (as a map) for the given message.
3737
// It recursively includes referenced custom types.
38-
func (gen *JSONSchemaGenerator) GenerateJSONSchema(name string, msg *spec.Message, allMessages map[string]spec.Message) (JSONSchema, error) {
38+
func (gen *JSONSchemaGenerator) GenerateJSONSchema(name string, msg *spec.Message, allMessages map[string]spec.Message, allEnums map[string]spec.Enum) (JSONSchema, error) {
3939
schema, has := gen.schemas[name]
4040
if has {
4141
return schema, nil
4242
}
4343

44-
schema, err := gen.generateJSONSchema(msg, allMessages)
44+
schema, err := gen.generateJSONSchema(msg, allMessages, allEnums)
4545
if err != nil {
4646
return nil, err
4747
}
@@ -50,7 +50,7 @@ func (gen *JSONSchemaGenerator) GenerateJSONSchema(name string, msg *spec.Messag
5050
return schema, nil
5151
}
5252

53-
func (gen *JSONSchemaGenerator) generateJSONSchema(msg *spec.Message, allMessages map[string]spec.Message) (JSONSchema, error) {
53+
func (gen *JSONSchemaGenerator) generateJSONSchema(msg *spec.Message, allMessages map[string]spec.Message, allEnums map[string]spec.Enum) (JSONSchema, error) {
5454
properties := make(map[string]any)
5555

5656
schema := map[string]any{
@@ -60,7 +60,7 @@ func (gen *JSONSchemaGenerator) generateJSONSchema(msg *spec.Message, allMessage
6060

6161
requiredFields := []string{}
6262
for _, field := range msg.Fields {
63-
fieldSchema, err := gen.fieldToSchema(field, allMessages)
63+
fieldSchema, err := gen.fieldToSchema(field, allMessages, allEnums)
6464
if err != nil {
6565
return nil, fmt.Errorf("field %q: %w", field.Name, err)
6666
}
@@ -80,36 +80,50 @@ func (gen *JSONSchemaGenerator) generateJSONSchema(msg *spec.Message, allMessage
8080
}
8181

8282
// fieldToSchema generates the JSON Schema for a single field, recursively if needed.
83-
func (gen *JSONSchemaGenerator) fieldToSchema(field spec.Field, allMessages map[string]spec.Message) (map[string]interface{}, error) {
83+
func (gen *JSONSchemaGenerator) fieldToSchema(field spec.Field, allMessages map[string]spec.Message, allEnums map[string]spec.Enum) (map[string]interface{}, error) {
8484
var baseSchema map[string]any
8585

86-
switch field.Type {
87-
case "string":
88-
baseSchema = map[string]any{"type": "string"}
89-
case "int", "int32", "int64":
90-
baseSchema = map[string]any{"type": "integer"}
91-
case "float", "float32", "float64":
92-
baseSchema = map[string]any{"type": "number"}
93-
case "bool":
94-
baseSchema = map[string]any{"type": "boolean"}
95-
case "datetime":
96-
baseSchema = map[string]any{"type": "string", "format": "date-time"} // RFC3339
97-
default:
98-
// Custom type - lookup in allMessages
99-
msg, ok := allMessages[field.Type]
100-
if !ok {
101-
return nil, fmt.Errorf("unknown custom type %q", field.Type)
86+
// Check if it's an enum type
87+
if enum, isEnum := allEnums[field.Type]; isEnum {
88+
baseSchema = map[string]any{
89+
"type": "string",
90+
"enum": enum.Values,
10291
}
92+
if enum.Description != "" {
93+
baseSchema["description"] = enum.Description
94+
}
95+
}
10396

104-
// Recursive schema for nested message
105-
nestedSchema, err := gen.GenerateJSONSchema(field.Type, &msg, allMessages)
106-
if err != nil {
107-
return nil, err
97+
// Check primitive types and custom message types
98+
if baseSchema == nil {
99+
switch field.Type {
100+
case "string":
101+
baseSchema = map[string]any{"type": "string"}
102+
case "int", "int32", "int64":
103+
baseSchema = map[string]any{"type": "integer"}
104+
case "float", "float32", "float64":
105+
baseSchema = map[string]any{"type": "number"}
106+
case "bool":
107+
baseSchema = map[string]any{"type": "boolean"}
108+
case "datetime":
109+
baseSchema = map[string]any{"type": "string", "format": "date-time"} // RFC3339
110+
default:
111+
// Custom message type - lookup in allMessages
112+
msg, ok := allMessages[field.Type]
113+
if !ok {
114+
return nil, fmt.Errorf("unknown custom type %q", field.Type)
115+
}
116+
117+
// Recursive schema for nested message
118+
nestedSchema, err := gen.GenerateJSONSchema(field.Type, &msg, allMessages, allEnums)
119+
if err != nil {
120+
return nil, err
121+
}
122+
baseSchema = nestedSchema
108123
}
109-
baseSchema = nestedSchema
110124
}
111125

112-
if field.Description != "" {
126+
if field.Description != "" && baseSchema["description"] == nil {
113127
baseSchema["description"] = field.Description
114128
}
115129

pkg/spec/spec.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,17 @@ import (
2626
type Spec struct {
2727
Version string `yaml:"version"`
2828
Package string `yaml:"package"`
29+
Enums map[string]Enum `yaml:"enums"`
2930
Messages map[string]Message `yaml:"messages"`
3031
Tools map[string]Tool `yaml:"tools"`
3132
Agents map[string]Agent `yaml:"agents"`
3233
}
3334

35+
type Enum struct {
36+
Description string `yaml:"description,omitempty"`
37+
Values []string `yaml:"values"`
38+
}
39+
3440
type Message struct {
3541
Fields []Field `yaml:"fields"`
3642
}
@@ -86,6 +92,12 @@ func isPrimitiveType(t string) bool {
8692
}
8793
}
8894

95+
// isEnumType checks if the given type is a defined enum type
96+
func (spec *Spec) isEnumType(t string) bool {
97+
_, exists := spec.Enums[t]
98+
return exists
99+
}
100+
89101
func (spec *Spec) Validate() error {
90102
if spec.Version == "" {
91103
return fmt.Errorf("spec: version is required")
@@ -94,6 +106,10 @@ func (spec *Spec) Validate() error {
94106
return fmt.Errorf("spec: package is required")
95107
}
96108

109+
if err := spec.validateEnums(); err != nil {
110+
return err
111+
}
112+
97113
if err := spec.validateMessages(); err != nil {
98114
return err
99115
}
@@ -104,6 +120,29 @@ func (spec *Spec) Validate() error {
104120
return spec.validateAgents()
105121
}
106122

123+
func (spec *Spec) validateEnums() error {
124+
for name, enum := range spec.Enums {
125+
if name == "" {
126+
return fmt.Errorf("spec: enum has empty name")
127+
}
128+
if len(enum.Values) == 0 {
129+
return fmt.Errorf("spec: enum %q has no values", name)
130+
}
131+
// Check for duplicate values
132+
seen := make(map[string]bool)
133+
for _, value := range enum.Values {
134+
if value == "" {
135+
return fmt.Errorf("spec: enum %q has empty value", name)
136+
}
137+
if seen[value] {
138+
return fmt.Errorf("spec: enum %q has duplicate value %q", name, value)
139+
}
140+
seen[value] = true
141+
}
142+
}
143+
return nil
144+
}
145+
107146
func (spec *Spec) validateMessages() error {
108147
for name, msg := range spec.Messages {
109148
if name == "" {
@@ -117,7 +156,7 @@ func (spec *Spec) validateMessages() error {
117156
return fmt.Errorf("spec: field %q in message %q has empty type", field.Name, name)
118157
}
119158
// Validate field type existence
120-
if !isPrimitiveType(field.Type) {
159+
if !isPrimitiveType(field.Type) && !spec.isEnumType(field.Type) {
121160
if _, ok := spec.Messages[field.Type]; !ok {
122161
return fmt.Errorf("spec: field %q in message %q references undefined type %q", field.Name, name, field.Type)
123162
}

0 commit comments

Comments
 (0)