@@ -24,6 +24,16 @@ public class EntraADHelper : IEntraADHelper
2424 private readonly HttpClient _httpClient ;
2525 private readonly AuthenticationHandler _authenticationHandler ;
2626
27+ // Whitelist of valid extension attribute property names
28+ private static readonly HashSet < string > ValidExtensionAttributes = new ( StringComparer . OrdinalIgnoreCase )
29+ {
30+ "ExtensionAttribute1" , "ExtensionAttribute2" , "ExtensionAttribute3" ,
31+ "ExtensionAttribute4" , "ExtensionAttribute5" , "ExtensionAttribute6" ,
32+ "ExtensionAttribute7" , "ExtensionAttribute8" , "ExtensionAttribute9" ,
33+ "ExtensionAttribute10" , "ExtensionAttribute11" , "ExtensionAttribute12" ,
34+ "ExtensionAttribute13" , "ExtensionAttribute14" , "ExtensionAttribute15"
35+ } ;
36+
2737 public EntraADHelper ( ILogger < IEntraADHelper > logger , IHttpClientFactory httpClientFactory , IOptions < EntraADHelperSettings > settings , GraphServiceClient graphServiceClient , AuthenticationHandler authenticationHandler )
2838 {
2939 _graphServiceClient = graphServiceClient ;
@@ -33,6 +43,29 @@ public EntraADHelper(ILogger<IEntraADHelper> logger, IHttpClientFactory httpClie
3343 _authenticationHandler = authenticationHandler ;
3444 }
3545
46+ /// <summary>
47+ /// Normalizes extension attribute name to proper casing (e.g. "extensionattribute1" → "extensionAttribute1")
48+ /// and validates it against the whitelist.
49+ /// </summary>
50+ private string NormalizeExtensionAttributeName ( string extensionAttribute )
51+ {
52+ ArgumentException . ThrowIfNullOrWhiteSpace ( extensionAttribute ) ;
53+
54+ // Find matching attribute from whitelist (case-insensitive)
55+ var match = ValidExtensionAttributes . FirstOrDefault ( a =>
56+ string . Equals ( a , extensionAttribute , StringComparison . OrdinalIgnoreCase ) ) ;
57+
58+ if ( match == null )
59+ {
60+ throw new ArgumentException (
61+ $ "Invalid extension attribute name: '{ extensionAttribute } '. Must be extensionAttribute1 through extensionAttribute15.",
62+ nameof ( extensionAttribute ) ) ;
63+ }
64+
65+ // Return with correct casing for Graph API: "extensionAttribute1" (lowercase 'e', uppercase 'A')
66+ return char . ToLowerInvariant ( match [ 0 ] ) + match [ 1 ..] ;
67+ }
68+
3669
3770 public async Task < IEnumerable < User > > GetUsers ( )
3871 {
@@ -110,7 +143,7 @@ public async Task<IEnumerable<User>> GetUsers(int pagingNum)
110143 }
111144 catch ( ServiceException ex )
112145 {
113- Console . WriteLine ( $ "An error occurred: { ex . Message } " ) ;
146+ _logger . LogError ( ex , "An error occurred retrieving user {UserId}" , userId ) ;
114147 return null ;
115148 }
116149 }
@@ -124,7 +157,7 @@ public async Task<IEnumerable<User>> GetUsers(int pagingNum)
124157 }
125158 catch ( ServiceException ex )
126159 {
127- Console . WriteLine ( $ "An error occurred retrieving Device directory object: { ex . Message } " ) ;
160+ _logger . LogError ( ex , "An error occurred retrieving Device directory object for {DeviceId}" , deviceId ) ;
128161 return null ;
129162 }
130163 }
@@ -139,7 +172,7 @@ public async Task<IEnumerable<User>> GetUsers(int pagingNum)
139172
140173 _logger . LogTrace ( "Building the request to the Microsoft Graph API (beta)" ) ;
141174 var request = new HttpRequestMessage ( HttpMethod . Get ,
142- $ "https://graph.microsoft.com/beta/devices?$filter=displayName eq '{ deviceName } '&$select={ string . Join ( "," , _settings . AttributesToLoad ?? new [ ] { "id" , "deviceId" , "accountEnabled" , "approximateLastSignInDateTime" , "displayName" , "trustType" } ) } ") ;
175+ $ "https://graph.microsoft.com/beta/devices?$filter=displayName eq '{ Uri . EscapeDataString ( deviceName ) } '&$select={ string . Join ( "," , _settings . AttributesToLoad ?? new [ ] { "id" , "deviceId" , "accountEnabled" , "approximateLastSignInDateTime" , "displayName" , "trustType" } ) } ") ;
143176 _logger . LogTrace ( "Request built" ) ;
144177
145178 _logger . LogTrace ( "Adding the access token to the request headers" ) ;
@@ -215,9 +248,7 @@ public async Task<IEnumerable<User>> GetUsers(int pagingNum)
215248 {
216249 _logger . LogTrace ( "GetExtensionAttribute Called" ) ;
217250
218- string lowerCasedExtensionAttribute = extensionAttribute . ToLowerInvariant ( ) ;
219- string correctCasingAttribute = char . ToUpperInvariant ( lowerCasedExtensionAttribute [ 0 ] ) + lowerCasedExtensionAttribute . Substring ( 1 ) ;
220- correctCasingAttribute = correctCasingAttribute . Substring ( 0 , 9 ) + char . ToUpperInvariant ( extensionAttribute [ 9 ] ) + extensionAttribute . Substring ( 10 ) ;
251+ string correctCasingAttribute = NormalizeExtensionAttributeName ( extensionAttribute ) ;
221252 _logger . LogTrace ( "Extension attribute name {extensionAttribute} converted to {correctCasingAttribute}" , extensionAttribute , correctCasingAttribute ) ;
222253
223254 try
@@ -255,7 +286,10 @@ public async Task<IEnumerable<User>> GetUsers(int pagingNum)
255286 _logger . LogError ( "Extension attributes are null for device {DeviceId}" , deviceId ) ;
256287 return null ;
257288 }
258- var extensionAttributeValue = device . ExtensionAttributes . GetType ( ) . GetProperty ( correctCasingAttribute ) ? . GetValue ( device . ExtensionAttributes , null ) ;
289+
290+ // Use PascalCase for reflection on the model property
291+ var pascalCaseAttr = char . ToUpperInvariant ( correctCasingAttribute [ 0 ] ) + correctCasingAttribute [ 1 ..] ;
292+ var extensionAttributeValue = device . ExtensionAttributes . GetType ( ) . GetProperty ( pascalCaseAttr ) ? . GetValue ( device . ExtensionAttributes , null ) ;
259293 if ( extensionAttributeValue == null )
260294 {
261295 _logger . LogWarning ( "Extension attribute {ExtensionAttribute} is null for device {DeviceId}" , extensionAttribute , deviceId ) ;
@@ -295,9 +329,7 @@ public async Task<IEnumerable<User>> GetUsers(int pagingNum)
295329
296330 foreach ( string extensionAttribute in extensionAttributes )
297331 {
298- string lowerCasedExtensionAttribute = extensionAttribute . ToLowerInvariant ( ) ;
299- string correctCasingAttribute = char . ToUpperInvariant ( lowerCasedExtensionAttribute [ 0 ] ) + lowerCasedExtensionAttribute . Substring ( 1 ) ;
300- correctCasingAttribute = correctCasingAttribute . Substring ( 0 , 9 ) + char . ToUpperInvariant ( extensionAttribute [ 9 ] ) + extensionAttribute . Substring ( 10 ) ;
332+ string correctCasingAttribute = NormalizeExtensionAttributeName ( extensionAttribute ) ;
301333 _logger . LogTrace ( "Extension attribute name {extensionAttribute} converted to {correctCasingAttribute}" , extensionAttribute , correctCasingAttribute ) ;
302334 correctCasingAttributes . Add ( correctCasingAttribute ) ;
303335 }
@@ -339,7 +371,8 @@ public async Task<IEnumerable<User>> GetUsers(int pagingNum)
339371 }
340372
341373 var extensionAttributeValuesList = device . ExtensionAttributes . GetType ( ) . GetProperties ( )
342- . Where ( p => correctCasingAttributes . Contains ( p . Name ) )
374+ . Where ( p => correctCasingAttributes . Any ( ca =>
375+ string . Equals ( char . ToUpperInvariant ( ca [ 0 ] ) + ca [ 1 ..] , p . Name , StringComparison . Ordinal ) ) )
343376 . Select ( p => p . GetValue ( device . ExtensionAttributes , null ) ? . ToString ( ) )
344377 . ToList ( ) ;
345378
@@ -374,7 +407,7 @@ public async Task<IEnumerable<User>> GetUsers(int pagingNum)
374407 _logger . LogTrace ( "SetExtensionAttributeValue Called" ) ;
375408
376409 var lowerCasedExtensionAttribute = extensionAttributeName . ToLowerInvariant ( ) ;
377- var correctCasingAttribute = lowerCasedExtensionAttribute . Substring ( 0 , 9 ) + char . ToUpperInvariant ( extensionAttributeName [ 9 ] ) + extensionAttributeName . Substring ( 10 ) ;
410+ var correctCasingAttribute = NormalizeExtensionAttributeName ( extensionAttributeName ) ;
378411 _logger . LogTrace ( "Correct casing for extension attribute {extensionAttribute} is {correctCasingAttribute}" , extensionAttributeName , correctCasingAttribute ) ;
379412
380413 try
@@ -393,7 +426,14 @@ public async Task<IEnumerable<User>> GetUsers(int pagingNum)
393426 _logger . LogTrace ( "Access token added to the request headers" ) ;
394427
395428 _logger . LogTrace ( "Adding the request body" ) ;
396- request . Content = new StringContent ( $ "{{\" extensionAttributes\" :{{\" { extensionAttributeName } \" :\" { extensionAttributeValue } \" }}}}", System . Text . Encoding . UTF8 , "application/json" ) ;
429+ var jsonBody = new System . Text . Json . Nodes . JsonObject
430+ {
431+ [ "extensionAttributes" ] = new System . Text . Json . Nodes . JsonObject
432+ {
433+ [ extensionAttributeName ] = extensionAttributeValue
434+ }
435+ } ;
436+ request . Content = new StringContent ( jsonBody . ToJsonString ( ) , System . Text . Encoding . UTF8 , "application/json" ) ;
397437 _logger . LogTrace ( "Request body added" ) ;
398438
399439 _logger . LogTrace ( "Sending the request" ) ;
@@ -815,7 +855,7 @@ public async Task<IEnumerable<string>> GetDeviceHWIdByComputerName(string comput
815855
816856 _logger . LogTrace ( "Building the request to the Microsoft Graph API (beta)" ) ;
817857 var request = new HttpRequestMessage ( HttpMethod . Get ,
818- $ "https://graph.microsoft.com/beta/devices?$filter=displayName eq '{ computerName } '") ;
858+ $ "https://graph.microsoft.com/beta/devices?$filter=displayName eq '{ Uri . EscapeDataString ( computerName ) } '") ;
819859 _logger . LogTrace ( "Request built" ) ;
820860
821861 _logger . LogTrace ( "Adding the access token to the request headers" ) ;
0 commit comments