Skip to content

Commit c9d4a48

Browse files
Merge pull request #682 from atlassian/msg/api-exposure
Feature: Expose Lambda Extension package
2 parents e6fd74a + 592cc1f commit c9d4a48

File tree

4 files changed

+286
-0
lines changed

4 files changed

+286
-0
lines changed

pkg/lambda/extension.go

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package lambda
2+
3+
import (
4+
"github.com/sirupsen/logrus"
5+
6+
"github.com/atlassian/gostatsd/internal/awslambda/extension"
7+
"github.com/atlassian/gostatsd/internal/flush"
8+
"github.com/atlassian/gostatsd/pkg/statsd"
9+
)
10+
11+
// Extension exposes the internal type definition.
12+
//
13+
// Note: This is not ideal but simplifies exposing the API.
14+
type Extension = extension.Server
15+
16+
// NewExtension extends a `statsd.Server` by adding a server that integrates with
17+
// the AWS Lambda Extension Runtime API to allow near native support for on host collection.
18+
// The provided server is then run within the extension in additional to the server for AWS integration.
19+
func NewExtension(logger logrus.FieldLogger, server *statsd.Server, opts Options) (Extension, error) {
20+
if err := opts.Validate(); err != nil {
21+
return nil, err
22+
}
23+
24+
s := *server
25+
26+
var extOpts []extension.ManagerOpt
27+
if opts.EnableManualFlush {
28+
fc := flush.NewFlushCoordinator()
29+
s.ForwarderFlushCoordinator = fc
30+
extOpts = append(extOpts, extension.WithManualFlushEnabled(fc, opts.TelemetryAddr))
31+
}
32+
33+
return extension.NewManager(
34+
opts.RuntimeAPI,
35+
opts.ExecutableName,
36+
logger,
37+
&s,
38+
extOpts...,
39+
), nil
40+
}

pkg/lambda/extension_test.go

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package lambda
2+
3+
import (
4+
"testing"
5+
6+
"github.com/sirupsen/logrus"
7+
"github.com/stretchr/testify/assert"
8+
9+
"github.com/atlassian/gostatsd/pkg/statsd"
10+
)
11+
12+
func TestNewExtension(t *testing.T) {
13+
t.Parallel()
14+
15+
for _, tc := range []struct {
16+
name string
17+
options Options
18+
errVal string
19+
}{
20+
{
21+
name: "no options value defined",
22+
options: Options{},
23+
errVal: "missing `RuntimeAPI` value; missing `ExecutableName` value",
24+
},
25+
{
26+
name: "valid options provided",
27+
options: Options{
28+
RuntimeAPI: "runtime-api",
29+
ExecutableName: "bin",
30+
},
31+
errVal: "",
32+
},
33+
{
34+
name: "valid options provided with enable manual flush",
35+
options: Options{
36+
RuntimeAPI: "runtime-api",
37+
ExecutableName: "bin",
38+
EnableManualFlush: true,
39+
TelemetryAddr: ":0",
40+
},
41+
errVal: "",
42+
},
43+
} {
44+
tc := tc
45+
t.Run(tc.name, func(t *testing.T) {
46+
t.Parallel()
47+
48+
actual, err := NewExtension(logrus.New(), &statsd.Server{}, tc.options)
49+
if tc.errVal != "" {
50+
assert.Nil(t, actual, "Must not be a valid extension")
51+
assert.EqualError(t, err, tc.errVal)
52+
} else {
53+
assert.NotNil(t, actual, "Must be a valid extension")
54+
assert.NoError(t, err, "Must not error")
55+
}
56+
})
57+
}
58+
}

