diff --git a/management/internals/modules/reverseproxy/interface.go b/management/internals/modules/reverseproxy/interface.go index 7614b3ce574..8a81ee30717 100644 --- a/management/internals/modules/reverseproxy/interface.go +++ b/management/internals/modules/reverseproxy/interface.go @@ -12,6 +12,7 @@ type Manager interface { CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) DeleteService(ctx context.Context, accountID, userID, serviceID string) error + DeleteAllServices(ctx context.Context, accountID, userID string) error SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error ReloadAllServicesForAccount(ctx context.Context, accountID string) error diff --git a/management/internals/modules/reverseproxy/interface_mock.go b/management/internals/modules/reverseproxy/interface_mock.go index d5f38c38a25..6533d90bf99 100644 --- a/management/internals/modules/reverseproxy/interface_mock.go +++ b/management/internals/modules/reverseproxy/interface_mock.go @@ -49,6 +49,20 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service) } +// DeleteAllServices mocks base method. +func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteAllServices indicates an expected call of DeleteAllServices. +func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID) +} + // DeleteService mocks base method. func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/manager/manager.go index 535705a37ff..8068178a59b 100644 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/manager/manager.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -150,7 +151,7 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") m.accountManager.UpdateAccountPeers(ctx, accountID) @@ -330,19 +331,33 @@ func (m *managerImpl) preserveServiceMetadata(service, existingService *reversep } func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) { - oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() - switch { case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster) - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) + m.sendServiceUpdate(service, reverseproxy.Delete, updateInfo.oldCluster, "") + m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") case !service.Enabled && updateInfo.serviceEnabledChanged: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster) + m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "") case service.Enabled && updateInfo.serviceEnabledChanged: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) + m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") default: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster) + m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "") + } +} + +func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) { + oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() + mapping := service.ToProtoMapping(operation, oldService, oidcCfg) + m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster) +} + +func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) { + if len(mappings) == 0 { + return + } + update := &proto.GetMappingUpdateResponse{ + Mapping: mappings, } + m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster) } // validateTargetReferences checks that all target IDs reference existing peers or resources in the account. @@ -397,7 +412,54 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta()) - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "") + + m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + var services []*reverseproxy.Service + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var err error + services, err = transaction.GetServicesByAccountID(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return err + } + + for _, service := range services { + if err = transaction.DeleteService(ctx, accountID, service.ID); err != nil { + return fmt.Errorf("failed to delete service: %w", err) + } + } + + return nil + }) + if err != nil { + return err + } + + clusterMappings := make(map[string][]*proto.ProxyMapping) + oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() + + for _, service := range services { + m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta()) + mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg) + clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping) + } + + for cluster, mappings := range clusterMappings { + m.sendMappingsToCluster(mappings, cluster) + } m.accountManager.UpdateAccountPeers(ctx, accountID) @@ -452,7 +514,7 @@ func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID st return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "") m.accountManager.UpdateAccountPeers(ctx, accountID) @@ -465,12 +527,20 @@ func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID return fmt.Errorf("failed to get services: %w", err) } + clusterMappings := make(map[string][]*proto.ProxyMapping) + oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() + for _, service := range services { err = m.replaceHostByLookup(ctx, accountID, service) if err != nil { return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + mapping := service.ToProtoMapping(reverseproxy.Update, "", oidcCfg) + clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping) + } + + for cluster, mappings := range clusterMappings { + m.sendMappingsToCluster(mappings, cluster) } return nil diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 4771d35af48..e47ea53152b 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -61,9 +61,6 @@ type ProxyServiceServer struct { // Map of cluster address -> set of proxy IDs clusterProxies sync.Map - // Channel for broadcasting reverse proxy updates to all proxies - updatesChan chan *proto.ProxyMapping - // Manager for access logs accessLogManager accesslogs.Manager @@ -101,7 +98,7 @@ type proxyConnection struct { proxyID string address string stream proto.ProxyService_GetMappingUpdateServer - sendChan chan *proto.ProxyMapping + sendChan chan *proto.GetMappingUpdateResponse ctx context.Context cancel context.CancelFunc } @@ -110,7 +107,6 @@ type proxyConnection struct { func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer { ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ - updatesChan: make(chan *proto.ProxyMapping, 100), accessLogManager: accessLogMgr, oidcConfig: oidcConfig, tokenStore: tokenStore, @@ -177,7 +173,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest proxyID: proxyID, address: proxyAddress, stream: stream, - sendChan: make(chan *proto.ProxyMapping, 100), + sendChan: make(chan *proto.GetMappingUpdateResponse, 100), ctx: connCtx, cancel: cancel, } @@ -288,7 +284,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) for { select { case msg := <-conn.sendChan: - if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil { + if err := conn.stream.Send(msg); err != nil { errChan <- err return } @@ -339,7 +335,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA // Management should call this when services are created/updated/removed. // For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. -func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) { +func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) { log.Debugf("Broadcasting service update to all connected proxy servers") s.connectedProxies.Range(func(key, value interface{}) bool { conn := value.(*proxyConnection) @@ -349,7 +345,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) { } select { case conn.sendChan <- msg: - log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID) + log.Debugf("Sent service update to proxy server %s", conn.proxyID) default: log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID) } @@ -418,7 +414,7 @@ func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) { // If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility). // For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. -func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) { +func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.GetMappingUpdateResponse, clusterAddr string) { if clusterAddr == "" { s.SendServiceUpdate(update) return @@ -441,7 +437,7 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMappi } select { case conn.sendChan <- msg: - log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) + log.Debugf("Sent service update to proxy %s in cluster %s", proxyID, clusterAddr) default: log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) } @@ -451,23 +447,31 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMappi } // perProxyMessage returns a copy of update with a fresh one-time token for -// create/update operations. For delete operations the original message is -// returned unchanged because proxies do not need to authenticate for removal. +// create/update operations. For delete operations the original mapping is +// used unchanged because proxies do not need to authenticate for removal. // Returns nil if token generation fails (the proxy should be skipped). -func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping { - if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" { - return update - } +func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse { + resp := make([]*proto.ProxyMapping, 0, len(update.Mapping)) + for _, mapping := range update.Mapping { + if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED { + resp = append(resp, mapping) + continue + } - token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute) - if err != nil { - log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) - return nil + token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute) + if err != nil { + log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err) + return nil + } + + msg := shallowCloneMapping(mapping) + msg.AuthToken = token + resp = append(resp, msg) } - msg := shallowCloneMapping(update) - msg.AuthToken = token - return msg + return &proto.GetMappingUpdateResponse{ + Mapping: resp, + } } // shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 84fb549231e..31b1df3b11d 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -17,6 +17,10 @@ type mockReverseProxyManager struct { err error } +func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + return nil +} + func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { if m.err != nil { return nil, m.err diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index 4c84e6010ea..de8ca3c84e1 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -16,8 +16,8 @@ import ( // registerFakeProxy adds a fake proxy connection to the server's internal maps // and returns the channel where messages will be received. -func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping { - ch := make(chan *proto.ProxyMapping, 10) +func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse { + ch := make(chan *proto.GetMappingUpdateResponse, 10) conn := &proxyConnection{ proxyID: proxyID, address: clusterAddr, @@ -31,7 +31,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan return ch } -func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping { +func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse { select { case msg := <-ch: return msg @@ -45,20 +45,19 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { defer tokenStore.Close() s := &ProxyServiceServer{ - tokenStore: tokenStore, - updatesChan: make(chan *proto.ProxyMapping, 100), + tokenStore: tokenStore, } const cluster = "proxy.example.com" const numProxies = 3 - channels := make([]chan *proto.ProxyMapping, numProxies) + channels := make([]chan *proto.GetMappingUpdateResponse, numProxies) for i := range numProxies { id := "proxy-" + string(rune('a'+i)) channels[i] = registerFakeProxy(s, id, cluster) } - update := &proto.ProxyMapping{ + mapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-1", AccountId: "account-1", @@ -68,14 +67,20 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { }, } + update := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{mapping}, + } + s.SendServiceUpdateToCluster(update, cluster) tokens := make([]string, numProxies) for i, ch := range channels { - msg := drainChannel(ch) - require.NotNil(t, msg, "proxy %d should receive a message", i) - assert.Equal(t, update.Domain, msg.Domain) - assert.Equal(t, update.Id, msg.Id) + resp := drainChannel(ch) + require.NotNil(t, resp, "proxy %d should receive a message", i) + require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i) + msg := resp.Mapping[0] + assert.Equal(t, mapping.Domain, msg.Domain) + assert.Equal(t, mapping.Id, msg.Id) assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i) tokens[i] = msg.AuthToken } @@ -100,31 +105,36 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { defer tokenStore.Close() s := &ProxyServiceServer{ - tokenStore: tokenStore, - updatesChan: make(chan *proto.ProxyMapping, 100), + tokenStore: tokenStore, } const cluster = "proxy.example.com" ch1 := registerFakeProxy(s, "proxy-a", cluster) ch2 := registerFakeProxy(s, "proxy-b", cluster) - update := &proto.ProxyMapping{ + mapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "service-1", AccountId: "account-1", Domain: "test.example.com", } + update := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{mapping}, + } + s.SendServiceUpdateToCluster(update, cluster) - msg1 := drainChannel(ch1) - msg2 := drainChannel(ch2) - require.NotNil(t, msg1) - require.NotNil(t, msg2) + resp1 := drainChannel(ch1) + resp2 := drainChannel(ch2) + require.NotNil(t, resp1) + require.NotNil(t, resp2) + require.Len(t, resp1.Mapping, 1) + require.Len(t, resp2.Mapping, 1) // Delete operations should not generate tokens - assert.Empty(t, msg1.AuthToken) - assert.Empty(t, msg2.AuthToken) + assert.Empty(t, resp1.Mapping[0].AuthToken) + assert.Empty(t, resp2.Mapping[0].AuthToken) // No tokens should have been created assert.Equal(t, 0, tokenStore.GetTokenCount()) @@ -135,27 +145,35 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { defer tokenStore.Close() s := &ProxyServiceServer{ - tokenStore: tokenStore, - updatesChan: make(chan *proto.ProxyMapping, 100), + tokenStore: tokenStore, } // Register proxies in different clusters (SendServiceUpdate broadcasts to all) ch1 := registerFakeProxy(s, "proxy-a", "cluster-a") ch2 := registerFakeProxy(s, "proxy-b", "cluster-b") - update := &proto.ProxyMapping{ + mapping := &proto.ProxyMapping{ Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, Id: "service-1", AccountId: "account-1", Domain: "test.example.com", } + update := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{mapping}, + } + s.SendServiceUpdate(update) - msg1 := drainChannel(ch1) - msg2 := drainChannel(ch2) - require.NotNil(t, msg1) - require.NotNil(t, msg2) + resp1 := drainChannel(ch1) + resp2 := drainChannel(ch2) + require.NotNil(t, resp1) + require.NotNil(t, resp2) + require.Len(t, resp1.Mapping, 1) + require.Len(t, resp2.Mapping, 1) + + msg1 := resp1.Mapping[0] + msg2 := resp2.Mapping[0] assert.NotEmpty(t, msg1.AuthToken) assert.NotEmpty(t, msg2.AuthToken) diff --git a/management/server/account.go b/management/server/account.go index 1e35d4ad181..d436445e808 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -714,6 +714,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) } + err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID) + if err != nil { + return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err) + } + for _, otherUser := range account.Users { if otherUser.Id == userID { continue diff --git a/management/server/account_test.go b/management/server/account_test.go index 1cc0c9571b1..f9e9c162d50 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -31,6 +31,7 @@ import ( reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" @@ -3122,7 +3123,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil)) + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil) + manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyGrpcServer, nil)) return manager, updateManager, nil } diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 0a9a560cde4..e467d284369 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -345,6 +345,10 @@ type testServiceManager struct { store store.Store } +func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + return nil +} + func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { return nil, nil } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index db7cfd32d24..f6c337929f2 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4906,6 +4906,28 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren return service, nil } +func (s *SqlStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { + tx := s.db.Preload("Targets") + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var serviceList []*reverseproxy.Service + result := tx.Find(&serviceList, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get services from store") + } + + for _, service := range serviceList { + if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil { + return nil, fmt.Errorf("decrypt service data: %w", err) + } + } + + return serviceList, nil +} + func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { var service *reverseproxy.Service result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service) diff --git a/management/server/store/store.go b/management/server/store/store.go index a8e44a438b8..a62e74c785f 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -256,6 +256,7 @@ type Store interface { UpdateService(ctx context.Context, service *reverseproxy.Service) error DeleteService(ctx context.Context, accountID, serviceID string) error GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) + GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 2f451dc430b..868128e4d26 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1094,6 +1094,21 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID) } +// GetServicesByAccountID mocks base method. +func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID) + ret0, _ := ret[0].([]*reverseproxy.Service) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServicesByAccountID indicates an expected call of GetServicesByAccountID. +func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID) +} + // GetAccountSettings mocks base method. func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) { m.ctrl.T.Helper() diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 53d7019f724..2a6cd108e17 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -179,6 +179,10 @@ type storeBackedServiceManager struct { tokenStore *nbgrpc.OneTimeTokenStore } +func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { + return nil +} + func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) }