diff --git a/dbscripts/postgres.sql b/dbscripts/postgres.sql index c588fbc..a4af444 100644 --- a/dbscripts/postgres.sql +++ b/dbscripts/postgres.sql @@ -60,8 +60,9 @@ CREATE TABLE profile_schema ( mutability VARCHAR(255) NOT NULL, multi_valued BOOLEAN DEFAULT FALSE, canonical_values JSONB DEFAULT '[]'::jsonb, - sub_attributes JSONB DEFAULT '[]'::jsonb + sub_attributes JSONB DEFAULT '[]'::jsonb, scim_dialect VARCHAR(255) , + mapped_local_claim VARCHAR(255), ); -- Application Data Table diff --git a/internal/profile/handler/profile_handler.go b/internal/profile/handler/profile_handler.go index 33728e4..fff2bcf 100644 --- a/internal/profile/handler/profile_handler.go +++ b/internal/profile/handler/profile_handler.go @@ -24,6 +24,7 @@ import ( "github.com/wso2/identity-customer-data-service/internal/profile/model" "github.com/wso2/identity-customer-data-service/internal/profile/provider" "github.com/wso2/identity-customer-data-service/internal/profile/service" + schemaService "github.com/wso2/identity-customer-data-service/internal/profile_schema/service" "github.com/wso2/identity-customer-data-service/internal/system/authn" "github.com/wso2/identity-customer-data-service/internal/system/constants" errors2 "github.com/wso2/identity-customer-data-service/internal/system/errors" @@ -325,7 +326,6 @@ func (ph *ProfileHandler) SyncProfile(writer http.ResponseWriter, request *http. if profileSync.Event == "POST_ADD_USER" { if profileSync.ProfileId != "" && profileSync.UserId != "" { - log.GetLogger().Info("wewwdscfdsvgf????") // This sceario is when the user anonymously tried and then trying to signup or login. So profile with profile id exists existingProfile, err = profilesService.GetProfile(profileSync.ProfileId) @@ -336,7 +336,11 @@ func (ph *ProfileHandler) SyncProfile(writer http.ResponseWriter, request *http. } for claimURI, value := range identityClaims { - attributeKeyPath := extractClaimKeyFromLocalURI(claimURI) + attributeKeyPath, err := extractAttributePathFromLocalURI(tenantId, claimURI) + if err != nil { + utils.HandleError(writer, fmt.Errorf("failed to extract attribute path from local URI: %w", err)) + return + } setNestedMapValue(existingProfile.IdentityAttributes, attributeKeyPath, value) } @@ -357,13 +361,16 @@ func (ph *ProfileHandler) SyncProfile(writer http.ResponseWriter, request *http. } return } else if profileSync.ProfileId == "" { - log.GetLogger().Info("am i herere????") // this is when we create a profile for a new user created in IS existingProfile, err = profilesService.FindProfileByUserId(profileSync.UserId) if existingProfile == nil { identityAttributes := make(map[string]interface{}) for claimURI, value := range identityClaims { - attributeKeyPath := extractClaimKeyFromLocalURI(claimURI) + attributeKeyPath, err := extractAttributePathFromLocalURI(tenantId, claimURI) + if err != nil { + utils.HandleError(writer, fmt.Errorf("failed to extract attribute path from local URI: %w", err)) + return + } setNestedMapValue(identityAttributes, attributeKeyPath, value) } @@ -408,7 +415,11 @@ func (ph *ProfileHandler) SyncProfile(writer http.ResponseWriter, request *http. identityAttributes := make(map[string]interface{}) for claimURI, value := range identityClaims { - attributeKeyPath := extractClaimKeyFromLocalURI(claimURI) + attributeKeyPath, err := extractAttributePathFromLocalURI(tenantId, claimURI) + if err != nil { + utils.HandleError(writer, fmt.Errorf("failed to extract attribute path from local URI: %w", err)) + return + } setNestedMapValue(identityAttributes, attributeKeyPath, value) } @@ -432,7 +443,11 @@ func (ph *ProfileHandler) SyncProfile(writer http.ResponseWriter, request *http. } for claimURI, value := range identityClaims { - attributeKeyPath := extractClaimKeyFromLocalURI(claimURI) + attributeKeyPath, err := extractAttributePathFromLocalURI(tenantId, claimURI) + if err != nil { + utils.HandleError(writer, fmt.Errorf("failed to extract attribute path from local URI: %w", err)) + return + } setNestedMapValue(existingProfile.IdentityAttributes, attributeKeyPath, value) } @@ -458,7 +473,11 @@ func (ph *ProfileHandler) SyncProfile(writer http.ResponseWriter, request *http. } for claimURI, value := range identityClaims { - attributeKeyPath := extractClaimKeyFromLocalURI(claimURI) + attributeKeyPath, err := extractAttributePathFromLocalURI(tenantId, claimURI) + if err != nil { + utils.HandleError(writer, fmt.Errorf("failed to extract attribute path from local URI: %w", err)) + return + } setNestedMapValue(existingProfile.IdentityAttributes, attributeKeyPath, value) } @@ -501,7 +520,17 @@ func setNestedMapValue(m map[string]interface{}, path string, value interface{}) // todo: ensure the value type and also try how we merge the values here. } -func extractClaimKeyFromLocalURI(localURI string) string { - parts := strings.Split(localURI, "/") - return parts[len(parts)-1] +// extractAttributePathFromLocalURI extracts the claim key from a local URI. +func extractAttributePathFromLocalURI(tenantId, localURI string) (string, error) { + + profileSchemaService := schemaService.GetProfileSchemaService() + claim, err := profileSchemaService.GetProfileSchemaAttributeByMappedLocalClaim(tenantId, localURI) + if err != nil { + return "", err + } + if claim.AttributeId == "" { + return "", fmt.Errorf("claim not found for local URI: %s", localURI) + } + key := strings.TrimPrefix(claim.AttributeName, "identity_attributes.") + return key, nil } diff --git a/internal/profile/service/profile_service.go b/internal/profile/service/profile_service.go index cff3ca0..706197c 100644 --- a/internal/profile/service/profile_service.go +++ b/internal/profile/service/profile_service.go @@ -169,13 +169,13 @@ func ValidateProfileAgainstSchema(profile profileModel.ProfileRequest, existingP return clientError } if isUpdate && existingProfile.IdentityAttributes != nil { - if !(attr.AttributeName == "identity_attributes.modified" || attr.AttributeName == "identity_attributes.created" || attr.AttributeName == "identity_attributes.userid") { + if !(attr.AttributeName == "identity_attributes.meta.lastModified" || attr.AttributeName == "identity_attributes.meta.created" || attr.AttributeName == "identity_attributes.id") { if err := validateMutability(attr.Mutability, isUpdate, existingProfile.IdentityAttributes[key], val); err != nil { return err } } } else { - if !(attr.AttributeName == "identity_attributes.modified" || attr.AttributeName == "identity_attributes.created" || attr.AttributeName == "identity_attributes.userid") { + if !(attr.AttributeName == "identity_attributes.meta.lastModified" || attr.AttributeName == "identity_attributes.meta.created" || attr.AttributeName == "identity_attributes.id") { if err := validateMutability(attr.Mutability, isUpdate, nil, val); err != nil { return err } diff --git a/internal/profile_schema/model/profile_schema.go b/internal/profile_schema/model/profile_schema.go index 3692ba8..642fd81 100644 --- a/internal/profile_schema/model/profile_schema.go +++ b/internal/profile_schema/model/profile_schema.go @@ -26,10 +26,11 @@ type ProfileSchemaAttribute struct { MergeStrategy string `json:"merge_strategy" bson:"merge_strategy" binding:"required"` Mutability string `json:"mutability" bson:"mutability"` ApplicationIdentifier string `json:"application_identifier,omitempty" bson:"application_identifier,omitempty"` - MultiValued bool `json:"multi_valued,omitempty" bson:"multi_valued,omitempty"` // Means the data type is an array of chosen data type - CanonicalValues []CanonicalValue `json:"canonical_values,omitempty" bson:"canonical_values,omitempty"` // String of options for the attribute - SubAttributes []SubAttribute `json:"sub_attributes,omitempty" bson:"sub_attributes,omitempty"` // If the datatype is object - SCIMDialect string `json:"scim_dialect,omitempty" bson:"scim_dialect,omitempty"` // Need to skip this in the response + MultiValued bool `json:"multi_valued,omitempty" bson:"multi_valued,omitempty"` // Means the data type is an array of chosen data type + CanonicalValues []CanonicalValue `json:"canonical_values,omitempty" bson:"canonical_values,omitempty"` // String of options for the attribute + SubAttributes []SubAttribute `json:"sub_attributes,omitempty" bson:"sub_attributes,omitempty"` // If the datatype is object + SCIMDialect string `json:"scim_dialect,omitempty" bson:"scim_dialect,omitempty"` // Need to skip this in the response + MappedLocalClaim string `json:"mapped_local_claim,omitempty" bson:"mapped_local_claim,omitempty"` // Local claims mapped to this attribute } type SubAttribute struct { diff --git a/internal/profile_schema/service/profile_schema_service.go b/internal/profile_schema/service/profile_schema_service.go index db4e7d1..a8a5144 100644 --- a/internal/profile_schema/service/profile_schema_service.go +++ b/internal/profile_schema/service/profile_schema_service.go @@ -40,6 +40,7 @@ type ProfileSchemaServiceInterface interface { GetProfileSchemaAttributesByScopeAndFilter(id, scope string, filters []string) (interface{}, error) DeleteProfileSchemaAttributesByScope(orgId, scope string) error GetProfileSchemaAttributeById(orgId, attributeId string) (model.ProfileSchemaAttribute, error) + GetProfileSchemaAttributeByMappedLocalClaim(orgId, mappedLocalClaim string) (model.ProfileSchemaAttribute, error) PatchProfileSchemaAttributeById(orgId, attributeId string, updates map[string]interface{}) error DeleteProfileSchemaAttributeById(orgId, attributeId string) error SyncProfileSchema(orgId string) error @@ -578,3 +579,9 @@ func (s *ProfileSchemaService) GetProfileSchemaAttributesByScopeAndFilter(orgId, } return schemaAttributes, nil } + +func (s *ProfileSchemaService) GetProfileSchemaAttributeByMappedLocalClaim(orgId, mappedLocalClaim string) (model.ProfileSchemaAttribute, error) { + + return psstr.GetProfileSchemaAttributeByMappedLocalClaim(orgId, mappedLocalClaim) + +} diff --git a/internal/profile_schema/store/profile_schema_store.go b/internal/profile_schema/store/profile_schema_store.go index ea49143..6daade3 100644 --- a/internal/profile_schema/store/profile_schema_store.go +++ b/internal/profile_schema/store/profile_schema_store.go @@ -635,9 +635,9 @@ func UpsertIdentityAttributes(orgID string, attrs []model.ProfileSchemaAttribute attrKey := extractClaimKeyFromURI(attr.AttributeName) attr.AttributeName = attrKey - valueStrings = append(valueStrings, fmt.Sprintf("($%d,$%d,$%d,$%d,$%d,$%d,$%d,$%d,$%d,$%d,$%d, $%d)", + valueStrings = append(valueStrings, fmt.Sprintf("($%d,$%d,$%d,$%d,$%d,$%d,$%d,$%d,$%d,$%d,$%d, $%d,$%d)", argIndex, argIndex+1, argIndex+2, argIndex+3, argIndex+4, argIndex+5, argIndex+6, - argIndex+7, argIndex+8, argIndex+9, argIndex+10, argIndex+11)) + argIndex+7, argIndex+8, argIndex+9, argIndex+10, argIndex+11, argIndex+12)) valueArgs = append(valueArgs, orgID, attr.AttributeId, @@ -650,9 +650,10 @@ func UpsertIdentityAttributes(orgID string, attrs []model.ProfileSchemaAttribute string(canonicalJSON), string(subAttrJSON), attr.SCIMDialect, + attr.MappedLocalClaim, constants.IdentityAttributes, ) - argIndex += 12 + argIndex += 13 } insertQuery += strings.Join(valueStrings, ",") @@ -745,3 +746,45 @@ func GetProfileSchemaAttributesByScopeAndFilter(orgId, scope string, filters []s return attributes, nil } + +func GetProfileSchemaAttributeByMappedLocalClaim(orgId string, claim string) (model.ProfileSchemaAttribute, error) { + + dbClient, err := provider.NewDBProvider().GetDBClient() + logger := log.GetLogger() + if err != nil { + errorMsg := fmt.Sprintf("Error occurred while fetching profile schema for org: %s and mapped claim: %s", + orgId, claim) + logger.Debug(errorMsg, log.Error(err)) + serverError := errors.NewServerError(errors.ErrorMessage{ + Code: errors.DB_CLIENT_INIT.Code, + Message: errors.DB_CLIENT_INIT.Message, + Description: errorMsg, + }, err) + return model.ProfileSchemaAttribute{}, serverError + } + defer dbClient.Close() + + query := scripts.GetProfileSchemaAttributeByMappedLocalClaim[provider.NewDBProvider().GetDBType()] + + results, err := dbClient.ExecuteQuery(query, orgId, claim) + if err != nil { + errorMsg := fmt.Sprintf("Error occurred while fetching profile schema for the org:%s", orgId) + logger.Debug(errorMsg, log.Error(err)) + serverError := errors.NewServerError(errors.ErrorMessage{ + Code: errors.GET_PROFILE_SCHEMA.Code, + Message: errors.GET_PROFILE_SCHEMA.Message, + Description: errorMsg, + }, err) + return model.ProfileSchemaAttribute{}, serverError + } + if len(results) == 0 { + clientError := errors.NewClientError(errors.ErrorMessage{ + Code: errors.ATTRIBUTE_NOT_FOUND.Code, + Message: errors.ATTRIBUTE_NOT_FOUND.Message, + Description: "Profile schema attribute not found for org: " + orgId + " and mapped claim : " + claim, + }, http.StatusNotFound) + return model.ProfileSchemaAttribute{}, clientError + } + row := results[0] + return mapRowToProfileAttribute(row), nil +} diff --git a/internal/system/client/identity_client.go b/internal/system/client/identity_client.go index e138455..bff1dc4 100644 --- a/internal/system/client/identity_client.go +++ b/internal/system/client/identity_client.go @@ -125,6 +125,7 @@ func (c *IdentityClient) fetchClientCredentialsToken() (map[string]interface{}, return result, nil } +// GetProfileSchema fetches the profile schema attributes from the identity server func (c *IdentityClient) GetProfileSchema(orgId string) ([]model.ProfileSchemaAttribute, error) { logger := log.GetLogger() @@ -171,7 +172,7 @@ func (c *IdentityClient) GetProfileSchema(orgId string) ([]model.ProfileSchemaAt continue } - attr, subAttr, parent := ConvertSCIMClaimWithLocal(scimClaim, localClaim, claims, orgId, dialectURI) + attr, subAttr, parent := ConvertSCIMClaimWithLocal(scimClaim, localClaim, claims, orgId) result = append(result, attr) existingAttrs[attr.AttributeName] = true @@ -194,14 +195,15 @@ func (c *IdentityClient) GetProfileSchema(orgId string) ([]model.ProfileSchemaAt dialect = "urn:synthetic" // fallback } result = append(result, model.ProfileSchemaAttribute{ - OrgId: orgId, - AttributeId: uuid.New().String(), - AttributeName: parent, - ValueType: constants.ComplexDataType, - MergeStrategy: constants.MergeStrategyOverwrite, - Mutability: constants.MutabilityReadWrite, - SubAttributes: subs, - SCIMDialect: dialect, // mark as generated + OrgId: orgId, + AttributeId: uuid.New().String(), + AttributeName: parent, + ValueType: constants.ComplexDataType, + MergeStrategy: constants.MergeStrategyOverwrite, + Mutability: constants.MutabilityReadWrite, + SubAttributes: subs, + SCIMDialect: dialect, // mark as generated + MappedLocalClaim: "", // synthetic, no local mapping }) } } @@ -278,12 +280,12 @@ func ConvertSCIMClaimWithLocal( scim map[string]interface{}, local map[string]interface{}, allClaims []map[string]interface{}, - orgId, dialectURI string, + orgId string, ) (model.ProfileSchemaAttribute, *model.SubAttribute, string) { - claimURI := fmt.Sprintf("%v", scim["claimURI"]) localURI := fmt.Sprintf("%v", scim["mappedLocalClaimURI"]) - attrKey := extractClaimKeyFromLocalURI(localURI) + claimURI := fmt.Sprintf("%v", scim["claimURI"]) + attrKey := extractClaimKeyFromSCIMURI(claimURI) readOnly := false multiValued := false @@ -324,23 +326,11 @@ func ConvertSCIMClaimWithLocal( var subAttrs []model.SubAttribute for _, otherClaim := range allClaims { otherURI := fmt.Sprintf("%v", otherClaim["claimURI"]) - if strings.HasPrefix(otherURI, claimURI+".") { - mappedLocalURI := fmt.Sprintf("%v", otherClaim["mappedLocalClaimURI"]) - - // Ensure mapped local URI is truly nested under the current local URI - if strings.HasPrefix(mappedLocalURI, localURI+".") { - subAttrKey := extractClaimKeyFromLocalURI(mappedLocalURI) - - if strings.HasPrefix(subAttrKey, attrKey+".") { - subAttrKey = strings.TrimPrefix(subAttrKey, attrKey+".") - } - - subAttrs = append(subAttrs, model.SubAttribute{ - AttributeId: fmt.Sprintf("%v", uuid.New().String()), - AttributeName: "identity_attributes." + attrKey + "." + subAttrKey, - }) - } + subAttrs = append(subAttrs, model.SubAttribute{ + AttributeId: uuid.New().String(), + AttributeName: "identity_attributes." + extractClaimKeyFromSCIMURI(otherURI), + }) } } @@ -349,7 +339,7 @@ func ConvertSCIMClaimWithLocal( valueType = "complex" } - fullAttrName := "identity_attributes." + attrKey + fullAttrName := "identity_attributes." + convertSCIMURIToAttributeName(claimURI) // Check if this is a sub-attribute (i.e., contains a dot after the prefix) if strings.Contains(attrKey, ".") { @@ -359,31 +349,33 @@ func ConvertSCIMClaimWithLocal( AttributeName: fullAttrName, } return model.ProfileSchemaAttribute{ - OrgId: orgId, - AttributeId: subAttr.AttributeId, - AttributeName: fullAttrName, - ValueType: valueType, - MergeStrategy: "overwrite", - Mutability: ifThenElse(readOnly, "readOnly", "readWrite"), - MultiValued: multiValued, - CanonicalValues: canonicalValues, - SubAttributes: nil, - SCIMDialect: dialectURI, + OrgId: orgId, + AttributeId: subAttr.AttributeId, + AttributeName: fullAttrName, + ValueType: valueType, + MergeStrategy: "overwrite", + Mutability: ifThenElse(readOnly, "readOnly", "readWrite"), + MultiValued: multiValued, + CanonicalValues: canonicalValues, + SubAttributes: nil, + SCIMDialect: claimURI, + MappedLocalClaim: localURI, }, &subAttr, parentAttrName } // It's a top-level or parent attribute return model.ProfileSchemaAttribute{ - OrgId: orgId, - AttributeId: fmt.Sprintf("%v", uuid.New().String()), - AttributeName: fullAttrName, - ValueType: valueType, - MergeStrategy: "overwrite", - Mutability: ifThenElse(readOnly, "readOnly", "readWrite"), - MultiValued: multiValued, - CanonicalValues: canonicalValues, - SubAttributes: subAttrs, - SCIMDialect: dialectURI, + OrgId: orgId, + AttributeId: fmt.Sprintf("%v", uuid.New().String()), + AttributeName: fullAttrName, + ValueType: valueType, + MergeStrategy: "overwrite", + Mutability: ifThenElse(readOnly, "readOnly", "readWrite"), + MultiValued: multiValued, + CanonicalValues: canonicalValues, + SubAttributes: subAttrs, + SCIMDialect: claimURI, + MappedLocalClaim: localURI, }, nil, "" } @@ -432,3 +424,16 @@ func flattenSCIMClaims(user map[string]interface{}) map[string]interface{} { return flat } + +func extractClaimKeyFromSCIMURI(scimURI string) string { + parts := strings.Split(scimURI, ":") + lastPart := parts[len(parts)-1] + return strings.TrimPrefix(lastPart, "schemas:") // removes redundant prefixes if present +} + +func convertSCIMURIToAttributeName(uri string) string { + // E.g., "urn:scim:schemas:core:2.0:name.givenName" → "name.givenName" + parts := strings.Split(uri, ":") + last := parts[len(parts)-1] + return last +} diff --git a/internal/system/database/scripts/queries.go b/internal/system/database/scripts/queries.go index 88adb1c..376ae5e 100644 --- a/internal/system/database/scripts/queries.go +++ b/internal/system/database/scripts/queries.go @@ -35,7 +35,7 @@ var DeleteIdentityClaimsOfProfileSchema = map[string]string{ var InsertIdentityClaimsForProfileSchema = map[string]string{ "postgres": `INSERT INTO profile_schema (tenant_id, attribute_id, attribute_name, value_type, merge_strategy, mutability, application_identifier, - multi_valued, canonical_values, sub_attributes, scim_dialect, scope) VALUES `, + multi_valued, canonical_values, sub_attributes, scim_dialect, mapped_local_claim, scope) VALUES `, } var GetProfileSchemaAttributeByName = map[string]string{ @@ -78,6 +78,12 @@ var GetProfileSchemaAttributeById = map[string]string{ FROM profile_schema WHERE tenant_id = $1 AND attribute_id = $2`, } +var GetProfileSchemaAttributeByMappedLocalClaim = map[string]string{ + "postgres": `SELECT attribute_id, attribute_name, value_type, merge_strategy, mutability , application_identifier, multi_valued, sub_attributes::text, + canonical_values::text + FROM profile_schema WHERE tenant_id = $1 AND mapped_local_claim = $2`, +} + var FilterProfileSchemaAttributes = map[string]string{ "postgres": `SELECT attribute_id, tenant_id, attribute_name, value_type, merge_strategy, mutability, application_identifier, multi_valued, sub_attributes::text, canonical_values::text FROM profile_schema WHERE tenant_id = $1`,