Conversation
tests/test_attributor.py
Outdated
    cfg = IndexConfig(run_path=str(tmp_path))
    cfg.skip_preconditioners = True
    ...
    kwargs = {
If you can't see the keys for the input args to functions in your IDE, something is broken; they should show up in grey.
    assert result.returncode == 0
    # Print the output to see what's failing
    if result.returncode != 0:
I fixed this in the other query PR; I think it's a bit different from this fix. Basically, returncode is 0 as long as the subprocess itself ran successfully, so you parse the output text unconditionally.
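A minimal sketch of that testing pattern (the child script here is a made-up stand-in, not one of the project's real scripts): the subprocess exits 0 whenever it runs to completion, so the test inspects stdout unconditionally instead of gating the parse on returncode.

```python
import subprocess
import sys

# Child process that "fails" logically but still exits cleanly, so
# returncode alone cannot tell us whether the run succeeded.
result = subprocess.run(
    [sys.executable, "-c", "print('loss=1.25'); print('status=failed')"],
    capture_output=True,
    text=True,
)
assert result.returncode == 0  # the subprocess itself was happy

# Parse the output text unconditionally to detect logical failures.
fields = dict(line.split("=") for line in result.stdout.splitlines())
failed = fields.get("status") == "failed"
```

Here `failed` ends up True even though `returncode` is 0, which is exactly why checking the exit code alone is not enough.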
tests/test_build.py
Outdated
    cfg = IndexConfig(run_path=str(tmp_path))
    # This build hangs in pytest with preconditioners enabled.
    # It works when run directly so it may be a pytest issue.
    cfg.skip_preconditioners = True
tests/test_build.py
Outdated
    # This build hangs in pytest with preconditioners enabled.
    # It works when run directly so it may be a pytest issue.
    cfg.skip_preconditioners = True
    kwargs = {
kwargs are undesirable when not strictly necessary because they add a layer of indirection. If the SWE we're working with wants this, I'm open to other lines of reasoning.
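To illustrate the point, a small sketch (build_index here is a hypothetical stand-in, not bergson's real signature) contrasting a kwargs dict with explicit keyword arguments at the call site:

```python
# Hypothetical stand-in for the function under test; not bergson's real API.
def build_index(run_path, skip_preconditioners=False):
    return run_path, skip_preconditioners

run_path = "/tmp/run"

# Indirect: the call site no longer shows which parameters are passed,
# and IDEs cannot surface the parameter names inline.
kwargs = {"run_path": run_path, "skip_preconditioners": True}
indirect = build_index(**kwargs)

# Direct: every argument is visible and type-checkable at the call site.
direct = build_index(run_path=run_path, skip_preconditioners=True)

assert indirect == direct
```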
bergson/build.py
Outdated
        ds = assert_type(Dataset, Dataset.from_json(data_str))
    else:
        try:
            ds = load_dataset(data_str, split=cfg.data.split, streaming=cfg.streaming)
Need to add back the lost subset arg.
By the way, LM eval harness uses a nice pattern where the CLI takes --model_kwargs "device='cuda',streaming=True,subset='hello'" --dataset_kwargs "..." so that they don't need to update their library whenever HF adds a new model or dataset kwarg. We could consider doing this too.
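A minimal sketch of how such a flag could be parsed, assuming a hypothetical parse_cli_kwargs helper (neither bergson's nor lm-eval-harness's actual implementation): it splits the argument string on commas and evaluates each value with ast.literal_eval, so new HF kwargs need no CLI changes.

```python
import ast

def parse_cli_kwargs(arg_string: str) -> dict:
    """Parse "device='cuda',streaming=True,subset='hello'" into a dict.

    Hypothetical helper for illustration. Note the naive comma split:
    values that themselves contain commas (e.g. tuples) would need a
    smarter parser.
    """
    kwargs = {}
    for item in arg_string.split(","):
        key, _, value = item.partition("=")
        kwargs[key.strip()] = ast.literal_eval(value.strip())
    return kwargs

parsed = parse_cli_kwargs("device='cuda',streaming=True,subset='hello'")
```

The resulting dict can then be forwarded directly, e.g. `load_dataset(path, **parsed)`, without the CLI enumerating every supported kwarg.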
bergson/build.py
Outdated
    @@ -91,7 +139,6 @@ def worker(
        device_map=device_map,
        quantization_config=quantization_config,
        dtype=dtype,
Need to add back the lost revision arg.
bergson/build.py
Outdated
        # Add a barrier to ensure all processes reach this point
        dist.barrier()
    except Exception:
        pass  # Ignore barrier failures during cleanup
I'm removing this unless we have a good reason; I don't think we should suppress errors (?)
I did this for the .part rename call and it was a mistake; I'm going to remove mine.
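One hedged alternative to swallowing the exception, as a small sketch (finalize and sync are hypothetical names; sync stands in for dist.barrier): log context during cleanup so the failure is debuggable, but still re-raise rather than `pass`.

```python
import logging

logger = logging.getLogger("bergson.cleanup")

def finalize(sync):
    """Run a cleanup-time synchronization without hiding failures.

    Hypothetical helper: `sync` stands in for dist.barrier(). If it
    raises, we log context for debugging and then re-raise, instead of
    suppressing the error with `except Exception: pass`.
    """
    try:
        sync()
    except Exception:
        logger.exception("synchronization failed during cleanup")
        raise

# Success path: the callable runs normally and nothing is swallowed.
calls = []
finalize(lambda: calls.append("synced"))
```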
bergson/build.py
Outdated
    # Write index config to json
    def distributed_computing(
        cfg: IndexConfig,
        worker_fn: Callable,
Currently distributed_computing is not generic enough to be used in more than one place (it contains build_index-specific logic, like not having a QueryConfig parameter), and worker_fn is only ever set to collect_gradients, so it doesn't currently make sense to add a layer of indirection through an abstracted name there either.
worker_wrapper also has a generic name but does something specific (device-specific artifact setup and orchestration). I think the use of datasets is omnipresent in our workflows and we should embrace that as something to keep concrete. So we can rename worker_wrapper to something like run_dataset_on_worker, and then we're not as surprised to find it's doing dataset processing.
Because both of our functions get the dataset in the main process, we can maybe remove this conditional for now and add it back later if we need it:
    # Do all the data loading and preprocessing on the main process
    if setup_data:
        ds = setup_data_pipeline(cfg)
    else:
        # Create empty dataset for compatibility
        ds = assert_type(Dataset, Dataset.from_list([]))
Then we may also choose to reduce the conceptual nesting of our functions by having a clear function called build that does

    ds = setup_data_pipeline(cfg)
    distributed_computing(worker_fn=collect_gradients, constant_worker_args=[cfg, ds], process_name="build")

and a query function that does

    ds = setup_data_pipeline(index_cfg)
    distributed_computing(worker_fn=collect_gradients, constant_worker_args=[index_cfg, query_cfg, ds], process_name="query")
Then distributed_computing will be truly generic, simply running
    args = {
        i: (i, world_size, *constant_worker_args)
        for i in range(world_size)
    }
If we ever need to use something other than collect_gradients it can become one of the constant worker args.
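The proposed generic launcher can be sketched as follows. This is a toy sketch, not bergson's implementation: it uses threads purely to stay self-contained (real code would spawn processes, e.g. via torch.multiprocessing), and the worker is a hypothetical stand-in for collect_gradients.

```python
from concurrent.futures import ThreadPoolExecutor

def distributed_computing(worker_fn, constant_worker_args, world_size):
    """Generic launcher sketch: each rank receives
    (rank, world_size, *constant_worker_args) and nothing else.

    Threads stand in for spawned processes to keep the sketch runnable.
    """
    args = {
        i: (i, world_size, *constant_worker_args)
        for i in range(world_size)
    }
    with ThreadPoolExecutor(max_workers=world_size) as pool:
        futures = [pool.submit(worker_fn, *args[i]) for i in range(world_size)]
        return [f.result() for f in futures]

# Toy worker standing in for collect_gradients: just reports its rank
# and the dataset size it was handed.
def collect(rank, world_size, cfg, ds):
    return (rank, len(ds))

results = distributed_computing(
    collect, constant_worker_args=[{"run": "x"}, [1, 2, 3]], world_size=2
)
```

Because all build/query-specific state travels through constant_worker_args, the launcher itself never needs to know about IndexConfig or QueryConfig.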
I'm going to draft this up
45708b2 to cf481d8
I'm going to yolo merge this; happy to raise another PR if there are issues or we want to add the kwargs pattern back to the tests or something like that. Thanks for working on this!!
This PR rewrites distributed setup to be more flexible.