diff --git a/.gitignore b/.gitignore index cf5efeb4..18195b3b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ node_modules # Go workspaces go.work go.work.sum +.idea \ No newline at end of file diff --git a/README.md b/README.md index f5944a2a..e4b26f0c 100644 --- a/README.md +++ b/README.md @@ -28,10 +28,10 @@ The Sliding Sync proxy requires some environment variables set to function. They Here is a short description of each, as of writing: ``` -SYNCV3_SERVER Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org' +SYNCV3_SERVER Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org' (Supports unix socket: /path/to/socket) SYNCV3_DB Required. The postgres connection string: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING SYNCV3_SECRET Required. A secret to use to encrypt access tokens. Must remain the same for the lifetime of the database. -SYNCV3_BINDADDR Default: 0.0.0.0:8008. The interface and port to listen on. +SYNCV3_BINDADDR Default: 0.0.0.0:8008. The interface and port to listen on. (Supports unix socket: /path/to/socket) SYNCV3_TLS_CERT Default: unset. Path to a certificate file to serve to HTTPS clients. Specifying this enables TLS on the bound address. SYNCV3_TLS_KEY Default: unset. Path to a key file for the certificate. Must be provided along with the certificate file. SYNCV3_PPROF Default: unset. The bind addr for pprof debugging e.g ':6060'. If not set, does not listen. diff --git a/cmd/syncv3/main.go b/cmd/syncv3/main.go index 9345c056..d4436ac0 100644 --- a/cmd/syncv3/main.go +++ b/cmd/syncv3/main.go @@ -60,10 +60,10 @@ const ( var helpMsg = fmt.Sprintf(` Environment var -%s Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org' +%s Required. The destination homeserver to talk to (CS API HTTPS URL) e.g 'https://matrix-client.matrix.org' (Supports unix socket: /path/to/socket) %s Required. The postgres connection string: https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING %s Required. A secret to use to encrypt access tokens. Must remain the same for the lifetime of the database. -%s Default: 0.0.0.0:8008. The interface and port to listen on. +%s Default: 0.0.0.0:8008. The interface and port to listen on. (Supports unix socket: /path/to/socket) %s Default: unset. Path to a certificate file to serve to HTTPS clients. Specifying this enables TLS on the bound address. %s Default: unset. Path to a key file for the certificate. Must be provided along with the certificate file. %s Default: unset. The bind addr for pprof debugging e.g ':6060'. If not set, does not listen. diff --git a/internal/util.go b/internal/util.go index 51760a78..c0dffad5 100644 --- a/internal/util.go +++ b/internal/util.go @@ -1,5 +1,12 @@ package internal +import ( + "context" + "net" + "net/http" + "strings" +) + // Keys returns a slice containing copies of the keys of the given map, in no particular // order. func Keys[K comparable, V any](m map[K]V) []K { @@ -12,3 +19,22 @@ func Keys[K comparable, V any](m map[K]V) []K { } return output } + +func IsUnixSocket(httpOrUnixStr string) bool { + return strings.HasPrefix(httpOrUnixStr, "/") +} + +func UnixTransport(httpOrUnixStr string) *http.Transport { + return &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", httpOrUnixStr) + }, + } +} + +func GetBaseURL(httpOrUnixStr string) string { + if IsUnixSocket(httpOrUnixStr) { + return "http://unix" + } + return httpOrUnixStr +} diff --git a/internal/util_test.go b/internal/util_test.go index 63c14237..33f88c05 100644 --- a/internal/util_test.go +++ b/internal/util_test.go @@ -28,3 +28,31 @@ func assertSlice(t *testing.T, got, want []string) { t.Errorf("After sorting, got %v but expected %v", got, want) } } + +func TestUnixSocket_True(t *testing.T) { + address := "/path/to/socket" + if !IsUnixSocket(address) { + t.Errorf("%s is socket", address) + } +} + +func TestUnixSocket_False(t *testing.T) { + address := "localhost:8080" + if IsUnixSocket(address) { + t.Errorf("%s is not socket", address) + } +} + +func TestGetBaseUrl_UnixSocket(t *testing.T) { + address := "/path/to/socket" + if GetBaseURL(address) != "http://unix" { + t.Errorf("%s is unix socket", address) + } +} + +func TestGetBaseUrl_Http(t *testing.T) { + address := "localhost:8080" + if GetBaseURL(address) != "localhost:8080" { + t.Errorf("%s is not a unix socket", address) + } +} diff --git a/sync2/client.go b/sync2/client.go index e072e5b5..7d47f611 100644 --- a/sync2/client.go +++ b/sync2/client.go @@ -4,7 +4,8 @@ import ( "context" "encoding/json" "fmt" - "io/ioutil" + "github.com/matrix-org/sliding-sync/internal" + "io" "net/http" "net/url" "time" @@ -40,15 +41,20 @@ type HTTPClient struct { func NewHTTPClient(shortTimeout, longTimeout time.Duration, destHomeServer string) *HTTPClient { return &HTTPClient{ - LongTimeoutClient: &http.Client{ - Timeout: longTimeout, - Transport: otelhttp.NewTransport(http.DefaultTransport), - }, - Client: &http.Client{ - Timeout: shortTimeout, - Transport: otelhttp.NewTransport(http.DefaultTransport), - }, - DestinationServer: destHomeServer, + LongTimeoutClient: newClient(longTimeout, destHomeServer), + Client: newClient(shortTimeout, destHomeServer), + DestinationServer: internal.GetBaseURL(destHomeServer), + } +} + +func newClient(timeout time.Duration, destHomeServer string) *http.Client { + transport := http.DefaultTransport + if internal.IsUnixSocket(destHomeServer) { + transport = internal.UnixTransport(destHomeServer) + } + return &http.Client{ + Timeout: timeout, + Transport: otelhttp.NewTransport(transport), } } @@ -66,7 +72,7 @@ func (v *HTTPClient) Versions(ctx context.Context) ([]string, error) { return nil, fmt.Errorf("/versions returned HTTP %d", res.StatusCode) } defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { return nil, err } @@ -99,7 +105,7 @@ func (v *HTTPClient) WhoAmI(ctx context.Context, accessToken string) (string, st return "", "", fmt.Errorf("/whoami returned HTTP %d", res.StatusCode) } defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { return "", "", err } diff --git a/v3.go b/v3.go index 8040fc6b..6333b81f 100644 --- a/v3.go +++ b/v3.go @@ -4,7 +4,10 @@ import ( "context" "embed" "encoding/json" + "errors" "fmt" + "io/fs" + "net" "net/http" "os" "strings" @@ -216,12 +219,18 @@ func RunSyncV3Server(h http.Handler, bindAddr, destV2Server, tlsCert, tlsKey str // Block forever var err error - if tlsCert != "" && tlsKey != "" { - logger.Info().Msgf("listening TLS on %s", bindAddr) - err = http.ListenAndServeTLS(bindAddr, tlsCert, tlsKey, srv) + if internal.IsUnixSocket(bindAddr) { + logger.Info().Msgf("listening on unix socket %s", bindAddr) + listener := unixSocketListener(bindAddr) + err = http.Serve(listener, srv) } else { - logger.Info().Msgf("listening on %s", bindAddr) - err = http.ListenAndServe(bindAddr, srv) + if tlsCert != "" && tlsKey != "" { + logger.Info().Msgf("listening TLS on %s", bindAddr) + err = http.ListenAndServeTLS(bindAddr, tlsCert, tlsKey, srv) + } else { + logger.Info().Msgf("listening on %s", bindAddr) + err = http.ListenAndServe(bindAddr, srv) + } } if err != nil { sentry.CaptureException(err) @@ -230,6 +239,23 @@ func RunSyncV3Server(h http.Handler, bindAddr, destV2Server, tlsCert, tlsKey str } } +func unixSocketListener(bindAddr string) net.Listener { + err := os.Remove(bindAddr) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + logger.Fatal().Err(err).Msg("failed to remove existing unix socket") + } + listener, err := net.Listen("unix", bindAddr) + if err != nil { + logger.Fatal().Err(err).Msg("failed to serve unix socket") + } + // TODO: safe default for now (rwxr-xr-x), could be extracted as env variable if needed + err = os.Chmod(bindAddr, 0755) + if err != nil { + logger.Fatal().Err(err).Msg("failed to set unix socket permissions") + } + return listener +} + type HandlerError struct { StatusCode int Err error