Skip to content

Commit 873a7ab

Browse files
authored
Merge pull request #115 from CarterFendley/carter/programatic-join
Common utilities and programmatic join.
2 parents b4775e4 + e94f95d commit 873a7ab

File tree

5 files changed

+485
-310
lines changed

5 files changed

+485
-310
lines changed

_launcher/solution.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ ns = solution.path(s)
155155
ext = file.ext(s)
156156
localcmd = if (s %in% c("clickhouse","h2o","juliadf", "juliads")) { # custom launcher bash script, for clickhouse h2o juliadf
157157
sprintf("exec.sh %s", t)
158-
} else sprintf("%s-%s.%s", t, ns, ext)
158+
} else if (s %in% c("dask")) {
159+
sprintf("%s_%s.%s", t, ns, ext)
160+
}else sprintf("%s-%s.%s", t, ns, ext)
159161
cmd = sprintf("./%s/%s", ns, localcmd)
160162
cmd
161163
ret = system(cmd, ignore.stdout=as.logical(args[["quiet"]]))

dask/common.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import sys
2+
import gc
3+
import os
4+
import logging
5+
import timeit
6+
from abc import ABC, abstractmethod
7+
from typing import Iterable, Any
8+
9+
import dask.dataframe as dd
10+
from dask import distributed
11+
12+
logging.basicConfig(
13+
level=logging.INFO,
14+
format='{ %(name)s:%(lineno)d @ %(asctime)s } - %(message)s'
15+
)
16+
logger = logging.getLogger(__name__)
17+
18+
THIS_DIR = os.path.abspath(
19+
os.path.dirname(__file__)
20+
)
21+
HELPERS_DIR = os.path.abspath(
22+
os.path.join(
23+
THIS_DIR, '../_helpers'
24+
)
25+
)
26+
sys.path.extend((THIS_DIR, HELPERS_DIR))
27+
from helpers import *
28+
29+
class Query(ABC):
    """Abstract base class for a single benchmark query.

    Concrete subclasses set `question` and implement the `query` and
    `check` static methods.
    """

    # Human-readable question text; assigned by each concrete subclass.
    question: str = None

    @staticmethod
    @abstractmethod
    def query(*args) -> "dd.DataFrame":
        """Compute and return the answer DataFrame for this query."""
        pass

    @staticmethod
    @abstractmethod
    def check(ans: "dd.DataFrame") -> Any:
        """Derive a checksum-style value from the answer DataFrame."""
        pass

    @classmethod
    def name(cls) -> str:
        """Return a display name: the class name plus the question text."""
        return "{}: {}".format(cls.__name__, cls.question)
45+
46+
class QueryRunner:
    """Executes `Query` implementations and records per-run results.

    Each run's wall time, memory usage and answer checksum are written
    via the `write_log` helper, tagged with the solution metadata given
    at construction time.
    """

    def __init__(
        self,
        task: str,
        solution: str,
        solution_version: str,
        solution_revision: str,
        fun: str,
        cache: str,
        on_disk: bool
    ):
        # Static metadata describing the solution under test; forwarded
        # unchanged to write_log for every run.
        self.task = task
        self.solution = solution
        self.solution_version = solution_version
        self.solution_revision = solution_revision
        self.fun = fun
        self.cache = cache
        self.on_disk = on_disk

    def run_query(
        self,
        data_name: str,
        in_rows: int,
        args: Iterable[Any],
        query: "Query",
        machine_type: str,
        runs: int = 2,
        raise_exception: bool = False,
    ):
        """Run `query` `runs` times, logging timing and checksum per run.

        Any exception is logged with its traceback; it is re-raised only
        when `raise_exception` is True, otherwise the failure is recorded
        and execution continues in the caller.
        """
        # Lazy %-style args so formatting only happens if the level is enabled.
        logger.info("Running '%s'", query.name())

        try:
            for run in range(1, runs + 1):
                gc.collect()  # TODO: Able to do this in worker processes? Want to?

                # Calculate ans
                t_start = timeit.default_timer()
                ans = query.query(*args)
                logger.debug("Answer shape: %s", (ans.shape, ))
                t = timeit.default_timer() - t_start
                m = memory_usage()

                logger.info("\tRun #%s: %0.3fs", run, t)

                # Calculate chk
                t_start = timeit.default_timer()
                chk = query.check(ans)
                chkt = timeit.default_timer() - t_start

                write_log(
                    task=self.task,
                    data=data_name,
                    in_rows=in_rows,
                    question=query.question,
                    out_rows=ans.shape[0],
                    out_cols=ans.shape[1],
                    solution=self.solution,
                    version=self.solution_version,
                    git=self.solution_revision,
                    fun=self.fun,
                    run=run,
                    time_sec=t,
                    mem_gb=m,
                    cache=self.cache,
                    chk=make_chk(chk),
                    chk_time_sec=chkt,
                    on_disk=self.on_disk,
                    machine_type=machine_type
                )
                if run == runs:
                    # Print head / tail on last run
                    logger.debug("Answer head:\n%s", ans.head(3))
                    logger.debug("Answer tail:\n%s", ans.tail(3))
                del ans
        except Exception:
            # logger.exception captures the full traceback in the same
            # stream as the rest of the run log (print(err) dropped it).
            logger.exception("Query '%s' failed!", query.name())

            # Re-raise if instructed; bare raise preserves the traceback.
            if raise_exception:
                raise
128+
129+
def dask_client() -> distributed.Client:
    """Create a local `distributed.Client` backed by worker processes."""
    # we use process-pool instead of thread-pool due to GIL cost
    return distributed.Client(processes=True, silence_logs=logging.ERROR)

0 commit comments

Comments
 (0)