Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions providers-sdk/v1/plugin/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ package plugin
import (
"bytes"
"context"
"runtime/debug"
"unicode/utf8"

plugin "github.com/hashicorp/go-plugin"
"github.com/rs/zerolog/log"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/status"
)

func init() {
Expand Down Expand Up @@ -83,6 +87,18 @@ func (m *GRPCClient) StoreData(req *StoreReq) (*StoreRes, error) {
return m.client.StoreData(context.Background(), req)
}

// recoverPanic converts a panic into a gRPC Internal error. The full stack
// trace is logged locally; only a short message is sent over the wire.
// The message is prefixed with "panic in provider " so the caller can
// distinguish recovered panics from other Internal errors.
func recoverPanic(method string, retErr *error) {
if r := recover(); r != nil {
stack := debug.Stack()
log.Error().Str("method", method).Interface("panic", r).Str("stack", string(stack)).Msg("recovered panic in provider")
*retErr = status.Errorf(codes.Internal, "panic in provider %s: %v", method, r)
}
}

// Here is the gRPC server that GRPCClient talks to.
type GRPCServer struct {
// This is the real implementation
Expand All @@ -95,11 +111,13 @@ func (m *GRPCServer) Heartbeat(ctx context.Context, req *HeartbeatReq) (*Heartbe
return m.Impl.Heartbeat(req)
}

func (m *GRPCServer) ParseCLI(ctx context.Context, req *ParseCLIReq) (*ParseCLIRes, error) {
func (m *GRPCServer) ParseCLI(ctx context.Context, req *ParseCLIReq) (resp *ParseCLIRes, err error) {
defer recoverPanic("ParseCLI", &err)
return m.Impl.ParseCLI(req)
}

func (m *GRPCServer) Connect(ctx context.Context, req *ConnectReq) (*ConnectRes, error) {
func (m *GRPCServer) Connect(ctx context.Context, req *ConnectReq) (resp *ConnectRes, err error) {
defer recoverPanic("Connect", &err)
conn, err := m.broker.Dial(req.CallbackServer)
if err != nil {
return nil, err
Expand All @@ -112,11 +130,13 @@ func (m *GRPCServer) Connect(ctx context.Context, req *ConnectReq) (*ConnectRes,
return m.Impl.Connect(req, a)
}

func (m *GRPCServer) Disconnect(ctx context.Context, req *DisconnectReq) (*DisconnectRes, error) {
func (m *GRPCServer) Disconnect(ctx context.Context, req *DisconnectReq) (resp *DisconnectRes, err error) {
defer recoverPanic("Disconnect", &err)
return m.Impl.Disconnect(req)
}

func (m *GRPCServer) MockConnect(ctx context.Context, req *ConnectReq) (*ConnectRes, error) {
func (m *GRPCServer) MockConnect(ctx context.Context, req *ConnectReq) (resp *ConnectRes, err error) {
defer recoverPanic("MockConnect", &err)
conn, err := m.broker.Dial(req.CallbackServer)
if err != nil {
return nil, err
Expand All @@ -129,20 +149,23 @@ func (m *GRPCServer) MockConnect(ctx context.Context, req *ConnectReq) (*Connect
return m.Impl.MockConnect(req, a)
}

func (m *GRPCServer) Shutdown(ctx context.Context, req *ShutdownReq) (*ShutdownRes, error) {
func (m *GRPCServer) Shutdown(ctx context.Context, req *ShutdownReq) (resp *ShutdownRes, err error) {
defer recoverPanic("Shutdown", &err)
return m.Impl.Shutdown(req)
}

func (m *GRPCServer) GetData(ctx context.Context, req *DataReq) (*DataRes, error) {
resp, err := m.Impl.GetData(req)
func (m *GRPCServer) GetData(ctx context.Context, req *DataReq) (resp *DataRes, err error) {
defer recoverPanic("GetData", &err)
resp, err = m.Impl.GetData(req)
if err != nil {
return nil, err
}
sanitizeDataRes(resp)
return resp, nil
}

func (m *GRPCServer) StoreData(ctx context.Context, req *StoreReq) (*StoreRes, error) {
func (m *GRPCServer) StoreData(ctx context.Context, req *StoreReq) (resp *StoreRes, err error) {
defer recoverPanic("StoreData", &err)
return m.Impl.StoreData(req)
}

Expand Down
55 changes: 55 additions & 0 deletions providers-sdk/v1/plugin/grpc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package plugin

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func TestRecoverPanic(t *testing.T) {
t.Run("recovers string panic", func(t *testing.T) {
var err error
func() {
defer recoverPanic("TestMethod", &err)
panic("something went wrong")
}()

require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.Internal, st.Code())
assert.Equal(t, "panic in provider TestMethod: something went wrong", st.Message())
assert.NotContains(t, st.Message(), "goroutine") // stack trace stays in log, not on the wire
})

t.Run("recovers nil pointer panic", func(t *testing.T) {
var err error
func() {
defer recoverPanic("GetData", &err)
var s *string
_ = *s // nil pointer dereference
}()

require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.Internal, st.Code())
assert.Contains(t, st.Message(), "panic in provider GetData")
})

t.Run("no panic leaves error nil", func(t *testing.T) {
var err error
func() {
defer recoverPanic("TestMethod", &err)
// no panic
}()

assert.NoError(t, err)
})
}
15 changes: 13 additions & 2 deletions providers/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package providers

import (
"errors"
"strings"
"sync"
"time"

Expand All @@ -18,6 +19,7 @@ import (
"go.mondoo.com/mql/v13/types"
"go.mondoo.com/mql/v13/utils/multierr"
"go.mondoo.com/mql/v13/utils/stringx"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

Expand Down Expand Up @@ -492,8 +494,17 @@ func (r *Runtime) handlePluginError(err error, provider *ConnectedProvider) (boo
}

switch st.Code() {
case 14:
// Error: Unavailable. Happens when the plugin crashes.
case codes.Internal:
// A recovered panic in the provider sends an Internal error prefixed
// with "panic in provider ". Only apply panic-specific handling when
// this prefix is present; other Internal errors fall through.
if strings.HasPrefix(st.Message(), "panic in provider ") {
log.Error().Str("provider", provider.Instance.Name).Msg(st.Message())
return true, errors.New("the '" + provider.Instance.Name + "' provider panicked: " + st.Message())
}

case codes.Unavailable:
// Happens when the plugin crashes.
// TODO: try to restart the plugin and reset its connections
provider.Instance.isClosed = true
provider.Instance.err = errors.New("the '" + provider.Instance.Name + "' provider crashed: " + err.Error())
Expand Down
Loading