[NOMERGE] Production concatenate calls squashed #6

Draft: wants to merge 51 commits into base: production


MTCam (Member) commented Mar 17, 2025

Summarize concat/outlining changes only.

majosm added 30 commits March 10, 2025 21:59
…cer#585)

* refactor deduplicate_data_wrappers to avoid dependence on erroneous super().rec usage in CachedMapAndCopyMapper

Here is a sketch of what happens with super().rec vs Mapper.rec for the previous implementation of
deduplicate_data_wrappers. Suppose we have 2 data wrappers a and b with the same data pointer.

With super().rec:

1) map_fn maps a to itself, then mapper copies a to a'; mapper caches a -> a' (twice, once in
   super().rec and then again in rec),
2) map_fn maps b to a, then mapper maps (via cache in super().rec call) a to a'; mapper caches
   b -> a'.

=> Only a' in output DAG.

With Mapper.rec:

1) map_fn maps a to itself, then mapper copies a to a'; caches a -> a',
2) map_fn maps b to a, then mapper copies a to a''; caches b -> a''.

=> Both a' and a'' in output DAG.

* call Mapper.rec instead of super().rec to avoid double caching

* call Mapper.rec from CachedMapper too just to avoid copy/paste errors

* add assertion to check for double caching

* add comment explaining use of Mapper.rec
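
The super().rec versus Mapper.rec sketch in the commit message above can be reproduced with a small toy model. This is only an illustration under stated assumptions: `ToyMapper`, `ToyCachedMapper`, `ToyMapAndCopyMapper`, `DataWrapper` and `deduplicate` are hypothetical stand-ins written for this example, not pytato's actual classes or the real `deduplicate_data_wrappers`, and the `use_super` switch exists only to compare the two call styles side by side.

```python
class DataWrapper:
    """Toy stand-in for a data wrapper node; hashed and compared by identity."""

    def __init__(self, name, data_ptr):
        self.name = name
        self.data_ptr = data_ptr


class ToyMapper:
    """Uncached base mapper: every call to rec() produces a fresh copy."""

    def rec(self, expr):
        return DataWrapper(expr.name, expr.data_ptr)


class ToyCachedMapper(ToyMapper):
    """Caches input -> output, so a repeated input yields the same copy."""

    def __init__(self):
        self.cache = {}

    def rec(self, expr):
        if expr in self.cache:
            return self.cache[expr]
        result = ToyMapper.rec(self, expr)
        self.cache[expr] = result
        return result


class ToyMapAndCopyMapper(ToyCachedMapper):
    """Applies map_fn first, then copies the mapped node."""

    def __init__(self, map_fn, use_super):
        super().__init__()
        self.map_fn = map_fn
        self.use_super = use_super  # hypothetical switch, for comparison only

    def rec(self, expr):
        if expr in self.cache:
            return self.cache[expr]
        mapped = self.map_fn(expr)
        if self.use_super:
            # Goes through ToyCachedMapper.rec, which also caches
            # mapped -> result (the "double caching" case).
            result = super().rec(mapped)
        else:
            # Bypasses the parent's cache; only expr -> result is cached.
            result = ToyMapper.rec(self, mapped)
        self.cache[expr] = result
        return result


def deduplicate(wrappers, use_super):
    """Map every wrapper to the first wrapper seen with the same data pointer."""
    first_seen = {}

    def map_fn(expr):
        return first_seen.setdefault(expr.data_ptr, expr)

    mapper = ToyMapAndCopyMapper(map_fn, use_super)
    return [mapper.rec(w) for w in wrappers]


a = DataWrapper("a", 0x1000)
b = DataWrapper("b", 0x1000)  # same data pointer as a

with_super = deduplicate([a, b], use_super=True)
with_base = deduplicate([a, b], use_super=False)

# super().rec: b is resolved through the cached a -> a', so one copy is shared.
print(with_super[0] is with_super[1])
# Mapper.rec: a is copied twice (a' and a''), so the outputs are distinct.
print(with_base[0] is with_base[1])
```

In this toy model the deduplication only works when the parent's cache is consulted for the mapped node, which matches the commit's point that the previous implementation depended on the super().rec behavior and had to be refactored before switching to Mapper.rec.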
…nition_cache_key for extra args case

ambiguous due to the fact that any arg can be specified with/without keyword
…n_definition_cache_key are not defined for general extra args/kwargs
apparently TypeVar(..., <type>) doesn't include subclasses of <type>

MTCam (Member, Author) commented Mar 17, 2025

FYI: I haven't been able to run the prediction driver past 128 ranks. I keep getting errors like this one:

2025-03-17 15:08:50,853 - INFO - pytato.distributed.verify - find_distributed_partition: Split 928 nodes into 3 parts, with [77, 482, 604] nodes in each partition.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/p/lustre5/mtcampbe/CEESD/Experimental/concat-03.13/miniforge3/envs/x.concat/lib/python3.12/site-packages/mpi4py/__main__.py", line 7, in <module>
2025-03-17 15:08:50,853 - INFO - pytato.distributed.verify - find_distributed_partition: Split 816 nodes into 3 parts, with [66, 424, 532] nodes in each partition.
    main()
  File "/p/lustre5/mtcampbe/CEESD/Experimental/concat-03.13/miniforge3/envs/x.concat/lib/python3.12/site-packages/mpi4py/run.py", line 214, in main
    run_command_line(args)
  File "/p/lustre5/mtcampbe/CEESD/Experimental/concat-03.13/miniforge3/envs/x.concat/lib/python3.12/site-packages/mpi4py/run.py", line 46, in run_command_line
    run_path(sys.argv[0], run_name='__main__')
  File "<frozen runpy>", line 287, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "driver.py", line 80, in <module>
    main(actx_class, restart_filename=restart_filename,
  File "/p/lustre5/mtcampbe/CEESD/Experimental/concat-03.13/mirgecom/mirgecom/mpi.py", line 152, in wrapped_func
    func(*args, **kwargs)
  File "/p/lustre5/mtcampbe/CEESD/Experimental/concat-03.13/drivers_y3-prediction/y3prediction/prediction.py", line 4174, in main
2025-03-17 15:08:50,857 - INFO - grudge.array_context - pt.find_distributed_partition: completed (63.69s wall 1.00x CPU)
    compute_smoothed_char_length_compiled(smoothed_char_length_fluid, i)
  File "/p/lustre5/mtcampbe/CEESD/Experimental/concat-03.13/arraycontext/arraycontext/impl/pytato/compile.py", line 350, in __call__
    compiled_func = self._dag_to_compiled_func(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/lustre5/mtcampbe/CEESD/Experimental/concat-03.13/grudge/grudge/array_context.py", line 374, in _dag_to_compiled_func
2025-03-17 15:08:50,859 - INFO - grudge.array_context - pt.find_distributed_partition: completed (61.00s wall 1.00x CPU)
    distributed_partition = pt.find_distributed_partition(
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/lustre5/mtcampbe/CEESD/Experimental/concat-03.13/pytato/pytato/distributed/partition.py", line 998, in find_distributed_partition
    name_to_output_per_part[pid][name] = ary
    ~~~~~~~~~~~~~~~~~~~~~~~^^^^^
IndexError: list index out of range

2 participants