Skip to content

Commit 0a996db

Browse files
authored
Merge pull request #32 from Qmeiyi/main
update save_interval
2 parents 199d0e0 + 5c2b31c commit 0a996db

2 files changed

Lines changed: 5 additions & 1 deletion

File tree

src/dataflex/configs/components.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ selectors:
88
gradient_type: adam
99
proj_dim: 4096
1010
seed: 123
11+
save_interval: 16
1112
reward_model_backend: local_vllm # choices: [local_vllm, api]
1213
reward_backend_params:
1314
local_vllm:
@@ -37,6 +38,7 @@ selectors:
3738
gradient_type: adam
3839
proj_dim: 4096
3940
seed: 123
41+
save_interval: 16
4042

4143
loss:
4244
name: loss

src/dataflex/train/selector/nice_selector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def __init__(self,
266266
reward_backend_params: Optional[Dict[str, Any]] = None,
267267
gradient_type: str = "adam",
268268
proj_dim: int = 8192,
269+
save_interval: int = 16,
269270
seed: int = 42,
270271
mc_samples: int = 4,
271272
max_new_tokens: int = 512,
@@ -280,6 +281,7 @@ def __init__(self,
280281
self.eval_dataset = eval_dataset
281282
self.gradient_type = gradient_type
282283
self.proj_dim = proj_dim
284+
self.save_interval = save_interval
283285
self.seed = seed
284286
self.mc_samples = mc_samples
285287
self.max_new_tokens = max_new_tokens
@@ -679,7 +681,7 @@ def indexed_collator_wrapper(features):
679681
dataloader = self.accelerator.prepare(dataloader)
680682

681683
# 4) 设置保存间隔
682-
save_interval = 64
684+
save_interval = self.save_interval
683685

684686
# 5) 断点续传
685687
max_index = self._get_max_saved_index(save_dir=save_dir)

0 commit comments

Comments
 (0)