Skip to content

Commit 60dbe7b

Browse files
committed
feat(protobuf): support loading protobuf schemas from a directory
The protobuf schema parser now supports directories by passing the directory path as the schema file. It will parse all '.proto' files within the directory, allowing for schemas split across multiple files. Signed-off-by: Jiyong Huang <huangjy@emqx.io>
1 parent 5dee73c commit 60dbe7b

File tree

6 files changed

+258
-17
lines changed

6 files changed

+258
-17
lines changed

internal/converter/protobuf/converter.go

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ package protobuf
1616

1717
import (
1818
"fmt"
19+
"os"
20+
"path/filepath"
21+
"strings"
1922

2023
"github.com/jhump/protoreflect/desc" //nolint:staticcheck
2124
"github.com/jhump/protoreflect/desc/protoparse" //nolint:staticcheck
@@ -44,19 +47,52 @@ func NewConverter(schemaFile string, soFile string, messageName string) (message
4447
if soFile != "" {
4548
return static.LoadStaticConverter(soFile, messageName)
4649
} else {
47-
if fds, err := protoParser.ParseFiles(schemaFile); err != nil {
48-
return nil, fmt.Errorf("parse schema file %s failed: %s", schemaFile, err)
49-
} else {
50-
messageDescriptor := fds[0].FindMessage(messageName)
51-
if messageDescriptor == nil {
52-
return nil, fmt.Errorf("message type %s not found in schema file %s", messageName, schemaFile)
50+
protoFiles, err := collectProtoFiles(schemaFile)
51+
if err != nil {
52+
return nil, fmt.Errorf("collect proto files from %s failed: %s", schemaFile, err)
53+
}
54+
fds, err := protoParser.ParseFiles(protoFiles...)
55+
if err != nil {
56+
return nil, fmt.Errorf("parse schema file(s) %s failed: %s", schemaFile, err)
57+
}
58+
for _, fd := range fds {
59+
messageDescriptor := fd.FindMessage(messageName)
60+
if messageDescriptor != nil {
61+
return &Converter{
62+
descriptor: messageDescriptor,
63+
fc: GetFieldConverter(),
64+
}, nil
5365
}
54-
return &Converter{
55-
descriptor: messageDescriptor,
56-
fc: GetFieldConverter(),
57-
}, nil
5866
}
67+
return nil, fmt.Errorf("message type %s not found in schema path %s", messageName, schemaFile)
68+
}
69+
}
70+
71+
// collectProtoFiles returns a list of .proto file paths for the given path.
72+
// If the path is a directory, it returns full paths for directory entries.
73+
// If it is a single file, it returns the path as-is.
74+
func collectProtoFiles(path string) ([]string, error) {
75+
info, err := os.Stat(path)
76+
if err != nil {
77+
return nil, err
78+
}
79+
if !info.IsDir() {
80+
return []string{path}, nil
81+
}
82+
entries, err := os.ReadDir(path)
83+
if err != nil {
84+
return nil, err
85+
}
86+
var result []string
87+
for _, e := range entries {
88+
if !e.IsDir() && strings.HasSuffix(e.Name(), ".proto") {
89+
result = append(result, filepath.Join(path, e.Name()))
90+
}
91+
}
92+
if len(result) == 0 {
93+
return nil, fmt.Errorf("no .proto files found in directory %s", path)
5994
}
95+
return result, nil
6096
}
6197

6298
func (c *Converter) Encode(ctx api.StreamContext, d any) (b []byte, err error) {

internal/converter/protobuf/converter_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package protobuf
1616

1717
import (
1818
"fmt"
19+
"path/filepath"
1920
"reflect"
2021
"testing"
2122

@@ -300,3 +301,53 @@ func TestErr(t *testing.T) {
300301
require.True(t, ok)
301302
require.Equal(t, errorx.CovnerterErr, errWithCode.Code())
302303
}
304+
305+
// ---- collectProtoFiles tests ----
306+
307+
func TestCollectProtoFiles_SingleFile(t *testing.T) {
308+
result, err := collectProtoFiles("../../schema/test/test1.proto")
309+
require.NoError(t, err)
310+
assert.Equal(t, []string{"../../schema/test/test1.proto"}, result)
311+
}
312+
313+
func TestCollectProtoFiles_Directory(t *testing.T) {
314+
result, err := collectProtoFiles("../../schema/test/multidir")
315+
require.NoError(t, err)
316+
assert.Len(t, result, 2)
317+
assert.Contains(t, result, filepath.Join("../../schema/test/multidir", "msg_a.proto"))
318+
assert.Contains(t, result, filepath.Join("../../schema/test/multidir", "msg_b.proto"))
319+
}
320+
321+
func TestCollectProtoFiles_EmptyDir(t *testing.T) {
322+
emptyDir := t.TempDir()
323+
_, err := collectProtoFiles(emptyDir)
324+
assert.Error(t, err)
325+
assert.Contains(t, err.Error(), "no .proto files found")
326+
}
327+
328+
func TestCollectProtoFiles_NotExist(t *testing.T) {
329+
_, err := collectProtoFiles("../../schema/test/nonexistent")
330+
assert.Error(t, err)
331+
}
332+
333+
// ---- Directory-based NewConverter tests ----
334+
335+
func TestNewConverter_FromDirectory(t *testing.T) {
336+
// SensorData is defined in multidir/msg_a.proto
337+
c, err := NewConverter("../../schema/test/multidir", "", "SensorData")
338+
require.NoError(t, err)
339+
require.NotNil(t, c)
340+
}
341+
342+
func TestNewConverter_FromDirectory_SecondFile(t *testing.T) {
343+
// VehicleStatus is defined in multidir/msg_b.proto
344+
c, err := NewConverter("../../schema/test/multidir", "", "VehicleStatus")
345+
require.NoError(t, err)
346+
require.NotNil(t, c)
347+
}
348+
349+
func TestNewConverter_FromDirectory_NotFound(t *testing.T) {
350+
_, err := NewConverter("../../schema/test/multidir", "", "NonExistentMsg")
351+
assert.Error(t, err)
352+
assert.Contains(t, err.Error(), "not found")
353+
}

internal/schema/protobuf.go

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@ func (p *PbType) Scan(logger api.Logger, schemaDir string) (map[string]*modules.
3636
} else {
3737
newSchemas = make(map[string]*modules.Files, len(files))
3838
for _, file := range files {
39+
// Subdirectory: treat as a single schema ID containing multiple .proto files
40+
if file.IsDir() {
41+
schemaId := file.Name()
42+
ffs, ok := newSchemas[schemaId]
43+
if !ok {
44+
ffs = &modules.Files{}
45+
newSchemas[schemaId] = ffs
46+
}
47+
// SchemaFile points to the directory itself
48+
ffs.SchemaFile = filepath.Join(schemaDir, file.Name())
49+
logger.Infof("schema directory %s/%s loaded", schemaDir, schemaId)
50+
continue
51+
}
3952
fileName := filepath.Base(file.Name())
4053
ext := filepath.Ext(fileName)
4154
schemaId := strings.TrimSuffix(fileName, filepath.Ext(fileName))
@@ -59,15 +72,48 @@ func (p *PbType) Scan(logger api.Logger, schemaDir string) (map[string]*modules.
5972
}
6073

6174
func (p *PbType) Infer(_ api.Logger, filePath string, messageId string) (ast.StreamFields, error) {
62-
if fds, err := protoParser.ParseFiles(filePath); err != nil {
63-
return nil, fmt.Errorf("parse schema file %s failed: %s", filePath, err)
64-
} else {
65-
messageDescriptor := fds[0].FindMessage(messageId)
66-
if messageDescriptor == nil {
67-
return nil, fmt.Errorf("message type %s not found in schema file %s", messageId, filePath)
75+
protoFiles, err := collectProtoFiles(filePath)
76+
if err != nil {
77+
return nil, fmt.Errorf("collect proto files from %s failed: %s", filePath, err)
78+
}
79+
fds, err := protoParser.ParseFiles(protoFiles...)
80+
if err != nil {
81+
return nil, fmt.Errorf("parse schema file(s) %s failed: %s", filePath, err)
82+
}
83+
for _, fd := range fds {
84+
messageDescriptor := fd.FindMessage(messageId)
85+
if messageDescriptor != nil {
86+
return convertMessage(messageDescriptor)
6887
}
69-
return convertMessage(messageDescriptor)
7088
}
89+
return nil, fmt.Errorf("message type %s not found in schema path %s", messageId, filePath)
90+
}
91+
92+
// collectProtoFiles returns a list of .proto file paths for the given path.
93+
// If the path is a directory, it returns dir-relative paths (e.g. "multidir/msg_a.proto").
94+
// If it is a single file, it returns the path as-is.
95+
func collectProtoFiles(path string) ([]string, error) {
96+
info, err := os.Stat(path)
97+
if err != nil {
98+
return nil, err
99+
}
100+
if !info.IsDir() {
101+
return []string{path}, nil
102+
}
103+
entries, err := os.ReadDir(path)
104+
if err != nil {
105+
return nil, err
106+
}
107+
var result []string
108+
for _, e := range entries {
109+
if !e.IsDir() && strings.HasSuffix(e.Name(), ".proto") {
110+
result = append(result, filepath.Join(path, e.Name()))
111+
}
112+
}
113+
if len(result) == 0 {
114+
return nil, fmt.Errorf("no .proto files found in directory %s", path)
115+
}
116+
return result, nil
71117
}
72118

73119
func convertMessage(m *desc.MessageDescriptor) (ast.StreamFields, error) {

internal/schema/protobuf_test.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package schema
22

33
import (
4+
"path/filepath"
45
"testing"
56

67
"github.com/stretchr/testify/assert"
@@ -51,3 +52,92 @@ func TestInferProtobufWithEmbedType(t *testing.T) {
5152
t.Errorf("InferProtobuf result is not expected, got %v, expected %v", result, expected)
5253
}
5354
}
55+
56+
// ---- collectProtoFiles tests ----
57+
58+
func TestCollectProtoFiles_SingleFile(t *testing.T) {
59+
result, err := collectProtoFiles("test/test1.proto")
60+
require.NoError(t, err)
61+
assert.Equal(t, []string{"test/test1.proto"}, result)
62+
}
63+
64+
func TestCollectProtoFiles_Directory(t *testing.T) {
65+
result, err := collectProtoFiles("test/multidir")
66+
require.NoError(t, err)
67+
assert.Len(t, result, 2)
68+
assert.Contains(t, result, filepath.Join("test/multidir", "msg_a.proto"))
69+
assert.Contains(t, result, filepath.Join("test/multidir", "msg_b.proto"))
70+
}
71+
72+
func TestCollectProtoFiles_EmptyDir(t *testing.T) {
73+
emptyDir := t.TempDir()
74+
_, err := collectProtoFiles(emptyDir)
75+
assert.Error(t, err)
76+
assert.Contains(t, err.Error(), "no .proto files found")
77+
}
78+
79+
func TestCollectProtoFiles_NotExist(t *testing.T) {
80+
_, err := collectProtoFiles("test/nonexistent")
81+
assert.Error(t, err)
82+
}
83+
84+
// ---- Directory-based Scan tests ----
85+
86+
type mockLogger struct{}
87+
88+
func (m *mockLogger) Debug(args ...interface{}) {}
89+
func (m *mockLogger) Info(args ...interface{}) {}
90+
func (m *mockLogger) Warn(args ...interface{}) {}
91+
func (m *mockLogger) Error(args ...interface{}) {}
92+
func (m *mockLogger) Debugln(args ...interface{}) {}
93+
func (m *mockLogger) Infoln(args ...interface{}) {}
94+
func (m *mockLogger) Warnln(args ...interface{}) {}
95+
func (m *mockLogger) Errorln(args ...interface{}) {}
96+
func (m *mockLogger) Debugf(format string, args ...interface{}) {}
97+
func (m *mockLogger) Infof(format string, args ...interface{}) {}
98+
func (m *mockLogger) Warnf(format string, args ...interface{}) {}
99+
func (m *mockLogger) Errorf(format string, args ...interface{}) {}
100+
101+
func TestScan_WithSubdirectory(t *testing.T) {
102+
pt := &PbType{}
103+
schemas, err := pt.Scan(&mockLogger{}, "test")
104+
require.NoError(t, err)
105+
// Should find "multidir" as a subdirectory-based schema ID
106+
assert.Contains(t, schemas, "multidir")
107+
assert.NotEmpty(t, schemas["multidir"].SchemaFile)
108+
// Should also find regular .proto files
109+
assert.Contains(t, schemas, "test1")
110+
}
111+
112+
// ---- Directory-based Infer tests ----
113+
114+
func TestInfer_FromDirectory(t *testing.T) {
115+
pt := &PbType{}
116+
// Infer a message from directory containing multiple proto files
117+
result, err := pt.Infer(nil, "test/multidir", "SensorData")
118+
require.NoError(t, err)
119+
expected := ast.StreamFields{
120+
{Name: "temperature", FieldType: &ast.BasicType{Type: ast.FLOAT}},
121+
{Name: "humidity", FieldType: &ast.BasicType{Type: ast.BIGINT}},
122+
}
123+
assert.Equal(t, expected, result)
124+
}
125+
126+
func TestInfer_FromDirectory_SecondFile(t *testing.T) {
127+
pt := &PbType{}
128+
// Infer a message defined in the second proto file
129+
result, err := pt.Infer(nil, "test/multidir", "VehicleStatus")
130+
require.NoError(t, err)
131+
// VehicleStatus has: speed (int32), vin (string), battery (BatteryInfo)
132+
assert.Len(t, result, 3)
133+
assert.Equal(t, "speed", result[0].Name)
134+
assert.Equal(t, "vin", result[1].Name)
135+
assert.Equal(t, "battery", result[2].Name)
136+
}
137+
138+
func TestInfer_FromDirectory_MessageNotFound(t *testing.T) {
139+
pt := &PbType{}
140+
_, err := pt.Infer(nil, "test/multidir", "NonExistentMessage")
141+
assert.Error(t, err)
142+
assert.Contains(t, err.Error(), "not found")
143+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
syntax = "proto3";
2+
3+
message SensorData {
4+
double temperature = 1;
5+
int32 humidity = 2;
6+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
syntax = "proto3";
2+
3+
message VehicleStatus {
4+
int32 speed = 1;
5+
string vin = 2;
6+
BatteryInfo battery = 3;
7+
}
8+
9+
message BatteryInfo {
10+
double voltage = 1;
11+
int32 level = 2;
12+
}

0 commit comments

Comments
 (0)