@@ -397,14 +397,17 @@ def name(cls):
         return "ms-sentinel"
 
 
-class AzureMonitorBatchReader(DataSourceReader):
-    """Reader for Azure Monitor / Log Analytics workspaces."""
+class AzureMonitorReader:
+    """Base reader class for Azure Monitor / Log Analytics workspaces.
+
+    Shared read logic for batch and streaming reads.
+    """
 
     def __init__(self, options, schema: StructType):
         """Initialize the reader with options and schema.
 
         Args:
-            options: Dictionary of options containing workspace_id, query, time range, credentials
+            options: Dictionary of options containing workspace_id, query, credentials
             schema: StructType schema (provided by DataSource.schema())
 
         """
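
The commit's net effect is a small mixin layout: credential handling and the shared read path live once in the new base class, and each PySpark reader interface is mixed in by a thin subclass. A minimal sketch of the resulting hierarchy (bodies elided; the import path assumes the PySpark 4 Python data source API):

from pyspark.sql.datasource import DataSourceReader, DataSourceStreamReader

class AzureMonitorReader:
    """Shared option validation and read(partition) logic (defined in this diff)."""

class AzureMonitorBatchReader(AzureMonitorReader, DataSourceReader):
    """Adds time-range parsing and partitions() for batch reads."""

class AzureMonitorStreamReader(AzureMonitorReader, DataSourceStreamReader):
    """Adds offset tracking (initialOffset/commit) for streaming reads."""
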
@@ -415,53 +418,16 @@ def __init__(self, options, schema: StructType):
         self.client_id = options.get("client_id")
         self.client_secret = options.get("client_secret")
 
-        # Time range options (mutually exclusive)
-        timespan = options.get("timespan")
-        start_time = options.get("start_time")
-        end_time = options.get("end_time")
-
-        # Optional options
-        self.num_partitions = int(options.get("num_partitions", "1"))
-
         # Validate required options
         assert self.workspace_id is not None, "workspace_id is required"
         assert self.query is not None, "query is required"
         assert self.tenant_id is not None, "tenant_id is required"
         assert self.client_id is not None, "client_id is required"
         assert self.client_secret is not None, "client_secret is required"
 
-        # Parse time range using module-level function
-        self.start_time, self.end_time = _parse_time_range(timespan=timespan, start_time=start_time, end_time=end_time)
-
         # Store schema (provided by DataSource.schema())
         self._schema = schema
 
-    def partitions(self):
-        """Generate list of non-overlapping time range partitions.
-
-        Returns:
-            List of TimeRangePartition objects, each containing start_time and end_time
-
-        """
-        # Calculate total time range duration
-        total_duration = self.end_time - self.start_time
-
-        # Split into N equal partitions
-        partition_duration = total_duration / self.num_partitions
-
-        partitions = []
-        for i in range(self.num_partitions):
-            partition_start = self.start_time + (partition_duration * i)
-            partition_end = self.start_time + (partition_duration * (i + 1))
-
-            # Ensure last partition ends exactly at end_time (avoid rounding errors)
-            if i == self.num_partitions - 1:
-                partition_end = self.end_time
-
-            partitions.append(TimeRangePartition(partition_start, partition_end))
-
-        return partitions
-
     def read(self, partition: TimeRangePartition):
         """Read data for the given partition time range.
 
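
For reference, a hypothetical options dictionary that would pass the assertions above; the key names come from this `__init__`, while the values (including the sample KQL query) are placeholders:

options = {
    "workspace_id": "<log-analytics-workspace-id>",   # placeholder
    "query": "AzureActivity | take 100",              # placeholder KQL query
    "tenant_id": "<entra-tenant-id>",                 # placeholder
    "client_id": "<service-principal-client-id>",     # placeholder
    "client_secret": "<service-principal-secret>",    # placeholder
}
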
@@ -522,6 +488,57 @@ def read(self, partition: TimeRangePartition):
                 yield Row(**row_dict)
 
 
+class AzureMonitorBatchReader(AzureMonitorReader, DataSourceReader):
+    """Batch reader for Azure Monitor / Log Analytics workspaces."""
+
+    def __init__(self, options, schema: StructType):
+        """Initialize the batch reader with options and schema.
+
+        Args:
+            options: Dictionary of options containing workspace_id, query, time range, credentials
+            schema: StructType schema (provided by DataSource.schema())
+
+        """
+        super().__init__(options, schema)
+
+        # Time range options (mutually exclusive)
+        timespan = options.get("timespan")
+        start_time = options.get("start_time")
+        end_time = options.get("end_time")
+
+        # Optional options
+        self.num_partitions = int(options.get("num_partitions", "1"))
+
+        # Parse time range using module-level function
+        self.start_time, self.end_time = _parse_time_range(timespan=timespan, start_time=start_time, end_time=end_time)
+
+    def partitions(self):
+        """Generate list of non-overlapping time range partitions.
+
+        Returns:
+            List of TimeRangePartition objects, each containing start_time and end_time
+
+        """
+        # Calculate total time range duration
+        total_duration = self.end_time - self.start_time
+
+        # Split into N equal partitions
+        partition_duration = total_duration / self.num_partitions
+
+        partitions = []
+        for i in range(self.num_partitions):
+            partition_start = self.start_time + (partition_duration * i)
+            partition_end = self.start_time + (partition_duration * (i + 1))
+
+            # Ensure last partition ends exactly at end_time (avoid rounding errors)
+            if i == self.num_partitions - 1:
+                partition_end = self.end_time
+
+            partitions.append(TimeRangePartition(partition_start, partition_end))
+
+        return partitions
+
+
 class AzureMonitorOffset:
     """Represents the offset for Azure Monitor streaming.
 
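
The new `partitions()` splits the requested window into `num_partitions` equal, non-overlapping slices and pins the final slice to `end_time`. A standalone sketch of that arithmetic, assuming a 4-hour window and four partitions:

from datetime import datetime

# Same split logic as partitions() above, outside the class for illustration.
start = datetime(2024, 1, 1, 0, 0, 0)
end = datetime(2024, 1, 1, 4, 0, 0)
num_partitions = 4

slice_duration = (end - start) / num_partitions
bounds = []
for i in range(num_partitions):
    p_start = start + slice_duration * i
    # Pin the final slice to `end` to avoid rounding drift, as the diff does.
    p_end = end if i == num_partitions - 1 else start + slice_duration * (i + 1)
    bounds.append((p_start, p_end))

# bounds -> [(00:00, 01:00), (01:00, 02:00), (02:00, 03:00), (03:00, 04:00)]
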
@@ -565,7 +582,7 @@ def from_json(json_str: str):
         return AzureMonitorOffset(data["timestamp"])
 
 
-class AzureMonitorStreamReader(DataSourceStreamReader):
+class AzureMonitorStreamReader(AzureMonitorReader, DataSourceStreamReader):
     """Stream reader for Azure Monitor / Log Analytics workspaces.
 
     Implements incremental streaming by tracking time-based offsets and splitting
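
For context, the streaming offset appears to round-trip through JSON with a single `timestamp` field, judging from `from_json` above; the timestamp encoding used below is an assumption, not something this hunk shows:

import json

# Hypothetical offset payload; only the "timestamp" key is confirmed by from_json.
offset = AzureMonitorOffset.from_json(json.dumps({"timestamp": "2024-01-01T00:00:00+00:00"}))
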
@@ -580,13 +597,8 @@ def __init__(self, options, schema: StructType):
             schema: StructType schema (provided by DataSource.schema())
 
         """
-        # Extract and validate required options
-        self.workspace_id = options.get("workspace_id")
-        self.query = options.get("query")
-        self.tenant_id = options.get("tenant_id")
-        self.client_id = options.get("client_id")
-        self.client_secret = options.get("client_secret")
-
+        super().__init__(options, schema)
+
         # Stream-specific options
         start_time = options.get("start_time", "latest")
         # Support 'latest' as alias for current timestamp
@@ -609,15 +621,6 @@ def __init__(self, options, schema: StructType):
         # Partition duration in seconds (default 1 hour)
         self.partition_duration = int(options.get("partition_duration", "3600"))
 
-        # Validate required options
-        assert self.workspace_id is not None, "workspace_id is required"
-        assert self.query is not None, "query is required"
-        assert self.tenant_id is not None, "tenant_id is required"
-        assert self.client_id is not None, "client_id is required"
-        assert self.client_secret is not None, "client_secret is required"
-
-        self._schema = schema
-
     def initialOffset(self):
         """Return the initial offset (start time).
 
@@ -682,63 +685,6 @@ def partitions(self, start, end):
 
         return partitions
 
-    def read(self, partition: TimeRangePartition):
-        """Read data for the given partition time range.
-
-        Args:
-            partition: TimeRangePartition containing start_time and end_time
-
-        Yields:
-            Row objects from the query results
-
-        """
-        # Import inside method for partition-level execution
-        from pyspark.sql import Row
-
-        # Use partition's time range
-        timespan_value = (partition.start_time, partition.end_time)
-
-        # Execute query using module-level function
-        response = _execute_logs_query(
-            workspace_id=self.workspace_id,
-            query=self.query,
-            timespan=timespan_value,
-            tenant_id=self.tenant_id,
-            client_id=self.client_id,
-            client_secret=self.client_secret,
-        )
-
-        # Create a mapping of column names to their expected types from schema
-        schema_field_map = {field.name: field.dataType for field in self._schema.fields}
-
-        # Process all tables in response (reuse same logic as batch reader)
-        for table in response.tables:
-            # Convert Azure Monitor rows to Spark Rows
-            for row_idx, row_data in enumerate(table.rows):
-                row_dict = {}
-
-                # First, process columns from the query results
-                for i, col in enumerate(table.columns):
-                    # Handle both string columns (real API) and objects with .name attribute (test mocks)
-                    column_name = str(col) if isinstance(col, str) else str(col.name)
-                    raw_value = row_data[i]
-
-                    # If column is in schema, convert to expected type
-                    if column_name in schema_field_map:
-                        expected_type = schema_field_map[column_name]
-                        try:
-                            converted_value = _convert_value_to_schema_type(raw_value, expected_type)
-                            row_dict[column_name] = converted_value
-                        except ValueError as e:
-                            raise ValueError(f"Row {row_idx}, column '{column_name}': {e}")
-
-                # Second, add NULL values for schema columns that are not in query results
-                for schema_column_name in schema_field_map.keys():
-                    if schema_column_name not in row_dict:
-                        row_dict[schema_column_name] = None
-
-                yield Row(**row_dict)
-
     def commit(self, end):
         """Called when a batch is successfully processed.
 
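
Finally, a hypothetical end-to-end usage sketch of the two readers after this refactor. The option keys match those read in the diff; the data source class name, the short format name, and the timestamp format are assumptions:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# spark.dataSource.register(AzureMonitorDataSource)  # hypothetical registration; not shown in this diff

batch_df = (
    spark.read.format("azure-monitor")                  # format name assumed
    .option("workspace_id", "<workspace-id>")
    .option("query", "AzureActivity | take 100")
    .option("tenant_id", "<tenant-id>")
    .option("client_id", "<client-id>")
    .option("client_secret", "<client-secret>")
    .option("start_time", "2024-01-01T00:00:00Z")       # parsed by _parse_time_range; format assumed
    .option("end_time", "2024-01-02T00:00:00Z")
    .option("num_partitions", "4")                       # batch reader splits the range into 4 slices
    .load()
)

stream_df = (
    spark.readStream.format("azure-monitor")
    .option("workspace_id", "<workspace-id>")
    .option("query", "AzureActivity")
    .option("tenant_id", "<tenant-id>")
    .option("client_id", "<client-id>")
    .option("client_secret", "<client-secret>")
    .option("start_time", "latest")                      # 'latest' resolves to the current timestamp
    .option("partition_duration", "3600")                # one-hour micro-batch slices
    .load()
)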