import logging
import sys

from tqdm.contrib.logging import logging_redirect_tqdm

import cubed
import cubed.array_api as xp
import cubed.random
from cubed.extensions.history import HistoryCallback
from cubed.extensions.timeline import TimelineVisualizationCallback
from cubed.extensions.tqdm import TqdmProgressBar
from cubed.runtime.executors.lithops import LithopsDagExecutor

logging.basicConfig(level=logging.INFO)
# suppress harmless connection pool warnings
logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR)

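# The temporary (work) directory for intermediate data and the Lithops runtime name
# are taken from the command line.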
if __name__ == "__main__":
    tmp_path = sys.argv[1]
    runtime = sys.argv[2]
    spec = cubed.Spec(tmp_path, allowed_mem="2GB")
    executor = LithopsDagExecutor()

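    # Each 5000 x 5000 float64 chunk is 5000 * 5000 * 8 bytes = 200 MB,
    # comfortably within the 2GB allowed_mem budget set in the Spec above.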
    # Note that we use the default float dtype, since np.matmul is not optimized for ints
    a = xp.ones((50000, 50000), chunks=(5000, 5000), spec=spec)
    b = xp.ones((50000, 50000), chunks=(5000, 5000), spec=spec)
    c = xp.matmul(a, b)
    d = xp.all(c == 50000)
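    # d is a lazy 0-d boolean array; nothing is executed until compute() is called below.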
    with logging_redirect_tqdm():
        progress = TqdmProgressBar()
        hist = HistoryCallback()
        timeline_viz = TimelineVisualizationCallback()
        res = d.compute(
            executor=executor,
            callbacks=[progress, hist, timeline_viz],
            runtime=runtime,
            runtime_memory=2048,  # Note that Lithops/Google Cloud Functions only accepts powers of 2 for this argument.
        )
    assert res, "Validation failed"
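# Example invocation (the script, bucket, and runtime names below are placeholders,
# not taken from this change):
#   python matmul.py gs://my-bucket/cubed-tmp my-gcf-runtime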