Skip to content

Commit f88989b

Browse files
committed
Add a complete DDP program, including the collection process and world_model/LLM training.
1 parent 0cef8b8 commit f88989b

File tree

6 files changed

+374
-322
lines changed

6 files changed

+374
-322
lines changed

lzero/entry/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,9 +528,11 @@ def calculate_update_per_collect(
528528
collected_transitions_tensor
529529
).item()
530530
updates = int(total_collected_transitions * cfg.policy.replay_ratio)
531+
print(f"total_collected_transitions={total_collected_transitions}\tupdates={updates}")
531532
else:
532533
# In a single-process setup.
533534
updates = int(collected_transitions_num * cfg.policy.replay_ratio)
535+
print(f"collected_transitions_num={collected_transitions_num}\tupdates={updates}")
534536

535537
return max(1, updates) # Ensure at least one update.
536538

zoo/jericho/priorzero/priorzero_config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ class PriorZeroLLMConfig:
135135
rft_kl_coef: float = 0.01
136136
kl_estimator: str = "k3"
137137

138-
train_llm_after_wm_warm_step: int = int(1e3)
138+
train_llm_after_wm_warm_step: int = int(1e2)
139139
value_norm_cfg: Optional[EasyDict] = field(default_factory=lambda: EasyDict({
140140
'enable_stability_optimizer': True,
141141
'value_norm_init_momentum': 0.9, # Fast adaptation in early training
@@ -153,6 +153,7 @@ def get_priorzero_config(
153153
exp_name: str = None,
154154
use_cot: bool = False,
155155
model_key: Optional[str] = None,
156+
multi_gpu: bool = False
156157
) -> Tuple[EasyDict, EasyDict]:
157158
"""
158159
Generate complete PriorZero configuration with automatic model configuration.
@@ -218,7 +219,7 @@ def get_priorzero_config(
218219
)
219220
policy_config = dict(
220221
type='priorzero',
221-
multi_gpu=False,
222+
multi_gpu=multi_gpu,
222223
use_wandb=False,
223224
learn=dict(
224225
learner=dict(

zoo/jericho/priorzero/priorzero_datafactory.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def build_llm_samples(self,
215215
)
216216
return samples
217217

218-
def make_llm_train_samples(self, priorzero_batch) -> List[Dict[str, Any]]:
218+
def make_llm_train_samples(self, priorzero_batch, ddp: bool = False) -> List[Dict[str, Any]]:
219219
"""
220220
Convert PriorZero batch to LLM training samples.
221221
@@ -235,14 +235,17 @@ def make_llm_train_samples(self, priorzero_batch) -> List[Dict[str, Any]]:
235235
samples = self.build_llm_samples(
236236
raw_obs_list, history_obs_list, action_logprob_list, target_value, cot_prefix_list
237237
)
238-
per_rank = len(samples) // self.world_size
239-
start = self.rank * per_rank
240-
end = (self.rank + 1) * per_rank if self.rank != self.world_size - 1 else len(samples)
241-
print(f"[Rank {self.rank}] process {start}: {end} samples, total {len(samples)} samples.")
242-
real_samples = samples[start:end]
238+
if ddp:
239+
print(f"[Rank {self.rank}] process {len(samples)} samples collected by Rank {self.rank}")
240+
real_samples = samples
241+
else:
242+
per_rank = len(samples) // self.world_size
243+
start = self.rank * per_rank
244+
end = (self.rank + 1) * per_rank if self.rank != self.world_size - 1 else len(samples)
245+
print(f"[Rank {self.rank}] process {start}: {end} samples. Total {len(samples)} samples collected by Rank 0.")
246+
real_samples = samples[start:end]
243247

244248
prompts_only = [s["prompt"] for s in real_samples]
245-
246249
if self.use_cot:
247250
targets_only = [s["prefix_cot"] + " " + s["target"] + self.tokenizer.eos_token for s in real_samples]
248251
if self.args.reward_func.format_reward:

0 commit comments

Comments (0)