import sys
import gc
import os
import logging
import timeit
from abc import ABC, abstractmethod
from typing import Any, Iterable, Optional

import dask.dataframe as dd
from dask import distributed

logging.basicConfig(
    level=logging.INFO,
    format='{ %(name)s:%(lineno)d @ %(asctime)s } - %(message)s'
)
logger = logging.getLogger(__name__)

THIS_DIR = os.path.abspath(
    os.path.dirname(__file__)
)
HELPERS_DIR = os.path.abspath(
    os.path.join(
        THIS_DIR, '../_helpers'
    )
)
sys.path.extend((THIS_DIR, HELPERS_DIR))
# The star import supplies write_log, make_chk, and memory_usage used below.
from helpers import *

class Query(ABC):
    """One benchmark query: how to compute an answer and how to verify it."""

    question: Optional[str] = None

    @staticmethod
    @abstractmethod
    def query(*args) -> dd.DataFrame:
        pass

    @staticmethod
    @abstractmethod
    def check(ans: dd.DataFrame) -> Any:
        pass

    @classmethod
    def name(cls) -> str:
        return f"{cls.__name__}: {cls.question}"

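# Illustrative sketch only, not part of the original module: a minimal concrete
# Query. The input column names 'id1' and 'v1' are invented for this example.
class ExampleGroupBySum(Query):
    question = "sum v1 by id1"  # hypothetical question label

    @staticmethod
    def query(x: dd.DataFrame):
        # Build the lazy aggregation, then materialize with .compute() so the
        # runner can read ans.shape as plain integers.
        return x.groupby('id1').agg({'v1': 'sum'}).reset_index().compute()

    @staticmethod
    def check(ans) -> Any:
        # A cheap scalar summary of the answer; the exact shape make_chk
        # expects is defined in the shared helpers, so this is an assumption.
        return [ans['v1'].sum()]
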
class QueryRunner:
    """Runs each query several times and records timing, memory, and checksum
    results via the shared logging helpers."""

    def __init__(
        self,
        task: str,
        solution: str,
        solution_version: str,
        solution_revision: str,
        fun: str,
        cache: str,
        on_disk: bool
    ):
        self.task = task
        self.solution = solution
        self.solution_version = solution_version
        self.solution_revision = solution_revision
        self.fun = fun
        self.cache = cache
        self.on_disk = on_disk

    def run_query(
        self,
        data_name: str,
        in_rows: int,
        args: Iterable[Any],
        query: Query,
        machine_type: str,
        runs: int = 2,
        raise_exception: bool = False,
    ):
        logger.info("Running '%s'", query.name())

        try:
            for run in range(1, runs + 1):
                gc.collect()  # TODO: Able to do this in worker processes? Want to?

                # Calculate ans
                t_start = timeit.default_timer()
                ans = query.query(*args)
                logger.debug("Answer shape: %s", ans.shape)
                t = timeit.default_timer() - t_start
                m = memory_usage()

                logger.info("\tRun #%d: %0.3fs", run, t)

                # Calculate chk
                t_start = timeit.default_timer()
                chk = query.check(ans)
                chkt = timeit.default_timer() - t_start

                write_log(
                    task=self.task,
                    data=data_name,
                    in_rows=in_rows,
                    question=query.question,
                    out_rows=ans.shape[0],
                    out_cols=ans.shape[1],
                    solution=self.solution,
                    version=self.solution_version,
                    git=self.solution_revision,
                    fun=self.fun,
                    run=run,
                    time_sec=t,
                    mem_gb=m,
                    cache=self.cache,
                    chk=make_chk(chk),
                    chk_time_sec=chkt,
                    on_disk=self.on_disk,
                    machine_type=machine_type
                )
                if run == runs:
                    # Print head / tail on last run
                    logger.debug("Answer head:\n%s", ans.head(3))
                    logger.debug("Answer tail:\n%s", ans.tail(3))
                del ans
        except Exception:
            # Log the failure with its traceback instead of printing to stdout.
            logger.exception("Query '%s' failed!", query.name())

            # Re-raise if instructed
            if raise_exception:
                raise

def dask_client() -> distributed.Client:
    # Use a process pool rather than a thread pool: the GIL would otherwise
    # serialize CPU-bound work across worker threads.
    return distributed.Client(processes=True, silence_logs=logging.ERROR)
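
# Illustrative usage sketch only: one way these pieces might be wired together.
# The input path, data name, machine type, and metadata strings are invented
# placeholders, not values taken from this repository.
if __name__ == "__main__":
    client = dask_client()  # becomes the default scheduler for .compute()
    x = dd.read_csv("G1_1e7_1e2_0_0.csv")  # hypothetical input file
    runner = QueryRunner(
        task="groupby",
        solution="dask",
        solution_version=distributed.__version__,  # assumes distributed exposes __version__
        solution_revision="unknown",
        fun=".groupby",
        cache="TRUE",
        on_disk=False,
    )
    runner.run_query(
        data_name="G1_1e7_1e2_0_0",
        in_rows=int(1e7),
        args=[x],
        query=ExampleGroupBySum(),
        machine_type="local",
    )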