Skip to content

Commit bbd5396

Browse files
angel-coreThe tunix Authors
authored andcommitted
Code update
PiperOrigin-RevId: 907767265
1 parent 7c73768 commit bbd5396

4 files changed

Lines changed: 22 additions & 4 deletions

File tree

examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,15 @@ actor_model_config:
3030
shape: "(2,4)"
3131
axis_names: "('fsdp','tp')"
3232
rollout_model_config:
33-
mesh:
34-
shape: "(2,4)"
35-
axis_names: "('fsdp','tp')"
33+
mesh: null
34+
same_mesh_as: "actor"
35+
reference_model_config:
36+
mesh: null
37+
same_mesh_as: "actor"
3638
tokenizer_config:
3739
tokenizer_type: "sentencepiece"
3840
add_bos: False
41+
data_source: "tfds"
3942
dataset_name: "gsm8k"
4043
batch_size: 1
4144
num_batches: 3738

tests/cli/config_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,12 @@ def test_perf_metrics_validation(
632632
else:
633633
config.initialize(argv)
634634

635+
def test_dict_to_cli_args_with_none(self):
636+
d = {"a": 1, "b": None, "c": {"d": None, "e": 2}}
637+
expected = ["a=1", "b=null", "c.d=null", "c.e=2"]
638+
got = list(config._dict_to_cli_args(d))
639+
self.assertEqual(expected, got)
640+
635641

636642
if __name__ == "__main__":
637643
if "HF_TOKEN" not in os.environ:

tests/models/naming_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ def test_get_model_family_and_version(
191191
(expected_family, expected_version),
192192
)
193193

194+
def test_get_model_family_and_version_format_agnostic(self):
195+
self.assertEqual(
196+
naming.get_model_family_and_version('gemma2-2b-it'),
197+
naming.get_model_family_and_version('gemma2_2b_it'),
198+
)
199+
194200
def test_get_model_family_and_version_invalid_fails(self):
195201
with self.assertRaisesRegex(
196202
ValueError, 'Could not determine model family for: foo-bar.'

tunix/cli/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,10 @@ def _dict_to_cli_args(
143143
else:
144144
yield f"{new_key}={{}}"
145145
else:
146-
yield f"{new_key}={v}"
146+
if v is None:
147+
yield f"{new_key}=null"
148+
else:
149+
yield f"{new_key}={v}"
147150

148151

149152
class HyperParameters:

0 commit comments

Comments
 (0)