-
Notifications
You must be signed in to change notification settings - Fork 187
Refactor task construction for generation #888
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
69c02f2
47308a0
0e3688d
849253f
934f334
7005d4e
c615778
62948b1
b1ef852
fd7e3c6
3629e09
f286644
db30066
876913f
55d7970
d8d6af1
b8275c9
78e1a9c
bc59c72
fcfe8b8
32cc014
ccd0a3a
4248725
c140260
2116744
5f25397
925ab76
5d8ddd2
6054321
bdfd79f
b40cf51
2563cf3
4ad06bf
cb37ddb
bac2d07
d3e8316
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,14 +14,16 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import importlib | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import List | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Callable, Dict, List, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import typer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import nemo_skills.pipeline.utils as pipeline_utils | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_skills.dataset.utils import import_from_path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_skills.inference import GENERATION_MODULE_MAP, GenerationType | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_skills.pipeline.app import app, typer_unpacker | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_skills.pipeline.utils.commands import sandbox_command | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_skills.pipeline.utils.declarative import Command, CommandGroup, HardwareConfig, Pipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from nemo_skills.utils import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| compute_chunk_ids, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_logger_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -35,6 +37,103 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TODO: add num_jobs here for consistency with eval? | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _create_commandgroup_from_config( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_cmd: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_config: Optional[Dict], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with_sandbox: bool, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sandbox_port: Optional[int], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cluster_config: Dict, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| installation_command: Optional[str], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_server_command_fn: Callable, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| partition: Optional[str], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qos: Optional[str], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| time_min: Optional[str], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| exclusive: bool, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| keep_mounts_for_sandbox: bool, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| log_dir: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> CommandGroup: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Create a CommandGroup from server_config. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Component ordering: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 1. Server (if server_config provided) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 2. Client command | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 3. Sandbox (if with_sandbox=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| components = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 1. Add server if server_config is provided | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if server_config is not None and int(server_config["num_gpus"]) > 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_type = server_config["server_type"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_container = server_config.pop("container", cluster_config["containers"][server_type]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Call server command builder directly with cluster_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cmd, num_tasks = get_server_command_fn(**server_config, cluster_config=cluster_config) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Create metadata dict | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metadata = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "num_tasks": num_tasks, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "gpus": server_config["num_gpus"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "nodes": server_config["num_nodes"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "log_prefix": "server", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_cmd = Command( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| command=cmd, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| container=server_container, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gpus=server_config["num_gpus"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nodes=server_config["num_nodes"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| name=task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metadata=metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| components.append(server_cmd) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 2. Add main generation command | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| client_cmd = Command( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| command=generation_cmd, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| container=cluster_config["containers"]["nemo-skills"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| name=task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| installation_command=installation_command, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metadata={"log_prefix": "main"}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| components.append(client_cmd) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # 3. Add sandbox if requested | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if with_sandbox: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Call sandbox command builder directly with cluster_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cmd, metadata = sandbox_command(cluster_config=cluster_config, port=sandbox_port) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metadata["log_prefix"] = "sandbox" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sandbox_cmd = Command( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| command=cmd, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| container=cluster_config["containers"]["sandbox"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| name=task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metadata=metadata, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| components.append(sandbox_cmd) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+103
to
+115
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Implement keep_mounts_for_sandbox to restore mount propagation. The Apply this diff: # 3. Add sandbox if requested
if with_sandbox:
# Call sandbox command builder directly with cluster_config
cmd, metadata = sandbox_command(cluster_config=cluster_config, port=sandbox_port)
metadata["log_prefix"] = "sandbox"
+
+ # Propagate mounts to sandbox if enabled
+ if keep_mounts_for_sandbox and "mounts" in cluster_config:
+ metadata["mounts"] = cluster_config["mounts"]
sandbox_cmd = Command(
command=cmd,
container=cluster_config["containers"]["sandbox"],
name=task_name,
metadata=metadata,
)
components.append(sandbox_cmd)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Find maximum GPUs/nodes needed by any component for the HardwareConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # The job-level resource request must be the maximum across all components | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_gpus = max((comp.gpus or 0) for comp in components) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_nodes = max((comp.nodes or 1) for comp in components) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return CommandGroup( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| commands=components, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hardware=HardwareConfig( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| partition=partition, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qos=qos, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| time_min=time_min, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| exclusive=exclusive, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_gpus=max_gpus, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_nodes=max_nodes, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| name=task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| log_dir=log_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @typer_unpacker | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def generate( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -257,89 +356,125 @@ def generate( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_ids=chunk_ids, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rerun_done=rerun_done, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_tasks = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_tasks = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if _task_dependencies is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _task_dependencies = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with pipeline_utils.get_exp(expname, cluster_config, _reuse_exp) as exp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for seed_idx, (seed, chunk_ids) in enumerate(remaining_jobs.items()): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if wandb_parameters: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # no need for chunks as it will run after merging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wandb_parameters["samples_file"] = pipeline_utils.get_chunked_rs_filename( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| random_seed=seed, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_id=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for chunk_id in chunk_ids: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| has_tasks = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_config, server_address, extra_arguments = pipeline_utils.configure_client( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model=model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_type=server_type, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_address=original_server_address, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_gpus=server_gpus, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_nodes=server_nodes, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_args=server_args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_entrypoint=server_entrypoint, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_container=server_container, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| extra_arguments=extra_arguments_original, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_random_port=get_random_port, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cmd = pipeline_utils.get_generation_cmd( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_file=input_file, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_dir=input_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| random_seed=seed, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_dir=output_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| extra_arguments=extra_arguments, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| eval_args=eval_args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_id=chunk_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_chunks=num_chunks, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| preprocess_cmd=preprocess_cmd, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| postprocess_cmd=postprocess_cmd, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wandb_parameters=wandb_parameters if seed_idx == 0 else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| script=generation_module, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Build jobs list using declarative interface | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jobs = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_job_names = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for seed_idx, (seed, chunk_ids) in enumerate(remaining_jobs.items()): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if wandb_parameters: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # no need for chunks as it will run after merging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wandb_parameters["samples_file"] = pipeline_utils.get_chunked_rs_filename( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| random_seed=seed, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_id=None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for chunk_id in chunk_ids: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Configure client (same as before) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_config, server_address, extra_arguments = pipeline_utils.configure_client( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model=model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_type=server_type, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_address=original_server_address, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_gpus=server_gpus, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_nodes=server_nodes, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_args=server_args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_entrypoint=server_entrypoint, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_container=server_container, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| extra_arguments=extra_arguments_original, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_random_port=get_random_port, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Build generation command (same as before) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cmd = pipeline_utils.get_generation_cmd( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_file=input_file, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_dir=input_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| random_seed=seed, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_dir=output_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| extra_arguments=extra_arguments, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| eval_args=eval_args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| chunk_id=chunk_id, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_chunks=num_chunks, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| preprocess_cmd=preprocess_cmd, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| postprocess_cmd=postprocess_cmd, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wandb_parameters=wandb_parameters if seed_idx == 0 else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| script=generation_module, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cmd = pipeline_utils.wrap_python_path(cmd=cmd) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Base task name (shared across all dependent jobs in the chain) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name = f"{expname}-rs{seed}" if seed is not None else expname | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if chunk_id is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name += f"-chunk{chunk_id}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Handle dependent_jobs chain | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dependencies = _task_dependencies.copy() if _task_dependencies else [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prev_job = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for dep_idx in range(dependent_jobs + 1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Create CommandGroup for this task | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cmd_group = _create_commandgroup_from_config( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_cmd=cmd, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_config=server_config.copy() if server_config else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with_sandbox=with_sandbox, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sandbox_port=None if get_random_port else 6000, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cluster_config=cluster_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| installation_command=installation_command, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_server_command_fn=generation_task.get_server_command_fn(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| partition=partition, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qos=qos, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| time_min=time_min, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| exclusive=exclusive, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| keep_mounts_for_sandbox=keep_mounts_for_sandbox, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name=task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| log_dir=log_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prev_tasks = _task_dependencies | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for _ in range(dependent_jobs + 1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name = f"{expname}-rs{seed}" if seed is not None else expname | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if chunk_id is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name += f"-chunk{chunk_id}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| new_task = pipeline_utils.add_task( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| exp, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cmd=pipeline_utils.wrap_python_path(cmd=cmd), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_name=task_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| log_dir=log_dir, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| container=cluster_config["containers"]["nemo-skills"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cluster_config=cluster_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| partition=partition, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qos=qos, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| time_min=time_min, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| server_config=server_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with_sandbox=with_sandbox, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| keep_mounts_for_sandbox=keep_mounts_for_sandbox, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| sandbox_port=None if get_random_port else 6000, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_after=run_after, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reuse_code=reuse_code, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reuse_code_exp=reuse_code_exp, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| task_dependencies=( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prev_tasks if cluster_config["executor"] == "slurm" else all_tasks + _task_dependencies | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| get_server_command=generation_task.get_server_command_fn(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| slurm_kwargs={"exclusive": exclusive} if exclusive else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| installation_command=installation_command, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| skip_hf_home_check=skip_hf_home_check, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prev_tasks = [new_task] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_tasks.append(new_task) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if has_tasks and not _reuse_exp: # if we are reusing an experiment, the tasks will run from there | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipeline_utils.run_exp(exp, cluster_config, dry_run=dry_run) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if _reuse_exp: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return all_tasks | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if has_tasks: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return exp | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Use unique internal job name for dependency tracking, but same task_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| internal_job_name = f"{task_name}-dep{dep_idx}" if dep_idx > 0 else task_name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Build dependencies: first job in chain gets external dependencies, rest chain to previous | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if dep_idx == 0: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # First job: add run_after if no task_dependencies | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| job_deps = dependencies.copy() if dependencies else [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not dependencies and run_after: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_after_list = run_after if isinstance(run_after, list) else [run_after] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| job_deps.extend(run_after_list) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| job_deps = job_deps if job_deps else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Subsequent jobs in chain depend on previous job (use job object, not string) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| job_deps = [prev_job] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| job_spec = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "name": internal_job_name, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "group": cmd_group, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "dependencies": job_deps, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jobs.append(job_spec) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prev_job = job_spec # Track for next iteration | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| all_job_names.append(internal_job_name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # If no jobs to run, return early | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not jobs: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Create and run pipeline | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pipeline = Pipeline( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| name=expname, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cluster_config=cluster_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| jobs=jobs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reuse_code=reuse_code, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reuse_code_exp=reuse_code_exp, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| skip_hf_home_check=skip_hf_home_check, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Pass _reuse_exp to pipeline.run() to add jobs to existing experiment | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = pipeline.run(dry_run=dry_run, _reuse_exp=_reuse_exp) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return result | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| typer.main.get_command_name = lambda name: name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we confident enough in our tests related to generate?
If our test coverage is low, I will strong recommend that we keep the old
generate.pyasgenerate_legacy.pyso that people getting the new version automatically use the new command (with warnings of course). So in case, it does cause any bugs, they can always have an escape hatch for this transition phase.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should run a few more tests, but I would probably just go ahead and switch to the new interface directly. As long as there are no silent issues (so it doesn't corrupt generations, but just throws an error), it should be easy to fix, and we will quickly resolve all problems. If we find that there are some really tricky things that we didn't handle, we can always just roll this back or add generate_legacy.py whenever we face this situation