diff --git a/src/SimpleReplay/extract/extractor/extractor.py b/src/SimpleReplay/extract/extractor/extractor.py index ed112342..58846b1c 100755 --- a/src/SimpleReplay/extract/extractor/extractor.py +++ b/src/SimpleReplay/extract/extractor/extractor.py @@ -11,16 +11,22 @@ import pathlib import re from collections import OrderedDict - import dateutil.parser +import datetime + + import redshift_connector from boto3 import client from tqdm import tqdm +import asyncio + + from audit_logs_parsing import ( ConnectionLog, ) from helper import aws_service as aws_service_helper +import util from log_validation import remove_line_comments from .cloudwatch_extractor import CloudwatchExtractor from .s3_extractor import S3Extractor @@ -75,6 +81,63 @@ def get_extract(self, log_location, start_time, end_time): else: return self.local_extractor.get_extract_locally(log_location, start_time, end_time) + async def get_stored_procedures(self, start_time, end_time, username, stored_procedures_map, sql_json): + ''' + Handled bind varibales for stored procedures by querying the sys_query_history table + ''' + parse_start_time = "'" + start_time + "'" + parse_end_time = "'" + end_time + "'" + query_results = [] + async def fetch_procedure_data(transaction_ids): + transaction_ids_str = ','.join(str(id) for id in transaction_ids) + sys_query_history = f'select transaction_id, query_text \ + from sys_query_history \ + WHERE user_id > 1 \ + AND transaction_id IN ({transaction_ids_str}) \ + AND start_time >= {parse_start_time} \ + AND end_time <= {parse_end_time}\ + ORDER BY 1;' + cluster_object = util.cluster_dict(endpoint=self.config["source_cluster_endpoint"]) + result = await aws_service_helper.redshift_execute_query_async( + redshift_cluster_id=cluster_object['id'], + redshift_database_name=cluster_object['database'], + redshift_user=username, + region=self.config['region'], + query=sys_query_history, + ) + query_results.append(result) + return query_results + transaction_ids = 
list(stored_procedures_map.keys())
+        tasks = []
+        batch_size = 100
+        for i in range(0, len(transaction_ids), batch_size):
+            # fetching a portion(100) of the transaction ids
+            batch = transaction_ids[i:i + batch_size]
+            tasks.append(fetch_procedure_data(batch))
+
+        # Gather all the results when the tasks are completed
+        batch_results = await asyncio.gather(*tasks)
+
+        # Combine results from all batches
+        results = [result for batch_result in batch_results for result in batch_result]
+
+        # Process the results and update the sql_json
+        for result in results:
+            if 'ColumnMetadata' in result and 'Records' in result:
+                for row in result['Records']:
+                    for field in row:
+                        if 'stringValue' in field and field['stringValue'].startswith('call'):
+                            modified_string_value = field['stringValue'].rsplit('--')[0]
+                            transaction_id = row[0]['longValue']
+                            stored_procedures_map[transaction_id] = modified_string_value
+
+        # Update sql_json with the modified query text
+        for xid, query_text in stored_procedures_map.items():
+            if xid in sql_json['transactions'] and sql_json['transactions'][xid]['queries']:
+                sql_json['transactions'][xid]['queries'][0]['text'] = query_text
+        return sql_json
+
+
     def save_logs(self, logs, last_connections, output_directory, connections, start_time, end_time):
         """
         saving the extracted logs in S3 location in the following format:
@@ -111,11 +174,16 @@ def save_logs(self, logs, last_connections, output_directory, connections, start
         )
 
         pathlib.Path(output_directory).mkdir(parents=True, exist_ok=True)
 
-        sql_json, missing_audit_log_connections, replacements = self.get_sql_connections_replacements(last_connections,
-                                                                                                      log_items)
-
-        with gzip.open(archive_filename, "wb") as f:
-            f.write(json.dumps(sql_json, indent=2).encode("utf-8"))
+        sql_json, missing_audit_log_connections, replacements, stored_procedures_map = self.get_sql_connections_replacements(last_connections,
+                                                                                                      log_items)
+        if self.config.get('replay_stored_procedures'):
+            logger.info(f'The total length of 
stored procedures found are : {len(stored_procedures_map)}') + sql_json_with_stored_procedure = asyncio.run(Extractor.get_stored_procedures(self,start_time=self.config['start_time'],end_time=self.config['end_time'],username=self.config['master_username'],stored_procedures_map=stored_procedures_map,sql_json=sql_json)) + with gzip.open(archive_filename, "wb") as f: + f.write(json.dumps(sql_json_with_stored_procedure, indent=2).encode("utf-8")) + else: + with gzip.open(archive_filename, "wb") as f: + f.write(json.dumps(sql_json, indent=2).encode("utf-8")) if is_s3: dest = output_prefix + "/SQLs.json.gz" @@ -169,6 +237,7 @@ def get_sql_connections_replacements(self, last_connections, log_items): sql_json = {"transactions": OrderedDict()} missing_audit_log_connections = set() replacements = set() + stored_procedures_map = {} for filename, queries in tqdm( log_items, disable=self.disable_progress_bar, @@ -220,7 +289,10 @@ def get_sql_connections_replacements(self, last_connections, log_items): query.text, flags=re.IGNORECASE, ) - + if self.config.get("replay_stored_procedures", None): + if query.text.lower().startswith("call"): + stored_procedures_map[query.xid] = query.text + query.text = f"{query.text.strip()}" if not len(query.text) == 0: if not query.text.endswith(";"): @@ -231,7 +303,7 @@ def get_sql_connections_replacements(self, last_connections, log_items): if not hash((query.database_name, query.username, query.pid)) in last_connections: missing_audit_log_connections.add((query.database_name, query.username, query.pid)) - return sql_json, missing_audit_log_connections, replacements + return sql_json, missing_audit_log_connections, replacements,stored_procedures_map def unload_system_table( self, diff --git a/src/SimpleReplay/extract/extractor/s3_extractor.py b/src/SimpleReplay/extract/extractor/s3_extractor.py index 25b4de7c..68a7c84b 100644 --- a/src/SimpleReplay/extract/extractor/s3_extractor.py +++ b/src/SimpleReplay/extract/extractor/s3_extractor.py @@ -32,7 
+32,7 @@ def get_extract_from_s3(self, log_bucket, log_prefix, start_time, end_time):
         last_connections = {}
         databases = set()
-        bucket_objects = aws_service_helper.s3_get_bucket_contents(log_bucket, log_prefix)
+        bucket_objects = aws_service_helper.sync_s3_get_bucket_contents(log_bucket, log_prefix)
         s3_connection_logs = []
         s3_user_activity_logs = []
diff --git a/src/SimpleReplay/helper/aws_service.py b/src/SimpleReplay/helper/aws_service.py
index bfc47c5e..992c1917 100644
--- a/src/SimpleReplay/helper/aws_service.py
+++ b/src/SimpleReplay/helper/aws_service.py
@@ -1,9 +1,24 @@
+import base64
 import datetime
+import json
 import logging
-
 import boto3
+from botocore.exceptions import ClientError
+import asyncio
+import functools
+
+logger = logging.getLogger("WorkloadReplicatorLogger")
+
+
+def redshift_get_serverless_workgroup(workgroup_name, region):
+    rs_client = boto3.client("redshift-serverless", region_name=region)
+    return rs_client.get_workgroup(workgroupName=workgroup_name)
-logger = logging.getLogger("SimpleReplayLogger")
+
+def redshift_describe_clusters(cluster_id, region):
+    rs_client = boto3.client("redshift", region_name=region)
+    response = rs_client.describe_clusters(ClusterIdentifier=cluster_id)
+    return response
 
 def redshift_describe_logging_status(source_cluster_endpoint):
@@ -17,12 +32,44 @@ def redshift_describe_logging_status(source_cluster_endpoint):
     )
     return result
 
-def redshift_execute_query(redshift_cluster_id, redshift_user, redshift_database_name, region, query):
+
+def redshift_get_cluster_credentials(
+    region,
+    user,
+    database_name,
+    cluster_id,
+    duration=900,
+    auto_create=False,
+    additional_client_args={},
+):
+    rs_client = boto3.client("redshift", region, **additional_client_args)
+    try:
+        response = rs_client.get_cluster_credentials(
+            DbUser=user,
+            DbName=database_name,
+            ClusterIdentifier=cluster_id,
+            DurationSeconds=duration,
+            AutoCreate=auto_create,
+        )
+    except Exception as e:
+        if isinstance(e, rs_client.exceptions.ClusterNotFoundFault):
+            logger.error(
+                f"Cluster {cluster_id} not found. Please confirm cluster endpoint, account, and region."
+            )
+        else:
+            logger.error(f"Error while getting cluster credentials: {e}", exc_info=True)
+        exit(-1)
+    return response
+
+
+def redshift_execute_query(
+    redshift_cluster_id, redshift_user, redshift_database_name, region, query
+):
     """
     Executes redshift query and gets response for query when finished
     """
     # get query id
-    redshift_data_api_client = boto3.client("redshift-data", region)
+    redshift_data_api_client = boto3.client("redshift-data", region)
     response_execute_statement = redshift_data_api_client.execute_statement(
         Database=redshift_database_name,
         DbUser=redshift_user,
@@ -31,16 +78,10 @@ def redshift_execute_query(redshift_cluster_id, redshift_user, redshift_database
     )
     query_id = response_execute_statement["Id"]
 
-    # get query status
-    response_describe_statement = redshift_data_api_client.describe_statement(
-        Id=query_id
-    )
     query_done = False
     while not query_done:
-        response_describe_statement = (
-            redshift_data_api_client.describe_statement(Id=query_id)
-        )
+        response_describe_statement = redshift_data_api_client.describe_statement(Id=query_id)
         query_status = response_describe_statement["Status"]
 
         if query_status == "FAILED":
@@ -51,34 +92,66 @@ def redshift_execute_query(redshift_cluster_id, redshift_user, redshift_database
             query_done = True
             # log result if there is a result (typically from Select statement)
             if response_describe_statement["HasResultSet"]:
-                response_get_statement_result = (
-                    redshift_data_api_client.get_statement_result(Id=query_id)
+                response_get_statement_result = redshift_data_api_client.get_statement_result(
+                    Id=query_id
                 )
     return response_get_statement_result
 
+def execute_query_sync(redshift_data_api_client, redshift_database_name, redshift_user, query, cluster_id):
+    response_execute_statement = redshift_data_api_client.execute_statement(
+        Database=redshift_database_name,
+        
DbUser=redshift_user,
+        Sql=query,
+        ClusterIdentifier=cluster_id,
+    )
+    query_id = response_execute_statement["Id"]
+    query_done = False
+    while not query_done:
+        response_describe_statement = redshift_data_api_client.describe_statement(Id=query_id)
+        query_status = response_describe_statement["Status"]
+        if query_status == "FAILED":
+            logger.debug(f"SQL execution failed. Query ID = {query_id}")
+            raise Exception
+        elif query_status == "FINISHED":
+            query_done = True
+            # log result if there is a result (typically from Select statement)
+            if response_describe_statement["HasResultSet"]:
+                response_get_statement_result = redshift_data_api_client.get_statement_result(
+                    Id=query_id
+                )
+                return response_get_statement_result
+    return None  # Handle the case where there's no result
+
+async def redshift_execute_query_async(
+    redshift_cluster_id, redshift_user, redshift_database_name, region, query
+):
+    """
+    Executes redshift query asynchronously and gets response for query when finished
+    """
+    loop = asyncio.get_event_loop()
+    redshift_data_api_client = boto3.client("redshift-data", region_name=region)
+    return await loop.run_in_executor(None, execute_query_sync, redshift_data_api_client, redshift_database_name, redshift_user, query, redshift_cluster_id)
+
 def cw_describe_log_groups(log_group_name=None, region=None):
     cloudwatch_client = boto3.client("logs", region)
     if log_group_name:
-        return cloudwatch_client.describe_log_groups(
-            logGroupNamePrefix=log_group_name
-        )
+        return cloudwatch_client.describe_log_groups(logGroupNamePrefix=log_group_name)
     else:
-        logs = cloudwatch_client.describe_log_groups()
+        response_pg_1 = cloudwatch_client.describe_log_groups()
+        logs = response_pg_1
 
-        token = logs.get('nextToken','')
-        while token != '':
+        token = response_pg_1.get("nextToken", "")
+        while token != "":
             response_itr = cloudwatch_client.describe_log_groups(nextToken=token)
-            logs['logGroups'].extend(response_itr['logGroups'])
-            token = response_itr['nextToken'] if 
'nextToken' in response_itr.keys() else '' + logs["logGroups"].extend(response_itr["logGroups"]) + token = response_itr["nextToken"] if "nextToken" in response_itr.keys() else "" return logs def cw_describe_log_streams(log_group_name, region): cloudwatch_client = boto3.client("logs", region) - return cloudwatch_client.describe_log_streams( - logGroupName=log_group_name - ) + return cloudwatch_client.describe_log_streams(logGroupName=log_group_name) def cw_get_paginated_logs(log_group_name, log_stream_name, start_time, end_time, region): @@ -86,8 +159,12 @@ def cw_get_paginated_logs(log_group_name, log_stream_name, start_time, end_time, cloudwatch_client = boto3.client("logs", region) paginator = cloudwatch_client.get_paginator("filter_log_events") pagination_config = {"MaxItems": 10000} - convert_to_millis_since_epoch = lambda time: int( - (time.replace(tzinfo=None) - datetime.datetime.utcfromtimestamp(0)).total_seconds()) * 1000 + convert_to_millis_since_epoch = ( + lambda time: int( + (time.replace(tzinfo=None) - datetime.datetime.utcfromtimestamp(0)).total_seconds() + ) + * 1000 + ) start_time_millis_since_epoch = convert_to_millis_since_epoch(start_time) end_time_millis_since_epoch = convert_to_millis_since_epoch(end_time) response_iterator = paginator.paginate( @@ -100,7 +177,7 @@ def cw_get_paginated_logs(log_group_name, log_stream_name, start_time, end_time, next_token = None while next_token != "": for response in response_iterator: - next_token = response.get('nextToken', '') + next_token = response.get("nextToken", "") for event in response["events"]: log_list.append(event["message"]) pagination_config.update({"StartingToken": next_token}) @@ -123,19 +200,46 @@ def s3_upload(local_file_name, bucket, key=None): def s3_put_object(file_content, bucket, key): s3 = boto3.client("s3") - s3.put_object( - Body=file_content, - Bucket=bucket, - Key=key - ) + s3.put_object(Body=file_content, Bucket=bucket, Key=key) + + +def s3_resource_put_object(bucket, prefix, 
body): + s3_resource = boto3.resource("s3") + s3_resource.Object(bucket, prefix).put(Body=body) + + +async def s3_get_bucket_contents(bucket, prefix): + s3_client = boto3.client("s3") + loop = asyncio.get_event_loop() + bucket_objects = [] + continuation_token = "" + while True: + if continuation_token != "": + f_list_bounded = functools.partial( + s3_client.list_objects_v2, + Bucket=bucket, + Prefix=prefix, + ContinuationToken=continuation_token, + ) + else: + f_list_bounded = functools.partial( + s3_client.list_objects_v2, Bucket=bucket, Prefix=prefix + ) + response = await loop.run_in_executor(executor=None, func=f_list_bounded) + bucket_objects.extend(response.get("Contents", [])) + if response["IsTruncated"]: + continuation_token = response["NextContinuationToken"] + else: + break + return bucket_objects -def s3_get_bucket_contents(bucket, prefix): +def sync_s3_get_bucket_contents(bucket, prefix): conn = boto3.client("s3") # get first set of response = conn.list_objects_v2(Bucket=bucket, Prefix=prefix) - bucket_objects = response.get('Contents', []) + bucket_objects = response.get("Contents", []) if "NextContinuationToken" in response: prev_key = response["NextContinuationToken"] @@ -149,11 +253,22 @@ def s3_get_bucket_contents(bucket, prefix): prev_key = response["NextContinuationToken"] return bucket_objects + +def s3_generate_presigned_url(client_method, bucket_name, object_name): + s3_client = boto3.client("s3") + response = s3_client.generate_presigned_url( + client_method, + Params={"Bucket": bucket_name, "Key": object_name}, + ExpiresIn=604800, + ) + return response + + def s3_copy_object(src_bucket, src_prefix, dest_bucket, dest_prefix): boto3.client("s3").copy_object( Bucket=dest_bucket, Key=dest_prefix, - CopySource={"Bucket": src_bucket, "Key": src_prefix} + CopySource={"Bucket": src_bucket, "Key": src_prefix}, ) @@ -161,12 +276,20 @@ def s3_get_object(bucket, filename): s3 = boto3.resource("s3") return s3.Object(bucket, filename) + +def 
s3_client_get_object(bucket, key): + s3 = boto3.client("s3") + return s3.get_object(Bucket=bucket, Key=key) + + def glue_get_table(database, table, region): table_get_response = boto3.client("glue", region).get_table( - DatabaseName=database,Name=table, + DatabaseName=database, + Name=table, ) return table_get_response + def glue_get_partition_indexes(database, table, region): index_response = boto3.client("glue", region).get_partition_indexes( DatabaseName=database, @@ -174,15 +297,14 @@ def glue_get_partition_indexes(database, table, region): ) return index_response + def glue_create_table(new_database, table_input, region): - boto3.client("glue", region).create_table( - DatabaseName=new_database, TableInput=table_input - ) + boto3.client("glue", region).create_table(DatabaseName=new_database, TableInput=table_input) + def glue_get_database(name, region): - boto3.client("glue", region).get_database( - Name=name - ) + boto3.client("glue", region).get_database(Name=name) + def glue_create_database(name, description, region): boto3.client("glue", region).create_database( @@ -192,3 +314,20 @@ def glue_create_database(name, description, region): } ) + +def get_secret(secret_name, region_name): + # Create a Secrets Manager client + client = boto3.client(service_name="secretsmanager", region_name=region_name) + try: + get_secret_value_response = client.get_secret_value(SecretId=secret_name) + # Decrypts secret using the associated KMS key. + # Depending on whether the secret is a string or binary, one of these fields will be populated. + if "SecretString" in get_secret_value_response: + return json.loads(get_secret_value_response["SecretString"]) + else: + return json.loads(base64.b64decode(get_secret_value_response["SecretBinary"])) + except ClientError as e: + logger.error( + f"Exception occurred while getting secret from Secrets manager {e}", exc_info=True + ) + raise e \ No newline at end of file