
Commit 791207e

Refactor the code for common classes
1 parent c73b897 commit 791207e

File tree: 1 file changed (+60, -114 lines)

cyber_connectors/MsSentinel.py

Lines changed: 60 additions & 114 deletions
@@ -397,14 +397,17 @@ def name(cls):
         return "ms-sentinel"
 
 
-class AzureMonitorBatchReader(DataSourceReader):
-    """Reader for Azure Monitor / Log Analytics workspaces."""
+class AzureMonitorReader:
+    """Base reader class for Azure Monitor / Log Analytics workspaces.
+
+    Shared read logic for batch and streaming reads.
+    """
 
     def __init__(self, options, schema: StructType):
         """Initialize the reader with options and schema.
 
         Args:
-            options: Dictionary of options containing workspace_id, query, time range, credentials
+            options: Dictionary of options containing workspace_id, query, credentials
             schema: StructType schema (provided by DataSource.schema())
 
         """
@@ -415,53 +418,16 @@ def __init__(self, options, schema: StructType):
         self.client_id = options.get("client_id")
         self.client_secret = options.get("client_secret")
 
-        # Time range options (mutually exclusive)
-        timespan = options.get("timespan")
-        start_time = options.get("start_time")
-        end_time = options.get("end_time")
-
-        # Optional options
-        self.num_partitions = int(options.get("num_partitions", "1"))
-
         # Validate required options
         assert self.workspace_id is not None, "workspace_id is required"
         assert self.query is not None, "query is required"
         assert self.tenant_id is not None, "tenant_id is required"
         assert self.client_id is not None, "client_id is required"
         assert self.client_secret is not None, "client_secret is required"
 
-        # Parse time range using module-level function
-        self.start_time, self.end_time = _parse_time_range(timespan=timespan, start_time=start_time, end_time=end_time)
-
         # Store schema (provided by DataSource.schema())
         self._schema = schema
 
-    def partitions(self):
-        """Generate list of non-overlapping time range partitions.
-
-        Returns:
-            List of TimeRangePartition objects, each containing start_time and end_time
-
-        """
-        # Calculate total time range duration
-        total_duration = self.end_time - self.start_time
-
-        # Split into N equal partitions
-        partition_duration = total_duration / self.num_partitions
-
-        partitions = []
-        for i in range(self.num_partitions):
-            partition_start = self.start_time + (partition_duration * i)
-            partition_end = self.start_time + (partition_duration * (i + 1))
-
-            # Ensure last partition ends exactly at end_time (avoid rounding errors)
-            if i == self.num_partitions - 1:
-                partition_end = self.end_time
-
-            partitions.append(TimeRangePartition(partition_start, partition_end))
-
-        return partitions
-
     def read(self, partition: TimeRangePartition):
         """Read data for the given partition time range.
 
@@ -522,6 +488,57 @@ def read(self, partition: TimeRangePartition):
                 yield Row(**row_dict)
 
 
+class AzureMonitorBatchReader(AzureMonitorReader, DataSourceReader):
+    """Batch reader for Azure Monitor / Log Analytics workspaces."""
+
+    def __init__(self, options, schema: StructType):
+        """Initialize the batch reader with options and schema.
+
+        Args:
+            options: Dictionary of options containing workspace_id, query, time range, credentials
+            schema: StructType schema (provided by DataSource.schema())
+
+        """
+        super().__init__(options, schema)
+
+        # Time range options (mutually exclusive)
+        timespan = options.get("timespan")
+        start_time = options.get("start_time")
+        end_time = options.get("end_time")
+
+        # Optional options
+        self.num_partitions = int(options.get("num_partitions", "1"))
+
+        # Parse time range using module-level function
+        self.start_time, self.end_time = _parse_time_range(timespan=timespan, start_time=start_time, end_time=end_time)
+
+    def partitions(self):
+        """Generate list of non-overlapping time range partitions.
+
+        Returns:
+            List of TimeRangePartition objects, each containing start_time and end_time
+
+        """
+        # Calculate total time range duration
+        total_duration = self.end_time - self.start_time
+
+        # Split into N equal partitions
+        partition_duration = total_duration / self.num_partitions
+
+        partitions = []
+        for i in range(self.num_partitions):
+            partition_start = self.start_time + (partition_duration * i)
+            partition_end = self.start_time + (partition_duration * (i + 1))
+
+            # Ensure last partition ends exactly at end_time (avoid rounding errors)
+            if i == self.num_partitions - 1:
+                partition_end = self.end_time
+
+            partitions.append(TimeRangePartition(partition_start, partition_end))
+
+        return partitions
+
+
 class AzureMonitorOffset:
     """Represents the offset for Azure Monitor streaming.
 
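For reference, the partitions() logic now owned by the batch reader splits the requested window into num_partitions equal, non-overlapping slices and pins the last slice to end_time. Below is a standalone sketch of that arithmetic; the TimeRangePartition dataclass here is only a stand-in for the one defined elsewhere in this module.

    from dataclasses import dataclass
    from datetime import datetime, timedelta


    @dataclass
    class TimeRangePartition:   # stand-in for the class defined elsewhere in this module
        start_time: datetime
        end_time: datetime


    def split_time_range(start, end, num_partitions):
        partition_duration = (end - start) / num_partitions
        partitions = []
        for i in range(num_partitions):
            p_start = start + partition_duration * i
            # pin the final slice to the exact end to avoid rounding drift
            p_end = end if i == num_partitions - 1 else start + partition_duration * (i + 1)
            partitions.append(TimeRangePartition(p_start, p_end))
        return partitions


    # A 4-hour window with num_partitions=4 yields four non-overlapping 1-hour slices.
    start = datetime(2024, 1, 1, 0, 0)
    for p in split_time_range(start, start + timedelta(hours=4), 4):
        print(p.start_time, "->", p.end_time)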
@@ -565,7 +582,7 @@ def from_json(json_str: str):
         return AzureMonitorOffset(data["timestamp"])
 
 
-class AzureMonitorStreamReader(DataSourceStreamReader):
+class AzureMonitorStreamReader(AzureMonitorReader, DataSourceStreamReader):
     """Stream reader for Azure Monitor / Log Analytics workspaces.
 
     Implements incremental streaming by tracking time-based offsets and splitting
@@ -580,13 +597,8 @@ def __init__(self, options, schema: StructType):
             schema: StructType schema (provided by DataSource.schema())
 
         """
-        # Extract and validate required options
-        self.workspace_id = options.get("workspace_id")
-        self.query = options.get("query")
-        self.tenant_id = options.get("tenant_id")
-        self.client_id = options.get("client_id")
-        self.client_secret = options.get("client_secret")
-
+        super().__init__(options, schema)
+
         # Stream-specific options
         start_time = options.get("start_time", "latest")
         # Support 'latest' as alias for current timestamp
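With this change the stream reader gets its option extraction and validation from AzureMonitorReader via super().__init__(options, schema); because AzureMonitorReader is listed first among the subclass bases, Python's method resolution order dispatches that call to the shared initializer rather than to the Spark reader interface. A minimal sketch of the pattern, with hypothetical class names:

    # Hypothetical names; the shared base owns option handling, the Spark-facing
    # subclass only layers its own options on top of super().__init__().
    class BaseReader:
        def __init__(self, options):
            self.workspace_id = options["workspace_id"]   # shared extraction/validation


    class ReaderInterface:
        """Stand-in for pyspark's DataSourceReader / DataSourceStreamReader."""


    class StreamReader(BaseReader, ReaderInterface):
        def __init__(self, options):
            super().__init__(options)   # MRO resolves this to BaseReader.__init__
            self.partition_duration = int(options.get("partition_duration", "3600"))


    r = StreamReader({"workspace_id": "w-123"})
    print(r.workspace_id, r.partition_duration)   # w-123 3600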
@@ -609,15 +621,6 @@ def __init__(self, options, schema: StructType):
         # Partition duration in seconds (default 1 hour)
         self.partition_duration = int(options.get("partition_duration", "3600"))
 
-        # Validate required options
-        assert self.workspace_id is not None, "workspace_id is required"
-        assert self.query is not None, "query is required"
-        assert self.tenant_id is not None, "tenant_id is required"
-        assert self.client_id is not None, "client_id is required"
-        assert self.client_secret is not None, "client_secret is required"
-
-        self._schema = schema
-
     def initialOffset(self):
         """Return the initial offset (start time).
 
@@ -682,63 +685,6 @@ def partitions(self, start, end):
 
         return partitions
 
-    def read(self, partition: TimeRangePartition):
-        """Read data for the given partition time range.
-
-        Args:
-            partition: TimeRangePartition containing start_time and end_time
-
-        Yields:
-            Row objects from the query results
-
-        """
-        # Import inside method for partition-level execution
-        from pyspark.sql import Row
-
-        # Use partition's time range
-        timespan_value = (partition.start_time, partition.end_time)
-
-        # Execute query using module-level function
-        response = _execute_logs_query(
-            workspace_id=self.workspace_id,
-            query=self.query,
-            timespan=timespan_value,
-            tenant_id=self.tenant_id,
-            client_id=self.client_id,
-            client_secret=self.client_secret,
-        )
-
-        # Create a mapping of column names to their expected types from schema
-        schema_field_map = {field.name: field.dataType for field in self._schema.fields}
-
-        # Process all tables in response (reuse same logic as batch reader)
-        for table in response.tables:
-            # Convert Azure Monitor rows to Spark Rows
-            for row_idx, row_data in enumerate(table.rows):
-                row_dict = {}
-
-                # First, process columns from the query results
-                for i, col in enumerate(table.columns):
-                    # Handle both string columns (real API) and objects with .name attribute (test mocks)
-                    column_name = str(col) if isinstance(col, str) else str(col.name)
-                    raw_value = row_data[i]
-
-                    # If column is in schema, convert to expected type
-                    if column_name in schema_field_map:
-                        expected_type = schema_field_map[column_name]
-                        try:
-                            converted_value = _convert_value_to_schema_type(raw_value, expected_type)
-                            row_dict[column_name] = converted_value
-                        except ValueError as e:
-                            raise ValueError(f"Row {row_idx}, column '{column_name}': {e}")
-
-                # Second, add NULL values for schema columns that are not in query results
-                for schema_column_name in schema_field_map.keys():
-                    if schema_column_name not in row_dict:
-                        row_dict[schema_column_name] = None
-
-                yield Row(**row_dict)
-
     def commit(self, end):
         """Called when a batch is successfully processed.
 
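For context, reads against the refactored batch reader remain option-driven. The sketch below is hypothetical: it assumes an active SparkSession named spark with the Azure Monitor data source already registered under a short name such as "azure-monitor" (the registration is not shown in this diff) and placeholder option values; only option keys that appear above are used.

    # Hypothetical usage sketch; the format name and option values are assumptions.
    df = (
        spark.read.format("azure-monitor")
        .option("workspace_id", "<workspace-guid>")
        .option("query", "AzureActivity | take 100")
        .option("tenant_id", "<tenant-guid>")
        .option("client_id", "<app-client-id>")
        .option("client_secret", "<app-client-secret>")
        .option("timespan", "<timespan>")    # or start_time / end_time (mutually exclusive)
        .option("num_partitions", "4")       # one TimeRangePartition per equal time slice
        .load()
    )
    df.show()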
