Skip to content

Commit 66150c5

Browse files
authored
better support for running multiple separate brokers (#85)
* better support for HA brokers * slightly better genkey/pubkey impl
1 parent 62bfa29 commit 66150c5

8 files changed

+96
-47
lines changed

cmd/dump.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ var dumpCmd = &cobra.Command{
1313
Use: "dump",
1414
Short: "Dump current config",
1515
Run: func(cmd *cobra.Command, args []string) {
16-
config, err := pkg.LoadConfig(configFiles, deploymentId)
16+
config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex)
1717
if err != nil {
1818
log.Panic(err)
1919
}

cmd/genkey.go

+23-5
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,43 @@ package cmd
33
import (
44
"encoding/base64"
55
"fmt"
6+
"os"
67

78
log "github.com/sirupsen/logrus"
89
"github.com/spf13/cobra"
910
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
1011
)
1112

13+
var replicaCount int
14+
15+
const defaultReplicaCount = 3
16+
const minReplicaCount = 1
17+
const maxReplicaCount = 16
18+
1219
var genkeyCmd = &cobra.Command{
1320
Use: "genkey",
14-
Short: "Generates a random private key in base64 and prints it to stdout",
21+
Short: "Generates a random Semgrep Network Broker private key and prints it to stdout.",
1522
Run: func(cmd *cobra.Command, args []string) {
16-
privateKey, err := wgtypes.GeneratePrivateKey()
17-
if err != nil {
18-
log.Panic(fmt.Errorf("failed to generate private key: %v", err))
23+
if replicaCount < minReplicaCount || replicaCount > maxReplicaCount {
24+
log.Panic(fmt.Errorf("replica count must be between %v and %v", minReplicaCount, maxReplicaCount))
1925
}
2026

21-
fmt.Println(base64.StdEncoding.EncodeToString(privateKey[:]))
27+
encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout)
28+
defer encoder.Close()
29+
30+
for i := 0; i < replicaCount; i++ {
31+
privateKey, err := wgtypes.GeneratePrivateKey()
32+
if err != nil {
33+
log.Panic(fmt.Errorf("failed to generate private key %v: %v", i, err))
34+
}
35+
if _, err := encoder.Write(privateKey[:]); err != nil {
36+
log.Panic(fmt.Errorf("failed to write private key %v: %v", i, err))
37+
}
38+
}
2239
},
2340
}
2441

2542
func init() {
43+
genkeyCmd.PersistentFlags().IntVarP(&replicaCount, "replica-count", "r", defaultReplicaCount, "Number of broker replicas to support")
2644
rootCmd.AddCommand(genkeyCmd)
2745
}

cmd/pubkey.go

+26-21
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,40 @@ import (
88

99
log "github.com/sirupsen/logrus"
1010
"github.com/spf13/cobra"
11+
"golang.zx2c4.com/wireguard/device"
1112
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
1213
)
1314

1415
var pubkeyCmd = &cobra.Command{
1516
Use: "pubkey",
16-
Short: "Reads a base64 private key from stdin, outputs the corresponding base64 public key",
17+
Short: "Reads a Semgrep Network Broker private key from stdin and ptints the corresponding public key to stdout.",
1718
Run: func(cmd *cobra.Command, args []string) {
18-
keyBase64, err := io.ReadAll(os.Stdin)
19-
if err != nil {
20-
log.Panic(err)
21-
}
22-
23-
keyBytes := make([]byte, 32)
24-
n, err := base64.StdEncoding.Decode(keyBytes, keyBase64)
25-
if err != nil {
26-
log.Panic(err)
27-
}
28-
if n != 32 {
29-
log.Panic("not enough bytes")
30-
}
3119

32-
privateKey, err := wgtypes.NewKey(keyBytes)
33-
if err != nil {
34-
log.Panic(err)
20+
decoder := base64.NewDecoder(base64.StdEncoding, os.Stdin)
21+
encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout)
22+
defer encoder.Close()
23+
24+
privateKeyBytes := make([]byte, device.NoisePrivateKeySize)
25+
26+
for i := 0; ; i++ {
27+
_, err := io.ReadFull(decoder, privateKeyBytes)
28+
if err != nil {
29+
if err == io.EOF {
30+
break
31+
} else {
32+
log.Panic(fmt.Errorf("error reading private key %v: %v", i, err))
33+
}
34+
}
35+
privateKey, err := wgtypes.NewKey(privateKeyBytes)
36+
if err != nil {
37+
log.Panic(fmt.Errorf("error creating private key %v: %v", i, err))
38+
}
39+
40+
publicKey := privateKey.PublicKey()
41+
if _, err := encoder.Write(publicKey[:]); err != nil {
42+
log.Panic(fmt.Errorf("error writing public key %v: %v", i, err))
43+
}
3544
}
36-
37-
publicKey := privateKey.PublicKey()
38-
39-
fmt.Println(base64.StdEncoding.EncodeToString(publicKey[:]))
4045
},
4146
}
4247

cmd/relay.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ var relayCmd = &cobra.Command{
3030
}()
3131

3232
// load config(s)
33-
config, err := pkg.LoadConfig(configFiles, 0)
33+
config, err := pkg.LoadConfig(configFiles, 0, 0)
3434
if err != nil {
3535
log.Panic(err)
3636
}

cmd/root.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
var configFiles []string
1919
var jsonLog bool
2020
var deploymentId int
21+
var brokerIndex int
2122

2223
var rootCmd = &cobra.Command{
2324
Use: "semgrep-network-broker",
@@ -39,7 +40,7 @@ var rootCmd = &cobra.Command{
3940
}()
4041

4142
// load config(s)
42-
config, err := pkg.LoadConfig(configFiles, deploymentId)
43+
config, err := pkg.LoadConfig(configFiles, deploymentId, brokerIndex)
4344
if err != nil {
4445
log.Panic(err)
4546
}
@@ -75,7 +76,7 @@ func StartNetworkBroker(config *pkg.Config) (func() error, error) {
7576
return wireguardTeardown()
7677
}
7778

78-
// start inbound proxy (r2c --> customer)
79+
// start inbound proxy (semgrep --> customer)
7980
if err := config.Inbound.Start(tnet); err != nil {
8081
teardown()
8182
return nil, fmt.Errorf("failed to start inbound proxy: %v", err)
@@ -95,4 +96,5 @@ func init() {
9596
rootCmd.PersistentFlags().StringArrayVarP(&configFiles, "config", "c", nil, "config file(s)")
9697
rootCmd.PersistentFlags().BoolVarP(&jsonLog, "json-log", "j", false, "JSON log output")
9798
rootCmd.PersistentFlags().IntVarP(&deploymentId, "deployment-id", "d", 0, "Semgrep deployment ID")
99+
rootCmd.PersistentFlags().IntVarP(&brokerIndex, "broker-index", "i", 0, "Semgrep network broker index")
98100
}

pkg/config.go

+13-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"net/http"
9+
"net/netip"
910
"net/url"
1011
"os"
1112
"reflect"
@@ -72,13 +73,15 @@ type WireguardPeer struct {
7273
}
7374

7475
type WireguardBase struct {
75-
LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"`
76-
Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"`
77-
Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"`
78-
PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"`
79-
ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"`
80-
Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"`
81-
Verbose bool `mapstructure:"verbose" json:"verbose"`
76+
resolvedLocalAddress netip.Addr
77+
LocalAddress string `mapstructure:"localAddress" json:"localAddress" validate:"format=ip"`
78+
Dns []string `mapstructure:"dns" json:"dns" validate:"empty=true > format=ip"`
79+
Mtu int `mapstructure:"mtu" json:"mtu" validate:"gte=0" default:"1420"`
80+
PrivateKey SensitiveBase64String `mapstructure:"privateKey" json:"privateKey" validate:"empty=false"`
81+
ListenPort int `mapstructure:"listenPort" json:"listenPort" validate:"gte=0"`
82+
Peers []WireguardPeer `mapstructure:"peers" json:"peers" validate:"empty=false"`
83+
Verbose bool `mapstructure:"verbose" json:"verbose"`
84+
BrokerIndex int `mapstructure:"brokerIndex" json:"brokerIndex" validate:"gte=0"`
8285
}
8386

8487
type BitTester interface {
@@ -262,9 +265,11 @@ type Config struct {
262265
Outbound OutboundProxyConfig `mapstructure:"outbound" json:"outbound"`
263266
}
264267

265-
func LoadConfig(configFiles []string, deploymentId int) (*Config, error) {
268+
func LoadConfig(configFiles []string, deploymentId int, brokerIndex int) (*Config, error) {
266269
config := new(Config)
267270

271+
config.Inbound.Wireguard.BrokerIndex = brokerIndex
272+
268273
if deploymentId > 0 {
269274
hostname := os.Getenv("SEMGREP_HOSTNAME")
270275
if hostname == "" {

pkg/config_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
)
1212

1313
func TestEmptyConfigs(t *testing.T) {
14-
config, err := LoadConfig(nil, 0)
14+
config, err := LoadConfig(nil, 0, 0)
1515
if err != nil {
1616
t.Error(err)
1717
}

pkg/wireguard.go

+27-8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package pkg
22

33
import (
44
"encoding/hex"
5+
"errors"
56
"fmt"
67
"io"
78
"math/rand"
@@ -34,10 +35,22 @@ func (peer WireguardPeer) WriteTo(sb io.StringWriter) {
3435
}
3536
}
3637

38+
func (base WireguardBase) Validate() error {
39+
privateKeyCount := len(base.PrivateKey) / device.NoisePrivateKeySize
40+
41+
if base.BrokerIndex >= privateKeyCount {
42+
return errors.New("broker index beyond private key count")
43+
}
44+
45+
return nil
46+
}
47+
3748
func (base WireguardBase) GenerateConfig() string {
3849
sb := strings.Builder{}
3950

40-
sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(base.PrivateKey)))
51+
indexedPrivateKey := base.PrivateKey[device.NoisePrivateKeySize*base.BrokerIndex : device.NoisePrivateKeySize*(base.BrokerIndex+1)]
52+
53+
sb.WriteString(fmt.Sprintf("private_key=%s\n", hex.EncodeToString(indexedPrivateKey)))
4154
sb.WriteString(fmt.Sprintf("listen_port=%d\n", base.ListenPort))
4255

4356
for i := range base.Peers {
@@ -47,7 +60,16 @@ func (base WireguardBase) GenerateConfig() string {
4760
return sb.String()
4861
}
4962

50-
func (base *WireguardBase) ResolvePeerEndpoints() error {
63+
func (base *WireguardBase) ResolveConfig() error {
64+
resolvedLocalAddress, err := netip.ParseAddr(base.LocalAddress)
65+
if err != nil {
66+
return fmt.Errorf("LocalAddress parse failed: %v", err)
67+
}
68+
for i := 0; i < base.BrokerIndex; i++ {
69+
resolvedLocalAddress = resolvedLocalAddress.Next()
70+
}
71+
base.resolvedLocalAddress = resolvedLocalAddress
72+
5173
for i := range base.Peers {
5274
if base.Peers[i].Endpoint == "" {
5375
continue
@@ -77,22 +99,19 @@ func (config *WireguardBase) Start() (*netstack.Net, func() error, error) {
7799
return nil, nil, fmt.Errorf("invalid wireguard config: %v", err)
78100
}
79101

80-
// resolve peer endpoints (if not IP address already)
81-
if err := config.ResolvePeerEndpoints(); err != nil {
102+
// resolve local address and peer endpoints (if not IP address already)
103+
if err := config.ResolveConfig(); err != nil {
82104
return nil, nil, fmt.Errorf("failed to resolve peer endpoint: %v", err)
83105
}
84106

85-
// parse localAddres and DNS addresses -- MustParseAddr is fine here because we've already validated the config
86-
localAddress := netip.MustParseAddr(config.LocalAddress)
87-
88107
var dnsAddresses = make([]netip.Addr, len(config.Dns))
89108
for i := range config.Dns {
90109
dnsAddresses[i] = netip.MustParseAddr(config.Dns[i])
91110
}
92111

93112
// create the wireguard interface
94113
tun, tnet, err := netstack.CreateNetTUN(
95-
[]netip.Addr{localAddress},
114+
[]netip.Addr{config.resolvedLocalAddress},
96115
dnsAddresses,
97116
config.Mtu,
98117
)

0 commit comments

Comments
 (0)