Commit a4991d2

Migrate resource configuration to Fray (#2154)
This migrates our various accelerator & resource configurations to use Fray. Naturally, we had at least three ways to specify these in the past. This doesn't remove all of the duplication, but it gets us most of the way there. Execution is still handled via Ray as usual; this change just replaces our various `V6_TPU_STRICT_PACK` etc. constants with Fray equivalents. All training & evaluation jobs now use a Fray `ResourceConfig` to specify the accelerator type, number of slices, and so on. We also port the FLOPs calculation over to Fray, since it was duplicated in a few places. As before, training jobs rely on the usual `ray_tpu` logic to pack TPU workers properly. For evaluation jobs, instead of the various Ray-specific helpers, we now inject strict packing via the `scheduling_strategy` helper when launching the evaluation jobs themselves.
1 parent 7f6e5ba commit a4991d2
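The unified shape the commit message describes can be sketched as follows. This is an illustration only: `MiniResourceConfig` and its fields are hypothetical stand-ins, not the actual `fray.cluster.ResourceConfig` implementation; only the factory-method names (`with_tpu`, `with_gpu`, `with_cpu`) mirror the ones that appear in the diffs.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class MiniResourceConfig:
    """Hypothetical stand-in for fray.cluster.ResourceConfig (illustration only)."""

    device: str                  # "cpu", "gpu", or "tpu"
    variant: str | None = None   # e.g. "v6e-8" or "H100"
    count: int = 1               # GPU count, or TPU slice count

    @classmethod
    def with_tpu(cls, tpu_type: str, slice_count: int = 1) -> MiniResourceConfig:
        return cls(device="tpu", variant=tpu_type, count=slice_count)

    @classmethod
    def with_gpu(cls, gpu_type: str | None = None, count: int = 1) -> MiniResourceConfig:
        # gpu_type=None would mean "auto-detect" in the real API
        return cls(device="gpu", variant=gpu_type, count=count)

    @classmethod
    def with_cpu(cls) -> MiniResourceConfig:
        return cls(device="cpu")


cfg = MiniResourceConfig.with_tpu("v6e-8", slice_count=2)
print(cfg.device, cfg.variant, cfg.count)  # tpu v6e-8 2
```

The point of the pattern is that one frozen value type covers CPU, GPU, and multislice TPU, so training and evaluation code paths can share a single resource spec instead of parallel constant tables.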

File tree

132 files changed (+1198 −1550 lines)


README.md

Lines changed: 3 additions & 2 deletions
@@ -53,11 +53,12 @@ You can check out the [full script](https://github.com/marin-community/marin/blo
 <!--marin-example-start-->
 
 ```python
+from fray.cluster import ResourceConfig
+
 from experiments.defaults import default_tokenize, default_train
 from experiments.llama import llama3_tokenizer, llama_nano
 from experiments.simple_train_config import SimpleTrainConfig
 from marin.execution.executor import executor_main
-from marin.resources import CpuOnlyConfig
 
 # 1. Choose a dataset
 tinystories_hf_id = "roneneldan/TinyStories"
@@ -72,7 +73,7 @@ tinystories_tokenized = default_tokenize(
 # 3. Define training configuration
 nano_train_config = SimpleTrainConfig(
     # Here we define the hardware resources we need.
-    resources=CpuOnlyConfig(num_cpus=1),
+    resources=ResourceConfig.with_cpu(),
     train_batch_size=4,
     num_train_steps=100,
     # set hyperparameters

docs/references/resource-config.md

Lines changed: 31 additions & 7 deletions
@@ -1,17 +1,41 @@
 # Hardware Resource Configuration
 
 Marin uses Ray for scheduling and resource management. Ray provides a flexible resource model that allows you to specify
-the resources that a task requires. In Marin, we specify a few wrapper types for common hardware configurations.
+the resources that a task requires. The `fray` library provides unified resource configuration types.
 
+## ResourceConfig
 
-## CPU-Only
+The main entry point for resource configuration. Use the static factory methods to create configurations:
 
-::: marin.resources.CpuOnlyConfig
+```python
+from fray.cluster import ResourceConfig
 
-## GPU
+# TPU configuration
+tpu_config = ResourceConfig.with_tpu("v4-8")
+tpu_multislice = ResourceConfig.with_tpu("v4-8", slice_count=2)
 
-::: marin.resources.GpuConfig
+# GPU configuration
+gpu_config = ResourceConfig.with_gpu("H100", count=8)
+gpu_auto = ResourceConfig.with_gpu()  # auto-detect GPU type
 
-## TPU
+# CPU-only configuration
+cpu_config = ResourceConfig.with_cpu()
+```
 
-::: marin.resources.TpuPodConfig
+::: fray.cluster.base.ResourceConfig
+
+## Device Configurations
+
+These are the underlying device types wrapped by `ResourceConfig`:
+
+### CPU
+
+::: fray.cluster.base.CpuConfig
+
+### GPU
+
+::: fray.cluster.base.GpuConfig
+
+### TPU
+
+::: fray.cluster.base.TpuConfig

docs/tutorials/first-experiment.md

Lines changed: 7 additions & 7 deletions
@@ -78,16 +78,16 @@ For this tutorial, we will use the `SimpleTrainConfig` class from `experiments.s
 This class defines basic training configuration that is sufficient for most experiments.
 
 !!! info "Training Configuration for Different Accelerators"
-    You need to provide the appropriate resource configuration based on your hardware setup. Marin supports different accelerator types through various [resource configurations](../references/resource-config.md). The `CpuOnlyConfig` is one such resource configuration that requests a certain number of CPUs. Other resource configurations include `GpuConfig` for requesting GPUs and `TpuPodConfig` for requesting TPUs.
+    You need to provide the appropriate resource configuration based on your hardware setup. Marin supports different accelerator types through [`ResourceConfig`](../references/resource-config.md) factory methods.
 
 === "CPU"
     ```python
-    from marin.resources import CpuOnlyConfig
+    from fray.cluster import ResourceConfig
     from experiments.simple_train_config import SimpleTrainConfig
 
     nano_train_config = SimpleTrainConfig(
         # Here we define the hardware resources we need.
-        resources=CpuOnlyConfig(num_cpus=1),
+        resources=ResourceConfig.with_cpu(),
         train_batch_size=4,
         num_train_steps=100,
         # set hyperparameters
@@ -100,12 +100,12 @@ This class defines basic training configuration that is sufficient for most expe
 
 === "GPU"
     ```python
-    from marin.resources import GpuConfig
+    from fray.cluster import ResourceConfig
     from experiments.simple_train_config import SimpleTrainConfig
 
     nano_train_config = SimpleTrainConfig(
         # Here we define the hardware resources we need.
-        resources=GpuConfig(gpu_count=1),
+        resources=ResourceConfig.with_gpu(count=1),
         train_batch_size=32,
         num_train_steps=100,
         # set hyperparameters
@@ -116,12 +116,12 @@ This class defines basic training configuration that is sufficient for most expe
 
 === "TPU"
     ```python
-    from marin.resources import TpuPodConfig
+    from fray.cluster import ResourceConfig
     from experiments.simple_train_config import SimpleTrainConfig
 
     nano_train_config = SimpleTrainConfig(
         # Here we define the hardware resources we need.
-        resources=TpuPodConfig(tpu_type="v4-8"),
+        resources=ResourceConfig.with_tpu("v4-8"),
         train_batch_size=4,
         num_train_steps=100,
         # set hyperparameters

docs/tutorials/run-alpaca-eval.md

Lines changed: 4 additions & 4 deletions
@@ -12,8 +12,8 @@ This tutorial shows how to configure and launch the Alpaca evaluation pipeline i
 The default evaluation script for alpaca is `experiments/evals/run_alpaca_eval.py`), if for some reason you want to make your own script import:
 
 ```python
+from fray.cluster import ResourceConfig
 from experiments.evals.engine_configs import DEFAULT_VLLM_ENGINE_KWARGS
-from experiments.evals.resource_configs import SINGLE_TPU_V6E_8
 from experiments.evals.evals import evaluate_alpaca_eval
 from marin.execution.executor import ExecutorMainConfig, executor_main
 ```
@@ -22,9 +22,9 @@ from marin.execution.executor import ExecutorMainConfig, executor_main
 
 ```python
 # nodryrun
+from fray.cluster import ResourceConfig
 from experiments.evals.engine_configs import DEFAULT_VLLM_ENGINE_KWARGS
 from experiments.evals.evals import evaluate_alpaca_eval
-from experiments.evals.resource_configs import SINGLE_TPU_V6E_8
 from marin.execution.executor import ExecutorMainConfig, executor_main
 
 # Retry any failed steps by default
@@ -34,7 +34,7 @@ steps = [
     evaluate_alpaca_eval(
         model_name="my_alpaca_model_eval",  # Name for logging / W&B
         model_path="path/to/your/model/checkpoint/hf/",  # HF checkpoint directory
-        resource_config=SINGLE_TPU_V6E_8,  # E.g., TPU v6e-8; choose GPU/TPU config
+        resource_config=ResourceConfig.with_tpu("v6e-8"),  # E.g., TPU v6e-8; choose GPU/TPU config
         engine_kwargs=DEFAULT_VLLM_ENGINE_KWARGS,  # vLLM backend parameters
 
         # IMPORTANT: stop_token_ids must include the eos_token_id of your HF model.
@@ -65,7 +65,7 @@ if __name__ == "__main__":
 |----------------------|----------------------|-------------|
 | model_name | `str` | Name for experiment tracking through executor framework. |
 | model_path | `str` | Path on GCP or URL to HF-format model checkpoint. |
-| resource_config | `ResourceConfig` | Hardware spec (e.g. `SINGLE_TPU_V6E_8`). |
+| resource_config | `ResourceConfig` | Hardware spec (e.g. `ResourceConfig.with_tpu("v6e-8")`). |
 | engine_kwargs | `dict \| None` | vLLM engine settings (e.g. batch size, sequence length). |
 | max_eval_instances | `int \| None` | Limits the number of examples to evaluate; `None` = all. |
 | temperature | `float` | Sampling temperature. |

docs/tutorials/run-lm-evals.md

Lines changed: 8 additions & 8 deletions
@@ -25,7 +25,7 @@ from experiments.evals.task_configs import (
 )
 
 # Hardware / executor
-from experiments.evals.resource_configs import SINGLE_TPU_V4_8, SINGLE_TPU_V6E_8
+from fray.cluster import ResourceConfig
 from marin.execution.executor import executor_main
 from marin.execution.executor import ExecutorMainConfig  # for retry logic
 ```
@@ -36,8 +36,8 @@ Run the canonical CORE_TASKS (subset of DCLM tasks) via LM Evaluation Harness:
 
 ```python
 # run_mcqa_eval.py
+from fray.cluster import ResourceConfig
 from experiments.evals.evals import default_eval
-from experiments.evals.resource_configs import SINGLE_TPU_V4_8
 from marin.execution.executor import executor_main
 
 # Example: evaluate a standalone checkpoint
@@ -46,7 +46,7 @@ model_path = "gs://marin-us-east5/gcsfuse_mount/perplexity-models/llama-200m"
 # This creates an ExecutorStep that runs CORE_TASKS
 core_evals_step = default_eval(
     step=model_path,
-    resource_config=SINGLE_TPU_V4_8,
+    resource_config=ResourceConfig.with_tpu("v4-8"),
     # Optional: override the task set:
     # evals=CORE_TASKS_PLUS_MMLU,
     # max_eval_instances=100,
@@ -65,8 +65,8 @@ Use `default_key_evals` to run a collection of generation tasks (`KEY_GENERATION
 
 ```python
 # run_key_evals.py (see 1:18:experiments/evals/run_key_evals.py)
+from fray.cluster import ResourceConfig
 from experiments.evals.evals import default_key_evals
-from experiments.evals.resource_configs import SINGLE_TPU_V6E_8
 from marin.execution.executor import executor_main
 
 # Point to your checkpoint or a training ExecutorStep
@@ -78,7 +78,7 @@ model_path = "gs://marin-us-east5/gcsfuse_mount/perplexity-models/llama-200m"
 # 3) Alpaca eval
 key_steps = default_key_evals(
     step=model_path,
-    resource_config=SINGLE_TPU_V6E_8,
+    resource_config=ResourceConfig.with_tpu("v6e-8"),
     model_name="my_key_evals",
     # max_eval_instances=50,
 )
@@ -104,7 +104,7 @@ from experiments.evals.evals import evaluate_alpaca_eval
 alpaca_step = evaluate_alpaca_eval(
     model_name="my_model",
     model_path="...",
-    resource_config=SINGLE_TPU_V6E_8,
+    resource_config=ResourceConfig.with_tpu("v6e-8"),
     engine_kwargs=DEFAULT_VLLM_ENGINE_KWARGS,
     stop_token_ids=[<YOUR_EOS_TOKEN_ID>],  # must match your HF model's eos_token_id
 )
@@ -117,8 +117,8 @@ alpaca_step = evaluate_alpaca_eval(
 If you want fine‐grained control over which tasks to run:
 
 ```python
+from fray.cluster import ResourceConfig
 from experiments.evals.evals import evaluate_lm_evaluation_harness
-from experiments.evals.resource_configs import SINGLE_TPU_V4_8
 from marin.evaluation.evaluation_config import EvalTaskConfig
 from marin.execution.executor import executor_main
 
@@ -132,7 +132,7 @@ custom_step = evaluate_lm_evaluation_harness(
     model_name="custom_eval",
     model_path="...",
     evals=custom_tasks,
-    resource_config=SINGLE_TPU_V4_8,
+    resource_config=ResourceConfig.with_tpu("v4-8"),
     max_eval_instances=200,
 )
 

experiments/anneal_config.py

Lines changed: 3 additions & 17 deletions
@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
+from fray.cluster import ResourceConfig
 from marin.execution import InputName
 from marin.processing.tokenize.data_configs import LMMixtureDatasetConfig
-from marin.resources import ResourceConfig, TpuPodConfig
 
 
 @dataclass(frozen=True)
@@ -54,7 +54,7 @@ class AnnealConfig:
 
     # Hardware related
     # The number of TPUs to use, type of TPU, and the number of pods to use.
-    resources: ResourceConfig = TpuPodConfig(tpu_type="v4-128", slice_count=2)  # noqa: RUF009
+    resources: ResourceConfig = field(default_factory=lambda: ResourceConfig.with_tpu("v4-128", slice_count=2))
 
     # Checkpoint related
     # The number of steps between saving checkpoints. Larger values will save checkpoints more frequently.
@@ -64,17 +64,3 @@ class AnnealConfig:
     # This argument is used in the default_train. If set to True, the validation set is Paloma.
     # If set to False, we will not calculate validation loss.
     use_default_validation: bool = True
-
-    @property
-    def tpu_type(self) -> str | None:
-        """For backward compatibility."""
-        if isinstance(self.resources, TpuPodConfig):
-            return self.resources.tpu_type
-        return None
-
-    @property
-    def node_count(self) -> int:
-        """For backward compatibility."""
-        if isinstance(self.resources, TpuPodConfig):
-            return self.resources.slice_count
-        return 1
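A side note on the `field(default_factory=...)` pattern this commit adopts for dataclass defaults: a default written as a direct constructor call (as in the old `TpuPodConfig(...)` default, hence the old `# noqa: RUF009`) is evaluated once at class-definition time and shared across every instance, whereas `default_factory` builds a fresh value per instance. A minimal self-contained sketch; the `Widget` and `JobConfig` names are hypothetical, for illustration only:

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Widget:
    name: str = "w"


@dataclass(frozen=True)
class JobConfig:
    # A direct default like `widget: Widget = Widget()` would be built once at
    # class-definition time and shared by every JobConfig instance; ruff flags
    # that pattern as RUF009. default_factory runs per instance instead.
    widget: Widget = field(default_factory=lambda: Widget())


a = JobConfig()
b = JobConfig()
print(a.widget == b.widget)   # True: equal by value (dataclass __eq__)
print(a.widget is b.widget)   # False: distinct objects, one per instance
```

For frozen, immutable configs the shared instance is mostly harmless, but using `default_factory` consistently keeps the linter quiet and stays safe if a config type ever grows mutable state.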

experiments/cooldown_quality.py

Lines changed: 5 additions & 5 deletions
@@ -23,16 +23,16 @@
 and determine their relative contributions to model performance.
 """
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 
 from experiments.anneal_config import AnnealConfig
-from experiments.pretraining_datasets.dclm import dclm_components_llama3
 from experiments.defaults import default_anneal
 from experiments.pretraining_datasets import tokenize_dolma
+from experiments.pretraining_datasets.dclm import dclm_components_llama3
+from fray.cluster import ResourceConfig
 from marin.execution.executor import ExecutorStep
-from marin.processing.tokenize.data_configs import TokenizerStep, lm_mixture_data_config, PermutationType
-from marin.resources import TpuPodConfig
+from marin.processing.tokenize.data_configs import PermutationType, TokenizerStep, lm_mixture_data_config
 
 
 @dataclass(frozen=True)
@@ -48,7 +48,7 @@ class QualityAblationConfig:
 
     # Training parameters
     num_anneal_tokens: int = 50_000_000_000
-    resources: TpuPodConfig = TpuPodConfig(tpu_type="v5litepod-128")  # noqa: RUF009
+    resources: ResourceConfig = field(default_factory=lambda: ResourceConfig.with_tpu("v5litepod-128"))
 
     # Naming
     model_name_prefix: str = "8b-quality-eval"

experiments/datashop/datashop_runner.py

Lines changed: 3 additions & 3 deletions
@@ -39,7 +39,7 @@
     default_train_quality_model,
 )
 from experiments.evals.evals import default_eval
-from experiments.evals.resource_configs import SINGLE_TPU_V6E_8, TPU_V6E_8_STRICT_PACK, ResourceConfig
+from fray.cluster import ResourceConfig
 from experiments.evals.task_configs import MMLU_5_SHOT
 from marin.datashop.pipeline import CorpusContent
 from marin.execution.executor import executor_main
@@ -73,7 +73,7 @@ class DatashopRunnerConfig:
     pretraining_data_path_name: str = "datashop-dclm-pretraining-subset"
 
     # How to schedule the TPUs (what hardware to use and how to pack them) specifically for labeling
-    labeler_resource_config: ResourceConfig = field(default_factory=lambda: TPU_V6E_8_STRICT_PACK)
+    labeler_resource_config: ResourceConfig = field(default_factory=lambda: ResourceConfig.with_tpu("v6e-8"))
 
     # What hardware to use for training the final model
     training_tpu_type: str = "v6e-128"
@@ -100,7 +100,7 @@ class DatashopRunnerConfig:
     consolidate_config_kwargs: dict | None = None
 
     # What hardware to use for evaluating the model
-    eval_resource_config: ResourceConfig = field(default_factory=lambda: SINGLE_TPU_V6E_8)
+    eval_resource_config: ResourceConfig = field(default_factory=lambda: ResourceConfig.with_tpu("v6e-8"))
 
 
 class DatashopRunner:

experiments/datashop/default_configs.py

Lines changed: 1 addition & 4 deletions
@@ -67,10 +67,7 @@
 
 default_inference_config_kwargs = {
     "model_type": "gte",
-    "runtime": RuntimeConfig(
-        memory_limit_gb=12,
-        resources={"TPU": 1},
-    ),
+    "runtime": RuntimeConfig(memory_limit_gb=12, resources={"TPU": 1}),
     "task": TaskConfig(max_in_flight=500),
     "filetype": "jsonl.zst",
     "classifier_kwargs": {"max_length": 512},
