diff --git a/cyhy_commander/commander.py b/cyhy_commander/commander.py index fb64fcc..4bc40b3 100755 --- a/cyhy_commander/commander.py +++ b/cyhy_commander/commander.py @@ -24,10 +24,12 @@ from collections import defaultdict import logging import os +import Queue import random import shutil import signal import sys +import threading import time import traceback from ConfigParser import SafeConfigParser @@ -89,6 +91,7 @@ DEFAULT = "DEFAULT" DEFAULT_SCHEDULER = "default-scheduler" DEFAULT_SECTION = "default-section" +JOB_PROCESSING_THREADS = "job-processing-threads" JOBS_PER_NESSUS_HOST = "jobs-per-nessus-host" JOBS_PER_NMAP_HOST = "jobs-per-nmap-host" KEEP_FAILURES = "keep-failures" @@ -138,18 +141,24 @@ def __init__(self, config_section=None, debug_logging=False, console_logging=Fal self.__all_hosts_idle = False self.__config_section = config_section self.__db = None + self.__failed_job_queue = None self.__failure_sinks = [] self.__host_exceptions = defaultdict(lambda: 0) self.__hosts_on_cooldown = [] + self.__is_processing_jobs = True self.__is_running = True + self.__job_processing_sleep_duration = 1 self.__keep_failures = False self.__keep_successes = False + self.__log_output_sleep_duration = 10 self.__nessus_sources = [] self.__next_scan_limit = 2000 self.__nmap_sources = [] + self.__queue_monitor_output_lock = threading.Lock() self.__setup_directories() self.__shutdown_when_idle = False self.__success_sinks = [] + self.__successful_job_queue = None self.__test_mode = False def __setup_logging(self, debug_logging, console_logging): @@ -305,9 +314,9 @@ def __done_jobs(self): local_job_path = os.path.join(destDir, job) if destDir == SUCCESS_DIR: - self.__process_successful_job(local_job_path) + self.__successful_job_queue.put(local_job_path) else: - self.__process_failed_job(local_job_path) + self.__failed_job_queue.put(local_job_path) else: self.__logger.warning( @@ -428,29 +437,108 @@ def __fill_hosts(self, counts, sources, workgroup_name, jobs_per_host): execute(self.__push_job, self, job_path, hosts=[lowest_host]) counts[lowest_host] += 1 + def __monitor_job_queues(self): + # Output the number of jobs that are not done for each queue every + # self.__log_output_sleep_duration seconds while work is on the queues. + while self.__is_processing_jobs: + with self.__queue_monitor_output_lock: + self.__logger.debug( + "%d unfinished jobs in the successful job queue" + % self.__successful_job_queue.unfinished_tasks + ) + self.__logger.debug( + "%d unfinished jobs in the failed job queue" + % self.__failed_job_queue.unfinished_tasks + ) + time.sleep(self.__log_output_sleep_duration) + + def __process_queued_jobs(self): + + # define an inner function to process jobs + def process_job_from_queue(target_job_queue, job_processing_function): + """Helper function to process jobs from a queue. + + Args: + target_job_queue (Queue.Queue): The queue to get a job to process. + job_processing_function (callable): The function used to process a job. + + Returns: + The job path that was processed or None if the queue was empty. + """ + job_path = None + + # check the successful jobs queue + try: + job_path = target_job_queue.get(timeout=1) + except Queue.Empty: + return job_path + + try: + job_processing_function(job_path) + except Exception, e: + self.__logger.critical(e) + self.__logger.critical(traceback.format_exc()) + + # report task completion no matter what so the queue can be joined + target_job_queue.task_done() + + # return path of the job that was processed + return job_path + + # run as long as the commander is processing jobs + while self.__is_processing_jobs: + # process successful job + job_processing_results = process_job_from_queue( + self.__successful_job_queue, self.__process_successful_job + ) + + # process failed job if a successful job was not processed + if job_processing_results is None: + job_processing_results = process_job_from_queue( + self.__failed_job_queue, self.__process_failed_job + ) + + # sleep if both queues are empty + if job_processing_results is None: + time.sleep(self.__job_processing_sleep_duration) + def __process_successful_job(self, job_path): + # Get the name of the current thread + thread_name = threading.current_thread().name + for sink in self.__success_sinks: if sink.can_handle(job_path): - self.__logger.info("Processing %s with %s" % (job_path, sink)) + self.__logger.info( + "[%s] Processing %s with %s" % (thread_name, job_path, sink) + ) sink.handle(job_path) - self.__logger.info("Processing completed") + self.__logger.info("[%s] Processing completed" % thread_name) if not self.__test_mode and not self.__keep_successes: shutil.rmtree(job_path) - self.__logger.info("%s deleted" % job_path) + self.__logger.info("[%s] %s deleted" % (thread_name, job_path)) return - self.__logger.warning("No handler was able to process %s" % job_path) + self.__logger.warning( + "[%s] No handler was able to process %s" % (thread_name, job_path) + ) def __process_failed_job(self, job_path): + # Get the name of the current thread + thread_name = threading.current_thread().name + for sink in self.__failure_sinks: if sink.can_handle(job_path): - self.__logger.warning("Processing %s with %s" % (job_path, sink)) + self.__logger.warning( + "[%s] Processing %s with %s" % (thread_name, job_path, sink) + ) sink.handle(job_path) - self.__logger.info("Processing completed") + self.__logger.info("[%s] Processing completed" % thread_name) if not self.__test_mode and not self.__keep_failures: shutil.rmtree(job_path) - self.__logger.info("%s deleted" % job_path) + self.__logger.info("[%s] %s deleted" % (thread_name, job_path)) return - self.__logger.warning("No handler was able to process %s" % job_path) + self.__logger.warning( + "[%s] No handler was able to process %s" % (thread_name, job_path) + ) def handle_term(self, signal, frame): self.__logger.warning( @@ -470,7 +558,7 @@ def __check_stop_file(self): def __check_database_pause(self): while self.__ch_db.should_commander_pause() and self.__is_running: self.__logger.info("Commander is paused due to database request.") - time.sleep(10) + time.sleep(self.__log_output_sleep_duration) self.__check_stop_file() def __write_config(self): @@ -478,6 +566,7 @@ def __write_config(self): config.set(None, DATABASE_URI, "mongodb://localhost:27017/") config.set(None, JOBS_PER_NMAP_HOST, "8") config.set(None, JOBS_PER_NESSUS_HOST, "8") + config.set(None, JOB_PROCESSING_THREADS, "4") config.set(None, POLL_INTERVAL, "30") config.set(None, NEXT_SCAN_LIMIT, "2000") config.set(None, DEFAULT_SECTION, TESTING_SECTION) @@ -603,6 +692,12 @@ def do_work(self): self.__test_mode = config.getboolean(config_section, TEST_MODE) self.__logger.info("Test mode: %s", self.__test_mode) self.__keep_failures = config.getboolean(config_section, KEEP_FAILURES) + job_processing_thread_count = config.getint( + config_section, JOB_PROCESSING_THREADS + ) + self.__logger.info( + "Number of job processing threads: %d", job_processing_thread_count + ) self.__logger.info("Keep failed jobs: %s", self.__keep_failures) self.__keep_successes = config.getboolean(config_section, KEEP_SUCCESSES) self.__logger.info("Keep successful jobs: %s", self.__keep_successes) @@ -617,6 +712,43 @@ def do_work(self): self.__setup_sources() self.__setup_sinks() + self.__successful_job_queue = Queue.Queue() + self.__failed_job_queue = Queue.Queue() + + # spin up the thread pool to process retrieved work + job_processing_threads = [] + for t in range(job_processing_thread_count): + job_processing_thread = threading.Thread( + name="JobProcessor-%d" % t, target=self.__process_queued_jobs + ) + job_processing_threads.append(job_processing_thread) + try: + job_processing_thread.start() + except Exception as e: + self.__logger.error("Unable to start job processing thread #%s", t) + self.__logger.error(e) + # bail out + self.__logger.critical( + "Shutting down due to inability to start job processing threads." + ) + self.__is_running = False + + # spin up a thread to output queue load information + self.__queue_monitor_output_lock.acquire() + job_queue_monitor_thread = threading.Thread( + name="QueueMonitor", target=self.__monitor_job_queues + ) + try: + job_queue_monitor_thread.start() + except Exception as e: + self.__logger.error("Unable to start job queue monitoring thread") + self.__logger.error(e) + # bail out + self.__logger.critical( + "Shutting down due to inability to start queue monitoring thread." + ) + self.__is_running = False + # pairs of hosts and job sources work_groups = ( (NMAP_WORKGROUP, nmap_hosts, self.__nmap_sources, jobs_per_nmap_host), @@ -706,11 +838,18 @@ def do_work(self): self.__logger.debug( "Checking remotes for completed jobs to download and process" ) + self.__queue_monitor_output_lock.release() for (workgroup_name, hosts, sources, jobs_per_host) in work_groups: if hosts == None: continue execute(self.__done_jobs, self, hosts=hosts) + # wait for work to process + self.__logger.debug("Waiting for completed jobs to be processed.") + self.__successful_job_queue.join() + self.__failed_job_queue.join() + self.__queue_monitor_output_lock.acquire() + # check for scheduled hosts self.__logger.debug( "Checking for scheduled DONE hosts to mark WAITING." @@ -752,12 +891,25 @@ def do_work(self): except Exception, e: self.__logger.critical(e) self.__logger.critical(traceback.format_exc()) + + # signal job processing threads to exit once they have finished all + # queued work + self.__is_processing_jobs = False + self.__queue_monitor_output_lock.release() + + # wait for the job processing threads to exit + for job_processing_thread in job_processing_threads: + job_processing_thread.join() + + # wait for the job queue monitoring thread to exit + job_queue_monitor_thread.join() + self.__logger.info("Shutting down.") disconnect_all() def main(): - args = docopt(__doc__, version="v1.0.2") + args = docopt(__doc__, version="v1.1.0") workingDir = os.path.join(os.getcwd(), args[""]) if not os.path.exists(workingDir): print >>sys.stderr, 'Working directory "%s" does not exist. Attempting to create...' % workingDir diff --git a/cyhy_commander/nessus/nessus_importer.py b/cyhy_commander/nessus/nessus_importer.py index 8cd703e..622105c 100755 --- a/cyhy_commander/nessus/nessus_importer.py +++ b/cyhy_commander/nessus/nessus_importer.py @@ -6,6 +6,7 @@ import netaddr import gzip import logging +import threading from cyhy.core import UNKNOWN_OWNER from cyhy.db import CHDatabase, VulnTicketManager from cyhy.util import util @@ -50,7 +51,10 @@ def __init__(self, db, manual_scan=False): self.manual_scan = manual_scan def process(self, filename, gzipped=False): - self.__logger.debug("Starting processing of %s" % filename) + # Get the name of the current thread + thread_name = threading.current_thread().name + + self.__logger.debug("[%s] Starting processing of %s" % (thread_name, filename)) if self.manual_scan: # if we are doing a manual scan import we have to assume a current time self.current_ip_time = util.utcnow() @@ -62,20 +66,28 @@ def process(self, filename, gzipped=False): f.close() def __try_to_clear_latest_flags(self): + # Get the name of the current thread + thread_name = threading.current_thread().name + # Once the ticket manager has all its information, # it can clear the previous latest flags if self.ticket_manager.ready_to_clear_vuln_latest_flags(): self.__logger.debug( - 'Ticket manager IS READY to clear VulnScan "latest" flags' + '[%s] Ticket manager IS READY to clear VulnScan "latest" flags' + % thread_name ) self.ticket_manager.clear_vuln_latest_flags() self.attempted_to_clear_latest_flags = True else: self.__logger.debug( - 'Ticket manager IS NOT READY to clear VulnScan "latest" flags' + '[%s] Ticket manager IS NOT READY to clear VulnScan "latest" flags' + % thread_name ) def targets_callback(self, targets_string): + # Get the name of the current thread + thread_name = threading.current_thread().name + """list of targets read from the policy section clear latest flags, and change host state this is done here since not all targets necessarily @@ -99,35 +111,53 @@ def targets_callback(self, targets_string): if len(parts) == 2 and parts[1].endswith("]"): t = parts[1][:-1] else: - self.__logger.warning("Skipping malformed target: '%s'" % t.strip()) + self.__logger.warning( + "[%s] Skipping malformed target: '%s'" + % (thread_name, t.strip()) + ) continue self.targets.add(netaddr.IPAddress(t)) - self.__logger.debug("Found %d targets in Nessus file" % len(self.targets)) + self.__logger.debug( + "[%s] Found %d targets in Nessus file" % (thread_name, len(self.targets)) + ) self.ticket_manager.ips = self.targets self.__try_to_clear_latest_flags() def plugin_set_callback(self, plugin_set_string): + # Get the name of the current thread + thread_name = threading.current_thread().name + string_list = plugin_set_string.split(";") if ( string_list[-1] == "" ): # this list ends with a ; creating a non-int empty string string_list.pop() plugin_set = set(int(s) for s in string_list) - self.__logger.debug("Found %d plugin_ids in Nessus file" % len(plugin_set)) + self.__logger.debug( + "[%s] Found %d plugin_ids in Nessus file" % (thread_name, len(plugin_set)) + ) self.ticket_manager.source_ids = plugin_set self.__try_to_clear_latest_flags() def port_range_callback(self, port_range_string): + # Get the name of the current thread + thread_name = threading.current_thread().name + # The base policy port range was used if port_range_string == "default": # Match the base policy value found in /extras/policy.xml port_range_string = "1-65535" ports = set(util.range_string_to_list(port_range_string)) - self.__logger.debug("Found %d ports in Nessus file" % len(ports)) + self.__logger.debug( + "[%s] Found %d ports in Nessus file" % (thread_name, len(ports)) + ) self.ticket_manager.ports = ports self.__try_to_clear_latest_flags() def host_callback(self, parsedHost): + # Get the name of the current thread + thread_name = threading.current_thread().name + # some fragile hosts don't list their host_ip # fallback to name if parsedHost.has_key("host_ip"): @@ -140,8 +170,8 @@ def host_callback(self, parsedHost): # When parsedHost['name'] is not a valid IP (see CYHY-113 in Jira) self.current_ip = None self.__logger.warning( - "Skipping vulnerability reports; invalid host IP detected: %s" - % parsedHost["name"] + "[%s] Skipping vulnerability reports; invalid host IP detected: %s" + % (thread_name, parsedHost["name"]) ) return parsedHost["ip"] = self.current_ip @@ -188,8 +218,9 @@ def host_callback(self, parsedHost): self.current_host_owner = UNKNOWN_OWNER if self.current_hostname: self.__logger.warning( - "Could not find owner for %s - %s (%d)" + "[%s] Could not find owner for %s - %s (%d)" % ( + thread_name, self.current_hostname, self.current_ip, int(self.current_ip), @@ -197,20 +228,23 @@ def host_callback(self, parsedHost): ) else: self.__logger.warning( - "Could not find owner for %s (%d)" - % (self.current_ip, int(self.current_ip)) + "[%s] Could not find owner for %s (%d)" + % (thread_name, self.current_ip, int(self.current_ip)) ) # Nessus host docs are not stored as we already have better data from nmap def report_callback(self, parsedReport): + # Get the name of the current thread + thread_name = threading.current_thread().name + # not storing severity 0 reports or reports with invalid IPs if parsedReport["severity"] == 0: return if self.current_ip is None: self.__logger.warning( - "No current IP; skipping vulnerability report: %s" - % parsedReport["plugin_name"] + "[%s] No current IP; skipping vulnerability report: %s" + % (thread_name, parsedReport["plugin_name"]) ) return report = self.__db.VulnScanDoc() @@ -231,6 +265,9 @@ def report_callback(self, parsedReport): self.ticket_manager.open_ticket(report, "vulnerability detected") def end_callback(self): + # Get the name of the current thread + thread_name = threading.current_thread().name + for ip in self.targets: if self.manual_scan: # update host priority and reschedule host @@ -241,11 +278,13 @@ def end_callback(self): self.ticket_manager.close_tickets() if not self.attempted_to_clear_latest_flags: self.__logger.warning( - 'Reached end of Nessus import but did not clear "latest" flags' + '[%s] Reached end of Nessus import but did not clear "latest" flags' + % thread_name ) self.__logger.warning( - "Ticket manager state counts: %d ips, %d ports, %d source_ids" + "[%s] Ticket manager state counts: %d ips, %d ports, %d source_ids" % ( + thread_name, len(self.ticket_manager.ips), len(self.ticket_manager.ports), len(self.ticket_manager.source_ids), @@ -253,5 +292,6 @@ def end_callback(self): ) else: self.__logger.debug( - "Reached end of Nessus import, VulnScan latest flags were cleared." + "[%s] Reached end of Nessus import, VulnScan latest flags were cleared." + % thread_name ) diff --git a/cyhy_commander/nmap/nmap_importer.py b/cyhy_commander/nmap/nmap_importer.py index d88f005..e8f8b7f 100755 --- a/cyhy_commander/nmap/nmap_importer.py +++ b/cyhy_commander/nmap/nmap_importer.py @@ -2,6 +2,7 @@ # built-in python libraries import logging +import threading from xml.sax import parse # third-party libraries (install with pip) @@ -93,13 +94,18 @@ def process(self, nmap_filename, target_filename): f.close() def __store_port_details(self, parsed_host): + # Get the name of the current thread + thread_name = threading.current_thread().name + has_at_least_one_open_port = False ip = parsed_host["addr"] host_doc = self.__db.HostDoc.get_by_ip(ip) if host_doc: ip_owner = host_doc.get("owner", UNKNOWN_OWNER) else: - self.__logger.warning("No HostDoc found for IP %s" % str(ip)) + self.__logger.warning( + "[%s] No HostDoc found for IP %s" % (thread_name, str(ip)) + ) ip_owner = UNKNOWN_OWNER for (port, details) in parsed_host["ports"].items(): if details["state"] != "open": # only storing open ports diff --git a/setup.py b/setup.py index 039654c..c7a31a0 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="cyhy-commander", - version="1.0.2", + version="1.1.0", author="Mark Feldhousen Jr.", author_email="mark.feldhousen@cisa.dhs.gov", packages=find_packages(),