pkg/lambda/options.go

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package lambda
2+
3+
import (
4+
"errors"
5+
"os"
6+
"path"
7+
8+
"github.com/spf13/pflag"
9+
"go.uber.org/multierr"
10+
11+
"github.com/atlassian/gostatsd/internal/awslambda/extension/api"
12+
)
13+
14+
const (
15+
ParamRuntimeAPI = `lambda-runtime-api`
16+
ParamEntryPointName = `lamda-entrypoint-name`
17+
ParamEnableManualFlush = `enable-manual-flush`
18+
ParamTelemetryAddr = `telemetry-addr`
19+
)
20+
21+
// Options is used to provide overrides when calling `NewExtension`
22+
type Options struct {
23+
// RuntimeAPI is AWS service that is responsible for orchestrating
24+
// the extension with the other Lambda's that are part of the deployment.
25+
// This will set to `env:AWS_LAMBDA_RUNTIME_API` by default as per the docs,
26+
// but can be overriden for testing purposes and validation.
27+
RuntimeAPI string
28+
// ExecutableName is the full file name of the lambda extension that is
29+
// used to validate the bootstrap process.
30+
ExecutableName string
31+
// EnableManualFlush allows for Lambda's
32+
// to control when the forwarder would send the accumlated metrics.
33+
EnableManualFlush bool
34+
// TelemetryAddr is used by the AWS runtime to publish
35+
// events back to service.
36+
TelemetryAddr string
37+
}
38+
39+
func NewOptionsFromEnvironment() Options {
40+
return Options{
41+
RuntimeAPI: os.Getenv(api.EnvLambdaAPIHostname),
42+
ExecutableName: path.Base(os.Args[0]),
43+
EnableManualFlush: false,
44+
TelemetryAddr: "http://sandbox:8083",
45+
}
46+
}
47+
48+
// AddFlags is used to add preconfigured entries
49+
// into an existing `FlagSet`
50+
func (o *Options) AddFlags(fs *pflag.FlagSet) {
51+
fs.StringVar(
52+
&o.RuntimeAPI,
53+
ParamRuntimeAPI,
54+
o.RuntimeAPI,
55+
"Sets the runtime api that is used to contact the lambda runtime service",
56+
)
57+
fs.StringVar(
58+
&o.ExecutableName,
59+
ParamEntryPointName,
60+
o.ExecutableName,
61+
"Sets the name of the executable name used within the bootstrap process.",
62+
)
63+
fs.BoolVar(
64+
&o.EnableManualFlush,
65+
ParamEnableManualFlush,
66+
o.EnableManualFlush,
67+
"When set, enables lambda(s) to force the forwarder to send data. Useful in low volume lambdas",
68+
)
69+
fs.StringVar(
70+
&o.TelemetryAddr,
71+
ParamTelemetryAddr,
72+
o.TelemetryAddr,
73+
"Allows for callbacks to force manual flushing of accumulated metrics",
74+
)
75+
}
76+
77+
// Validate ensure all the values are valid,
78+
// any values that not are reported as errors.
79+
// All invalid values are reported together.
80+
func (o Options) Validate() (errs error) {
81+
if o.RuntimeAPI == "" {
82+
errs = multierr.Append(errs, errors.New("missing `RuntimeAPI` value"))
83+
}
84+
if o.ExecutableName == "" {
85+
errs = multierr.Append(errs, errors.New("missing `ExecutableName` value"))
86+
}
87+
if o.EnableManualFlush && o.TelemetryAddr == "" {
88+
errs = multierr.Append(errs, errors.New("missing `TelemetryAddr` when `EnableManualFlush` is enabled"))
89+
}
90+
return errs
91+
}

pkg/lambda/options_test.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package lambda
2+
3+
import (
4+
"os"
5+
"path"
6+
"testing"
7+
8+
"github.com/spf13/pflag"
9+
"github.com/stretchr/testify/assert"
10+
11+
"github.com/atlassian/gostatsd/internal/awslambda/extension/api"
12+
)
13+
14+
func TestDefaultOptions(t *testing.T) {
15+
t.Setenv(api.EnvLambdaAPIHostname, "example")
16+
17+
expected := Options{
18+
RuntimeAPI: "example",
19+
ExecutableName: path.Base(os.Args[0]),
20+
EnableManualFlush: false,
21+
TelemetryAddr: "http://sandbox:8083",
22+
}
23+
24+
assert.Equal(t, expected, NewOptionsFromEnvironment(), "Must match the expected value")
25+
}
26+
27+
func TestOptionsAddFlag(t *testing.T) {
28+
t.Parallel()
29+
30+
o := &Options{
31+
RuntimeAPI: "my-awesome-api",
32+
ExecutableName: "bin",
33+
EnableManualFlush: true,
34+
TelemetryAddr: ":8089",
35+
}
36+
37+
fs := pflag.NewFlagSet(t.Name(), pflag.ContinueOnError)
38+
o.AddFlags(fs)
39+
40+
for _, flag := range []string{
41+
ParamRuntimeAPI,
42+
ParamEntryPointName,
43+
ParamEnableManualFlush,
44+
ParamTelemetryAddr,
45+
} {
46+
assert.NotNil(
47+
t,
48+
fs.Lookup(flag),
49+
"Must have a valid entry for expected flag name %q",
50+
flag,
51+
)
52+
}
53+
54+
}
55+
56+
func TestOptionsValidate(t *testing.T) {
57+
t.Parallel()
58+
59+
for _, tc := range []struct {
60+
name string
61+
options Options
62+
errVal string
63+
}{
64+
{
65+
name: "empty options",
66+
options: Options{},
67+
errVal: "missing `RuntimeAPI` value; missing `ExecutableName` value",
68+
},
69+
{
70+
name: "enabled manual flushes",
71+
options: Options{
72+
EnableManualFlush: true,
73+
},
74+
errVal: "missing `RuntimeAPI` value; missing `ExecutableName` value; missing `TelemetryAddr` when `EnableManualFlush` is enabled",
75+
},
76+
{
77+
name: "(simulated) Default options",
78+
options: Options{
79+
RuntimeAPI: "runtime-api",
80+
ExecutableName: "bin",
81+
},
82+
errVal: "",
83+
},
84+
} {
85+
tc := tc
86+
t.Run(tc.name, func(t *testing.T) {
87+
t.Parallel()
88+
89+
err := tc.options.Validate()
90+
if tc.errVal != "" {
91+
assert.EqualError(t, err, tc.errVal, "Must match the expected error message")
92+
} else {
93+
assert.NoError(t, err, "Must not error")
94+
}
95+
})
96+
}
97+
}

0 commit comments

Comments
 (0)