diff --git a/providers-sdk/v1/plugin/grpc.go b/providers-sdk/v1/plugin/grpc.go index a2cefd04d2..b325fe5220 100644 --- a/providers-sdk/v1/plugin/grpc.go +++ b/providers-sdk/v1/plugin/grpc.go @@ -6,12 +6,16 @@ package plugin import ( "bytes" "context" + "runtime/debug" "unicode/utf8" plugin "github.com/hashicorp/go-plugin" + "github.com/rs/zerolog/log" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/encoding" _ "google.golang.org/grpc/encoding/proto" + "google.golang.org/grpc/status" ) func init() { @@ -83,6 +87,18 @@ func (m *GRPCClient) StoreData(req *StoreReq) (*StoreRes, error) { return m.client.StoreData(context.Background(), req) } +// recoverPanic converts a panic into a gRPC Internal error. The full stack +// trace is logged locally; only a short message is sent over the wire. +// The message is prefixed with "panic in provider " so the caller can +// distinguish recovered panics from other Internal errors. +func recoverPanic(method string, retErr *error) { + if r := recover(); r != nil { + stack := debug.Stack() + log.Error().Str("method", method).Interface("panic", r).Str("stack", string(stack)).Msg("recovered panic in provider") + *retErr = status.Errorf(codes.Internal, "panic in provider %s: %v", method, r) + } +} + // Here is the gRPC server that GRPCClient talks to. type GRPCServer struct { // This is the real implementation @@ -95,11 +111,13 @@ func (m *GRPCServer) Heartbeat(ctx context.Context, req *HeartbeatReq) (*Heartbe return m.Impl.Heartbeat(req) } -func (m *GRPCServer) ParseCLI(ctx context.Context, req *ParseCLIReq) (*ParseCLIRes, error) { +func (m *GRPCServer) ParseCLI(ctx context.Context, req *ParseCLIReq) (resp *ParseCLIRes, err error) { + defer recoverPanic("ParseCLI", &err) return m.Impl.ParseCLI(req) } -func (m *GRPCServer) Connect(ctx context.Context, req *ConnectReq) (*ConnectRes, error) { +func (m *GRPCServer) Connect(ctx context.Context, req *ConnectReq) (resp *ConnectRes, err error) { + defer recoverPanic("Connect", &err) conn, err := m.broker.Dial(req.CallbackServer) if err != nil { return nil, err @@ -112,11 +130,13 @@ func (m *GRPCServer) Connect(ctx context.Context, req *ConnectReq) (*ConnectRes, return m.Impl.Connect(req, a) } -func (m *GRPCServer) Disconnect(ctx context.Context, req *DisconnectReq) (*DisconnectRes, error) { +func (m *GRPCServer) Disconnect(ctx context.Context, req *DisconnectReq) (resp *DisconnectRes, err error) { + defer recoverPanic("Disconnect", &err) return m.Impl.Disconnect(req) } -func (m *GRPCServer) MockConnect(ctx context.Context, req *ConnectReq) (*ConnectRes, error) { +func (m *GRPCServer) MockConnect(ctx context.Context, req *ConnectReq) (resp *ConnectRes, err error) { + defer recoverPanic("MockConnect", &err) conn, err := m.broker.Dial(req.CallbackServer) if err != nil { return nil, err @@ -129,12 +149,14 @@ func (m *GRPCServer) MockConnect(ctx context.Context, req *ConnectReq) (*Connect return m.Impl.MockConnect(req, a) } -func (m *GRPCServer) Shutdown(ctx context.Context, req *ShutdownReq) (*ShutdownRes, error) { +func (m *GRPCServer) Shutdown(ctx context.Context, req *ShutdownReq) (resp *ShutdownRes, err error) { + defer recoverPanic("Shutdown", &err) return m.Impl.Shutdown(req) } -func (m *GRPCServer) GetData(ctx context.Context, req *DataReq) (*DataRes, error) { - resp, err := m.Impl.GetData(req) +func (m *GRPCServer) GetData(ctx context.Context, req *DataReq) (resp *DataRes, err error) { + defer recoverPanic("GetData", &err) + resp, err = m.Impl.GetData(req) if err != nil { return nil, err } @@ -142,7 +164,8 @@ func (m *GRPCServer) GetData(ctx context.Context, req *DataReq) (*DataRes, error return resp, nil } -func (m *GRPCServer) StoreData(ctx context.Context, req *StoreReq) (*StoreRes, error) { +func (m *GRPCServer) StoreData(ctx context.Context, req *StoreReq) (resp *StoreRes, err error) { + defer recoverPanic("StoreData", &err) return m.Impl.StoreData(req) } diff --git a/providers-sdk/v1/plugin/grpc_test.go b/providers-sdk/v1/plugin/grpc_test.go new file mode 100644 index 0000000000..de1a14827a --- /dev/null +++ b/providers-sdk/v1/plugin/grpc_test.go @@ -0,0 +1,55 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package plugin + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestRecoverPanic(t *testing.T) { + t.Run("recovers string panic", func(t *testing.T) { + var err error + func() { + defer recoverPanic("TestMethod", &err) + panic("something went wrong") + }() + + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Equal(t, "panic in provider TestMethod: something went wrong", st.Message()) + assert.NotContains(t, st.Message(), "goroutine") // stack trace stays in log, not on the wire + }) + + t.Run("recovers nil pointer panic", func(t *testing.T) { + var err error + func() { + defer recoverPanic("GetData", &err) + var s *string + _ = *s // nil pointer dereference + }() + + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) + assert.Contains(t, st.Message(), "panic in provider GetData") + }) + + t.Run("no panic leaves error nil", func(t *testing.T) { + var err error + func() { + defer recoverPanic("TestMethod", &err) + // no panic + }() + + assert.NoError(t, err) + }) +} diff --git a/providers/runtime.go b/providers/runtime.go index 5cc8ec7841..41bbf0a745 100644 --- a/providers/runtime.go +++ b/providers/runtime.go @@ -5,6 +5,7 @@ package providers import ( "errors" + "strings" "sync" "time" @@ -18,6 +19,7 @@ import ( "go.mondoo.com/mql/v13/types" "go.mondoo.com/mql/v13/utils/multierr" "go.mondoo.com/mql/v13/utils/stringx" + "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -492,8 +494,17 @@ func (r *Runtime) handlePluginError(err error, provider *ConnectedProvider) (boo } switch st.Code() { - case 14: - // Error: Unavailable. Happens when the plugin crashes. + case codes.Internal: + // A recovered panic in the provider sends an Internal error prefixed + // with "panic in provider ". Only apply panic-specific handling when + // this prefix is present; other Internal errors fall through. + if strings.HasPrefix(st.Message(), "panic in provider ") { + log.Error().Str("provider", provider.Instance.Name).Msg(st.Message()) + return true, errors.New("the '" + provider.Instance.Name + "' provider panicked: " + st.Message()) + } + + case codes.Unavailable: + // Happens when the plugin crashes. // TODO: try to restart the plugin and reset its connections provider.Instance.isClosed = true provider.Instance.err = errors.New("the '" + provider.Instance.Name + "' provider crashed: " + err.Error())