# example_parallelism.py
"""
We can get parllelism in certain edge cases.
How it works:
- @arrow_jit is a wrapper around Polars' ``Expr.map_batches()``.
- When ``map_batches()`` has ``is_elementwise=True, returns_scalar=False``,
the passed in function receives batches rather than the whole series when
in streaming mode.
- We can therefore combine a two pass operation to get batched operations,
which then get run in the thread pool in parallel, especially in
streaming mode.
"""
from time import time, process_time
import numpy as np
import polars as pl
from polars_numba import arrow_jit
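

# For context, a minimal sketch of the Polars API the docstring refers to.
# This is an assumption about the mechanism @arrow_jit builds on, not
# polars_numba's actual implementation: with is_elementwise=True and
# returns_scalar=False, map_batches() may hand the function individual
# chunks rather than the full Series, particularly in streaming mode.
def _map_batches_sketch(column: pl.Expr) -> pl.Expr:
    # Emit one partial sum per batch received; a later .sum() would combine
    # them, mirroring the two-pass pattern implemented with @arrow_jit below.
    return column.map_batches(
        lambda s: pl.Series([s.sum()], dtype=pl.Float64),
        return_dtype=pl.Float64(),
        is_elementwise=True,
        returns_scalar=False,
    )

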
def timeit(prefix, f, count=50):
    start, cpu_start = time(), process_time()
    for _ in range(count):
        f()
    print(
        f"{prefix}:",
        (time() - start) / count,
        "(secs)",
        (process_time() - cpu_start) / count,
        "(CPU secs)",
    )


@arrow_jit(returns_scalar=True, return_dtype=pl.Float64())
def not_parallel_sum(arr):
    result = 0.0
    for value in arr:
        if value is not None:
            # Try a complex expression so we're not bottlenecked on memory
            # bandwidth:
            result += np.log(np.cos(value) + np.sin(value) + 7)
    return result


# is_elementwise means we won't always get the full Series; we may get
# individual chunks in some cases.
@arrow_jit(returns_scalar=False, is_elementwise=True, return_dtype=pl.Float64())
def sum_chunk(arr, array_builder):
    result = 0.0
    for value in arr:
        if value is not None:
            result += np.log(np.cos(value) + np.sin(value) + 7)
    # Append this chunk's partial sum to the output array:
    array_builder.real(result)


def parallel_sum(column: pl.Expr) -> pl.Expr:
    # First, sum each chunk, which results in a Series of partial sums:
    partial_sums = sum_chunk(column)
    # Then sum those partial sums:
    return partial_sums.sum()
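

# Worked example of the two-pass reduction (with the per-value expression
# simplified to the identity for illustration): if a column arrives as two
# chunks [a, b] and [c, d], sum_chunk yields the partial sums
# [a + b, c + d], and the final .sum() combines them into a + b + c + d,
# the same total a single sequential pass would produce.
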
df = pl.DataFrame({"values": range(1_000_000)})
print(df.select(parallel_sum(pl.col("values"))))
print(df.select(not_parallel_sum(pl.col("values"))))

# Weirdly, doing just this, and not the above, results in Numba issues?!
print(df.lazy().select(parallel_sum(pl.col("values"))).collect(engine="streaming"))

# Check correctness
assert (
    abs(
        df.select(not_parallel_sum(pl.col("values"))).item()
        - df.lazy()
        .select(parallel_sum(pl.col("values")))
        .collect(engine="streaming")
        .item()
    )
    < 0.00001
)

timeit("Eager, not_parallel:", lambda: df.select(not_parallel_sum(pl.col("values"))))
timeit("Eager, parallel:", lambda: df.select(parallel_sum(pl.col("values"))))
timeit(
"Lazy, not_parallel:",
lambda: df.lazy().select(not_parallel_sum(pl.col("values"))).collect(),
)
timeit(
"Lazy, parallel:",
lambda: df.lazy().select(parallel_sum(pl.col("values"))).collect(),
)
timeit(
"Lazy streaming, not_parallel:",
lambda: df.lazy()
.select(not_parallel_sum(pl.col("values")))
.collect(engine="streaming"),
)
timeit(
"Lazy streaming, parallel:",
lambda: df.lazy()
.select(parallel_sum(pl.col("values")))
.collect(engine="streaming"),
)