Skip to content

Commit 29bb579

Browse files
authored
Add matmul example that validates result (#244)
1 parent e217865 commit 29bb579

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import logging
2+
import sys
3+
4+
from tqdm.contrib.logging import logging_redirect_tqdm
5+
6+
import cubed
7+
import cubed.array_api as xp
8+
import cubed.random
9+
from cubed.extensions.history import HistoryCallback
10+
from cubed.extensions.timeline import TimelineVisualizationCallback
11+
from cubed.extensions.tqdm import TqdmProgressBar
12+
from cubed.runtime.executors.lithops import LithopsDagExecutor
13+
14+
logging.basicConfig(level=logging.INFO)
15+
# suppress harmless connection pool warnings
16+
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)
17+
18+
if __name__ == "__main__":
19+
tmp_path = sys.argv[1]
20+
runtime = sys.argv[2]
21+
spec = cubed.Spec(tmp_path, allowed_mem="2GB")
22+
executor = LithopsDagExecutor()
23+
24+
# Note we use default float dtype, since np.matmul is not optimized for ints
25+
a = xp.ones((50000, 50000), chunks=(5000, 5000), spec=spec)
26+
b = xp.ones((50000, 50000), chunks=(5000, 5000), spec=spec)
27+
c = xp.matmul(a, b)
28+
d = xp.all(c == 50000)
29+
with logging_redirect_tqdm():
30+
progress = TqdmProgressBar()
31+
hist = HistoryCallback()
32+
timeline_viz = TimelineVisualizationCallback()
33+
res = d.compute(
34+
executor=executor,
35+
callbacks=[progress, hist, timeline_viz],
36+
runtime=runtime,
37+
runtime_memory=2048, # Note that Lithops/Google Cloud Functions only accepts powers of 2 for this argument.
38+
)
39+
assert res, "Validation failed"

0 commit comments

Comments
 (0)