diff --git a/.github/workflows/check-license-dependencies.yml b/.github/workflows/check-license-dependencies.yml index d3da427b086..da59ce4453a 100644 --- a/.github/workflows/check-license-dependencies.yml +++ b/.github/workflows/check-license-dependencies.yml @@ -15,27 +15,28 @@ jobs: - name: Check for problematic license dependencies run: | echo "Checking for dependencies on management/, signal/, and relay/ packages..." + echo "" # Find all directories except the problematic ones and system dirs FOUND_ISSUES=0 - find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do + while IFS= read -r dir; do echo "=== Checking $dir ===" # Search for problematic imports, excluding test files - RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) - if [ ! -z "$RESULTS" ]; then + RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true) + if [ -n "$RESULTS" ]; then echo "❌ Found problematic dependencies:" echo "$RESULTS" FOUND_ISSUES=1 else echo "✓ No problematic dependencies found" fi - done + done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort) + + echo "" if [ $FOUND_ISSUES -eq 1 ]; then - echo "" echo "❌ Found dependencies on management/, signal/, or relay/ packages" - echo "These packages will change license and should not be imported by client or shared code" + echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code" exit 1 else - echo "" echo "✅ All license dependencies are clean" fi diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index 90bde7707f9..ab13cf3895f 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -193,10 +193,10 @@ func TestWireguardPortZeroExplicit(t *testing.T) { func TestWireguardPortDefaultVsExplicit(t *testing.T) { tests := []struct { - name string - wireguardPort *int - expectedPort int - description string + name string + wireguardPort *int + expectedPort int + description string }{ { name: "no port specified uses default", diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go index 7f72daa8e25..c5036da37af 100644 --- a/client/ssh/proxy/proxy_test.go +++ b/client/ssh/proxy/proxy_test.go @@ -29,7 +29,7 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/server" "github.com/netbirdio/netbird/client/ssh/testutil" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) func TestMain(m *testing.M) { diff --git a/client/ssh/server/jwt_test.go b/client/ssh/server/jwt_test.go index 1cf7c0f7094..068d709b4b7 100644 --- a/client/ssh/server/jwt_test.go +++ b/client/ssh/server/jwt_test.go @@ -26,7 +26,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/client" "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/ssh/testutil" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) func TestJWTEnforcement(t *testing.T) { diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 39899352af2..80e2ae1413a 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -21,8 +21,8 @@ import ( "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/ssh/detection" - "github.com/netbirdio/netbird/management/server/auth/jwt" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/auth/jwt" "github.com/netbirdio/netbird/version" ) @@ -349,7 +349,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error { return nil } -func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) { +func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) { s.mu.RLock() jwtExtractor := s.jwtExtractor s.mu.RUnlock() @@ -372,7 +372,7 @@ func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth return &userAuth, nil } -func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool { +func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool { return userAuth.UserId != "" } diff --git a/management/server/account.go b/management/server/account.go index 1a484b58e60..a15263c822f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -17,6 +17,8 @@ import ( "sync/atomic" "time" + "github.com/netbirdio/netbird/shared/auth" + cacheStore "github.com/eko/gocache/lib/v4/store" "github.com/eko/gocache/store/redis/v4" "github.com/rs/xid" @@ -1046,7 +1048,7 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun } // updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes -func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth, +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth auth.UserAuth, primaryDomain bool, ) error { if userAuth.Domain == "" { @@ -1095,7 +1097,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, userAccountID string, domainAccountID string, - userAuth nbcontext.UserAuth, + userAuth auth.UserAuth, ) error { primaryDomain := domainAccountID == "" || userAccountID == domainAccountID err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain) @@ -1114,7 +1116,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount( // addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) { if userAuth.UserId == "" { return "", fmt.Errorf("user ID is empty") } @@ -1145,7 +1147,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai return newAccount.Id, nil } -func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) { newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID @@ -1309,7 +1311,7 @@ func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, ac return newOnboarding, nil } -func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { +func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) } @@ -1353,7 +1355,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. // requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager -func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { +func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error { if userAuth.IsChild || userAuth.IsPAT { return nil } @@ -1511,7 +1513,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // // UserAuth IsChild -> checks that account exists -func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth auth.UserAuth) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory) @@ -1590,7 +1592,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont return domainAccountID, cancel, nil } -func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth auth.UserAuth) (string, error) { userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) @@ -1638,7 +1640,7 @@ func handleNotFound(err error) error { return nil } -func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool { +func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAuth) bool { return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index fe9fb25c632..8251b951fa0 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -6,10 +6,11 @@ import ( "net/netip" "time" + "github.com/netbirdio/netbird/shared/auth" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peers/ephemeral" @@ -45,10 +46,10 @@ type Manager interface { GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) - GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error GetUserByID(ctx context.Context, id string) (*types.User, error) - GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error @@ -120,12 +121,12 @@ type Manager interface { UpdateAccountPeers(ctx context.Context, accountID string) BufferUpdateAccountPeers(ctx context.Context, accountID string) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) - SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error + SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error GetStore() store.Store GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) - GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) AllowSync(string, uint64) bool } diff --git a/management/server/account_test.go b/management/server/account_test.go index c7cad6a4f4d..44cb7d62931 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -25,7 +25,6 @@ import ( nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -42,6 +41,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" ) func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account *types.Account, userID string) { @@ -442,7 +442,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { } func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { - type initUserParams nbcontext.UserAuth + type initUserParams auth.UserAuth var ( publicDomain = "public.com" @@ -465,7 +465,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { testCases := []struct { name string - inputClaims nbcontext.UserAuth + inputClaims auth.UserAuth inputInitUserParams initUserParams inputUpdateAttrs bool inputUpdateClaimAccount bool @@ -480,7 +480,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }{ { name: "New User With Public Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: publicDomain, UserId: "pub-domain-user", DomainCategory: types.PublicCategory, @@ -497,7 +497,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Unknown Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: unknownDomain, UserId: "unknown-domain-user", DomainCategory: types.UnknownCategory, @@ -514,7 +514,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New User With Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: privateDomain, UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -531,7 +531,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "New Regular User With Existing Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: privateDomain, UserId: "new-pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -549,7 +549,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing User With Existing Reclassified Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -566,7 +566,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, DomainCategory: types.PrivateCategory, @@ -584,7 +584,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { }, { name: "User With Private Category And Empty Domain", - inputClaims: nbcontext.UserAuth{ + inputClaims: auth.UserAuth{ Domain: "", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, @@ -613,7 +613,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, nbcontext.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } @@ -653,7 +653,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount Domain: domain, UserId: userId, @@ -912,13 +912,13 @@ func TestAccountManager_DeleteAccount(t *testing.T) { } func BenchmarkTest_GetAccountWithclaims(b *testing.B) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ Domain: "example.com", UserId: "pvt-domain-user", DomainCategory: types.PrivateCategory, } - publicClaims := nbcontext.UserAuth{ + publicClaims := auth.UserAuth{ Domain: "test.com", UserId: "public-domain-user", DomainCategory: types.PublicCategory, @@ -2648,7 +2648,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") t.Run("skip sync for token auth type", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group3"}, @@ -2663,7 +2663,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("empty jwt groups", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{}, @@ -2677,7 +2677,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("jwt match existing api group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1"}, @@ -2698,7 +2698,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"])) - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1"}, @@ -2716,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add jwt group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group1", "group2"}, @@ -2730,7 +2730,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("existed group not update", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{"group2"}, @@ -2744,7 +2744,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("add new group", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user2", AccountId: "accountID", Groups: []string{"group1", "group3"}, @@ -2762,7 +2762,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when list is empty", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user1", AccountId: "accountID", Groups: []string{}, @@ -2777,7 +2777,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { }) t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: "user2", AccountId: "accountID", Groups: []string{}, @@ -3630,7 +3630,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { // Test adding new user to existing account with approval required newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: newUserID, Domain: "example.com", DomainCategory: types.PrivateCategory, @@ -3660,7 +3660,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { } // Create a domain-based account without user approval - ownerUserAuth := nbcontext.UserAuth{ + ownerUserAuth := auth.UserAuth{ UserId: "owner-user", Domain: "example.com", DomainCategory: types.PrivateCategory, @@ -3679,7 +3679,7 @@ func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { // Test adding new user to existing account without approval required newUserID := "new-user-id" - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: newUserID, Domain: "example.com", DomainCategory: types.PrivateCategory, diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index ece9dc32110..0c62357dcc0 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -9,18 +9,19 @@ import ( "github.com/golang-jwt/jwt/v5" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/base62" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) var _ Manager = (*manager)(nil) type Manager interface { - ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) - EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) MarkPATUsed(ctx context.Context, tokenID string) error GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) } @@ -55,20 +56,20 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s } } -func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { +func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) { token, err := m.validator.ValidateAndParse(ctx, value) if err != nil { - return nbcontext.UserAuth{}, nil, err + return auth.UserAuth{}, nil, err } userAuth, err := m.extractor.ToUserAuth(token) if err != nil { - return nbcontext.UserAuth{}, nil, err + return auth.UserAuth{}, nil, err } return userAuth, token, err } -func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) { if userAuth.IsChild || userAuth.IsPAT { return userAuth, nil } diff --git a/management/server/auth/manager_mock.go b/management/server/auth/manager_mock.go index 30a7a716190..edf158a4978 100644 --- a/management/server/auth/manager_mock.go +++ b/management/server/auth/manager_mock.go @@ -3,9 +3,10 @@ package auth import ( "context" + "github.com/netbirdio/netbird/shared/auth" + "github.com/golang-jwt/jwt/v5" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/types" ) @@ -15,18 +16,18 @@ var ( // @note really dislike this mocking approach but rather than have to do additional test refactoring. type MockManager struct { - ValidateAndParseTokenFunc func(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) - EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) + ValidateAndParseTokenFunc func(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) + EnsureUserAccessByJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) MarkPATUsedFunc func(ctx context.Context, tokenID string) error GetPATInfoFunc func(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) } // EnsureUserAccessByJWTGroups implements Manager. -func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func (m *MockManager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) { if m.EnsureUserAccessByJWTGroupsFunc != nil { return m.EnsureUserAccessByJWTGroupsFunc(ctx, userAuth, token) } - return nbcontext.UserAuth{}, nil + return auth.UserAuth{}, nil } // GetPATInfo implements Manager. @@ -46,9 +47,9 @@ func (m *MockManager) MarkPATUsed(ctx context.Context, tokenID string) error { } // ValidateAndParseToken implements Manager. -func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (nbcontext.UserAuth, *jwt.Token, error) { +func (m *MockManager) ValidateAndParseToken(ctx context.Context, value string) (auth.UserAuth, *jwt.Token, error) { if m.ValidateAndParseTokenFunc != nil { return m.ValidateAndParseTokenFunc(ctx, value) } - return nbcontext.UserAuth{}, &jwt.Token{}, nil + return auth.UserAuth{}, &jwt.Token{}, nil } diff --git a/management/server/auth/manager_test.go b/management/server/auth/manager_test.go index c8015eb370f..b9f091b1ee8 100644 --- a/management/server/auth/manager_test.go +++ b/management/server/auth/manager_test.go @@ -17,10 +17,10 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/auth" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbauth "github.com/netbirdio/netbird/shared/auth" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) { @@ -131,7 +131,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) { } // this has been validated and parsed by ValidateAndParseToken - userAuth := nbcontext.UserAuth{ + userAuth := nbauth.UserAuth{ AccountId: account.Id, Domain: domain, UserId: userId, @@ -236,7 +236,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tests := []struct { name string tokenFunc func() string - expected *nbcontext.UserAuth // nil indicates expected error + expected *nbauth.UserAuth // nil indicates expected error }{ { name: "Valid with custom claims", @@ -258,7 +258,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tokenString, _ := token.SignedString(key) return tokenString }, - expected: &nbcontext.UserAuth{ + expected: &nbauth.UserAuth{ UserId: "user-id|123", AccountId: "account-id|567", Domain: "http://localhost", @@ -282,7 +282,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) { tokenString, _ := token.SignedString(key) return tokenString }, - expected: &nbcontext.UserAuth{ + expected: &nbauth.UserAuth{ UserId: "user-id|123", }, }, diff --git a/management/server/context/auth.go b/management/server/context/auth.go index 5cb28ddb7cb..cc59b8a63d0 100644 --- a/management/server/context/auth.go +++ b/management/server/context/auth.go @@ -4,7 +4,8 @@ import ( "context" "fmt" "net/http" - "time" + + "github.com/netbirdio/netbird/shared/auth" ) type key int @@ -13,45 +14,22 @@ const ( UserAuthContextKey key = iota ) -type UserAuth struct { - // The account id the user is accessing - AccountId string - // The account domain - Domain string - // The account domain category, TBC values - DomainCategory string - // Indicates whether this user was invited, TBC logic - Invited bool - // Indicates whether this is a child account - IsChild bool - - // The user id - UserId string - // Last login time for this user - LastLogin time.Time - // The Groups the user belongs to on this account - Groups []string - - // Indicates whether this user has authenticated with a Personal Access Token - IsPAT bool -} - -func GetUserAuthFromRequest(r *http.Request) (UserAuth, error) { +func GetUserAuthFromRequest(r *http.Request) (auth.UserAuth, error) { return GetUserAuthFromContext(r.Context()) } -func SetUserAuthInRequest(r *http.Request, userAuth UserAuth) *http.Request { +func SetUserAuthInRequest(r *http.Request, userAuth auth.UserAuth) *http.Request { return r.WithContext(SetUserAuthInContext(r.Context(), userAuth)) } -func GetUserAuthFromContext(ctx context.Context) (UserAuth, error) { - if userAuth, ok := ctx.Value(UserAuthContextKey).(UserAuth); ok { +func GetUserAuthFromContext(ctx context.Context) (auth.UserAuth, error) { + if userAuth, ok := ctx.Value(UserAuthContextKey).(auth.UserAuth); ok { return userAuth, nil } - return UserAuth{}, fmt.Errorf("user auth not in context") + return auth.UserAuth{}, fmt.Errorf("user auth not in context") } -func SetUserAuthInContext(ctx context.Context, userAuth UserAuth) context.Context { +func SetUserAuthInContext(ctx context.Context, userAuth auth.UserAuth) context.Context { //nolint ctx = context.WithValue(ctx, UserIDKey, userAuth.UserId) //nolint diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 4b9b79fdc18..c5c48ef3210 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -236,7 +237,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: adminUser.Id, AccountId: accountID, Domain: "hotmail.com", diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 08a0b2afd2d..67638aea5ad 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/types" ) // dnsSettingsHandler is a handler that returns the DNS settings of the account diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index 42b519c292e..a027c067e36 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -11,13 +11,14 @@ import ( "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -107,7 +108,7 @@ func TestDNSSettingsHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, AccountId: testingDNSSettingsAccount.Id, Domain: testingDNSSettingsAccount.Domain, diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index d49b6c7e063..4716782f3fa 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -19,6 +19,7 @@ import ( "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -193,7 +194,7 @@ func TestNameserversHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", AccountId: testNSGroupAccountID, Domain: "hotmail.com", diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index a0695fa3fa4..923a24e31e5 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -14,11 +14,12 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" ) func initEventsTestData(account string, events ...*activity.Event) *handler { @@ -188,7 +189,7 @@ func TestEvents_GetEvents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_account", diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index e861e873c1e..208a2e8288f 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -11,10 +11,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns groups of the account diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 34694ec8c4a..b7dd3944a2b 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -19,12 +19,13 @@ import ( "github.com/netbirdio/netbird/management/server" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" ) var TestPeers = map[string]*nbpeer.Peer{ @@ -122,7 +123,7 @@ func TestGetGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -248,7 +249,7 @@ func TestWriteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -330,7 +331,7 @@ func TestDeleteGroup(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index d7b598a5d7b..f99eca7941f 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -12,15 +12,15 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/types" - "github.com/netbirdio/netbird/shared/management/status" nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" ) // handler is a handler that returns networks of the account diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index 59396dcebb6..c31729a39c7 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -8,10 +8,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" ) type resourceHandler struct { diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index 2e64c637ff5..c311a29feb4 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -7,10 +7,10 @@ import ( "github.com/gorilla/mux" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" ) type routersHandler struct { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 94564113f5d..e22c937334b 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -17,9 +17,10 @@ import ( "golang.org/x/exp/maps" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -277,7 +278,7 @@ func TestGetPeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "admin_user", Domain: "hotmail.com", AccountId: "test_id", @@ -425,7 +426,7 @@ func TestGetAccessiblePeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: tc.callerUserID, Domain: "hotmail.com", AccountId: "test_id", @@ -508,7 +509,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/peers/%s", tc.peerID), bytes.NewBuffer([]byte(tc.requestBody))) req.Header.Set("Content-Type", "application/json") - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: tc.callerUserID, Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index cedd5ac8872..094a36e38f3 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -16,12 +16,13 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/util" ) @@ -113,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -206,7 +207,7 @@ func TestGetAllCountries(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index cb699579305..a2d656a4716 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -9,11 +9,11 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" - "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index 4d6bad5e390..ab1639ab146 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -10,10 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns policy of the account diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index fd39ae2a3bc..ca5a0a6abfb 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -14,10 +14,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) func initPoliciesTestData(policies ...*types.Policy) *handler { @@ -103,7 +104,7 @@ func TestPoliciesGetPolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -267,7 +268,7 @@ func TestPoliciesWritePolicy(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 3ebc4d1e105..744cde10b0e 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" - "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index c644b533a1a..8c60d6fe871 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -16,9 +16,10 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -175,7 +176,7 @@ func TestGetPostureCheck(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", @@ -828,7 +829,7 @@ func TestPostureCheckUpdate(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: "test_id", diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index 466a7987f4b..a44d81e3ec8 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" @@ -493,7 +494,7 @@ func TestRoutesHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: "test_user", Domain: "hotmail.com", AccountId: testAccountID, diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 2287dadfe22..d267b6eea2a 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -10,10 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns a list of setup keys of the account diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 7b46b486b64..b137b6dd1e5 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -15,10 +15,11 @@ import ( "github.com/stretchr/testify/assert" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -163,7 +164,7 @@ func TestSetupKeysHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: adminUser.Id, Domain: "hotmail.com", AccountId: "testAccountId", diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index bae07af4a58..867db3ca9f7 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -8,10 +8,10 @@ import ( "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/management/server/types" ) // patHandler is the nameserver group handler of the account diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index 92544c56da0..7cda144686c 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -17,10 +17,11 @@ import ( "github.com/netbirdio/netbird/management/server/util" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -173,7 +174,7 @@ func TestTokenHandlers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index e080042187e..37f0a6c1dc8 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) @@ -128,7 +129,7 @@ func initUsersTestData() *handler { return nil }, - GetCurrentUserInfoFunc: func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { + GetCurrentUserInfoFunc: func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { switch userAuth.UserId { case "not-found": return nil, status.NewUserNotFoundError("not-found") @@ -225,7 +226,7 @@ func TestGetUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -335,7 +336,7 @@ func TestUpdateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -432,7 +433,7 @@ func TestCreateUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -481,7 +482,7 @@ func TestInviteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -540,7 +541,7 @@ func TestDeleteUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) req = mux.SetURLVars(req, tc.requestVars) - req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{ + req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{ UserId: existingUserID, Domain: testDomain, AccountId: existingAccountID, @@ -565,7 +566,7 @@ func TestCurrentUser(t *testing.T) { tt := []struct { name string expectedStatus int - requestAuth nbcontext.UserAuth + requestAuth auth.UserAuth expectedResult *api.User }{ { @@ -574,27 +575,27 @@ func TestCurrentUser(t *testing.T) { }, { name: "user not found", - requestAuth: nbcontext.UserAuth{UserId: "not-found"}, + requestAuth: auth.UserAuth{UserId: "not-found"}, expectedStatus: http.StatusNotFound, }, { name: "not of account", - requestAuth: nbcontext.UserAuth{UserId: "not-of-account"}, + requestAuth: auth.UserAuth{UserId: "not-of-account"}, expectedStatus: http.StatusForbidden, }, { name: "blocked user", - requestAuth: nbcontext.UserAuth{UserId: "blocked-user"}, + requestAuth: auth.UserAuth{UserId: "blocked-user"}, expectedStatus: http.StatusForbidden, }, { name: "service user", - requestAuth: nbcontext.UserAuth{UserId: "service-user"}, + requestAuth: auth.UserAuth{UserId: "service-user"}, expectedStatus: http.StatusForbidden, }, { name: "owner", - requestAuth: nbcontext.UserAuth{UserId: "owner"}, + requestAuth: auth.UserAuth{UserId: "owner"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "owner", @@ -613,7 +614,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "regular user", - requestAuth: nbcontext.UserAuth{UserId: "regular-user"}, + requestAuth: auth.UserAuth{UserId: "regular-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "regular-user", @@ -632,7 +633,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "admin user", - requestAuth: nbcontext.UserAuth{UserId: "admin-user"}, + requestAuth: auth.UserAuth{UserId: "admin-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "admin-user", @@ -651,7 +652,7 @@ func TestCurrentUser(t *testing.T) { }, { name: "restricted user", - requestAuth: nbcontext.UserAuth{UserId: "restricted-user"}, + requestAuth: auth.UserAuth{UserId: "restricted-user"}, expectedStatus: http.StatusOK, expectedResult: &api.User{ Id: "restricted-user", @@ -783,7 +784,7 @@ func TestApproveUserEndpoint(t *testing.T) { req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) require.NoError(t, err) - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ AccountId: existingAccountID, UserId: tc.requestingUser.Id, } @@ -841,7 +842,7 @@ func TestRejectUserEndpoint(t *testing.T) { req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil) require.NoError(t, err) - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ AccountId: existingAccountID, UserId: tc.requestingUser.Id, } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6091a4c316f..b23deb01e58 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -10,22 +10,23 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/auth" + serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) -type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) -type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error +type EnsureAccountFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error) +type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) error -type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) +type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - authManager auth.Manager + authManager serverauth.Manager ensureAccount EnsureAccountFunc getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc @@ -33,7 +34,7 @@ type AuthMiddleware struct { // NewAuthMiddleware instance constructor func NewAuthMiddleware( - authManager auth.Manager, + authManager serverauth.Manager, ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, @@ -53,18 +54,18 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return } - auth := strings.Split(r.Header.Get("Authorization"), " ") - authType := strings.ToLower(auth[0]) + authHeader := strings.Split(r.Header.Get("Authorization"), " ") + authType := strings.ToLower(authHeader[0]) // fallback to token when receive pat as bearer - if len(auth) >= 2 && authType == "bearer" && strings.HasPrefix(auth[1], "nbp_") { + if len(authHeader) >= 2 && authType == "bearer" && strings.HasPrefix(authHeader[1], "nbp_") { authType = "token" - auth[0] = authType + authHeader[0] = authType } switch authType { case "bearer": - request, err := m.checkJWTFromRequest(r, auth) + request, err := m.checkJWTFromRequest(r, authHeader) if err != nil { log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) @@ -73,7 +74,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { h.ServeHTTP(w, request) case "token": - request, err := m.checkPATFromRequest(r, auth) + request, err := m.checkPATFromRequest(r, authHeader) if err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) @@ -88,8 +89,8 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*http.Request, error) { - token, err := getTokenFromJWTRequest(auth) +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { + token, err := getTokenFromJWTRequest(authHeaderParts) // If an error occurs, call the error handler and return an error if err != nil { @@ -139,8 +140,8 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*http.Request, error) { - token, err := getTokenFromPATRequest(auth) +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { + token, err := getTokenFromPATRequest(authHeaderParts) if err != nil { return r, fmt.Errorf("error extracting token: %w", err) } @@ -159,7 +160,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h return r, err } - userAuth := nbcontext.UserAuth{ + userAuth := auth.UserAuth{ UserId: user.Id, AccountId: user.AccountID, Domain: accDomain, diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index d815f54229a..de1c75f523c 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -12,11 +12,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/auth" - nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" + nbauth "github.com/netbirdio/netbird/shared/auth" + nbjwt "github.com/netbirdio/netbird/shared/auth/jwt" ) const ( @@ -61,9 +62,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use return nil, nil, "", "", fmt.Errorf("PAT invalid") } -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { +func mockValidateAndParseToken(_ context.Context, token string) (nbauth.UserAuth, *jwt.Token, error) { if token == JWT { - return nbcontext.UserAuth{ + return nbauth.UserAuth{ UserId: userID, AccountId: accountID, Domain: testAccount.Domain, @@ -77,7 +78,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA Valid: true, }, nil } - return nbcontext.UserAuth{}, nil, fmt.Errorf("JWT invalid") + return nbauth.UserAuth{}, nil, fmt.Errorf("JWT invalid") } func mockMarkPATUsed(_ context.Context, token string) error { @@ -87,7 +88,7 @@ func mockMarkPATUsed(_ context.Context, token string) error { return fmt.Errorf("Should never get reached") } -func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbcontext.UserAuth, token *jwt.Token) (nbcontext.UserAuth, error) { +func mockEnsureUserAccessByJWTGroups(_ context.Context, userAuth nbauth.UserAuth, token *jwt.Token) (nbauth.UserAuth, error) { if userAuth.IsChild || userAuth.IsPAT { return userAuth, nil } @@ -183,13 +184,13 @@ func TestAuthMiddleware_Handler(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, ) @@ -226,13 +227,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name string path string authHeader string - expectedUserAuth *nbcontext.UserAuth // nil expects 401 response status + expectedUserAuth *nbauth.UserAuth // nil expects 401 response status }{ { name: "Valid PAT Token", path: "/test", authHeader: "Token " + PAT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: accountID, UserId: userID, Domain: testAccount.Domain, @@ -244,7 +245,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid PAT Token accesses child", path: "/test?account=xyz", authHeader: "Token " + PAT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, @@ -257,7 +258,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid JWT Token", path: "/test", authHeader: "Bearer " + JWT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: accountID, UserId: userID, Domain: testAccount.Domain, @@ -269,7 +270,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { name: "Valid JWT Token with child", path: "/test?account=xyz", authHeader: "Bearer " + JWT, - expectedUserAuth: &nbcontext.UserAuth{ + expectedUserAuth: &nbauth.UserAuth{ AccountId: "xyz", UserId: userID, Domain: testAccount.Domain, @@ -288,13 +289,13 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { authMiddleware := NewAuthMiddleware( mockAuth, - func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { return userAuth.AccountId, userAuth.UserId, nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) error { + func(ctx context.Context, userAuth nbauth.UserAuth) error { return nil }, - func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, ) diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 804b4a73f7b..8a37a875960 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -14,8 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/auth" - nbcontext "github.com/netbirdio/netbird/management/server/context" + serverauth "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" http2 "github.com/netbirdio/netbird/management/server/http" @@ -29,6 +28,7 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" ) func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { @@ -69,8 +69,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee } // @note this is required so that PAT's validate from store, but JWT's are mocked - authManager := auth.NewManager(store, "", "", "", "", []string{}, false) - authManagerMock := &auth.MockManager{ + authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) + authManagerMock := &serverauth.MockManager{ ValidateAndParseTokenFunc: mockValidateAndParseToken, EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups, MarkPATUsedFunc: authManager.MarkPATUsed, @@ -115,8 +115,8 @@ func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.Up } } -func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) { - userAuth := nbcontext.UserAuth{} +func mockValidateAndParseToken(_ context.Context, token string) (auth.UserAuth, *jwt.Token, error) { + userAuth := auth.UserAuth{} switch token { case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go index 49075a0d345..126a7691900 100644 --- a/management/server/idp/pocketid_test.go +++ b/management/server/idp/pocketid_test.go @@ -1,138 +1,137 @@ package idp import ( - "context" - "testing" + "context" + "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/telemetry" ) - func TestNewPocketIdManager(t *testing.T) { - type test struct { - name string - inputConfig PocketIdClientConfig - assertErrFunc require.ErrorAssertionFunc - assertErrFuncMessage string - } - - defaultTestConfig := PocketIdClientConfig{ - APIToken: "api_token", - ManagementEndpoint: "http://localhost", - } - - tests := []test{ - { - name: "Good Configuration", - inputConfig: defaultTestConfig, - assertErrFunc: require.NoError, - assertErrFuncMessage: "shouldn't return error", - }, - { - name: "Missing ManagementEndpoint", - inputConfig: PocketIdClientConfig{ - APIToken: defaultTestConfig.APIToken, - ManagementEndpoint: "", - }, - assertErrFunc: require.Error, - assertErrFuncMessage: "should return error when field empty", - }, - { - name: "Missing APIToken", - inputConfig: PocketIdClientConfig{ - APIToken: "", - ManagementEndpoint: defaultTestConfig.ManagementEndpoint, - }, - assertErrFunc: require.Error, - assertErrFuncMessage: "should return error when field empty", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) - tc.assertErrFunc(t, err, tc.assertErrFuncMessage) - }) - } + type test struct { + name string + inputConfig PocketIdClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := PocketIdClientConfig{ + APIToken: "api_token", + ManagementEndpoint: "http://localhost", + } + + tests := []test{ + { + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + }, + { + name: "Missing ManagementEndpoint", + inputConfig: PocketIdClientConfig{ + APIToken: defaultTestConfig.APIToken, + ManagementEndpoint: "", + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + { + name: "Missing APIToken", + inputConfig: PocketIdClientConfig{ + APIToken: "", + ManagementEndpoint: defaultTestConfig.ManagementEndpoint, + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) + tc.assertErrFunc(t, err, tc.assertErrFuncMessage) + }) + } } func TestPocketID_GetUserDataByID(t *testing.T) { - client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} - - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client - - md := AppMetadata{WTAccountID: "acc1"} - got, err := mgr.GetUserDataByID(context.Background(), "u1", md) - require.NoError(t, err) - assert.Equal(t, "u1", got.ID) - assert.Equal(t, "user1@example.com", got.Email) - assert.Equal(t, "User One", got.Name) - assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) + client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + md := AppMetadata{WTAccountID: "acc1"} + got, err := mgr.GetUserDataByID(context.Background(), "u1", md) + require.NoError(t, err) + assert.Equal(t, "u1", got.ID) + assert.Equal(t, "user1@example.com", got.Email) + assert.Equal(t, "User One", got.Name) + assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) } func TestPocketID_GetAccount_WithPagination(t *testing.T) { - // Single page response with two users - client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} - - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client - - users, err := mgr.GetAccount(context.Background(), "accX") - require.NoError(t, err) - require.Len(t, users, 2) - assert.Equal(t, "u1", users[0].ID) - assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) - assert.Equal(t, "u2", users[1].ID) + // Single page response with two users + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + users, err := mgr.GetAccount(context.Background(), "accX") + require.NoError(t, err) + require.Len(t, users, 2) + assert.Equal(t, "u1", users[0].ID) + assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "u2", users[1].ID) } func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) { - client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - accounts, err := mgr.GetAllAccounts(context.Background()) - require.NoError(t, err) - require.Len(t, accounts[UnsetAccountID], 2) + accounts, err := mgr.GetAllAccounts(context.Background()) + require.NoError(t, err) + require.Len(t, accounts[UnsetAccountID], 2) } func TestPocketID_CreateUser(t *testing.T) { - client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} - - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client - - ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") - require.NoError(t, err) - assert.Equal(t, "newid", ud.ID) - assert.Equal(t, "new@example.com", ud.Email) - assert.Equal(t, "New User", ud.Name) - assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) - if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { - assert.True(t, *ud.AppMetadata.WTPendingInvite) - } - assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) + client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") + require.NoError(t, err) + assert.Equal(t, "newid", ud.ID) + assert.Equal(t, "new@example.com", ud.Email) + assert.Equal(t, "New User", ud.Name) + assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) + if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { + assert.True(t, *ud.AppMetadata.WTPendingInvite) + } + assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) } func TestPocketID_InviteAndDeleteUser(t *testing.T) { - // Same mock for both calls; returns OK with empty JSON - client := &mockHTTPClient{code: 200, resBody: `{}`} + // Same mock for both calls; returns OK with empty JSON + client := &mockHTTPClient{code: 200, resBody: `{}`} - mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) - require.NoError(t, err) - mgr.httpClient = client + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client - err = mgr.InviteUserByID(context.Background(), "u1") - require.NoError(t, err) + err = mgr.InviteUserByID(context.Background(), "u1") + require.NoError(t, err) - err = mgr.DeleteUser(context.Background(), "u1") - require.NoError(t, err) + err = mgr.DeleteUser(context.Background(), "u1") + require.NoError(t, err) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index e87043f2642..93fb84e4340 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -2,6 +2,7 @@ package mock_server import ( "context" + "github.com/netbirdio/netbird/shared/auth" "net" "net/netip" "time" @@ -12,7 +13,6 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peers/ephemeral" @@ -34,7 +34,7 @@ type MockAccountManager struct { GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error @@ -84,7 +84,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) - GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (string, string, error) DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func(settings *types.Settings) string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) @@ -119,7 +119,7 @@ type MockAccountManager struct { GetStoreFunc func() store.Store UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) - GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetCurrentUserInfoFunc func(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) @@ -469,7 +469,7 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { +func (am *MockAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) { if am.GetUserFromUserAuthFunc != nil { return am.GetUserFromUserAuthFunc(ctx, userAuth) } @@ -674,7 +674,7 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } -func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { +func (am *MockAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error) { if am.GetAccountIDFromUserAuthFunc != nil { return am.GetAccountIDFromUserAuthFunc(ctx, userAuth) } @@ -936,7 +936,7 @@ func (am *MockAccountManager) BuildUserInfosForAccount(ctx context.Context, acco return nil, status.Errorf(codes.Unimplemented, "method BuildUserInfosForAccount is not implemented") } -func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error { +func (am *MockAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error { return status.Errorf(codes.Unimplemented, "method SyncUserJWTGroups is not implemented") } @@ -968,7 +968,7 @@ func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string return nil, status.Errorf(codes.Unimplemented, "method GetOwnerInfo is not implemented") } -func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { +func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { if am.GetCurrentUserInfoFunc != nil { return am.GetCurrentUserInfoFunc(ctx, userAuth) } diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index c6cec6f7e75..e2dea2c6bde 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -10,8 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 7874be85865..6b8cf94129c 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -8,11 +8,11 @@ import ( "github.com/rs/xid" - nbDomain "github.com/netbirdio/netbird/shared/management/domain" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" + nbDomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go index 8054d05c6ca..6be90baa7a9 100644 --- a/management/server/networks/routers/manager_test.go +++ b/management/server/networks/routers/manager_test.go @@ -9,8 +9,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index 72b15fd9a9a..e90c61a97e1 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -5,8 +5,8 @@ import ( "github.com/rs/xid" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/shared/management/http/api" ) type NetworkRouter struct { diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index d65dc50456d..f0bbbc32e10 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,8 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/netbirdio/netbird/shared/management/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go index 6eb391cb5fd..da29e1d87f4 100644 --- a/management/server/types/route_firewall_rule.go +++ b/management/server/types/route_firewall_rule.go @@ -1,8 +1,8 @@ package types import ( - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // RouteFirewallRule a firewall rule applicable for a routed network. diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index da12f1b707a..a3013df146a 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -7,9 +7,9 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) const channelBufferSize = 100 diff --git a/management/server/user.go b/management/server/user.go index d40d33c6a8e..ec05bcd0cb6 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -7,12 +7,13 @@ import ( "strings" "time" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" - nbContext "github.com/netbirdio/netbird/management/server/context" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -175,9 +176,9 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*t return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id) } -// GetUser looks up a user by provided nbContext.UserAuths. +// GetUser looks up a user by provided auth.UserAuths. // Expects account to have been created already. -func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { +func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { return nil, err @@ -940,7 +941,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou var peerIDs []string for _, peer := range peers { // nolint:staticcheck - ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key) + ctx = context.WithValue(ctx, nbcontext.PeerIDKey, peer.Key) if peer.UserID == "" { // we do not want to expire peers that are added via setup key @@ -1171,7 +1172,7 @@ func validateUserInvite(invite *types.UserInfo) error { } // GetCurrentUserInfo retrieves the account's current user info and permissions -func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { +func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) { accountID, userID := userAuth.AccountId, userAuth.UserId user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) diff --git a/management/server/user_test.go b/management/server/user_test.go index 5920a2a3316..9cbfc16af6c 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -11,12 +11,12 @@ import ( "golang.org/x/exp/maps" nbcache "github.com/netbirdio/netbird/management/server/cache" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -966,7 +966,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { permissionsManager: permissionsManager, } - claims := nbcontext.UserAuth{ + claims := auth.UserAuth{ UserId: mockUserID, AccountId: mockAccountID, } @@ -1573,33 +1573,33 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { tt := []struct { name string - userAuth nbcontext.UserAuth + userAuth auth.UserAuth expectedErr error expectedResult *users.UserInfoWithPermissions }{ { name: "not found", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "not-found"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "not-found"}, expectedErr: status.NewUserNotFoundError("not-found"), }, { name: "not part of account", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account2Owner"}, expectedErr: status.NewUserNotPartOfAccountError(), }, { name: "blocked", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "blocked-user"}, expectedErr: status.NewUserBlockedError(), }, { name: "service user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "service-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "service-user"}, expectedErr: status.NewPermissionDeniedError(), }, { name: "owner user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "account1Owner"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "account1Owner", @@ -1619,7 +1619,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "regular user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "regular-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "regular-user", @@ -1638,7 +1638,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "admin user", - userAuth: nbcontext.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, + userAuth: auth.UserAuth{AccountId: account1.Id, UserId: "admin-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "admin-user", @@ -1657,7 +1657,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "settings blocked regular user", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "settings-blocked-user", @@ -1678,7 +1678,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { { name: "settings blocked regular user child account", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "settings-blocked-user", IsChild: true}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "settings-blocked-user", @@ -1698,7 +1698,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { }, { name: "settings blocked owner user", - userAuth: nbcontext.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, + userAuth: auth.UserAuth{AccountId: account2.Id, UserId: "account2Owner"}, expectedResult: &users.UserInfoWithPermissions{ UserInfo: &types.UserInfo{ ID: "account2Owner", diff --git a/relay/server/peer.go b/relay/server/peer.go index c47f2e96038..c5ff41857e3 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -9,10 +9,10 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/relay/healthcheck" - "github.com/netbirdio/netbird/shared/relay/messages" "github.com/netbirdio/netbird/relay/metrics" "github.com/netbirdio/netbird/relay/server/store" + "github.com/netbirdio/netbird/shared/relay/healthcheck" + "github.com/netbirdio/netbird/shared/relay/messages" ) const ( diff --git a/management/server/auth/jwt/extractor.go b/shared/auth/jwt/extractor.go similarity index 92% rename from management/server/auth/jwt/extractor.go rename to shared/auth/jwt/extractor.go index d270d0ff173..a41d5f07a60 100644 --- a/management/server/auth/jwt/extractor.go +++ b/shared/auth/jwt/extractor.go @@ -8,7 +8,7 @@ import ( "github.com/golang-jwt/jwt/v5" log "github.com/sirupsen/logrus" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/auth" ) const ( @@ -87,9 +87,10 @@ func (c ClaimsExtractor) audienceClaim(claimName string) string { return url } -func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, error) { +// ToUserAuth extracts user authentication information from a JWT token +func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (auth.UserAuth, error) { claims := token.Claims.(jwt.MapClaims) - userAuth := nbcontext.UserAuth{} + userAuth := auth.UserAuth{} userID, ok := claims[c.userIDClaim].(string) if !ok { @@ -122,6 +123,7 @@ func (c *ClaimsExtractor) ToUserAuth(token *jwt.Token) (nbcontext.UserAuth, erro return userAuth, nil } +// ToGroups extracts group information from a JWT token func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string { claims := token.Claims.(jwt.MapClaims) userJWTGroups := make([]string, 0) diff --git a/management/server/auth/jwt/validator.go b/shared/auth/jwt/validator.go similarity index 100% rename from management/server/auth/jwt/validator.go rename to shared/auth/jwt/validator.go diff --git a/shared/auth/user.go b/shared/auth/user.go new file mode 100644 index 00000000000..c1bae808e20 --- /dev/null +++ b/shared/auth/user.go @@ -0,0 +1,28 @@ +package auth + +import ( + "time" +) + +type UserAuth struct { + // The account id the user is accessing + AccountId string + // The account domain + Domain string + // The account domain category, TBC values + DomainCategory string + // Indicates whether this user was invited, TBC logic + Invited bool + // Indicates whether this is a child account + IsChild bool + + // The user id + UserId string + // Last login time for this user + LastLogin time.Time + // The Groups the user belongs to on this account + Groups []string + + // Indicates whether this user has authenticated with a Personal Access Token + IsPAT bool +} diff --git a/shared/context/keys.go b/shared/context/keys.go index 5345ee214c4..c5b5da044e7 100644 --- a/shared/context/keys.go +++ b/shared/context/keys.go @@ -5,4 +5,4 @@ const ( AccountIDKey = "accountID" UserIDKey = "userID" PeerIDKey = "peerID" -) \ No newline at end of file +) diff --git a/shared/management/operations/operation.go b/shared/management/operations/operation.go index b9b50036254..b1ba128154c 100644 --- a/shared/management/operations/operation.go +++ b/shared/management/operations/operation.go @@ -1,4 +1,4 @@ package operations // Operation represents a permission operation type -type Operation string \ No newline at end of file +type Operation string diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index 967e18d799a..c057ef08960 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -11,8 +11,8 @@ import ( "github.com/quic-go/quic-go" log "github.com/sirupsen/logrus" - quictls "github.com/netbirdio/netbird/shared/relay/tls" nbnet "github.com/netbirdio/netbird/client/net" + quictls "github.com/netbirdio/netbird/shared/relay/tls" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 66fff344773..37b189e05c5 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -14,9 +14,9 @@ import ( "github.com/coder/websocket" log "github.com/sirupsen/logrus" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/constants.go b/shared/relay/constants.go index 3c7c3cd296e..0f2a276103c 100644 --- a/shared/relay/constants.go +++ b/shared/relay/constants.go @@ -3,4 +3,4 @@ package relay const ( // WebSocketURLPath is the path for the websocket relay connection WebSocketURLPath = "/relay" -) \ No newline at end of file +) diff --git a/version/url_windows.go b/version/url_windows.go index 14fdb7ae638..a0fb6e5dd5f 100644 --- a/version/url_windows.go +++ b/version/url_windows.go @@ -6,7 +6,7 @@ import ( ) const ( - urlWinExe = "https://pkgs.netbird.io/windows/x64" + urlWinExe = "https://pkgs.netbird.io/windows/x64" urlWinExeArm = "https://pkgs.netbird.io/windows/arm64" ) @@ -18,11 +18,11 @@ func DownloadUrl() string { if err != nil { return downloadURL } - + url := urlWinExe if runtime.GOARCH == "arm64" { url = urlWinExeArm } - + return url }