Skip to content

Commit 92e1b7b

Browse files
authored
Truncate service error message (#10136)
## What changed? Truncate service error message to prevent grpc-go to return `RST_STREAM` instead of the service error. ## Why? grpc-go returns `RST_STREAM` error when the error message is too long, masking the actual service error returned by Temporal. ## How did you test it? - [x] built - [ ] run locally and tested manually - [ ] covered by existing tests - [x] added new unit test(s) - [ ] added new functional test(s) ## Potential risks
1 parent 2ef71ec commit 92e1b7b

2 files changed

Lines changed: 104 additions & 10 deletions

File tree

common/rpc/interceptor/service_error_interceptor.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@ import (
66

77
"go.temporal.io/api/serviceerror"
88
"go.temporal.io/server/common/persistence/serialization"
9+
"go.temporal.io/server/common/util"
910
"google.golang.org/grpc"
11+
"google.golang.org/grpc/status"
12+
)
13+
14+
const (
15+
maxMessageLength = 4000
16+
truncatedSuffix = "... <truncated>"
1017
)
1118

1219
func ServiceErrorInterceptor(
@@ -24,5 +31,14 @@ func ServiceErrorInterceptor(
2431
if errors.As(err, &deserializationError) || errors.As(err, &serializationError) {
2532
err = serviceerror.NewDataLoss(err.Error())
2633
}
27-
return resp, serviceerror.ToStatus(err).Err()
34+
35+
// truncate message length if needed
36+
st := serviceerror.ToStatus(err)
37+
if len(st.Message()) > maxMessageLength {
38+
p := st.Proto()
39+
p.Message = util.TruncateUTF8(p.Message, maxMessageLength-len(truncatedSuffix)) + truncatedSuffix
40+
st = status.FromProto(p)
41+
}
42+
43+
return resp, st.Err()
2844
}

common/rpc/interceptor/service_error_interceptor_test.go

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ package interceptor
22

33
import (
44
"context"
5+
"strings"
56
"testing"
67

7-
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
89
enumspb "go.temporal.io/api/enums/v1"
910
"go.temporal.io/api/serviceerror"
1011
"go.temporal.io/server/common/persistence/serialization"
@@ -28,24 +29,24 @@ func (e *ErrorWithoutStatus) Error() string {
2829
// Error returns string message.
2930
func TestServiceErrorInterceptorUnknown(t *testing.T) {
3031

31-
_, err := ServiceErrorInterceptor(context.Background(), nil, nil,
32+
_, err := ServiceErrorInterceptor(t.Context(), nil, nil,
3233
func(ctx context.Context, req any) (any, error) {
3334
return nil, status.Error(codes.InvalidArgument, "invalid argument")
3435
})
3536

36-
assert.Error(t, err)
37-
assert.Equal(t, codes.InvalidArgument, status.Code(err))
37+
require.Error(t, err)
38+
require.Equal(t, codes.InvalidArgument, status.Code(err))
3839

39-
_, err = ServiceErrorInterceptor(context.Background(), nil, nil,
40+
_, err = ServiceErrorInterceptor(t.Context(), nil, nil,
4041
func(ctx context.Context, req any) (any, error) {
4142
errWithoutStatus := &ErrorWithoutStatus{
4243
Message: "unknown error without status",
4344
}
4445
return nil, errWithoutStatus
4546
})
4647

47-
assert.Error(t, err)
48-
assert.Equal(t, codes.Unknown, status.Code(err))
48+
require.Error(t, err)
49+
require.Equal(t, codes.Unknown, status.Code(err))
4950
}
5051

5152
func TestServiceErrorInterceptorSer(t *testing.T) {
@@ -54,10 +55,87 @@ func TestServiceErrorInterceptorSer(t *testing.T) {
5455
serialization.NewSerializationError(enumspb.ENCODING_TYPE_PROTO3, nil),
5556
}
5657
for _, inErr := range serErrors {
57-
_, err := ServiceErrorInterceptor(context.Background(), nil, nil,
58+
_, err := ServiceErrorInterceptor(t.Context(), nil, nil,
5859
func(_ context.Context, _ any) (any, error) {
5960
return nil, inErr
6061
})
61-
assert.Equal(t, serviceerror.ToStatus(err).Code(), codes.DataLoss)
62+
require.Equal(t, codes.DataLoss, serviceerror.ToStatus(err).Code())
6263
}
6364
}
65+
66+
func TestServiceErrorInterceptorTruncation(t *testing.T) {
67+
t.Run("nil error is not affected", func(t *testing.T) {
68+
_, err := ServiceErrorInterceptor(t.Context(), nil, nil,
69+
func(_ context.Context, _ any) (any, error) {
70+
return "ok", nil
71+
})
72+
require.NoError(t, err)
73+
})
74+
75+
t.Run("short message is not truncated", func(t *testing.T) {
76+
msg := "short error"
77+
_, err := ServiceErrorInterceptor(t.Context(), nil, nil,
78+
func(_ context.Context, _ any) (any, error) {
79+
return nil, serviceerror.NewInternal(msg)
80+
})
81+
require.Error(t, err)
82+
st := status.Convert(err)
83+
require.Equal(t, msg, st.Message())
84+
})
85+
86+
t.Run("message at exact limit is not truncated", func(t *testing.T) {
87+
msg := strings.Repeat("a", maxMessageLength)
88+
_, err := ServiceErrorInterceptor(t.Context(), nil, nil,
89+
func(_ context.Context, _ any) (any, error) {
90+
return nil, serviceerror.NewInternal(msg)
91+
})
92+
require.Error(t, err)
93+
st := status.Convert(err)
94+
require.Equal(t, msg, st.Message())
95+
})
96+
97+
t.Run("message over limit is truncated", func(t *testing.T) {
98+
msg := strings.Repeat("a", maxMessageLength+100)
99+
_, err := ServiceErrorInterceptor(t.Context(), nil, nil,
100+
func(_ context.Context, _ any) (any, error) {
101+
return nil, serviceerror.NewInternal(msg)
102+
})
103+
require.Error(t, err)
104+
st := status.Convert(err)
105+
require.LessOrEqual(t, len(st.Message()), maxMessageLength)
106+
require.True(t, strings.HasSuffix(st.Message(), truncatedSuffix))
107+
})
108+
109+
t.Run("truncation preserves error code", func(t *testing.T) {
110+
msg := strings.Repeat("x", maxMessageLength+500)
111+
_, err := ServiceErrorInterceptor(t.Context(), nil, nil,
112+
func(_ context.Context, _ any) (any, error) {
113+
return nil, serviceerror.NewNotFound(msg)
114+
})
115+
require.Error(t, err)
116+
require.Equal(t, codes.NotFound, status.Code(err))
117+
})
118+
119+
t.Run("truncation respects multi-byte UTF-8 boundary", func(t *testing.T) {
120+
// Fill up to near the limit with multi-byte characters (3 bytes each for '€')
121+
// then push over the limit so truncation must split within the repeated chars.
122+
euroCount := maxMessageLength / len("€") // each '€' is 3 bytes
123+
msg := strings.Repeat("€", euroCount+100)
124+
_, err := ServiceErrorInterceptor(t.Context(), nil, nil,
125+
func(_ context.Context, _ any) (any, error) {
126+
return nil, serviceerror.NewInternal(msg)
127+
})
128+
require.Error(t, err)
129+
st := status.Convert(err)
130+
require.LessOrEqual(t, len(st.Message()), maxMessageLength)
131+
require.True(t, strings.HasSuffix(st.Message(), truncatedSuffix))
132+
// Verify the truncated body (without suffix) is valid UTF-8 by checking
133+
// that no partial rune was left behind — the full message should be valid.
134+
require.True(t, strings.HasSuffix(st.Message(), truncatedSuffix))
135+
body := strings.TrimSuffix(st.Message(), truncatedSuffix)
136+
// Every character in body should be '€' (no partial runes).
137+
for _, r := range body {
138+
require.Equal(t, '€', r)
139+
}
140+
})
141+
}

0 commit comments

Comments
 (0)