diff --git a/api/api.go b/api/api.go index ba7c26f83..6fd9c697d 100644 --- a/api/api.go +++ b/api/api.go @@ -1774,3 +1774,7 @@ func (api *api) GetForwards() (*GetForwardsResponse, error) { NumForwards: uint64(numForwards), }, nil } + +func (a *api) IsShuttingDown() bool { + return a.svc.IsShuttingDown() +} diff --git a/api/models.go b/api/models.go index 971d80d1b..1103e22ed 100644 --- a/api/models.go +++ b/api/models.go @@ -83,6 +83,7 @@ type API interface { ExecuteCustomNodeCommand(ctx context.Context, command string) (interface{}, error) SendEvent(event string, properties interface{}) GetForwards() (*GetForwardsResponse, error) + IsShuttingDown() bool } type App struct { diff --git a/http/http_service.go b/http/http_service.go index be7637c86..8ce08edc0 100644 --- a/http/http_service.go +++ b/http/http_service.go @@ -14,7 +14,7 @@ import ( "github.com/golang-jwt/jwt/v5" echojwt "github.com/labstack/echo-jwt/v4" "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + echoMiddleware "github.com/labstack/echo/v4/middleware" "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -23,6 +23,7 @@ import ( "github.com/getAlby/hub/events" "github.com/getAlby/hub/logger" "github.com/getAlby/hub/service" + "github.com/getAlby/hub/middleware" "github.com/getAlby/hub/api" "github.com/getAlby/hub/frontend" @@ -63,20 +64,20 @@ func NewHttpService(svc service.Service, eventPublisher events.EventPublisher) * func (httpSvc *HttpService) RegisterSharedRoutes(e *echo.Echo) { e.HideBanner = true - e.Use(middleware.SecureWithConfig(middleware.SecureConfig{ + e.Use(echoMiddleware.SecureWithConfig(echoMiddleware.SecureConfig{ ContentTypeNosniff: "nosniff", XFrameOptions: "DENY", ContentSecurityPolicy: "default-src 'self'; img-src 'self' https://uploads.getalby-assets.com https://getalby.com; connect-src 'self' https://api.getalby.com https://getalby.com https://zapplanner.albylabs.com wss://relay.getalby.com/v1; frame-src https://embed.bitrefill.com", ReferrerPolicy: "no-referrer", })) - e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ + e.Use(echoMiddleware.RequestLoggerWithConfig(echoMiddleware.RequestLoggerConfig{ LogURI: true, LogStatus: true, LogRemoteIP: true, LogUserAgent: true, LogHost: true, LogRequestID: true, - LogValuesFunc: func(c echo.Context, values middleware.RequestLoggerValues) error { + LogValuesFunc: func(c echo.Context, values echoMiddleware.RequestLoggerValues) error { logger.Logger.WithFields(logrus.Fields{ "uri": values.URI, "status": values.Status, @@ -89,15 +90,17 @@ func (httpSvc *HttpService) RegisterSharedRoutes(e *echo.Echo) { }, })) - e.Use(middleware.Recover()) - e.Use(middleware.RequestID()) + e.Use(middleware.ShutdownMiddleware(httpSvc.api)) + + e.Use(echoMiddleware.Recover()) + e.Use(echoMiddleware.RequestID()) e.GET("/api/info", httpSvc.infoHandler) e.POST("/api/setup", httpSvc.setupHandler) e.POST("/api/restore", httpSvc.restoreBackupHandler) // allow one unlock request per second - unlockRateLimiter := middleware.RateLimiter(middleware.NewRateLimiterMemoryStore(1)) + unlockRateLimiter := echoMiddleware.RateLimiter(echoMiddleware.NewRateLimiterMemoryStore(1)) e.POST("/api/start", httpSvc.startHandler, unlockRateLimiter) e.POST("/api/unlock", httpSvc.unlockHandler, unlockRateLimiter) e.POST("/api/backup", httpSvc.createBackupHandler, unlockRateLimiter) diff --git a/http/http_service_test.go b/http/http_service_test.go index 3e1bbf720..b553c500b 100644 --- a/http/http_service_test.go +++ b/http/http_service_test.go @@ -13,12 +13,14 @@ import ( "github.com/getAlby/hub/config" "github.com/getAlby/hub/constants" "github.com/getAlby/hub/events" + "github.com/getAlby/hub/lnclient" "github.com/getAlby/hub/logger" "github.com/getAlby/hub/tests/db" "github.com/getAlby/hub/tests/mocks" "github.com/labstack/echo/v4" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -41,6 +43,7 @@ func TestUnlock_IncorrectPassword(t *testing.T) { mockSvc.On("GetKeys").Return(mocks.NewMockKeys(t)) mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) mockSvc.On("GetAlbyOAuthSvc").Return(mocks.NewMockAlbyOAuthService(t)) + mockSvc.On("IsShuttingDown").Return(false) httpSvc := NewHttpService(mockSvc, mockEventPublisher) httpSvc.RegisterSharedRoutes(e) @@ -75,6 +78,7 @@ func TestUnlock_UnknownPermission(t *testing.T) { mockSvc.On("GetKeys").Return(mocks.NewMockKeys(t)) mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) mockSvc.On("GetAlbyOAuthSvc").Return(mocks.NewMockAlbyOAuthService(t)) + mockSvc.On("IsShuttingDown").Return(false) httpSvc := NewHttpService(mockSvc, mockEventPublisher) httpSvc.RegisterSharedRoutes(e) @@ -108,6 +112,7 @@ func TestGetApps_NoToken(t *testing.T) { mockSvc.On("GetKeys").Return(mocks.NewMockKeys(t)) mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) mockSvc.On("GetAlbyOAuthSvc").Return(mocks.NewMockAlbyOAuthService(t)) + mockSvc.On("IsShuttingDown").Return(false) httpSvc := NewHttpService(mockSvc, mockEventPublisher) httpSvc.RegisterSharedRoutes(e) @@ -139,6 +144,7 @@ func TestGetApps_ReadonlyPermission(t *testing.T) { mockSvc.On("GetKeys").Return(mocks.NewMockKeys(t)) mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) mockSvc.On("GetAlbyOAuthSvc").Return(mocks.NewMockAlbyOAuthService(t)) + mockSvc.On("IsShuttingDown").Return(false) httpSvc := NewHttpService(mockSvc, mockEventPublisher) httpSvc.RegisterSharedRoutes(e) @@ -192,6 +198,7 @@ func TestGetApps_FullPermission(t *testing.T) { mockSvc.On("GetKeys").Return(mocks.NewMockKeys(t)) mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) mockSvc.On("GetAlbyOAuthSvc").Return(mocks.NewMockAlbyOAuthService(t)) + mockSvc.On("IsShuttingDown").Return(false) httpSvc := NewHttpService(mockSvc, mockEventPublisher) httpSvc.RegisterSharedRoutes(e) @@ -243,6 +250,7 @@ func TestCreateApp_NoToken(t *testing.T) { mockSvc.On("GetKeys").Return(mocks.NewMockKeys(t)) mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) mockSvc.On("GetAlbyOAuthSvc").Return(mocks.NewMockAlbyOAuthService(t)) + mockSvc.On("IsShuttingDown").Return(false) httpSvc := NewHttpService(mockSvc, mockEventPublisher) httpSvc.RegisterSharedRoutes(e) @@ -284,6 +292,7 @@ func TestCreateApp_FullPermission(t *testing.T) { mockSvc.On("GetKeys").Return(mockKeys) mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) mockSvc.On("GetAlbyOAuthSvc").Return(mockAlbyOAuthService) + mockSvc.On("IsShuttingDown").Return(false) httpSvc := NewHttpService(mockSvc, mockEventPublisher) httpSvc.RegisterSharedRoutes(e) @@ -345,6 +354,7 @@ func TestCreateApp_ReadonlyPermission(t *testing.T) { mockSvc.On("GetKeys").Return(mockKeys) mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) mockSvc.On("GetAlbyOAuthSvc").Return(mockAlbyOAuthService) + mockSvc.On("IsShuttingDown").Return(false) httpSvc := NewHttpService(mockSvc, mockEventPublisher) httpSvc.RegisterSharedRoutes(e) @@ -381,3 +391,140 @@ func TestCreateApp_ReadonlyPermission(t *testing.T) { assert.Equal(t, http.StatusForbidden, rec2.Code) } + +func TestShutdown_BlockedEndpoint(t *testing.T) { + e := echo.New() + logger.Init(strconv.Itoa(int(logrus.DebugLevel))) + mockSvc := mocks.NewMockService(t) + gormDb, err := db.NewDB(t) + require.NoError(t, err) + defer db.CloseDB(gormDb) + + mockEventPublisher := events.NewEventPublisher() + + mockConfig := mocks.NewMockConfig(t) + mockConfig.On("GetEnv").Return(&config.AppConfig{}) + mockConfig.On("CheckUnlockPassword", "123").Return(true) + mockConfig.On("GetJWTSecret").Return("dummy secret", nil) + + mockKeys := mocks.NewMockKeys(t) + + mockAlbyOAuthService := mocks.NewMockAlbyOAuthService(t) + + mockSvc.On("GetDB").Return(gormDb) + mockSvc.On("GetConfig").Return(mockConfig) + mockSvc.On("GetKeys").Return(mockKeys) + mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) + mockSvc.On("GetAlbyOAuthSvc").Return(mockAlbyOAuthService) + mockSvc.On("IsShuttingDown").Return(false) + + httpSvc := NewHttpService(mockSvc, mockEventPublisher) + httpSvc.RegisterSharedRoutes(e) + + requestBody := api.UnlockRequest{UnlockPassword: "123", Permission: "readonly"} + jsonBody, _ := json.Marshal(requestBody) + req := httptest.NewRequest(http.MethodPost, "/api/unlock", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") // Set Content-Type header + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + // mock node sutting down after unlock + mockSvc.On("IsShuttingDown").Unset() + mockSvc.On("IsShuttingDown").Return(true) + + body, err := io.ReadAll(rec.Body) + require.NoError(t, err) + + type authTokenResponse struct { + Token string `json:"token"` + } + + var unlockAuthTokenResponse authTokenResponse + err = json.Unmarshal(body, &unlockAuthTokenResponse) + require.NoError(t, err) + assert.NotEmpty(t, unlockAuthTokenResponse.Token) + + req2 := httptest.NewRequest(http.MethodGet, "/api/peers", nil) + req2.Header.Set("Authorization", "Bearer "+unlockAuthTokenResponse.Token) + req2.Header.Set("Content-Type", "application/json") + + rec2 := httptest.NewRecorder() + e.ServeHTTP(rec2, req2) + + assert.Equal(t, http.StatusServiceUnavailable, rec2.Code) + assert.Contains(t, rec2.Body.String(), "Node is shutting down") + +} + +func TestShutdown_AllowedEndpoint(t *testing.T) { + e := echo.New() + logger.Init(strconv.Itoa(int(logrus.DebugLevel))) + mockSvc := mocks.NewMockService(t) + gormDb, err := db.NewDB(t) + require.NoError(t, err) + defer db.CloseDB(gormDb) + + mockEventPublisher := events.NewEventPublisher() + + mockConfig := mocks.NewMockConfig(t) + mockConfig.On("GetEnv").Return(&config.AppConfig{}) + mockConfig.On("CheckUnlockPassword", "123").Return(true) + mockConfig.On("GetJWTSecret").Return("dummy secret", nil) + + mockKeys := mocks.NewMockKeys(t) + + mockAlbyOAuthService := mocks.NewMockAlbyOAuthService(t) + + mockSvc.On("GetDB").Return(gormDb) + mockSvc.On("GetConfig").Return(mockConfig) + mockSvc.On("GetKeys").Return(mockKeys) + mockSvc.On("GetAlbySvc").Return(mocks.NewMockAlbyService(t)) + mockSvc.On("GetAlbyOAuthSvc").Return(mockAlbyOAuthService) + mockSvc.On("IsShuttingDown").Return(false) + + httpSvc := NewHttpService(mockSvc, mockEventPublisher) + httpSvc.RegisterSharedRoutes(e) + + requestBody := api.UnlockRequest{UnlockPassword: "123", Permission: "readonly"} + jsonBody, _ := json.Marshal(requestBody) + req := httptest.NewRequest(http.MethodPost, "/api/unlock", bytes.NewBuffer(jsonBody)) + req.Header.Set("Content-Type", "application/json") // Set Content-Type header + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + mockLNClient := mocks.NewMockLNClient(t) + mockLNClient.On("GetNodeStatus", mock.Anything).Return(&lnclient.NodeStatus{ + IsReady: true, // or false, doesn't matter + InternalNodeStatus: map[string]interface{}{"running": true}, + }, nil) + + // mock node sutting down after unlock + mockSvc.On("IsShuttingDown").Unset() + mockSvc.On("IsShuttingDown").Return(true) + mockSvc.On("GetLNClient").Return(mockLNClient) + + body, err := io.ReadAll(rec.Body) + require.NoError(t, err) + + type authTokenResponse struct { + Token string `json:"token"` + } + + var unlockAuthTokenResponse authTokenResponse + err = json.Unmarshal(body, &unlockAuthTokenResponse) + require.NoError(t, err) + assert.NotEmpty(t, unlockAuthTokenResponse.Token) + + req2 := httptest.NewRequest(http.MethodGet, "/api/node/status", nil) + req2.Header.Set("Authorization", "Bearer "+unlockAuthTokenResponse.Token) + req2.Header.Set("Content-Type", "application/json") + + rec2 := httptest.NewRecorder() + e.ServeHTTP(rec2, req2) + + assert.Equal(t, http.StatusOK, rec2.Code) +} diff --git a/middleware/shutdown.go b/middleware/shutdown.go new file mode 100644 index 000000000..a9e265752 --- /dev/null +++ b/middleware/shutdown.go @@ -0,0 +1,39 @@ +package middleware + +import ( + "net/http" + "strings" + + "github.com/labstack/echo/v4" +) + +type ShutdownNotifier interface { + IsShuttingDown() bool +} + +func ShutdownMiddleware(notifier ShutdownNotifier) echo.MiddlewareFunc { + + // whitelist routes that can still be called when the node is shutting down + safeRoutes := map[string]bool{ + "/api/health": true, + "/api/node/status": true, + "/api/info": true, + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + path := c.Request().URL.Path + + if safeRoutes[path] || strings.HasPrefix(path, "/api/alby/") { + return next(c) + } + + if notifier.IsShuttingDown() { + return c.JSON(http.StatusServiceUnavailable, map[string]string{ + "message": "Node is shutting down. Please wait.", + }) + } + return next(c) + } + } +} diff --git a/service/models.go b/service/models.go index 024950d74..c6fb9d01f 100644 --- a/service/models.go +++ b/service/models.go @@ -34,4 +34,5 @@ type Service interface { GetKeys() keys.Keys GetRelayStatuses() []RelayStatus GetStartupState() string + IsShuttingDown() bool } diff --git a/service/service.go b/service/service.go index baee4f0df..3bef46b34 100644 --- a/service/service.go +++ b/service/service.go @@ -6,6 +6,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "time" "github.com/adrg/xdg" @@ -46,6 +47,7 @@ type service struct { keys keys.Keys relayStatuses []RelayStatus startupState string + shuttingDown atomic.Bool } func NewService(ctx context.Context) (*service, error) { @@ -279,6 +281,10 @@ func (svc *service) GetStartupState() string { return svc.startupState } +func (svc *service) IsShuttingDown() bool { + return svc.shuttingDown.Load() +} + func (svc *service) removeExcessEvents() { logger.Logger.Debug("Cleaning up excess events") diff --git a/service/stop.go b/service/stop.go index 83e558c2a..878be618c 100644 --- a/service/stop.go +++ b/service/stop.go @@ -10,8 +10,10 @@ import ( func (svc *service) StopApp() { if svc.appCancelFn != nil { logger.Logger.Info("Stopping app...") + svc.shuttingDown.Store(true) svc.appCancelFn() svc.wg.Wait() + svc.shuttingDown.Store(false) logger.Logger.Info("app stopped") } } diff --git a/tests/mocks/Service.go b/tests/mocks/Service.go index 7564cd89d..7f7804479 100644 --- a/tests/mocks/Service.go +++ b/tests/mocks/Service.go @@ -548,6 +548,50 @@ func (_c *MockService_GetTransactionsService_Call) RunAndReturn(run func() trans return _c } +// IsShuttingDown provides a mock function for the type MockService +func (_mock *MockService) IsShuttingDown() bool { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for IsShuttingDown") + } + + var r0 bool + if returnFunc, ok := ret.Get(0).(func() bool); ok { + r0 = returnFunc() + } else { + r0 = ret.Get(0).(bool) + } + return r0 +} + +// MockService_IsShuttingDown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsShuttingDown' +type MockService_IsShuttingDown_Call struct { + *mock.Call +} + +// IsShuttingDown is a helper method to define mock.On call +func (_e *MockService_Expecter) IsShuttingDown() *MockService_IsShuttingDown_Call { + return &MockService_IsShuttingDown_Call{Call: _e.mock.On("IsShuttingDown")} +} + +func (_c *MockService_IsShuttingDown_Call) Run(run func()) *MockService_IsShuttingDown_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockService_IsShuttingDown_Call) Return(b bool) *MockService_IsShuttingDown_Call { + _c.Call.Return(b) + return _c +} + +func (_c *MockService_IsShuttingDown_Call) RunAndReturn(run func() bool) *MockService_IsShuttingDown_Call { + _c.Call.Return(run) + return _c +} + // Shutdown provides a mock function for the type MockService func (_mock *MockService) Shutdown() { _mock.Called() @@ -604,14 +648,20 @@ type MockService_StartApp_Call struct { } // StartApp is a helper method to define mock.On call -// - encryptionKey +// - encryptionKey string func (_e *MockService_Expecter) StartApp(encryptionKey interface{}) *MockService_StartApp_Call { return &MockService_StartApp_Call{Call: _e.mock.On("StartApp", encryptionKey)} } func (_c *MockService_StartApp_Call) Run(run func(encryptionKey string)) *MockService_StartApp_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) }) return _c } diff --git a/wails/wails_app.go b/wails/wails_app.go index cc41605ea..3cf18a2a7 100644 --- a/wails/wails_app.go +++ b/wails/wails_app.go @@ -3,6 +3,7 @@ package wails import ( "context" "embed" + "fmt" "github.com/getAlby/hub/api" "github.com/getAlby/hub/apps" @@ -90,6 +91,13 @@ func LaunchWailsApp(app *WailsApp, assets embed.FS, appIcon []byte) { } } +func (app *WailsApp) CheckShutdown() error { + if app.svc.IsShuttingDown() { + return fmt.Errorf("node is shutting down, please wait") + } + return nil +} + func NewWailsLogger() WailsLogger { return WailsLogger{} } diff --git a/wails/wails_handlers.go b/wails/wails_handlers.go index 45cc08993..144166858 100644 --- a/wails/wails_handlers.go +++ b/wails/wails_handlers.go @@ -27,6 +27,26 @@ type WailsRequestRouterResponse struct { func (app *WailsApp) WailsRequestRouter(route string, method string, body string) WailsRequestRouterResponse { ctx := app.ctx + // parse route to remove query parameters for matching + path := route + if q := strings.IndexByte(route, '?'); q >= 0 { + path = route[:q] + } + // whitelist routes that can still be called when the node is shutting down + safeRoutes := map[string]bool{ + "/api/health": true, + "/api/node/status": true, + "/api/info": true, + } + + isSafe := safeRoutes[path] || strings.HasPrefix(path, "/api/alby/") + + if !isSafe { + if err := app.CheckShutdown(); err != nil { + return WailsRequestRouterResponse{Body: nil, Error: "Node is shutting down. Please wait."} + } + } + // the grouping is done to avoid other parameters like &unused=true albyCallbackRegex := regexp.MustCompile( `/api/alby/callback\?code=([^&]+)(&.*)?`,