Skip to content

[management] support account retrieval and creation by private domain #3825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -1719,23 +1719,26 @@ func (am *DefaultAccountManager) GetStore() store.Store {
return am.Store
}

// Creates account by private domain.
// Expects domain value to be a valid and a private dns domain.
func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) {
cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel()

domain = strings.ToLower(domain)

count, err := am.Store.CountAccountsByPrivateDomain(ctx, domain)
if err != nil {
return nil, err
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
if handleNotFound(err) != nil {
return nil, false, err
}

if count > 0 {
return nil, status.Errorf(status.InvalidArgument, "account with private domain already exists")
// a primary account already exists for this private domain
if err == nil {
existingAccount, err := am.Store.GetAccount(ctx, existingPrimaryAccountID)
if err != nil {
return nil, false, err
}

return existingAccount, false, nil
}

// create a new account for this private domain
// retry twice for new ID clashes
for range 2 {
accountId := xid.New().String()
Expand Down Expand Up @@ -1765,7 +1768,7 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex
Users: users,
// @todo check if using the MSP owner id here is ok
CreatedBy: initiatorId,
Domain: domain,
Domain: strings.ToLower(domain),
DomainCategory: types.PrivateCategory,
IsDomainPrimaryAccount: false,
Routes: routes,
Expand All @@ -1784,19 +1787,22 @@ func (am *DefaultAccountManager) CreateAccountByPrivateDomain(ctx context.Contex
}

if err := newAccount.AddAllGroup(); err != nil {
return nil, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
return nil, false, status.Errorf(status.Internal, "failed to add all group to new account by private domain")
}

if err := am.Store.SaveAccount(ctx, newAccount); err != nil {
log.WithContext(ctx).Errorf("failed to save new account %s by private domain: %v", newAccount.Id, err)
return nil, err
log.WithContext(ctx).WithFields(log.Fields{
"accountId": newAccount.Id,
"domain": domain,
}).Errorf("failed to create new account: %v", err)
return nil, false, err
}

am.StoreEvent(ctx, initiatorId, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
return newAccount, true, nil
}

return nil, status.Errorf(status.Internal, "failed to create new account by private domain")
return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain")
}

func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
Expand All @@ -1809,21 +1815,29 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return account, nil
}

// additional check to ensure there is only one account for this domain at the time of update
count, err := am.Store.CountAccountsByPrivateDomain(ctx, account.Domain)
if err != nil {
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)

// error is not a not found error
if handleNotFound(err) != nil {
return nil, err
}

if count > 1 {
return nil, status.Errorf(status.Internal, "more than one account exists with the same private domain")
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return nil, status.Errorf(status.Internal, "cannot update account to primary")
}

account.IsDomainPrimaryAccount = true

if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to update primary account %s by private domain: %v", account.Id, err)
return nil, status.Errorf(status.Internal, "failed to update primary account %s by private domain", account.Id)
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return nil, status.Errorf(status.Internal, "failed to update account to primary")
}

return account, nil
Expand Down
2 changes: 1 addition & 1 deletion management/server/account/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ type Manager interface {
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
GetStore() store.Store
CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error)
GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
Expand Down
40 changes: 32 additions & 8 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"time"

"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/management/server/idp"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -25,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
Expand Down Expand Up @@ -3198,7 +3198,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
}
}

func Test_CreateAccountByPrivateDomain(t *testing.T) {
func Test_GetCreateAccountByPrivateDomain(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
Expand All @@ -3209,9 +3209,10 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
initiatorId := "test-user"
domain := "example.com"

account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)

assert.True(t, created)
assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain)
assert.Equal(t, types.PrivateCategory, account.DomainCategory)
Expand All @@ -3220,9 +3221,25 @@ func Test_CreateAccountByPrivateDomain(t *testing.T) {
assert.Equal(t, 0, len(account.Users))
assert.Equal(t, 0, len(account.SetupKeys))

// retry should fail
_, err = manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.Error(t, err)
// should return a new account because the previous one is not primary
account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)

assert.True(t, created2)
assert.False(t, account2.IsDomainPrimaryAccount)
assert.Equal(t, domain, account2.Domain)
assert.Equal(t, types.PrivateCategory, account2.DomainCategory)
assert.Equal(t, initiatorId, account2.CreatedBy)
assert.Equal(t, 1, len(account2.Groups))
assert.Equal(t, 0, len(account2.Users))
assert.Equal(t, 0, len(account2.SetupKeys))

account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)

_, err = manager.UpdateToPrimaryAccount(ctx, account2.Id)
assert.Error(t, err, "should not be able to update a second account to primary")
}

func Test_UpdateToPrimaryAccount(t *testing.T) {
Expand All @@ -3236,14 +3253,21 @@ func Test_UpdateToPrimaryAccount(t *testing.T) {
initiatorId := "test-user"
domain := "example.com"

account, err := manager.CreateAccountByPrivateDomain(ctx, initiatorId, domain)
account, created, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.True(t, created)
assert.False(t, account.IsDomainPrimaryAccount)
assert.Equal(t, domain, account.Domain)

// retry should fail
account, err = manager.UpdateToPrimaryAccount(ctx, account.Id)
assert.NoError(t, err)
assert.True(t, account.IsDomainPrimaryAccount)

account2, created2, err := manager.GetOrCreateAccountByPrivateDomain(ctx, initiatorId, domain)
assert.NoError(t, err)
assert.False(t, created2)
assert.True(t, account.IsDomainPrimaryAccount)
assert.Equal(t, account.Id, account2.Id)
}

func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
Expand Down
11 changes: 6 additions & 5 deletions management/server/mock_server/account_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,12 @@ type MockAccountManager struct {
DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error
BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
GetStoreFunc func() store.Store
CreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, error)
UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error)
GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)

GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
}

func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
Expand Down Expand Up @@ -862,11 +863,11 @@ func (am *MockAccountManager) GetStore() store.Store {
return nil
}

func (am *MockAccountManager) CreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, error) {
if am.CreateAccountByPrivateDomainFunc != nil {
return am.CreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) {
if am.GetOrCreateAccountByPrivateDomainFunc != nil {
return am.GetOrCreateAccountByPrivateDomainFunc(ctx, initiatorId, domain)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateAccountByPrivateDomain is not implemented")
return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented")
}

func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
Expand Down
Loading